-1

是否有一种简单的本地方法可以在 DqnAgent 上实现带有QNetwork的 tfa.optimizers.CyclicalLearningRate

尽量避免编写我自己的 DqnAgent。

我想更好的问题可能是,在 DqnAgent 上实现回调的正确方法是什么?

4

1 回答 1

1

从您链接的教程中,他们设置优化器的部分是

optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)

train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)

agent.initialize()

因此,您可以用您更愿意使用的任何优化器替换优化器。根据文档类似

optimizer = tf.keras.optimizers.Adam(learning_rate=tfa.optimizers.CyclicalLearningRate)

应该可以工作,除非他们在教程中使用 tf 1.0 adam 导致任何潜在的兼容性问题。

于 2020-09-13T01:15:03.770 回答