我想使用自定义要素提取器来计算FID
根据https://lightning.ai/docs/torchmetrics/stable/image/frechet_inception_distance.html,我可以使用nn.Module
作为feature
下面的代码有什么问题?
import torch
_ = torch.manual_seed(123)
from torchmetrics.image.fid import FrechetInceptionDistance
from torchvision.models import inception_v3
net = inception_v3()
checkpoint = torch.load('checkpoint.pt')
net.load_state_dict(checkpoint['state_dict'])
net.eval()
fid = FrechetInceptionDistance(feature=net)
# generate two slightly overlapping image intensity distributions
imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
fid.update(imgs_dist1, real=True)
fid.update(imgs_dist2, real=False)
result = fid.compute()
print(result)
个字符
1条答案
按热度按时间nbysray51#
问题是你将输入转换为
dtype=torch.uint8
。模型需要一个浮点Tensor。