pytorch Torch softmax with value threshold

3hvapo4f  于 12个月前  发布在  其他
关注(0)|答案(2)|浏览(123)

我尝试用一个阈值来应用softmax。就像,给定一个阈值,输出的总和仍然是1,但是每个值都低于阈值。
这是我当前的代码-我构建了剪辑模块来处理Softmax输出。然而,我收到下面的错误。我该怎么做才能让它工作?

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [200, 31]], which is output 0 of SoftmaxBackward0, is at version 3; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck

个字符

8cdiaqws

8cdiaqws1#

您在PyTorch中遇到的RuntimeError是由于Clip类中的一个inplace操作,特别是x[i] = new_x_i。该操作修改了原始Tensorx,PyTorch要求通过其计算图进行梯度计算时保持不变。
要解决这个问题,请避免对x进行inplace修改;相反,使用一个新的Tensor并返回它。下面修改后的Clip类避免了inplace操作,从而防止了错误:

class Clip(nn.Module):
    def __init__(self, threshold):
        super().__init__()
        self.threshold = threshold

    def clip(self, x):
        clipped_x = torch.zeros_like(x)
        for i in range(len(x)):
            if torch.all(x[i] < self.threshold):
                clipped_x[i] = x[i]
                continue
            ind = torch.where(x[i] >= self.threshold)[0]
            new_x_i = torch.zeros_like(x[i])
            if len(ind) > 0:
                new_x_i[ind] = x[i][ind] / torch.sum(x[i][ind])
            new_x_i[~ind] = self.threshold
            clipped_x[i] = new_x_i
        return clipped_x

    def forward(self, x):
        return self.clip(x)

字符串

ax6ht2ek

ax6ht2ek2#

首先,检查你自己的代码。当你的代码发现一个值超过阈值的行时,它会替换阈值的值,但也会将所有其他值清零,我认为这不是你的意图。
你所描述的任务实际上很难有效地完成。也许如果你能提供更多关于为什么你需要这个的信息,我可以想出一个更好的解决方案。
为了说明为什么这是困难的,采取下面的代码。
给定一个输入x和一个threshold,我们:
1.计算原始softmax值
1.为threshold下的值创建threshold_mask
1.计算under_threshold_sumsover_threshold_sumsunder_threshold_sums是我们要放大的低于阈值的当前总和。over_threshold_sums将是在我们将它们裁剪到阈值之后的超过阈值的总和值
1.将rescale_values计算为(1-over_threshold_sums)/(under_threshold_sums)。假设我们有两个值超过0.3threshold。这意味着这两个值将被裁剪为0.3。裁剪后的两个超过阈值的值将0.6贡献到总数中。这意味着我们希望缩放后的低于阈值的值之和为0.4。因此形式为(1-over_threshold_sums)/(under_threshold_sums)
1.计算rescale_mask。这个掩码在所有位置都有我们的rescale_values,除了超过阈值的位置
1.计算scaled_softmax_vals,方法是将阈值裁剪到threshold,并按比例放大低于阈值的值,以便这些值的总和仍然为1。

import torch
import torch.nn.functional as F

x = torch.randn(8, 32)/0.1 # divide by 0.1 to push more values over the thresold
threshold = 0.3

# 1
softmax_vals = F.softmax(x, -1)

# 2
threshold_mask = softmax_vals < threshold

# 3
under_threshold_sums = (softmax_vals * threshold_mask).sum(-1)
over_threshold_sums = threshold * (~threshold_mask).sum(-1)

# 4
rescale_values = (1-over_threshold_sums)/(under_threshold_sums)

# 5
rescale_mask = rescale_values.unsqueeze(1).expand(-1, softmax_vals.shape[-1])
rescale_mask = rescale_mask.masked_fill(~threshold_mask, 1)

# 6
scaled_softmax_vals = softmax_vals.clip(max=threshold) * rescale_mask

字符串
这一方法实现了:
1.将阈值限幅到阈值
1.缩放低于阈值的值,使得总softmax值的总和仍然为1

然而有一个问题。放大一个低于阈值的值可能会导致创建一个新的超过阈值的值。这就是为什么你的问题特别具有挑战性。你必须迭代地应用上面的代码,直到所有值都低于阈值,这可能没有收敛保证。

还有另一个问题,在softmax输出中有很多非常小的值(即0.99999...对于一个值,其他值都很小)将导致这种重新缩放产生nan。
我建议询问 * 为什么 * 你需要这种特定类型的softmax重新缩放,以及为什么更传统的方法,如温度缩放或标签平滑不适合你的需要。

相关问题