pytorch中预训练resnet50的特征提取

huwehgph  于 2023-03-02  发布在  其他
关注(0)|答案(2)|浏览(408)

嗨,伙计们,我想提取我的预训练resnet50的全连接层的特征。
我在前面创建了一个方法来给予我特征向量:

def get_vector(image):

#layer = model._modules.get('fc')

layer = model.fc
my_embedding = torch.zeros(2048) #2048 is the in_features of FC , output of avgpool

def copy_data(m, i, o):

    my_embedding.copy_(o.data)

h = layer.register_forward_hook(copy_data)
tmp = model(image)

h.remove()

# return the vector
return my_embedding

在我调用这个方法之后

column = ["FlickrID", "Features"]

path = "./train_dataset/train_imgs/"

pathCSV = "./train_dataset/features/img_info_TRAIN.csv"


f_id=[]
features_extr=[]

df = pd.DataFrame(columns=column)

tr=transforms.Compose([transforms.Resize(256),
                       transforms.CenterCrop(224),
                       transforms.ToTensor(),
                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])


test = Dataset(path, pathCSV, transform=tr)

test_loader = DataLoader(test, batch_size=1, num_workers=2, shuffle = False)


#Leggiamo le immagini
for batch in test_loader:
    nome = batch['FlickrID']
    f_id.append(nome)
    image = batch['image']


    #print(image)
    with torch.no_grad():
        pred = get_vector(image)

    features_extr.append(pred)

df["FlickrID"] = f_id
df["Features"] = features_extr  

df.to_hdf("Places.h5", key='df', mode='w')

我有这样一个错误:输出形状[2048]与广播形状[1,2048,1,2048]不匹配
如何获取这个resnet50的完全连接的in_feature?Dataset是一个自定义的Dataset类。
抱歉我英语不好

zxlwwiss

zxlwwiss1#

该模型采用批处理输入,这意味着全连接层的输入大小为 [batch_size,2048]。由于您使用的批处理大小为1,因此变为 [1,2048]。因此,这不适合Tensortorch.zeros(2048),因此应该改为torch.zeros(1, 2048)
您还尝试使用图层model.fc的输出(o),而不是输入(i)。
除此之外,使用钩子过于复杂,获取特征的一种更简单的方法是通过将model.fc替换为nn.Identity来修改模型,nn.Identity只返回输入作为输出,由于特征是其输入,因此整个模型的输出将是特征。

model.fc = nn.Identity()

features = model(image)
4sup72z8

4sup72z82#

这对我很有效,和Michael的回答一样有效,实际上两者是一样的

class EmptyModule(nn.Module) :
    def __init__(self, *args) :
        super(EmptyModule, self).__init__()
    def forward(self,x):
        return x

model.fc = EmptyModule()

编辑:NVM,nn.Identity()层具有相同的代码

相关问题