PyTorch: fine-tuning nnUNet with frozen weights

smtd7mpg · posted on 2022-11-09

Good morning, I have followed the instructions in this GitHub issue:
https://github.com/MIC-DKFZ/nnUNet/issues/1108
I want to fine-tune an nnUNet model (PyTorch) starting from a pre-trained model, but that approach retrains all of the weights. I would like to freeze all of the weights and retrain only the last layer, changing the number of segmentation classes from 3 to 1. Do you know how to achieve this? Thanks in advance.

xu3bshqb1#

To freeze weights, you need to set parameter.requires_grad = False.
Example:

from nnunet.network_architecture.generic_UNet import Generic_UNet

# num_pool=1 matches the architecture printed below (a single pooling stage)
model = Generic_UNet(input_channels=3, base_num_features=64, num_classes=4, num_pool=1)

# Freeze everything except the segmentation heads
for name, parameter in model.named_parameters():
    if 'seg_outputs' in name:
        print(f"parameter '{name}' will not be frozen")
        parameter.requires_grad = True
    else:
        parameter.requires_grad = False
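
Since you also want to change the number of segmentation classes, one option (a minimal sketch, assuming the nnUNet v1 Generic_UNet layout printed further below, where seg_outputs is a ModuleList of 1x1 convolutions) is to replace each head with a freshly initialized one:

import torch.nn as nn

# Swap every deep-supervision head for a new 1-class 1x1 convolution.
# Assumes model.seg_outputs is a ModuleList of Conv2d layers, as in the
# architecture printed below.
for i, head in enumerate(model.seg_outputs):
    model.seg_outputs[i] = nn.Conv2d(
        in_channels=head.in_channels,
        out_channels=1,  # new number of segmentation classes
        kernel_size=1,
        stride=1,
        bias=False,
    )

Newly constructed modules have requires_grad=True by default, so if you replace the heads after the freezing loop above, they stay trainable while the rest of the network remains frozen.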

To check the parameter names, you can print the model:

print(model)

which produces:

Generic_UNet(
  (conv_blocks_localization): ModuleList(
    (0): Sequential(
      (0): StackedConvLayers(
        (blocks): Sequential(
          (0): ConvDropoutNormNonlin(
            (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (instnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
      )
      (1): StackedConvLayers(
        (blocks): Sequential(
          (0): ConvDropoutNormNonlin(
            (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (instnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
      )
    )
  )
  (conv_blocks_context): ModuleList(
    (0): StackedConvLayers(
      (blocks): Sequential(
        (0): ConvDropoutNormNonlin(
          (conv): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (dropout): Dropout2d(p=0.5, inplace=True)
          (instnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
        )
        (1): ConvDropoutNormNonlin(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (dropout): Dropout2d(p=0.5, inplace=True)
          (instnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
        )
      )
    )
    (1): Sequential(
      (0): StackedConvLayers(
        (blocks): Sequential(
          (0): ConvDropoutNormNonlin(
            (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (dropout): Dropout2d(p=0.5, inplace=True)
            (instnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
      )
      (1): StackedConvLayers(
        (blocks): Sequential(
          (0): ConvDropoutNormNonlin(
            (conv): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (dropout): Dropout2d(p=0.5, inplace=True)
            (instnorm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (lrelu): LeakyReLU(negative_slope=0.01, inplace=True)
          )
        )
      )
    )
  )
  (td): ModuleList(
    (0): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
  )
  (tu): ModuleList(
    (0): Upsample()
  )
  (seg_outputs): ModuleList(
    (0): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1), bias=False)
  )
)
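
To double-check which parameters will actually be updated, and to make sure the optimizer never touches the frozen backbone, you can do something like the following (a sketch; the choice of optimizer and the hyperparameters are only placeholders):

import torch

# Sanity check: list trainable vs. frozen parameters.
for name, p in model.named_parameters():
    print(f"{name}: requires_grad={p.requires_grad}")

# Pass only the trainable parameters to the optimizer so the frozen
# weights are guaranteed to stay unchanged.
optimizer = torch.optim.SGD(
    (p for p in model.parameters() if p.requires_grad),
    lr=0.01, momentum=0.99, nesterov=True,
)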

Alternatively, you can use Netron to visualize your network:
https://github.com/lutzroeder/netron
