加载从分布式模式下的模型保存的模型时,模型名称不同,导致此错误。我该如何解决这个问题?
File "/code/src/bert_structure_prediction/model.py", line 36, in __init__
self.load_state_dict(state_dict)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1223, in load_state
_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for BertCoordinatePredictor:
Missing key(s) in state_dict: "bert.embeddings.position_ids", "bert.embeddings.word_embeddin
gs.weight", ...etc.
字符串
2条答案
按热度按时间mwkjh3gx1#
模型名称不匹配的原因是因为DDP Package 了模型对象,导致在以分布式数据并行模式保存模型时出现不同的层名称(具体来说,层名称将在模型名称前添加
module.
)。要解决此问题,请使用字符串
代替了
型
当从数据并行保存时。
41ik7eoe2#
您可以在
load_state_dict()
函数中将strict
参数设置为False
,以忽略不匹配的键:字符串