pytorch 自定义函数为1维Tensor错误提供过多索引

p3rjfoxz  于 2023-01-30  发布在  其他
关注(0)|答案(1)|浏览(428)

我正在计算损失的多类(7)分类程序使用pytorch。

class AFL(nn.Module):
   
    def __init__(self, delta=0.7, gamma=2., epsilon=1e-07):
        super(AFL, self).__init__()
        self.delta = delta
        self.gamma = gamma
        self.epsilon = epsilon

    def forward(self, y_pred, y_true):
        #y_pred=y_pred.size()[1]
        print(y_pred.shape) #[32,7]
        print(y_true.shape) #[32]
        y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon)
        cross_entropy = np.empty(y_pred.shape)
        for i in range(len(y_pred)):
            for j in range(len(y_pred[i])):
                cross_entropy[i][j] = -y_true * torch.log(y_pred[i][j])
        #cross_entropy = -y_true * torch.log(y_pred[0][0]) #here i want to calculate cross_entropy for for each class
        
    # Calculate losses separately for each class, only suppressing background class
        back_ce = torch.pow(1 - y_pred[:,0], self.gamma) * cross_entropy[:,0]
        back_ce =  (1 - self.delta) * back_ce

        fore_ce = cross_entropy[:,1,:,:]
        fore_ce = self.delta * fore_ce

        loss = torch.mean(torch.sum(torch.stack([back_ce, fore_ce], axis=-1), axis=-1))
        return loss

我想分别计算每个类的back_ce,但得到的误差为:

back_ce = torch.pow(1 - y_pred[:,0], self.gamma) * cross_entropy[:,0]
IndexError: too many indices for tensor of dimension 1

有人能告诉我哪里做错了吗?提到了y_pred和y_true的大小。

ycl3bljg

ycl3bljg1#

以下是具有多个常见类和罕见类的多类的AFL。

class AsymmetricFocalLoss(nn.Module):
    """For Imbalanced datasets
    Parameters
    ----------
    delta : float, optional
        controls weight given to false positive and false negatives, by default 0.25
    gamma : float, optional
        Focal Tversky loss' focal parameter controls degree of down-weighting of easy examples, by default 2.0
    epsilon : float, optional
        clip values to prevent division by zero error
    common : list, required
        a list of common class indices
    rare : list, required
        a list of rare class indices
    """
    def __init__(self, common, rare, delta=0.7, gamma=2., epsilon=1e-07):
        super(AsymmetricFocalLoss, self).__init__()
        self.delta = delta
        self.gamma = gamma
        self.epsilon = epsilon
        self.common = common
        self.rare = rare

    def forward(self, y_pred, y_labels):
        # assume y_pred contain probabilities (batch_size_ n_class)
        # y_labels contain integer class lables (batch_size, )

        # convert one-hot
        y_true = torch.zeros_like(y_pred)
        for i,j in enumerate(y_labels):y_true[i, j]=1

        # clamp
        y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon)
        cross_entropy = -y_true * torch.log(y_pred)
        #print(f'{cross_entropy.shape=}\n{cross_entropy=}')
    
        # Calculate losses separately for each class,
        all_ce=[]

        for c in self.common:
            back_ce = (1 - self.delta) * (torch.pow(1 - y_pred[:,c], self.gamma) * cross_entropy[:,c])
            all_ce.append(back_ce)

        for r in self.rare:
            fore_ce=self.delta * cross_entropy[:,r]
            all_ce.append(fore_ce)

        loss_stack = torch.stack(all_ce, axis=-1)
        #print(f'{loss_stack.shape=}\n{loss_stack=}')

        loss_sum=torch.sum(loss_stack, axis=-1)
        #print(f'{loss_sum.shape=}\n{loss_sum=}')

        loss = torch.mean(loss_sum)

        return loss

为了使用这个,

batch_size = 5
n_class = 7

y_pred = torch.softmax( torch.rand((batch_size, n_class)), dim=-1)
y_labels = torch.randint(0, n_class, size=(batch_size,))
print(f'{y_pred=}\n{y_labels=}')

lossF = AsymmetricFocalLoss(common = [0,2,4,6], rare = [1,3,5])
loss = lossF(y_pred, y_labels)

print(f'{loss=}')

输出:

"""
y_pred=tensor([[0.1955, 0.1455, 0.0976, 0.1869, 0.1043, 0.1173, 0.1529],
        [0.1613, 0.1635, 0.1121, 0.1290, 0.1571, 0.0993, 0.1777],
        [0.0978, 0.1340, 0.1025, 0.1993, 0.2197, 0.1041, 0.1425],
        [0.1371, 0.1113, 0.1771, 0.1560, 0.0897, 0.1554, 0.1734],
        [0.1960, 0.1890, 0.1403, 0.1076, 0.1714, 0.1079, 0.0878]])
y_labels=tensor([0, 3, 2, 5, 3])
loss=tensor(1.0328)
"""

相关问题