在 tensorflow 2 中,不再支持获取和分配。可以按照https://stackoverflow.com/a/47081613/9949099中提供的答案在自定义 keras 回调中访问 tf 1.x 中的批处理结果 在 tf.keras 和 tf 2.0 下急切执行获取不支持,因此为 tf 1.x 提供的解决方案不起作用。有没有办法在 tf.keras 自定义回调的 on_batch_end 回调中获取 y_true 和 y_pred?
我试图修改在 tf.1 中工作的答案,如下所示
from tf.keras.callbacks import Callback
class CollectOutputAndTarget(Callback):
def __init__(self):
super(CollectOutputAndTarget, self).__init__()
self.targets = [] # collect y_true batches
self.outputs = [] # collect y_pred batches
def on_batch_end(self, batch, logs=None):
# evaluate the variables and save them into lists
# How to change the following 2 lines so that in tf.2 eager execution collect the batch results
self.targets.append(K.eval(self.model._targets[0]))
self.outputs.append(K.eval(self.model.outputs[0]))
当我运行上面的代码时,代码失败,访问 self.model._targets[0] 或 self.model.outputs[0] 中的数据显然是不可能的