pytorch 3DTensor的平均交叉熵

wsxa1bj1  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(178)

我有一个输出Tensor(目标值和预测值)(32 x 8 x 5000)。这里,批大小为32,类的数量为5000,每个批的点数为8。我想以这样的方式计算CELoss:损失是为每一个点计算的(跨越5000个类),然后在8个点上平均。2我怎么做呢?
为了清楚起见,一个批中有32个 * 批点 *(bs=32)。每个批点有8个 * 向量点 *,每个向量点有5000个类。对于给定的批,我希望计算所有(8)个 * 向量点 * 的CELoss,计算它们的平均值,并对所有 * 批点 *(32)进行计算。
如果我的问题不清楚或不明确,请告诉我。
例如:

op = torch.rand((4,3,5))

gt = torch.tensor([
    [[0,1,1,0,0],[0,0,1,0,0],[1,1,0,0,1]],
    [[1,1,0,0,1],[0,0,0,1,0],[0,0,1,0,0]],
    [[0,0,1,0,0],[1,1,1,1,0],[1,1,0,0,1]],
    [[1,1,0,0,1],[1,1,0,0,1],[1,0,0,0,0]]
])
vngu2lb8

vngu2lb81#

数据库

op = torch.rand((4,3,5))
gt = torch.tensor([
    [[0,1,1,0,0],[0,0,1,0,0],[1,1,0,0,1]],
    [[1,1,0,0,1],[0,0,0,1,0],[0,0,1,0,0]],
    [[0,0,1,0,0],[1,1,1,1,0],[1,1,0,0,1]],
    [[1,1,0,0,1],[1,1,0,0,1],[1,0,0,0,0]]
], dtype=torch.float)

现在,如果你的输出是[0,1](如果不是,请在模型末尾提供S形激活),你可以用下面的方法计算二元交叉熵损失(每个元素的每个点的N_class值):

torch.nn.BCELoss(reduction="none")(op, gt)

最后,您可以计算批中每个元素的平均损耗,如下所示:

torch.nn.BCELoss(reduction="none")(op, gt).mean(dim=[-1,-2])

如果这不是你正在寻找的解决方案或它是不清楚让我知道

相关问题