我正在继承tf.keras.model
。我需要覆盖compute_output_shape
,否则,我最终会得到 a NotImplementedError
from here。
class Custom(tf.keras.Model):
...
def compute_output_shape(self, input_shape):
# input_shape = (None, ...)
batch_size = ???
return (batch_size, ...)
compute_output_shape
input_shape
作为输入。然而,这并没有多大帮助,因为批量大小在 TensorFlow 的某个地方不知何故丢失了。
None
如果我尝试以与相同的方式返回一个以 开头的形状input_shape
,我会得到TypeError: 'str' object cannot be interpreted as an integer
. 只是省略批量大小也不起作用。
批量大小是可变的,所以我不能硬编码它。