如何在pytorch中停止图层更新?

wgmfuz8q  于 2022-11-23  发布在  其他
关注(0)|答案(2)|浏览(140)
class DR(nn.Module):
    def __init__(self, orginal, latent_dims):
        super(DR, self).__init__()
        self.latent_dims=latent_dims
        self.linear1 = nn.Linear(orginal, 1000)
        self.linear2 = nn.Linear(1000, 2000)
        self.linear3 = nn.Linear(2000, latent_dims)
        

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        return x

在这个DR类中,我不想在训练时更新linear1和linear2。这意味着,层应该与它初始化时相同,我只想在训练时更新linear3。我该怎么做呢?
我希望这个问题能得到解决。谢谢

jfgube3f

jfgube3f1#

一种解决方案是将这些层中所有参数的requires_grad属性设置为False。

model = DR(64, 100)
for layer in [model.linear1, model.linear2]:
    for p in layer.parameters():
        p.requires_grad_(False)
print([(n, p.requires_grad) for n,p in model.named_parameters()])
k97glaaz

k97glaaz2#

PyTorch的一个未被充分利用的特性是requires_grad,它用于禁用梯度计算的推理缓存,可以在torch.Tensor上调用(作为一个就地和不就地操作),也可以在nn.Module上调用(就地)。
在您的情况下,可以简单地写两行:

model.linear1.requires_grad_(False)
model.linear2.requires_grad_(False)

相关问题