我使用Trainer
类的fit
方法训练了SegFormer模型。
segformer_finetuner = SegformerFinetuner(
train_dataset.id2label,
train_dataloader=train_dataloader,
val_dataloader=val_dataloader,
test_dataloader=test_dataloader,
metrics_interval=10,
)
early_stop_callback = EarlyStopping(
monitor="val_loss",
min_delta=0.00,
patience=10,
verbose=False,
mode="min",
)
checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="val_loss")
trainer = pytorch_lightning.Trainer(
gpus=1,
callbacks=[early_stop_callback, checkpoint_callback],
max_epochs=500,
val_check_interval=len(train_dataloader),
)
trainer.fit(segformer_finetuner)
我得到了一个检查点文件epoch=151-step=304.ckpt
作为输出,但我不知道如何使用它来预测单个映像。
我是这样试的:
model = SegformerForSemanticSegmentation()
model.load_state_dict(torch.load('lightning_logs/version_33/checkpoints/epoch=151-step=304.ckpt'))
model.eval()
# Load the image
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img = torchvision.datasets.ImageFolder(root='path/to/image', transform=transform)
img = img.unsqueeze(0)
# Make the prediction
with torch.no_grad():
output = model(img)
但我想我走错方向了。
1条答案
按热度按时间pw9qyyiw1#
您尝试加载的检查点文件是字典。您要查找的是加载“state_dict”中的值。这对我有效:
然后你就可以像以前那样使用这个模型了: