pytorch 沿着轴对Tensor进行 Torch 求和

wbgh16ku  于 2022-11-09  发布在  其他
关注(0)|答案(5)|浏览(436)

如何对Tensor的列求和?

torch.Size([10, 100])    --->    torch.Size([10])
gstyhher

gstyhher1#

最简单和最好的解决方案是使用torch.sum()
要对Tensor的所有元素求和,请执行以下操作:

torch.sum(x) # gives back a scalar

要对所有行求和(即对每列求和):

torch.sum(x, dim=0) # size = [ncol]

要对所有列求和(即对每行求和):

torch.sum(x, dim=1) # size = [nrow]

应该注意的是,求和的维度从结果Tensor中消除。

6qqygrtg

6qqygrtg2#

或者,您可以使用tensor.sum(axis),其中axis表示01,分别对2DTensor的行和列求和。

In [210]: X
Out[210]: 
tensor([[  1,  -3,   0,  10],
        [  9,   3,   2,  10],
        [  0,   3, -12,  32]])

In [211]: X.sum(1)
Out[211]: tensor([ 8, 24, 23])

In [212]: X.sum(0)
Out[212]: tensor([ 10,   3, -10,  52])

从上面的输出中我们可以看到,在两种情况下,输出都是一个一维Tensor。另一方面,如果你想在输出中保留原始Tensor的维数,那么你已经将布尔kwarg keepdim设置为True,如下所示:

In [217]: X.sum(0, keepdim=True)
Out[217]: tensor([[ 10,   3, -10,  52]])

In [218]: X.sum(1, keepdim=True)
Out[218]: 
tensor([[ 8],
        [24],
        [23]])
nwlqm0z1

nwlqm0z13#

如果你有Tensormy_tensor,并且你想在第二维数组上求和(也就是索引为1的那个,如果Tensor是二维的,它是列维,就像你的Tensor是二维的),使用torch.sum(my_tensor,1)或等价的my_tensor.sum(1),请参阅此处的文档。
文档中没有明确提到的一件事是:您可以使用-1last 数组维度求和(或使用-2对倒数第二个维度求和,等等)。
因此,在您的示例中,您可以用途:outputs.sum(1)torch.sum(outputs,1),或者等价地,outputs.sum(-1)torch.sum(outputs,-1)。所有这些将给予相同的结果,即大小为torch.Size([10])的输出Tensor,其中每个条目是Tensoroutputs的给定列中的所有行的和。
用三维Tensor来说明:

In [1]: my_tensor = torch.arange(24).view(2, 3, 4) 
Out[1]: 
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

In [2]: my_tensor.sum(2)
Out[2]:
tensor([[ 6, 22, 38],
        [54, 70, 86]])

In [3]: my_tensor.sum(-1)
Out[3]:
tensor([[ 6, 22, 38],
        [54, 70, 86]])
jobtbby3

jobtbby34#

基于文档https://pytorch.org/docs/stable/generated/torch.sum.html
它应该
dim(int或python:int的元组)-要减少的一个或多个维度。
dim=0表示减少行维度:压缩所有行=按列求和
dim=1表示减少列维数:压缩列=按行求和

djmepvbi

djmepvbi5#

沿着多个轴或尺寸的 Torch 总和

只是为了完整起见(我不能很容易地找到它),我包括如何沿着多个维度与torch.sum求和,这是大量用于计算机视觉任务中,你必须沿着HW维度减少。
如果你有一个x形状为C x H x W的图像,并且想计算每个通道的平均像素强度值,你可以:

avg = torch.sum(x, dim=(1,2)) / (H*W)     # Sum along (H,W) and norm

相关问题