1

所以我想在 Keras Tuner 中使用 tf.keras.callbacks.ModelCheckpoint,但是您选择保存检查点的路径的方式不允许您将其保存为具有特定名称的文件,该名称与该检查点的试验和执行,仅与一个时期相关联。

也就是说,如果我只是简单地将这个回调放在 Keras Tuner 中,在 checkpoints 保存发生的那一刻,到最后,我将不知道如何将保存的 checkpoints 与试验和试验执行相关联,只与 epoch 相关联。

4

1 回答 1

0

您可以使用tf.keras.callbacks.ModelCheckpointKeras tuner其他模型相同的方式来保存检查点。

根据该模型使用从搜索中获得的超参数训练模型后,您可以定义模型检查点并将其保存如下:

hypermodel = tuner.hypermodel.build(best_hps)

# Retrain the model
hypermodel.fit(img_train, label_train, epochs=best_epoch, validation_split=0.2)

import os
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

# Create a callback that saves the model's weights
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)
history = hypermodel.fit(img_train, label_train, epochs=5, validation_split=0.2, callbacks=[cp_callback])
os.listdir(checkpoint_dir)

# Re-evaluate the model
loss, acc = hypermodel.evaluate(img_test, label_test, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

# Loads the weights
hypermodel.load_weights(checkpoint_path)

# Re-evaluate the model
loss, acc = hypermodel.evaluate(img_test, label_test, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

有关保存和加载模型检查点的更多信息,请参阅此链接。

于 2021-12-06T16:04:19.030 回答