pytorch torch.nn.BCEWithLogitsLoss的正确输入维是什么?

ctzwtxfj  于 2022-11-09  发布在  Git
关注(0)|答案(1)|浏览(262)

我正在训练一个二元分类器,并使用torch.nn.BCEWithLogitsLoss作为损失函数。我对损失函数的正确输入维数感到困惑。它应该是[n,1]还是[n,2]?在[n,1]的情况下,其中n是样本数,目标的值将只是0或1,这表示样本所属的类。在[n,1]的情况下,2],目标将是torch.nn.functional.one_hot(targets, num_classes=2).float()。哪一个是正确的维数,对应的logit和/或我的网络的最后一层是什么?

6ojccjat

6ojccjat1#

该文档将告诉您需要了解的所有信息。
对于形状为(num_samples, num_classes)的预测Tensorx,目标Tensory应当具有完全相同的形状。
案例1
如果类的数量为2,并且它们是互斥的,那么(num_samples,)x应该与(num_samples,)y进行对比,(num_samples,)的值为0 | 1(但类型转换为float)。
案例2
如果类的数目是两个,但它们不是互斥的(多类分类),则xy的形状应该是(num_samples, 2),其中y仍然从0 | 1中获取浮点值。
在这两种情况下,你都需要一个最后一层,将你的网络维度Map到你的类数(无论它在你的上下文中意味着什么),所以在情况1中,类似于Linear(model_dim, 1),而在情况2中,Linear(model_dim, 2)
请记住不要将任何激活函数应用于网络输出,因为BCEWithLogitsLoss中已经包含了该函数。
如果目标标注不是离散标注,请进行相应调整(值从[0..1]开始,而不是从0开始|1)中。

相关问题