我的理解是torch.nn.CrossEntropyLoss为批次中的每个项目创建损失值,并将其相加以给予最终损失。假设在下面的玩具例子中,我们想在求和之前将3个项目中的每一个的损失与数组multiplier
中的相应元素相乘。我如何做到这一点?
import torch
import numpy as np
multiplier = torch.from_numpy(np.array([1.0, -2.0, 5.0]))
source = torch.from_numpy(
np.array([[0.5, -0.6],
[-3.0, -2.0],
[-4.0, 2.3]]))
target = torch.from_numpy(np.array([0, 1, 0]))
loss_fn = torch.nn.CrossEntropyLoss()
loss_fn(source, target)
字符串
请注意,函数的权重参数对不同的目标索引进行了不同的加权,这不是我想要的。
1条答案
按热度按时间68de4m5k1#
您可以手动进行,不减少损失(不通过取平均值来汇总损失),而是手动进行加权求和:
字符串
参见https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html