1

我想在 TensorFlow 中使用具有不同输入的相同计算图。不幸的是,我没有编写模型函数,我想尽可能避免修改它。

def model_fn(features, labels, mode):
    # some code here
    return loss, train_op

我正在编写一个联邦学习算法,我想在n 个不同的数据集上训练相同的模型,而不需要n 个客户端,所以我计划使用单个计算图和不同的输入。

我计划这样做:

with tf.variable_scope("client", reuse=tf.AUTO_REUSE):
    for _ in range(num_of_clients):
        features, labels = input_fn()
        client_model = model_fn(features, labels, "train")

我曾希望 client_model 能够重用所有共享变量。

  • 我可以确认它没有tf.global_variables()
  • 我也知道如果model_fn()used tf.get_variable(),这会起作用,但它不会。

我如何在 TensorFlow 中做到这一点而不修改model_fn(否则我只会使用这个答案)?

4

0 回答 0