pytorch 为什么nn.参数没有保存在模型state_dict中?

flvlnr44  于 2023-06-23  发布在  其他
关注(0)|答案(1)|浏览(273)

我尝试在Pytorch(版本1.9.1)中训练一个自定义模型,并在每个训练步骤中保存模型权重。为此,我使用torch.save(model.state_dict(), 'filename.pt')。训练后,当我试图加载保存的权重时,我得到一个错误,说模型的state_dict中缺少一个键,这表明模型的某些部分没有正确保存。
我找到了罪魁祸首,它是一个n.Parameter对象初始化如下:

self.class_token = nn.Parameter(torch.rand(1, self.hidden_d)).to(self.device)

我不知道为什么它没有拯救。类似问题的报告似乎暗示它可能是the conversion between devices,但网络中所有其他层的转换方式与此参数相同。此外,模型代码是由我的一个同事编写的,保存/加载对他来说没有任何问题。可能是版本问题,他运行的是最新版本的Pytorch,但我运行的系统很难升级到Pytorch 2.0,如果不是不可能的话。
是否有其他方法可以在设备之间转换,而不会将其从模型中分离?否则,我是否必须手动添加此参数,以便将其保存在state_dict中?

ogq8wdun

ogq8wdun1#

.to(self.device)操作移动到Parameter调用中:

self.class_token = nn.Parameter(torch.rand(1, self.hidden_d).to(self.device))

您可能会遇到这个问题,因为您正在尝试为GPU加载Tensor,但实际上,Tensor正在为CPU加载。
检查类似的问题here

相关问题