我使用ResNet50从图像中提取特征。我如何删除网络的分类头,并且还因为我需要拆分网络以具有中间特征,所以我如何将网络转换为可以访问中间网络架构的分层形式,如下所示:
import torch
from torchvision.models import resnet50, ResNet50_Weights
import torch.nn as nn
class model(nn.Module):
def __init__(self, pretrained=True):
super(model, self).__init__()
self.featureExtractor = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
def forward(self, x):
x1= self.featureExtractor_1(x) #(number of feature maps:256)
x2= self.featureExtractor_2(x1) # (number of feature maps:512)
x3= self.featureExtractor_3(x2) # (number of feature maps:1024)
x4= self.featureExtractor_4(x3) # (number of feature maps:2048)
return x1, x2, x3, x4
虽然我知道如何使用hook
方法提取网络的中间特征,但我不知道如何将网络划分为这样的层次结构?
你知道吗?
1条答案
按热度按时间smtd7mpg1#
您可以执行以下操作:
其给出: