权重归一化导致PyTorch中的nan

zdwk9cvp  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(85)

我使用PyTorch 1.2.0中内置的权重归一化。当使用权重范数的层的权重变得接近于0时,权重范数运算导致NaN,其然后在整个网络中传播。为了解决这个问题,我想在PyTorch权重范数函数中的weight_v的范数上添加一个像eps = 1e-6这样的小值。
因此,我试图在本地计算机上找到该函数,并在miniconda3/envs/pytorch1_2/lib/python3.7/site-packages/torch/nn/utils/weight_norm.py(GitHub代码)中找到它,并试图修改它。
我想知道哪个函数在计算权重范数,所以我在每个函数中添加了hi,发现compute_weight函数在每次向前传递之前都会被调用。
此函数调用存储在miniconda3/envs/pytorch1_2/lib/python3.7/site-packages/torch/onnx/symbolic_opset9.py(GitHub代码)的_weight_norm函数。当我在_weight_norm函数中添加print("hi")时,它没有被打印出来。
那么,通过将eps添加到权重的范数来修改权重范数代码的正确方法是什么?也许可以用最新的PyTorch 1.9.0版本替换我本地计算机上的_weight_norm函数,但不确定如何添加eps

6ojccjat

6ojccjat1#

在https://github.com/facebookresearch/multiface/blob/main/models.py#L580找到了临时解决方案
所以不使用常规的nn.conv2d,而是使用

class Conv2dWN(nn.Conv2d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=True,
    ):
        super(Conv2dWN, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            True,
        )
        self.g = nn.Parameter(torch.ones(out_channels))

    def forward(self, x):
        wnorm = torch.sqrt(torch.sum(self.weight**2))
        return F.conv2d(
            x,
            self.weight * self.g[:, None, None, None] / wnorm,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
        )

字符串

相关问题