我有一个torch.nn.Module的子类,它的初始化器有以下形式:(在类A中)
def __init__(self, additional_layer=False):
...
if additional_layer:
self.additional = nn.Sequential(nn.Linear(8,3)).to(self.device)
else:
self.additional = None
...
...
字符串
我使用additional_layer=True进行训练,并使用torch.save
保存模型。我保存的对象是model.state_dict()
。然后我加载模型进行推理。但随后我得到以下错误:
model.load_state_dict(best_model["my_model"])
RuntimeError: Error(s) in loading state_dict for A:
Unexpected key(s) in state_dict: "additional.0.weight"
型
是使用一个可选字段,可以是无不允许??如何正确处理这一点?[还张贴here ]
2条答案
按热度按时间hm2xizp91#
使用
torch.nn.Identity()
而不是None
。或者你可以在
forward()
函数中传递Tensor。umuewwlo2#
这不是一个特别与值为
None
相关的问题;如果你使用任何其他的nn.Module
,你会遇到同样的问题。(作为additional
属性的值)不是序列(additional
之后的0
),并且在顺序模块的第一个nn.Module
中没有名为weight
的参数(additional.0
之后的weight
)。问题是,在你的训练模式中,当你初始化你的模型时,你已经为
additional_layer
参数传递了True
,即:字符串
因此,
self.additional
被设置为nn.Module
(具体为nn.Sequential
)。因此,model
对象的state_dict
将具有self.additional
属性所引用的模块的参数。现在,当您重新初始化模型以进行 * 推断 * 时,您没有额外的层,因为您可能通过以下方式之一来初始化模型:
型
这一次,
self.additional
(即model.additional
属性)将是None
。因此,当您调用model.load_state_dict
并将之前在 train 模式下保存的状态字典传递给它时(当附加层存在时),它会给您一个例外,即additional
属性的所有键都丢失了。假设当
self.additional
是None
时,你在forward
方法中有正确的条件设置,你可以忽略异常,并通过在使用load_state_dict
时将strict
参数设置为False
来绕过加载缺少的键/参数:型