pytorch: How do I use a trained SegFormer model on a single image?

dwbf0jvd  posted on 2023-02-08  in Other

I trained a SegFormer model using the fit method of the Trainer class.

segformer_finetuner = SegformerFinetuner(
    train_dataset.id2label,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=test_dataloader,
    metrics_interval=10,
)

early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=10,
    verbose=False,
    mode="min",
)

checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="val_loss")

trainer = pytorch_lightning.Trainer(
    gpus=1,
    callbacks=[early_stop_callback, checkpoint_callback],
    max_epochs=500,
    val_check_interval=len(train_dataloader),
)
trainer.fit(segformer_finetuner)

I got a checkpoint file, epoch=151-step=304.ckpt, as output, but I don't know how to use it to run a prediction on a single image.
This is what I tried:

model = SegformerForSemanticSegmentation()
model.load_state_dict(torch.load('lightning_logs/version_33/checkpoints/epoch=151-step=304.ckpt'))
model.eval()

# Load the image
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img = torchvision.datasets.ImageFolder(root='path/to/image', transform=transform)
img = img.unsqueeze(0)

# Make the prediction
with torch.no_grad():
    output = model(img)

But I think I'm heading in the wrong direction.

pw9qyyiw #1

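The checkpoint file you are trying to load is a dictionary. What you need to load are the values stored under its "state_dict" key. This worked for me: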

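import torch
from transformers import SegformerForSemanticSegmentation

# A Lightning .ckpt file is a dict; the weights are stored under "state_dict".
checkpoint = torch.load("<your checkpoint file path here>", map_location="cpu")
state_dict = checkpoint["state_dict"]
# The constructor needs a config: pass the SegformerConfig used for training.
model = SegformerForSemanticSegmentation(config)
model.load_state_dict(state_dict)
model.eval()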
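If load_state_dict then complains about unexpected or missing keys, a likely cause is that Lightning saves the parameters under the attribute name used inside the LightningModule, so every key carries a prefix. Here is a minimal sketch of one way to handle that. It assumes the fine-tuner stores the Hugging Face model as `self.model` (hence the "model." prefix) and that fine-tuning started from "nvidia/mit-b0"; both are guesses about your setup, and `strip_prefix` and `load_segformer` are just illustrative names, so adjust them to what your SegformerFinetuner actually does.

```python
def strip_prefix(state_dict, prefix="model."):
    """Return a copy of state_dict with `prefix` removed from keys that start with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_segformer(ckpt_path, num_labels, base="nvidia/mit-b0"):
    """Rebuild the model and load fine-tuned weights from a Lightning checkpoint.

    `base` is an assumption: use whatever pretrained name the training code used.
    """
    # Heavy imports live here so strip_prefix stays dependency-free.
    import torch
    from transformers import SegformerForSemanticSegmentation

    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # Build the same architecture as in training, then overwrite its weights.
    model = SegformerForSemanticSegmentation.from_pretrained(base, num_labels=num_labels)
    model.load_state_dict(strip_prefix(checkpoint["state_dict"]))
    return model.eval()
```

Printing a few entries of `checkpoint["state_dict"].keys()` first is a quick way to see which prefix, if any, your checkpoint really uses.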

Then you can use the model the way you did before:

with torch.no_grad():
    output = model(img)
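One thing to keep in mind: SegformerForSemanticSegmentation returns logits at a quarter of the input resolution, so getting a per-pixel class map usually means upsampling and taking the argmax. Below is a rough sketch of that step. It assumes the image is already an H x W x 3 uint8 tensor and that training used the ImageNet mean/std from the question; `preprocess` and `predict_mask` are hypothetical helper names, not part of any library.

```python
import torch
import torch.nn.functional as F


def preprocess(image, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Turn an HxWx3 uint8 tensor into a normalized 1x3xHxW float batch."""
    x = image.permute(2, 0, 1).float() / 255.0
    mean_t = torch.tensor(mean).view(3, 1, 1)
    std_t = torch.tensor(std).view(3, 1, 1)
    return ((x - mean_t) / std_t).unsqueeze(0)


@torch.no_grad()
def predict_mask(model, pixel_values):
    """Return an HxW tensor of predicted class ids for a single-image batch."""
    logits = model(pixel_values=pixel_values).logits  # (1, num_labels, H/4, W/4)
    # Upsample the logits back to the input size before picking a class.
    upsampled = F.interpolate(
        logits, size=pixel_values.shape[-2:], mode="bilinear", align_corners=False
    )
    return upsampled.argmax(dim=1)[0]
```

With a real file, something like `torch.from_numpy(numpy.array(PIL.Image.open(path).convert("RGB")))` should give the tensor that `preprocess` expects, instead of going through ImageFolder.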
