pytorch Torch 加载后,性能急剧下降,评估后,性能上升

d8tt03nd  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(146)

我有个奇怪的虫子。这是我的模型。

class MyModel(nn.Module):
    def __init__(self, feat_dim, num_classes):
        super(MyModel, self).__init__()
        self.model_resnet = models.resnet50(pretrained=False)
        num_ftrs = self.model_resnet.fc.in_features
        self.model_resnet.fc = nn.Identity()

        self.head1 = nn.Sequential(
                nn.Linear(num_ftrs, num_ftrs),
                nn.ReLU(inplace=True),
                nn.Linear(num_ftrs, feat_dim)
            )

        self.head2 = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        self.eps=self.eps+1
        x = self.model_resnet(x)
        feat = F.normalize(self.head1(x), dim=1)
        classes = self.head2(x)

        return feat,classes

下面的代码用于保存和加载

torch.save(model.state_dict(),"./test.pth")
model.load_state_dict(torch.load("test.pth"))

然后我训练了它,保存了重量,测试准确度0. 95。下次我加载它,测试一些东西。就像随机猜测,准确度接近0。
经过对整个测试集的评估,测试精度恢复到0.8,但仍有性能损失。
我检查了model.state_dict(),在评估整个测试集之前和之后,权重是相同的。
有人有什么想法吗?

bqf10yzr

bqf10yzr1#

经过一整天的调试,我终于找到了原因,重新加载模型后,我没有设置model.eval(),做了一些计算。

相关问题