pytorch 公制收集 Torch 中输入的类别数和形状之间的问题

nx7onnlm  于 2022-11-09  发布在  其他
关注(0)|答案(2)|浏览(210)

我有一个问题,因为我想计算一些指标在torchmetrics.但有一个问题:

ValueError: The implied number of classes (from shape of inputs) does not match num_classes.

输出来自UNet,损失函数为BCEWithLogitsLoss(二进制分段)
通道= 1,因为灰度img
输入形状:(批次大小,通道,h,w)torch.float32
标签形状:(批次大小,通道,h,w)torch.float32用于BCE
输出图形:(批次大小、通道、h、w):torch.float32

inputs, labels = batch
outputs = model(input)
loss = self.loss_function(outputs, labels)
prec = torchmetrics.Precision(num_classes=1)(outputs, labels.type(torch.int32)
ego6inou

ego6inou1#

torchmetrics似乎需要不同的形状。请尝试将输出和标签都平面化:

prec = torchmetrics.Precision(num_classes=1)(outputs.view(-1), labels.type(torch.int32).view(-1))
0vvn1miw

0vvn1miw2#

我使用Torchmetrics库来计算分割任务的F1得分、精确度和召回率;当我遇到上述错误时,我试图获得我的两个单独类的F1分数,这个解决方案有效,但首先我必须将“multi_class=True”设置为“num_classes=2

torchmetrics_f1_none = torchmetrics.classification.F1Score(average=None, num_classes=2, multiclass=True) 

f1_0, f1_1 = torchmetrics_f1_none(thres_out.view(-1), masks.int().view(-1)) 

print("F1 Score for Background - {}, F1 Score for Foreground - {} \n".format(f1_0, f1_1))

相关问题