python 如何从密集网中解冻层?(PyTorch)

zysjyyx4  于 2022-12-10  发布在  Python
关注(0)|答案(1)|浏览(166)

我想对来自DenseNet-161的整个块执行微调。目前,我知道我可以使用以下命令冻结除分类器之外的所有层:

model = models.densenet161(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
num_ftrs = model.classifier.in_features
    
model.classifier = torch.nn.Linear(num_ftrs,2)

然而,我想解冻最后几层/块的密集网微调。什么是最好的最优雅的方式来实现这一点?

b1zrtrql

b1zrtrql1#

首先,您还可以通过将分类器的参数requires_grad设置为True来解冻分类器。

for param in model.classifier.parameters():
    param.requires_grad = True

这样就可以保留该层的原始参数,而不是在创建新的nn.Linear时获得的新的随机初始化。
这也适用于DenseNet的任何其他子模块。您可以通过打印模块来查看还有哪些其他模块。要解冻最后一个块和最后一个BatchNorm,您可以执行以下操作

# this is a torch.nn.Sequential containing the 
# "denseblock4" and "norm5" submodules
submodules = model.features[-2:]  
for param in submodules.parameters():
    param.requires_grad = True

如果你想把参数重置为一个新的随机初始化,你可以在每个参数上使用torch.nn.init中的一些初始化器。

相关问题