3

我一直在尝试在 tensorflow 中使用 tf-agents 构建一个 rl 代理。我在自定义构建的环境中遇到了这个问题,但使用官方 tf colab 示例重现了它。每当我尝试使用 QRnnNetwork 作为 DqnAgent 的网络时,就会出现问题。该代理可以在常规 qnetwork 上正常工作,但在使用 qrnn 时会重新调整 policy_state_spec。我将如何解决这个问题?

这是 policy_state_spec 转换为的形状,但原始形状是 ()

ListWrapper([TensorSpec(shape=(16,), dtype=tf.float32, name='network_state_0'), TensorSpec(shape=(16,), dtype=tf.float32, name='network_state_1')])

q_net = q_rnn_network.QRnnNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    lstm_size=(16,),
    )
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)
agent.initialize()

collect_policy = agent.collect_policy
example_environment = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'))
time_step = example_environment.reset()


collect_policy.action(time_step)

我收到此错误:

TypeError: policy_state and policy_state_spec structures do not match:
  ()
vs.
  ListWrapper([., .])
4

1 回答 1

0

我进入了代码,似乎对于 RNN,在action(time_step, policy_state, seed)方法中您需要在上一步中提供策略的状态,如文档所述:

policy_state:一个张量,或一个嵌套的字典、张量列表或元组,表示先前的策略状态。 https://www.tensorflow.org/agents/api_docs/python/tf_agents/policies/GreedyPolicy#action

你的错误:

TypeError: policy_state and policy_state_spec structures do not match:
  ()
 vs.
  ListWrapper([., .])

想说的是你应该向action方法提供RNN的内部状态。我在文档中找到了一个示例:

https://www.tensorflow.org/agents/api_docs/python/tf_agents/policies/TFPolicy#example_usage

它显示的代码(截至 2021 年 8 月 8 日)如下:

env = SomeTFEnvironment()
policy = TFRandomPolicy(env.time_step_spec(), env.action_spec())
# Or policy = agent.policy or agent.collect_policy

policy_state = policy.get_initial_state(env.batch_size)
time_step = env.reset()

while not time_step.is_last():
  policy_step = policy.action(time_step, policy_state)
  time_step = env.step(policy_step.action)

  policy_state = policy_step.state
  # policy_step.info may contain side info for logging, such as action log
  # probabilities.

如果您以这种方式实现代码,它可能会起作用!

于 2021-08-08T14:33:49.027 回答