我正在训练一个二元分类器,并使用torch.nn.BCEWithLogitsLoss作为损失函数。我对损失函数的正确输入维数感到困惑。它应该是[n,1]还是[n,2]?在[n,1]的情况下,其中n是样本数,目标的值将只是0或1,这表示样本所属的类。在[n,1]的情况下,2],目标将是torch.nn.functional.one_hot(targets, num_classes=2).float()。哪一个是正确的维数,对应的logit和/或我的网络的最后一层是什么?
torch.nn.BCEWithLogitsLoss
torch.nn.functional.one_hot(targets, num_classes=2).float()
6ojccjat1#
该文档将告诉您需要了解的所有信息。对于形状为(num_samples, num_classes)的预测Tensorx,目标Tensory应当具有完全相同的形状。案例1如果类的数量为2,并且它们是互斥的,那么(num_samples,)的x应该与(num_samples,)的y进行对比,(num_samples,)的值为0 | 1(但类型转换为float)。案例2如果类的数目是两个,但它们不是互斥的(多类分类),则x和y的形状应该是(num_samples, 2),其中y仍然从0 | 1中获取浮点值。在这两种情况下,你都需要一个最后一层,将你的网络维度Map到你的类数(无论它在你的上下文中意味着什么),所以在情况1中,类似于Linear(model_dim, 1),而在情况2中,Linear(model_dim, 2)。请记住不要将任何激活函数应用于网络输出,因为BCEWithLogitsLoss中已经包含了该函数。如果目标标注不是离散标注,请进行相应调整(值从[0..1]开始,而不是从0开始|1)中。
(num_samples, num_classes)
x
y
(num_samples,)
0 | 1
(num_samples, 2)
Linear(model_dim, 1)
Linear(model_dim, 2)
BCEWithLogitsLoss
1条答案
按热度按时间6ojccjat1#
该文档将告诉您需要了解的所有信息。
对于形状为
(num_samples, num_classes)
的预测Tensorx
,目标Tensory
应当具有完全相同的形状。案例1
如果类的数量为2,并且它们是互斥的,那么
(num_samples,)
的x
应该与(num_samples,)
的y
进行对比,(num_samples,)
的值为0 | 1
(但类型转换为float)。案例2
如果类的数目是两个,但它们不是互斥的(多类分类),则
x
和y
的形状应该是(num_samples, 2)
,其中y
仍然从0 | 1
中获取浮点值。在这两种情况下,你都需要一个最后一层,将你的网络维度Map到你的类数(无论它在你的上下文中意味着什么),所以在情况1中,类似于
Linear(model_dim, 1)
,而在情况2中,Linear(model_dim, 2)
。请记住不要将任何激活函数应用于网络输出,因为
BCEWithLogitsLoss
中已经包含了该函数。如果目标标注不是离散标注,请进行相应调整(值从[0..1]开始,而不是从0开始|1)中。