编辑:
为了澄清为什么这个问题与建议的重复项不同,这个 SO 问题跟进了那些建议的重复项,Keras 究竟在用这些 SO 问题中描述的技术做了什么。建议的重复项使用数据集 API 指定make_one_shot_iterator()
,model.fit
我的后续行动是make_one_shot_iterator()
只能通过数据集一次,但是在给出的解决方案中,指定了几个时期。
这是对这些 SO 问题的跟进
如何正确结合 TensorFlow 的 Dataset API 和 Keras?
使用 tf.data.Dataset 作为 Keras 模型的训练输入不起作用
其中“从 Tensorflow 1.9 开始,可以将 tf.data.Dataset 对象直接传递给 keras.Model.fit(),它的行为类似于 fit_generator”。每个示例都有一个 TF 数据集 one shot iterator 输入到 Kera 的 model.fit 中。
下面给出一个例子
# Load mnist training data
(x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
training_set = tfdata_generator(x_train, y_train,is_training=True)
model = # your keras model here
model.fit(
training_set.make_one_shot_iterator(),
steps_per_epoch=len(x_train) // 128,
epochs=5,
verbose = 1)
但是,根据 Tensorflow 数据集 API 指南(此处为https://www.tensorflow.org/guide/datasets):
one-shot 迭代器是最简单的迭代器形式,仅支持对数据集进行一次迭代
所以它只适用于 1 个 epoch。但是,SO 问题中的代码指定了几个 epoch,上面的代码示例指定了 5 个 epoch。
这个矛盾有什么解释吗?Keras 是否知道当单次迭代器遍历数据集时,它可以重新初始化和打乱数据?