神经网络冻结层，固定参数

x33g5p2x  于2022-04-25 转载在 其他  
字(0.7k)|赞(0)|评价(0)|浏览(255)

测试：冻结卷积层（conv1）后，卷积层与 BN 层参数的梯度变化：

import torch
from torch import nn

def init_weights(m, *, layer_types=(nn.Linear,)):
    """Set the weight and bias of selected layers to all ones.

    Intended for use with ``nn.Module.apply``, which calls this on every
    submodule; modules that are not instances of ``layer_types`` are left
    untouched.

    Args:
        m: The module to (possibly) initialise.
        layer_types: Module class or tuple of classes whose parameters are
            overwritten. Defaults to ``nn.Linear`` (and its subclasses).
    """
    if not isinstance(m, layer_types):
        return
    # nn.init.ones_ fills in place without recording autograd history, so
    # the tensors remain leaf Parameters that still require grad.
    weight = getattr(m, "weight", None)
    if weight is not None:
        nn.init.ones_(weight)
    # Layers built with bias=False expose bias=None; the previous
    # torch.ones_like(None) raised a TypeError for them.
    bias = getattr(m, "bias", None)
    if bias is not None:
        nn.init.ones_(bias)
class Net(nn.Module):
    """Minimal conv -> batch-norm stack used to test layer freezing.

    conv1: 3 -> 6 channels, 3x3 kernel, stride 2, padding 1, no bias.
    bn1: BatchNorm2d over the 6 channels produced by conv1.
    """

    def __init__(self):
        super().__init__()
        # conv1 is registered before bn1 so parameters()/state_dict()
        # enumerate entries in the same order.
        self.conv1 = nn.Conv2d(
            3, 6, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(6)

    def forward(self, x):
        """Return ``bn1(conv1(x))``."""
        return self.bn1(self.conv1(x))

if __name__ == '__main__':
    # Seed the RNG so conv1's random initial weights and the input batch are
    # reproducible. init_weights only touches nn.Linear layers and Net has
    # none, so net.apply(init_weights) by itself does not fix the initial
    # parameters of this network.
    torch.manual_seed(0)
    net = Net()
    net.apply(init_weights)
    # No backward pass has run yet, so both gradients print as None.
    print(net.conv1.weight.grad, net.bn1.weight.grad, sep="\n")
    print("-------------")
    # Freeze conv1: its parameters stop requiring grad, so backward() will
    # not populate their .grad, while bn1 stays trainable.
    net.conv1.requires_grad_(False)
    net = net.train()
    x = torch.rand([2, 3, 8, 8], dtype=torch.float32)
    y = net(x)
    y.sum().backward()
    # Expected: conv1.weight.grad is still None; bn1.weight.grad is a tensor
    # (its values are close to zero here, because in train mode the
    # normalized activations of each channel sum to ~0 over the batch).
    print(net.conv1.weight.grad, net.bn1.weight.grad, sep="\n")

相关文章