python 将图像分类模型转换为分层模型

zqdjd7g9  于 2023-06-28  发布在  Python
关注(0)|答案(1)|浏览(145)

我使用ResNet50从图像中提取特征。我如何删除网络的分类头,并且还因为我需要拆分网络以具有中间特征,所以我如何将网络转换为可以访问中间网络架构的分层形式,如下所示:

import torch
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn

class model(nn.Module):
    def __init__(self, pretrained=True):
        super(model, self).__init__()

        self.featureExtractor = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

def forward(self, x):   
        
      x1= self.featureExtractor_1(x)     #(number of feature maps:256)
      x2= self.featureExtractor_2(x1)    # (number of feature maps:512)
      x3= self.featureExtractor_3(x2)    # (number of feature maps:1024)
      x4= self.featureExtractor_4(x3)    # (number of feature maps:2048)

 return x1, x2, x3, x4

虽然我知道如何使用hook方法提取网络的中间特征,但我不知道如何将网络划分为这样的层次结构?
你知道吗?

smtd7mpg

smtd7mpg1#

您可以执行以下操作:

import torch
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn

class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()

        self.model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        self.block1 = nn.Sequential(*list(self.model.children())[:5])
        self.block2 = nn.Sequential(*list(self.model.children())[5:6])
        self.block3 = nn.Sequential(*list(self.model.children())[6:7])
        self.block4 = nn.Sequential(*list(self.model.children())[7:8])

    def forward(self, x):   
        
        x1 = self.block1(x)    
        x2 = self.block2(x1)
        x3 = self.block3(x2)
        x4 = self.block4(x3)

        return(x1, x2, x3, x4)

x = torch.randn(1, 3, 256, 256)
model = model()
x1, x2, x3, x4 = model(x)

print(x1.shape, x2.shape, x3.shape, x4.shape)

其给出:

torch.Size([1, 256, 64, 64]) torch.Size([1, 512, 32, 32]) torch.Size([1, 1024, 16, 16]) torch.Size([1, 2048, 8, 8])

相关问题