PyTorch:从另一个模型加载权重而不保存

w80xi6nr  于 12个月前  发布在  其他
关注(0)|答案(3)|浏览(139)

假设我在PyTorch中有两个模型,我如何在不保存权重的情况下通过模型2的权重加载模型1的权重?
就像这样:

model1.weights = model2.weights

字符串
在TensorFlow中,我可以这样做:

variables1 = model1.trainable_variables
variables2 = model2.trainable_variables
for v1, v2 in zip(variables1, variables2):
    v1.assign(v2.numpy())

iibxawm4

iibxawm41#

假设你有两个相同模型的示例(必须子类化nn.Module),那么你可以使用nn.Module.state_dict()nn.Module.load_state_dict()。你可以找到状态字典here的简要介绍。

model1.load_state_dict(model2.state_dict())

字符串

anauzrmj

anauzrmj2#

这里有两种方法可以做到这一点。

# Use load state dict
model_source = Model()
model_dest = Model()
model_dest.load_state_dict(model_source.state_dict())

# Use deep copy
model_source = Model()
model_dest = copy.deepcopy(model_source )

字符串

vsnjm48y

vsnjm48y3#

在解决方案中添加另一种方法,尽管它与load_state_dict()相同,但当load_state_dict()因任何原因抛出错误时可能会很有用:

with torch.no_grad():
    for source_param, target_param in zip(model_to_copy_from.parameters(), model.parameters()):
        target_param.data.copy_(source_param.data)

字符串

相关问题