python 手动计算pytorch中的交叉熵损失

c0vxltue  于 2023-01-19  发布在  Python
关注(0)|答案(2)|浏览(182)

我尝试在Pytorch中手动计算编码器-解码器模型的cross_entropy loss。
我使用这里发布的代码来计算它:Cross Entropy in PyTorch
我更新了代码以丢弃填充的令牌(-100)。

class compute_crossentropyloss_manual:
    """
    y0 is the vector with shape (batch_size,C)
    x shape is the same (batch_size), whose entries are integers from 0 to C-1
    """
    def __init__(self, ignore_index=-100) -> None:
        self.ignore_index=ignore_index
    
    def __call__(self, y0, x):
        loss = 0.
        n_batch, n_class = y0.shape
        # print(n_class)
        for y1, x1 in zip(y0, x):
            class_index = int(x1.item())
            if class_index == self.ignore_index:  # <------ I added this if-statement
                continue
            loss = loss + torch.log(torch.exp(y1[class_index])/(torch.exp(y1).sum()))
        loss = - loss/n_batch
        return loss

为了验证它是否工作正常,我在一个文本生成任务中测试了它,并使用pytorch.nn实现和下面的代码计算了损失。
损失值不相同:
使用nn.CrossEntropyLoss

使用上面链接中的代码:

我错过什么了吗?
我试着得到nn.CrossEntropyLoss的源代码,但是我不能,在nn/www.example.com的链接第2955行,你会看到函数指向另一个cross_entropy损失,叫做torch._C._nn.cross_entropy_loss;functional.py at line 2955, you will see that the function points to another cross_entropy loss called torch._C._nn.cross_entropy_loss ; I can't find this function in the repo.

    • 编辑:**

我注意到,只有当我在黄金中有-100令牌时,才会出现差异。
演示示例:

y = torch.randint(1, 50, (100, 50), dtype=float)
x = torch.randint(1, 50, (100,))

x[40:] = -100
print(criterion(y, x).item())
print(criterion2(y, x).item())
> 25.55788695847976
> 10.223154783391905

如果没有-100

x[40:] = 30 # any positive number
print(criterion(y, x).item())
print(criterion2(y, x).item())
> 24.684453267596453
> 24.684453267596453
hmmo2u0o

hmmo2u0o1#

我通过更新代码解决了这个问题。我在-100标记之前丢弃了(上面的if语句),但是我忘记了减小hidden_state的大小(在上面的代码中称为n_batch)。在这样做之后,丢失数与nn.CrossEntropyLoss的值相同。最后的代码:

class CrossEntropyLossManual:
    """
    y0 is the vector with shape (batch_size,C)
    x shape is the same (batch_size), whose entries are integers from 0 to C-1
    """
    def __init__(self, ignore_index=-100) -> None:
        self.ignore_index=ignore_index
    
    def __call__(self, y0, x):
        loss = 0.
        n_batch, n_class = y0.shape
        # print(n_class)
        for y1, x1 in zip(y0, x):
            class_index = int(x1.item())
            if class_index == self.ignore_index:
                n_batch -= 1
                continue
            loss = loss + torch.log(torch.exp(y1[class_index])/(torch.exp(y1).sum()))
        loss = - loss/n_batch
        return loss
fdx2calv

fdx2calv2#

我也需要这个--谢谢你的手动交叉熵损失代码。它与pytorch的结果完美匹配(与我的数据匹配)。我对你上面的修复做了一个小小的修复。最后你需要除以未忽略行的最终计数(那些没有标签-100的行)。所以你需要一个计数器:

class compute_crossentropyloss_manual:
    """
    y0 is the vector with shape (batch_size,C)
    x shape is the same (batch_size), whose entries are integers from 0 to C-1
    """
    def __init__(self, ignore_index=-100) -> None:
        self.ignore_index=ignore_index
    
    def __call__(self, y0, x):
        loss = 0.
        n_batch, n_class = y0.shape
        # print(n_class)
        cnt = 0             # <----- I added this
        for y1, x1 in zip(y0, x):
            class_index = int(x1.item())
            if class_index == self.ignore_index:
                continue
            loss = loss + torch.log(torch.exp(y1[class_index])/(torch.exp(y1).sum()))
            cnt += 1        # <----- I added this
        loss = - loss/cnt   # <---- I changed this from nbatch to 'cnt'
        return loss

相关问题