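Feature Description
I am trying to use a deep learning model built with torch together with a reinforcement learning model built with PARL (PaddlePaddle). I am currently running into the following problem: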
Exception has occurred: ValueError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
(InvalidArgument) unsqueeze2(): argument 'X' (position 0) must be Tensor, but got Tensor (at /paddle/paddle/fluid/pybind/op_function_common.cc:737)
File "/root/miniconda3/envs/pylight/lib/python3.8/site-packages/paddle/fluid/layers/nn.py", line 6572, in unsqueeze
out, _ = _C_ops.unsqueeze2(input, 'axes', axes)
File "/root/miniconda3/envs/pylight/lib/python3.8/site-packages/paddle/tensor/manipulation.py", line 1339, in unsqueeze
return layers.unsqueeze(x, axis, name)
File "/root/miniconda3/envs/pylight/lib/python3.8/site-packages/paddle/nn/functional/conv.py", line 379, in conv1d
x = unsqueeze(x, axis=[squeeze_aixs])
File "/root/miniconda3/envs/pylight/lib/python3.8/site-packages/paddle/nn/layer/conv.py", line 339, in forward
out = F.conv1d(
File "/root/miniconda3/envs/pylight/lib/python3.8/site-packages/paddle/fluid/dygraph/layers.py", line 915, in _dygraph_call_func
outputs = self.forward(*inputs, **kwargs)
File "/root/miniconda3/envs/pylight/lib/python3.8/site-packages/paddle/fluid/dygraph/layers.py", line 930, in __call__
return self._dygraph_call_func(*inputs, **kwargs)
File "/data4/jack/HiVT_ours/models/ppo/atari_model.py", line 52, in value
out = F.relu(self.conv1(obs))
File "/data4/jack/HiVT_ours/models/ppo/ppo.py", line 170, in sample
value = self.model.value(obs)
File "/data4/jack/HiVT_ours/models/ppo/agent.py", line 63, in sample
value, action, action_log_probs, action_entropy = self.alg.sample(obs_tensor)
File "/data4/jack/HiVT_ours/models/drl_decoder.py", line 170, in run_episodes
value, action, logprob, _ = self.super_agent.sample(obs)
File "/data4/jack/HiVT_ours/models/drl_decoder.py", line 72, in forward
game_scores, rolling_scores, time_taken = self.run_episodes(agents_embed, self.config)
File "/data4/jack/HiVT_ours/models/drl_decoder.py", line 424, in <module>
result = ppo_decoder.forward(agents_embed, agents_position, lanelets_embed, gt_target_loaction, info)
File "/root/miniconda3/envs/pylight/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/root/miniconda3/envs/pylight/lib/python3.8/runpy.py", line 194, in _run_module_as_main (Current frame)
return _run_code(code, main_globals, None,
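My guess is that the problem is caused by a type mismatch: the front-end deep learning model outputs a torch.Tensor, but paddle expects a paddle.Tensor. Is there a way to convert between the two?
Alternatives
The solution I have come up with so far is to first convert the torch.Tensor to numpy, then process it with paddle.to_tensor.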
1 answer
oyjwcjzk #1
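@Jacky-gsq The two frameworks use different underlying data structures for their tensors, so the recommended approach is to convert to numpy first and then call paddle.to_tensor.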