pytorch 加载与分布式数据并行保存的模型时出错

iezvtpos  于 2023-08-05  发布在  其他
关注(0)|答案(2)|浏览(107)

加载从分布式模式下的模型保存的模型时,模型名称不同,导致此错误。我该如何解决这个问题?

File "/code/src/bert_structure_prediction/model.py", line 36, in __init__                         
    self.load_state_dict(state_dict)                                                                
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1223, in load_state
_dict                                                                                               
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(                       
RuntimeError: Error(s) in loading state_dict for BertCoordinatePredictor:                           
        Missing key(s) in state_dict: "bert.embeddings.position_ids", "bert.embeddings.word_embeddin
gs.weight", ...etc.

字符串

mwkjh3gx

mwkjh3gx1#

模型名称不匹配的原因是因为DDP Package 了模型对象,导致在以分布式数据并行模式保存模型时出现不同的层名称(具体来说,层名称将在模型名称前添加module.)。要解决此问题,请使用

torch.save(model.module.state_dict(), PATH)

字符串
代替了

torch.save(model.state_dict(), PATH)


当从数据并行保存时。

41ik7eoe

41ik7eoe2#

您可以在load_state_dict()函数中将strict参数设置为False,以忽略不匹配的键:

model.load_state_dict(torch.load(path, map_location=torch.device('cpu')), strict=False)

字符串

相关问题