我一直在尝试在 tensorflow 中使用 tf-agents 构建一个 rl 代理。我在自定义构建的环境中遇到了这个问题,但使用官方 tf colab 示例重现了它。每当我尝试使用 QRnnNetwork 作为 DqnAgent 的网络时,就会出现问题。该代理可以在常规 qnetwork 上正常工作,但在使用 qrnn 时会重新调整 policy_state_spec。我将如何解决这个问题?
这是 policy_state_spec 转换为的形状,但原始形状是 ()
ListWrapper([TensorSpec(shape=(16,), dtype=tf.float32, name='network_state_0'), TensorSpec(shape=(16,), dtype=tf.float32, name='network_state_1')])
q_net = q_rnn_network.QRnnNetwork(
train_env.observation_spec(),
train_env.action_spec(),
lstm_size=(16,),
)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.Variable(0)
agent = dqn_agent.DqnAgent(
train_env.time_step_spec(),
train_env.action_spec(),
q_network=q_net,
optimizer=optimizer,
td_errors_loss_fn=common.element_wise_squared_loss,
train_step_counter=train_step_counter)
agent.initialize()
collect_policy = agent.collect_policy
example_environment = tf_py_environment.TFPyEnvironment(
suite_gym.load('CartPole-v0'))
time_step = example_environment.reset()
collect_policy.action(time_step)
我收到此错误:
TypeError: policy_state and policy_state_spec structures do not match:
()
vs.
ListWrapper([., .])