我使用PyTorch 1.2.0中内置的权重归一化。当使用权重范数的层的权重变得接近于0时,权重范数运算导致NaN
,其然后在整个网络中传播。为了解决这个问题,我想在PyTorch权重范数函数中的weight_v
的范数上添加一个像eps = 1e-6
这样的小值。
因此,我试图在本地计算机上找到该函数,并在miniconda3/envs/pytorch1_2/lib/python3.7/site-packages/torch/nn/utils/weight_norm.py
(GitHub代码)中找到它,并试图修改它。
我想知道哪个函数在计算权重范数,所以我在每个函数中添加了hi
,发现compute_weight
函数在每次向前传递之前都会被调用。
此函数调用存储在miniconda3/envs/pytorch1_2/lib/python3.7/site-packages/torch/onnx/symbolic_opset9.py
(GitHub代码)的_weight_norm
函数。当我在_weight_norm
函数中添加print("hi")
时,它没有被打印出来。
那么,通过将eps
添加到权重的范数来修改权重范数代码的正确方法是什么?也许可以用最新的PyTorch 1.9.0版本替换我本地计算机上的_weight_norm
函数,但不确定如何添加eps
。
1条答案
按热度按时间6ojccjat1#
在https://github.com/facebookresearch/multiface/blob/main/models.py#L580找到了临时解决方案
所以不使用常规的
nn.conv2d
,而是使用字符串