如何在pytorch CrossEntropyLoss中对每个批次进行不同的称重

kb5ga3dv  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(98)

我的理解是torch.nn.CrossEntropyLoss为批次中的每个项目创建损失值,并将其相加以给予最终损失。假设在下面的玩具例子中,我们想在求和之前将3个项目中的每一个的损失与数组multiplier中的相应元素相乘。我如何做到这一点?

import torch
import numpy as np
multiplier = torch.from_numpy(np.array([1.0, -2.0, 5.0]))
source = torch.from_numpy(
    np.array([[0.5, -0.6],
              [-3.0, -2.0],
              [-4.0, 2.3]]))
target = torch.from_numpy(np.array([0, 1, 0]))
loss_fn = torch.nn.CrossEntropyLoss()
loss_fn(source, target)

字符串
请注意,函数的权重参数对不同的目标索引进行了不同的加权,这不是我想要的。

68de4m5k

68de4m5k1#

您可以手动进行,不减少损失(不通过取平均值来汇总损失),而是手动进行加权求和:

loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
loss = torch.sum( batch_weights*loss_fn(source, target)) / torch.sum(batch_weights)

字符串
参见https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html

相关问题