python Pytorch多类分类2DTensor

o8x7eapl  于 2023-10-14  发布在  Python
关注(0)|答案(1)|浏览(142)

我使用torch.nn.CrossEntropyLoss(),它只排除1DTensor。我正在执行二进制分类,我知道这是可以做到的,但在阅读所有关于同一主题的帖子后,我似乎无法让它工作。
我的数据形状是:

out.shape = torch.size([1,4])
target.shape = torch.size([1,2])

在跑步之后

criterion = nn.CrossEntropyLoss()
for graphs in training_loader:
    out = model(graphs.x, graphs.edge_index, graphs.batch)
    loss = criterion(out, graphs.y)

我得到

Traceback (most recent call last):
  File "/home/polar/venv/pygeo/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3526, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-19-29fe12a317a0>", line 1, in <module>
    loss = criterion(out, graphs.y.squeeze_(dim=1))
  File "/home/polar/venv/pygeo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/polar/venv/pygeo/lib/python3.10/site-packages/torch/nn/modules/loss.py", line 1174, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/home/polar/venv/pygeo/lib/python3.10/site-packages/torch/nn/functional.py", line 3029, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
RuntimeError: 0D or 1D target tensor expected, multi-target not supported

我试过graphs.y.squeezetorch.max(graphs.y, 1)[0]都没有用。我需要的预测输出4 logits的概率的4类。

x6yk4ghg

x6yk4ghg1#

torch.nn.CrossEntropyLoss期望1DTensor是不正确的。API文档中明确说明了其他情况。
也就是说,预测Tensor和目标Tensor需要具有相同的形状,以使交叉熵损失有意义。如果预测Tensor的形状为(1, 4),则意味着批量大小为1,预测(logits)为4个类。然而,您的目标Tensor具有(1, 2)的形状,因此批量大小为1,但标签仅为2个类。当然,如果你在预测和标签之间缺少两个类,那么无论你如何重塑Tensor,都无法计算损失。

相关问题