How do I change the activation layers in a pretrained PyTorch module?

0mkxixxg  posted on 2023-01-17 in Other

How do I change the activation layers of a pretrained PyTorch network? Here is my code:

print("All modules")
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)

print('Before changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)
        child=nn.SELU()
        print(child)
print('after changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)

Here is my output:

All modules
ReLU(inplace=True)
Before changing activation
ReLU(inplace=True)
SELU()
after changing activation
ReLU(inplace=True)

wvmv3b1j1#

._modules solved this for me.

for name, child in net.named_children():
    if isinstance(child, nn.ReLU) or isinstance(child, nn.SELU):
        # Use the child's own name instead of hard-coding 'relu'
        net._modules[name] = nn.SELU()
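
The reason the loop in the question changes nothing is that `child = nn.SELU()` only rebinds the local name `child`; the parent still holds the original layer. A minimal sketch of the difference, assuming a hypothetical two-layer `nn.Sequential` rather than the asker's network:

```python
import torch.nn as nn

# Hypothetical toy network, used only for illustration.
net = nn.Sequential(nn.Linear(4, 4), nn.ReLU(inplace=True))

# Rebinding the loop variable changes only the local name `child`;
# the parent module still holds the original ReLU.
for child in net.children():
    if isinstance(child, nn.ReLU):
        child = nn.SELU()
still_relu = isinstance(net[1], nn.ReLU)

# Assigning through the parent replaces the registered submodule.
# list(...) takes a snapshot so the loop does not depend on the live dict.
for name, child in list(net.named_children()):
    if isinstance(child, nn.ReLU):
        setattr(net, name, nn.SELU())
now_selu = isinstance(net[1], nn.SELU)
```

`setattr` on the parent goes through the public `nn.Module` attribute machinery, so it is usually preferable to writing into the private `_modules` dict directly.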

dhxwm5r42#

Here is a generic function for replacing any layer:

def replace_layers(model, old, new):
    for n, module in model.named_children():
        if len(list(module.children())) > 0:
            # compound module: recurse into it
            replace_layers(module, old, new)

        if isinstance(module, old):
            # simple module: swap it on the parent
            setattr(model, n, new)

replace_layers(model, nn.ReLU, nn.ReLU6())

I struggled with this for a few days, so I did some digging and wrote a Kaggle notebook explaining how to access the different types of layers/modules in PyTorch.
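
One thing to keep in mind with the function above: the same `new` instance is installed at every matching position. That is harmless for stateless activations, but for a layer with parameters all positions would share them. A possible variant, sketched here with a hypothetical `replace_layers_with_factory` helper (not part of the original answer), takes a callable and builds a fresh module per replacement:

```python
import torch.nn as nn

def replace_layers_with_factory(model, old_type, make_new):
    """Replace every submodule of type `old_type` with a fresh module from `make_new()`.

    `make_new` is an assumed zero-argument callable; this helper is illustrative only.
    """
    # Snapshot the children so replacing entries does not interfere with iteration.
    for name, child in list(model.named_children()):
        if isinstance(child, old_type):
            setattr(model, name, make_new())
        else:
            replace_layers_with_factory(child, old_type, make_new)

# Hypothetical nested model for demonstration.
model = nn.Sequential(
    nn.Linear(8, 8),
    nn.ReLU(),
    nn.Sequential(nn.Linear(8, 8), nn.ReLU()),
)

# PReLU has a learnable parameter, so each position should get its own instance.
replace_layers_with_factory(model, nn.ReLU, lambda: nn.PReLU())
```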


ldxq2e6h3#

The default PyTorch API works well for me:

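def replace_layer(module: nn.Module, old: type, new: nn.Module, full_name: str = ""):
    for name, m in module.named_children():
        # Keep the child's path in its own variable so sibling names do not accumulate
        child_name = f"{full_name}.{name}"

        if isinstance(m, old):
            setattr(module, name, new)
            print(f"replaced {child_name}: {old}->{new}")
        elif len(list(m.children())) > 0:
            replace_layer(m, old, new, child_name)

# replace_layer already recurses, so call it once on the root rather than via model.apply
replace_layer(model, nn.ReLU, nn.Hardswish(True))

This replaces the layers and prints a trace. The trace below is from the original run, in which sibling names accumulated in the printed paths: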

replaced ._model.norm_layer.0.1.2: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.0.1.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.0.1.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.0.1.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.3.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.3.0.1.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
replaced ._model.norm_layer.backbone.encoder.blocks.0.1.2.3.4.0.conv1.bn1.relu: <class 'torch.nn.modules.activation.ReLU'>->Hardswish()
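
If you only need the dotted path of each matching layer, `named_modules()` already yields it, so there is no need to build the names by hand. A small sketch on a hypothetical nested `nn.Sequential` (not the model from the trace above):

```python
import torch.nn as nn

# Hypothetical nested model, used only to show the naming scheme.
model = nn.Sequential(
    nn.Linear(4, 4),
    nn.ReLU(),
    nn.Sequential(nn.Linear(4, 4), nn.ReLU()),
)

# named_modules() walks the whole tree and yields (dotted_name, module) pairs;
# the root itself is reported with an empty name.
relu_names = [name for name, m in model.named_modules() if isinstance(m, nn.ReLU)]
print(relu_names)  # ['1', '2.1']
```

Running the same comprehension before and after a replacement is a quick way to check that no layer was missed.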

t9aqgxwy4#

I assume you created the activation layers with the module interface nn.ReLU rather than the functional interface F.relu. If so, setattr works for me.

import torch
import torch.nn as nn

# Recursively replace every nn.ReLU module with an nn.SELU module.
def replace_relu_to_selu(model):
    for child_name, child in model.named_children():
        if isinstance(child, nn.ReLU):
            setattr(model, child_name, nn.SELU())
        else:
            replace_relu_to_selu(child)

########## A toy example ##########
net = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1),
            nn.ReLU(inplace=True)
          )

########## Test ##########
print('Before changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)
# Before changing activation
# ReLU(inplace=True)
# ReLU(inplace=True)

replace_relu_to_selu(net)

print('after changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)
# after changing activation
# SELU()
# SELU()
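
To illustrate the assumption at the top of this answer: an activation applied through the functional interface is not a registered submodule, so no traversal of children can find or replace it. A minimal sketch with two hypothetical toy classes (`FunctionalNet` and `ModuleNet` are made up for this example):

```python
import torch.nn as nn
import torch.nn.functional as F

class FunctionalNet(nn.Module):
    """Applies ReLU via F.relu inside forward; no ReLU module is registered."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)

    def forward(self, x):
        return F.relu(self.fc(x))

class ModuleNet(nn.Module):
    """Registers ReLU as a submodule, so it can be swapped with setattr."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.fc(x))

def count_relu(model):
    # Count registered nn.ReLU submodules anywhere in the tree.
    return sum(isinstance(m, nn.ReLU) for m in model.modules())

functional_count = count_relu(FunctionalNet())
module_count = count_relu(ModuleNet())
```

For a network written in the functional style, the activation has to be changed in the `forward` code itself.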

monwx1rj5#

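I'll offer a more general solution that works for any layer (and avoids other problems, such as modifying a dictionary while iterating over it, or handling nn.Modules that are nested recursively inside one another).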

import torch

def replace_bn(module, name):
    '''
    Recursively replace the batch norm layers inside an nn.Module.

    Pass module = net to start.
    '''
    # Go through all attributes of the module (e.g. a network or a layer) and replace any batch norms found
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.BatchNorm2d:
            print('replaced: ', name, attr_str)
            new_bn = torch.nn.BatchNorm2d(target_attr.num_features, target_attr.eps, target_attr.momentum, target_attr.affine,
                                          track_running_stats=False)
            setattr(module, attr_str, new_bn)

    # Iterate through the immediate child modules. The recursion is done by this code, so named_modules() is not needed
    for child_name, immediate_child_module in module.named_children():
        replace_bn(immediate_child_module, child_name)

replace_bn(model, 'model')

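The key point is that you need to keep changing layers recursively (mainly because sometimes an attribute will itself contain modules). I think better code than the above would add one more if statement (after the batch norm check) that detects whether it has to recurse, and recurses if so. The code above works, but it first changes the batch norms at the outer level (the first loop) and then uses a second loop to make sure no other object that should be recursed into is missed (and then recurses).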
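
The single-loop version described above could look roughly like this. It is only a sketch: `replace_bn_single_pass` is a hypothetical name, and it walks `named_children()` instead of `dir(module)`, which assumes every batch norm is a registered submodule.

```python
import torch.nn as nn

def replace_bn_single_pass(module):
    """Replace BatchNorm2d children in one loop, recursing into everything else."""
    # Snapshot the children so entries can be replaced safely while looping.
    for name, child in list(module.named_children()):
        if isinstance(child, nn.BatchNorm2d):
            new_bn = nn.BatchNorm2d(
                child.num_features,
                eps=child.eps,
                momentum=child.momentum,
                affine=child.affine,
                track_running_stats=False,
            )
            setattr(module, name, new_bn)
        else:
            # Not a batch norm: it may contain one deeper down, so recurse.
            replace_bn_single_pass(child)

# Hypothetical toy model for demonstration.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.Sequential(nn.Conv2d(8, 8, kernel_size=3), nn.BatchNorm2d(8)),
)
replace_bn_single_pass(model)
```

Note that each new layer starts with freshly initialised affine weights. If the pretrained `weight` and `bias` should be kept, they would need to be copied over separately.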
Original: https://discuss.pytorch.org/t/how-to-modify-a-pretrained-model/60509/10
