I am running into a problem when the mean squared error is computed in the critic loss of a DDPG agent. The error message indicates a shape mismatch between the expected and actual tensor shapes, specifically between td_targets and q_values, inside the agent's critic loss function.
Here are the relevant code snippets:
# Create the agent
self.ddpg_agent = DdpgAgent(
    time_step_spec=self.tf_env.time_step_spec(),
    action_spec=self.tf_env.action_spec(),
    actor_network=actor_network,
    critic_network=critic_network,
    actor_optimizer=Adam(learning_rate=self.learning_rate),
    critic_optimizer=Adam(learning_rate=self.learning_rate),
    gamma=self.discount_factor,
    target_update_tau=0.01,
    ou_stddev=0.3,
    ou_damping=0.3,
    td_errors_loss_fn=common.element_wise_squared_loss,
)

# Initialize the replay buffer
replay_buffer = replay_buffers.tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=self.ddpg_agent.collect_data_spec,
    batch_size=1,
    max_length=5000)

# Add experiences to the replay buffer
experience = trajectory.from_transition(time_step, action_step, next_time_step)
replay_buffer.add_batch(experience)

# Create the dataset
dataset = replay_buffer.as_dataset(
    sample_batch_size=self.batch_size,  # self.batch_size = 32
    num_steps=2,
    num_parallel_calls=3,
    single_deterministic_pass=False
).prefetch(3)

# Train the agent
iterator = iter(dataset)
experience_set, _ = next(iterator)
loss = self.ddpg_agent.train(experience_set)
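Before the failing train call, I printed the shape of every field in the sampled batch (a debugging sketch, reusing the experience_set variable above):

```python
# Debugging sketch: print the shape of every tensor in the sampled batch.
# experience_set is the Trajectory returned by next(iterator) above.
tf.nest.map_structure(lambda t: print(t.shape), experience_set)
```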
When I run the code, it aborts during the loss computation with this error:
  File "main.py", line 138, in <module>
    main()
  File "main.py", line 109, in main
    a2c.train_agent()
  File "a2c.py", line 41, in train_agent
    self.agent.train_agent()
  File "agent.py", line 161, in train_agent
    loss = self.ddpg_agent.train(experience_set)
  File "tf_agents\agents\tf_agent.py", line 330, in train
    loss_info = self._train_fn(
  File "tf_agents\utils\common.py", line 188, in with_check_resource_vars
    return fn(*fn_args, **fn_kwargs)
  File "tf_agents\agents\ddpg\ddpg_agent.py", line 247, in _train
    critic_loss = self.critic_loss(time_steps, actions, next_time_steps,
  File "tf_agents\agents\ddpg\ddpg_agent.py", line 343, in critic_loss
    critic_loss = self._td_errors_loss_fn(td_targets, q_values)
  File "tf_agents\utils\common.py", line 1139, in element_wise_squared_loss
    return tf.compat.v1.losses.mean_squared_error(
  File "tensorflow\python\util\traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "tensorflow\python\framework\tensor_shape.py", line 1361, in assert_is_compatible_with
    raise ValueError("Shapes %s and %s are incompatible" % (self, other))
ValueError: Shapes (32, 1) and (32, 32) are incompatible
I have checked all the spec shapes, the experience shapes, and the output shapes of my actor and critic networks. They all look correct: the actor and critic output layers produce the expected shape (32, 1) for a batch size of 32. The mismatch is between td_targets and q_values in the loss function in tf_agents\agents\ddpg\ddpg_agent.py, where:

TD targets shape: (32, 32)
Q values shape: (32, 1)

Can anyone tell me what I am missing?
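To illustrate the kind of broadcasting I suspect, here is a minimal NumPy sketch (not my actual code, just the shape arithmetic) showing how combining a rank-1 reward tensor with a (32, 1) critic output silently broadcasts the TD targets up to (32, 32):

```python
import numpy as np

batch_size = 32
rewards = np.ones(batch_size)        # shape (32,)   -- rank-1 rewards
next_q = np.ones((batch_size, 1))    # shape (32, 1) -- critic network output
gamma = 0.99

# Correct: squeeze the trailing dim so both operands are rank-1.
td_targets_ok = rewards + gamma * np.squeeze(next_q, axis=-1)
print(td_targets_ok.shape)           # (32,)

# Broken: (32,) + (32, 1) broadcasts to (32, 32).
td_targets_bad = rewards + gamma * next_q
print(td_targets_bad.shape)          # (32, 32)
```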
1 Answer
I solved this by choosing a different loss function when initializing the DDPG agent:
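One drop-in choice that avoids the strict shape assertion is tf.math.squared_difference, which broadcasts its operands rather than requiring identical shapes (a hypothetical sketch, assuming the same DdpgAgent constructor as in the question; note that broadcasting (32, 1) against (32, 32) silently yields a (32, 32) loss, so this works around the symptom, and the underlying shape mismatch in the TD targets is still worth chasing down):

```python
self.ddpg_agent = DdpgAgent(
    time_step_spec=self.tf_env.time_step_spec(),
    action_spec=self.tf_env.action_spec(),
    actor_network=actor_network,
    critic_network=critic_network,
    actor_optimizer=Adam(learning_rate=self.learning_rate),
    critic_optimizer=Adam(learning_rate=self.learning_rate),
    gamma=self.discount_factor,
    target_update_tau=0.01,
    ou_stddev=0.3,
    ou_damping=0.3,
    # Broadcasting loss instead of common.element_wise_squared_loss, which
    # asserts identical shapes via tf.compat.v1.losses.mean_squared_error.
    td_errors_loss_fn=tf.math.squared_difference,
)
```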