pytorch 正确处理具有可选成员(可以是无)的模型?

lymnna71  于 11个月前  发布在  其他
关注(0)|答案(2)|浏览(131)

我有一个torch.nn.Module的子类,它的初始化器有以下形式:(在类A中)

def __init__(self, additional_layer=False):
    ...
    if additional_layer:
        self.additional = nn.Sequential(nn.Linear(8,3)).to(self.device)
    else:
        self.additional = None
    ...
    ...

字符串
我使用additional_layer=True进行训练,并使用torch.save保存模型。我保存的对象是model.state_dict()。然后我加载模型进行推理。但随后我得到以下错误:

model.load_state_dict(best_model["my_model"])

RuntimeError: Error(s) in loading state_dict for A:
        Unexpected key(s) in state_dict: "additional.0.weight"


是使用一个可选字段,可以是无不允许??如何正确处理这一点?[还张贴here ]

hm2xizp9

hm2xizp91#

使用torch.nn.Identity()而不是None
或者你可以在forward()函数中传递Tensor。

umuewwlo

umuewwlo2#

这不是一个特别与值为None相关的问题;如果你使用任何其他的nn.Module,你会遇到同样的问题。(作为additional属性的值)不是序列(additional之后的0),并且在顺序模块的第一个nn.Module中没有名为weight的参数(additional.0之后的weight)。
问题是,在你的训练模式中,当你初始化你的模型时,你已经为additional_layer参数传递了True,即:

model = YourModelClass(additional_layer=True)

字符串
因此,self.additional被设置为nn.Module(具体为nn.Sequential)。因此,model对象的state_dict将具有self.additional属性所引用的模块的参数。
现在,当您重新初始化模型以进行 * 推断 * 时,您没有额外的层,因为您可能通过以下方式之一来初始化模型:

model = YourModelClass(additional_layer=False)
model = YourModelClass()


这一次,self.additional(即model.additional属性)将是None。因此,当您调用model.load_state_dict并将之前在 train 模式下保存的状态字典传递给它时(当附加层存在时),它会给您一个例外,即additional属性的所有键都丢失了。
假设当self.additionalNone时,你在forward方法中有正确的条件设置,你可以忽略异常,并通过在使用load_state_dict时将strict参数设置为False来绕过加载缺少的键/参数:

model.load_state_dict(best_model["my_model"], strict=False)

相关问题