我想在 TensorFlow 中使用具有不同输入的相同计算图。不幸的是,我没有编写模型函数,我想尽可能避免修改它。
def model_fn(features, labels, mode):
# some code here
return loss, train_op
我正在编写一个联邦学习算法,我想在n 个不同的数据集上训练相同的模型,而不需要n 个客户端,所以我计划使用单个计算图和不同的输入。
我计划这样做:
with tf.variable_scope("client", reuse=tf.AUTO_REUSE):
for _ in range(num_of_clients):
features, labels = input_fn()
client_model = model_fn(features, labels, "train")
我曾希望 client_model 能够重用所有共享变量。
- 我可以确认它没有
tf.global_variables()
。 - 我也知道如果
model_fn()
usedtf.get_variable()
,这会起作用,但它不会。
我如何在 TensorFlow 中做到这一点而不修改model_fn
(否则我只会使用这个答案)?