我尝试在Pytorch中手动计算编码器-解码器模型的cross_entropy loss。
我使用这里发布的代码来计算它:Cross Entropy in PyTorch
我更新了代码以丢弃填充的令牌(-100)。
class compute_crossentropyloss_manual:
"""
y0 is the vector with shape (batch_size,C)
x shape is the same (batch_size), whose entries are integers from 0 to C-1
"""
def __init__(self, ignore_index=-100) -> None:
self.ignore_index=ignore_index
def __call__(self, y0, x):
loss = 0.
n_batch, n_class = y0.shape
# print(n_class)
for y1, x1 in zip(y0, x):
class_index = int(x1.item())
if class_index == self.ignore_index: # <------ I added this if-statement
continue
loss = loss + torch.log(torch.exp(y1[class_index])/(torch.exp(y1).sum()))
loss = - loss/n_batch
return loss
为了验证它是否工作正常,我在一个文本生成任务中测试了它,并使用pytorch.nn实现和下面的代码计算了损失。
损失值不相同:
使用nn.CrossEntropyLoss
:
使用上面链接中的代码:
我错过什么了吗?
我试着得到nn.CrossEntropyLoss
的源代码,但是我不能,在nn/www.example.com的链接第2955行,你会看到函数指向另一个cross_entropy损失,叫做torch._C._nn.cross_entropy_loss
;functional.py at line 2955, you will see that the function points to another cross_entropy loss called torch._C._nn.cross_entropy_loss
; I can't find this function in the repo.
- 编辑:**
我注意到,只有当我在黄金中有-100
令牌时,才会出现差异。
演示示例:
y = torch.randint(1, 50, (100, 50), dtype=float)
x = torch.randint(1, 50, (100,))
x[40:] = -100
print(criterion(y, x).item())
print(criterion2(y, x).item())
> 25.55788695847976
> 10.223154783391905
如果没有-100
x[40:] = 30 # any positive number
print(criterion(y, x).item())
print(criterion2(y, x).item())
> 24.684453267596453
> 24.684453267596453
2条答案
按热度按时间hmmo2u0o1#
我通过更新代码解决了这个问题。我在
-100
标记之前丢弃了(上面的if语句),但是我忘记了减小hidden_state的大小(在上面的代码中称为n_batch
)。在这样做之后,丢失数与nn.CrossEntropyLoss
的值相同。最后的代码:fdx2calv2#
我也需要这个--谢谢你的手动交叉熵损失代码。它与pytorch的结果完美匹配(与我的数据匹配)。我对你上面的修复做了一个小小的修复。最后你需要除以未忽略行的最终计数(那些没有标签-100的行)。所以你需要一个计数器: