pytorch 焊炬交叉熵损失计算2D输入和3D输入之间的差异

cngwdvgl  于 2023-02-19  发布在  其他
关注(0)|答案(1)|浏览(219)

我正在torch.nn.CrossEntropyLoss上运行一个测试。我正在使用官方页面上显示的示例。

loss = nn.CrossEntropyLoss()
input = torch.randn(3, 5, requires_grad=False)
target = torch.randn(3, 5).softmax(dim=1)
output = loss(input, target)

输出是2.05。在这个例子中,输入和目标都是2DTensor。因为在大多数NLP的情况下,输入应该是3DTensor,相应的输出也应该是3DTensor。因此,我写了几行测试代码,发现了一个奇怪的问题。

input = torch.stack([input])
target = torch.stack([target])
output = loss(ins, ts)

输出是0.9492这个结果真的让我很困惑,除了维数,Tensor里面的数字完全一样。2有人知道为什么会有这样的差别吗?
我测试该方法的原因是我正在使用Transformers.BartForConditionalGeneration处理项目。损失结果在输出中给出,输出始终在(1,)shape。输出令人困惑。如果我的批量大小大于1,我应该得到批量大小的损失,而不是只有一个。我看了看代码,它只是简单地使用nn.CrossEntropyLoss(),所以我考虑问题可能出在nn.CrossEntropyLoss()方法上,但是卡在方法上了。

2vuwiymt

2vuwiymt1#

在第二种情况下,您添加了一个额外的维度,这意味着最终,logitsTensor(input)上的softmax不会应用于其他维度。
这里我们分别计算这两个量:

>>> loss = nn.CrossEntropyLoss()
>>> input = torch.randn(3, 5, requires_grad=False)
>>> target = torch.randn(3, 5).softmax(dim=1)

首先,您有loss(input, target),它等同于:

>>> o = -target*F.log_softmax(input, 1)
>>> o.sum(1).mean()

第二个场景loss(input[None], target[None])与以下内容相同:

>>> o = -target[None]*F.log_softmax(input[None], 1)
>>> o.sum(1).mean()

相关问题