我使用torch.nn.CrossEntropyLoss()
,它只排除1DTensor。我正在执行二进制分类,我知道这是可以做到的,但在阅读所有关于同一主题的帖子后,我似乎无法让它工作。
我的数据形状是:
out.shape = torch.size([1,4])
target.shape = torch.size([1,2])
在跑步之后
criterion = nn.CrossEntropyLoss()
for graphs in training_loader:
out = model(graphs.x, graphs.edge_index, graphs.batch)
loss = criterion(out, graphs.y)
我得到
Traceback (most recent call last):
File "/home/polar/venv/pygeo/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3526, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-19-29fe12a317a0>", line 1, in <module>
loss = criterion(out, graphs.y.squeeze_(dim=1))
File "/home/polar/venv/pygeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/polar/venv/pygeo/lib/python3.10/site-packages/torch/nn/modules/loss.py", line 1174, in forward
return F.cross_entropy(input, target, weight=self.weight,
File "/home/polar/venv/pygeo/lib/python3.10/site-packages/torch/nn/functional.py", line 3029, in cross_entropy
return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: 0D or 1D target tensor expected, multi-target not supported
我试过graphs.y.squeeze
和torch.max(graphs.y, 1)[0]
都没有用。我需要的预测输出4 logits的概率的4类。
1条答案
按热度按时间x6yk4ghg1#
torch.nn.CrossEntropyLoss
期望1DTensor是不正确的。API文档中明确说明了其他情况。也就是说,预测Tensor和目标Tensor需要具有相同的形状,以使交叉熵损失有意义。如果预测Tensor的形状为
(1, 4)
,则意味着批量大小为1,预测(logits)为4个类。然而,您的目标Tensor具有(1, 2)
的形状,因此批量大小为1,但标签仅为2个类。当然,如果你在预测和标签之间缺少两个类,那么无论你如何重塑Tensor,都无法计算损失。