TensorFlow DDPG agent: shape mismatch between the td_targets and q_values tensors in the DDPG agent's critic loss function

h43kikqp · asked 2023-10-23

I am running into a problem when computing the mean squared error in the critic loss of a TF-Agents DDPG agent. The error message indicates a shape mismatch between the td_targets and q_values tensors inside the agent's critic loss function.
Below is the relevant code snippet:

import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tf_agents.agents.ddpg.ddpg_agent import DdpgAgent
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

# Create the agent
self.ddpg_agent = DdpgAgent(
    time_step_spec=self.tf_env.time_step_spec(),
    action_spec=self.tf_env.action_spec(),
    actor_network=actor_network,
    critic_network=critic_network,
    actor_optimizer=Adam(learning_rate=self.learning_rate),
    critic_optimizer=Adam(learning_rate=self.learning_rate),
    gamma=self.discount_factor,
    target_update_tau=0.01,
    ou_stddev=0.3,
    ou_damping=0.3,
    td_errors_loss_fn=common.element_wise_squared_loss,
)

# Initialize the replay buffer
replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=self.ddpg_agent.collect_data_spec,
    batch_size=1,
    max_length=5000)

# Add experiences to the replay buffer
experience = trajectory.from_transition(time_step, action_step, next_time_step)
replay_buffer.add_batch(experience)

# Create the dataset
dataset = replay_buffer.as_dataset(
    sample_batch_size=self.batch_size,  # self.batch_size = 32
    num_steps=2,  # DDPG trains on (t, t+1) transition pairs
    num_parallel_calls=3,
    single_deterministic_pass=False
).prefetch(3)

# Train the agent
iterator = iter(dataset)
experience_set, _ = next(iterator)
loss = self.ddpg_agent.train(experience_set)
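
The variables time_step, action_step, and next_time_step come from a collection step that is not shown above; a typical version looks like this (a sketch assuming the standard TF-Agents collect pattern, not code from the question):

# Collect one transition with the agent's exploration (collect) policy
time_step = self.tf_env.current_time_step()
action_step = self.ddpg_agent.collect_policy.action(time_step)
next_time_step = self.tf_env.step(action_step.action)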

When I run the code, it aborts with an error during the loss computation:

File "main.py", line 138, in <module>
    main()
  File "main.py", line 109, in main
    a2c.train_agent()
  File "a2c.py", line 41, in train_agent
    self.agent.train_agent()
  File "agent.py", line 161, in train_agent
    loss = self.ddpg_agent.train(experience_set)
  File "tf_agents\agents\tf_agent.py", line 330, in train
    loss_info = self._train_fn(
  File "tf_agents\utils\common.py", line 188, in with_check_resource_vars
    return fn(*fn_args, **fn_kwargs)
  File "tf_agents\agents\ddpg\ddpg_agent.py", line 247, in _train
    critic_loss = self.critic_loss(time_steps, actions, next_time_steps,
  File "tf_agents\agents\ddpg\ddpg_agent.py", line 343, in critic_loss
    critic_loss = self._td_errors_loss_fn(td_targets, q_values)
  File "tf_agents\utils\common.py", line 1139, in element_wise_squared_loss
    return tf.compat.v1.losses.mean_squared_error(
  File "tensorflow\python\util\traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "tensorflow\python\framework\tensor_shape.py", line 1361, in assert_is_compatible_with
    raise ValueError("Shapes %s and %s are incompatible" % (self, other))
ValueError: Shapes (32, 1) and (32, 32) are incompatible

I have checked all the spec shapes, the experience shapes, and the output shapes of my actor and critic networks. They all look correct: with a batch size of 32, the actor and critic output layers produce the expected shape of (32, 1). The mismatch is between td_targets and q_values in the loss function in tf_agents\agents\ddpg\ddpg_agent.py, where:

TD targets shape: (32, 32)
Q values shape: (32, 1)
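
One likely cause, offered here as an assumption based on the shapes rather than anything confirmed in the question: TF-Agents' built-in critic networks return Q-values with shape (batch_size,), whereas a custom critic that ends in a Dense(1) layer returns (batch_size, 1). Inside critic_loss the TD targets are computed roughly as reward + gamma * discount * next_q, so a rank-1 reward of shape (32,) broadcasts against a (32, 1) target Q-value and silently produces the (32, 32) tensor seen in the error. A minimal sketch of the effect:

import tensorflow as tf

reward = tf.zeros((32,))     # rank-1 reward from the sampled batch
next_q = tf.zeros((32, 1))   # critic output with a trailing singleton dim

td_targets = reward + 0.99 * next_q
print(td_targets.shape)      # (32, 32) -- silent broadcast, not (32,)

If that is what is happening here, squeezing the critic network's output to rank 1 (for example ending the network with tf.reshape(..., [-1])) fixes the shapes at the source.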
Can anyone tell me what I'm missing?


ldxq2e6h1#

I solved this by choosing a different loss function when initializing the DDPG agent:

# Create the agent
self.ddpg_agent = DdpgAgent(
    time_step_spec=self.tf_env.time_step_spec(),
    action_spec=self.tf_env.action_spec(),
    actor_network=actor_network,
    critic_network=critic_network,
    actor_optimizer=Adam(learning_rate=self.learning_rate),
    critic_optimizer=Adam(learning_rate=self.learning_rate),
    gamma=self.discount_factor,
    target_update_tau=0.01,
    ou_stddev=0.3,
    ou_damping=0.3,
    td_errors_loss_fn=tf.keras.losses.MeanSquaredError(),  # changed line
)
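
It is worth noting why this change makes the error disappear (a caveat, not part of the original answer): common.element_wise_squared_loss delegates to tf.compat.v1.losses.mean_squared_error, which asserts that the two shapes are compatible and raises the ValueError above. tf.keras.losses.MeanSquaredError instead broadcasts the (32, 1) predictions against the (32, 32) targets and reduces to a scalar, so training runs, but the underlying shape mismatch in the TD targets is still there. A small sketch of the difference:

import tensorflow as tf

td_targets = tf.zeros((32, 32))
q_values = tf.zeros((32, 1))

# Raises "Shapes (32, 1) and (32, 32) are incompatible":
# tf.compat.v1.losses.mean_squared_error(td_targets, q_values)

# Runs, because Keras broadcasts q_values to (32, 32) before reducing:
loss = tf.keras.losses.MeanSquaredError()(td_targets, q_values)
print(loss)  # scalar, computed against broadcast copies of each Q-value

If the critic is really supposed to output Q-values of shape (batch_size,), fixing the network's output shape is the more robust solution.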
