简短的回答是:做你通常会做的事情,Tensorflow 会处理其余的事情。
答案隐藏在方法的文档字符串中(强调添加):save_weights
tf.keras.Model
以 TensorFlow 格式保存时,网络引用的所有对象都以与 相同的格式保存tf.train.Checkpoint
,包括任何Layer
实例或Optimizer
分配给对象属性的实例。对于使用 由输入和输出构建的网络,网络使用tf.keras.Model(inputs,
outputs)
的Layer
实例会被自动跟踪/保存。对于继承自 的用户定义类tf.keras.Model
,
Layer
必须将实例分配给对象属性,通常在构造函数中。
实现目标的最简单方法是将图层分配给 Python 对象。在以下示例中,我使用字典来保留原始名称。
class MyModel(tf.keras.Model):
def __init__(self):
super(MyModel, self).__init__()
self.my_weight_dict = {}
self.my_weight_dict["dense1"] = tf.keras.layers.Dense(6, activation=tf.nn.relu)
self.my_weight_dict["dense2"] = tf.keras.layers.Dense(3, activation=tf.nn.softmax) # changed to fit the dataset
def call(self,inputs):
x = self.my_weight_dict["dense1"](inputs)
return self.my_weight_dict["dense2"](x)
这允许您以编程方式指定将更改模型属性的属性 - 例如,对于自动超参数调整很有用。
这是一个使用上面定义的类的完全可重现的示例:
import tensorflow as tf
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize
# load the data and split it into train and test
iris_dataset = load_iris()
X = iris_dataset.data
y = iris_dataset.target
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.3,stratify=y)
# normalize the features
X_train = normalize(X_train, axis=0,norm='max')
X_test = normalize(X_test, axis=0,norm='max')
# create, compile, and fit the model
model = MyModel()
model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.05, momentum=0.9),
loss="sparse_categorical_crossentropy", #tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(X_train, y_train, epochs=50, verbose = 2, batch_size=128,
validation_data = (X_test, y_test))
# just call the save_weights
model.save_weights(filepath="path/to/your/weights/file")
# create a new model with the same structure
model_2 = MyModel()
model_2.load_weights("path/to/your/weights/file")
model_2.compile(optimizer=tf.keras.optimizers.SGD(lr=0.05, momentum=0.9),
loss="sparse_categorical_crossentropy", #tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model_2.evaluate(X_test,y_test)