pytorch 数据发生微小变化时,损失为NaN

j5fpnvbx  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(106)

在我的PyTorch项目中,我稍微编辑了一下数据,现在我得到了loss.item()的nan。使用旧数据,它仍然在计算一个很大的损失。
新数据

旧数据

谁能告诉我为什么会这样?

xfb7svmp

xfb7svmp1#

可能导致此问题的几个因素:
1.学习速率太大。请尝试设定较小的学习速率,看看这样是否可以解决问题。
1.您需要将输入到网络中的数据归一化。您可以尝试X = X - X.mean(axis=0)/X.std(axis=0),或者因为数据看起来像一个numpy数组,所以选择在转换为Tensor之前使用scikit-learn对其进行预处理。例如:最小最大缩放器1
1.尝试添加一个batchnorm层(例如:nn.BatchNorm1d)连接到您的网络,以进一步稳定层到层的输出。
1.检查数据中的离群值(如果有非常大的值或nan值),并将其过滤掉。
1.您可以通过设置torch.autograd.detect_anomaly(True)进行调试

相关问题