1

我正在继承tf.keras.model。我需要覆盖compute_output_shape,否则,我最终会得到 a NotImplementedErrorfrom here

class Custom(tf.keras.Model):
    ...

    def compute_output_shape(self, input_shape):
        # input_shape = (None, ...)
        batch_size = ???
        return (batch_size, ...)

compute_output_shapeinput_shape作为输入。然而,这并没有多大帮助,因为批量大小在 TensorFlow 的某个地方不知何故丢失了。

None如果我尝试以与相同的方式返回一个以 开头的形状input_shape,我会得到TypeError: 'str' object cannot be interpreted as an integer. 只是省略批量大小也不起作用。

批量大小是可变的,所以我不能硬编码它。

4

0 回答 0