pytorch torch.einsum如何执行这个4DTensor乘法?

hgc7kmma  于 2023-01-09  发布在  其他
关注(0)|答案(1)|浏览(165)

我遇到过一个代码,它使用torch.einsum来计算Tensor乘法I am able to understand the workings for lower order tensors,但不用于4DTensor,如下所示:

import torch

a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))

c = torch.einsum('nxhd,nyhd->nhxy', [a,b])

print(c.size())

# output: torch.Size([3, 2, 5, 4])

我需要以下方面的帮助:
1.这里执行了什么操作(解释矩阵如何相乘/转置等)?

  1. torch.einsum在这个场景中实际上是有益的吗?
irlmq6kh

irlmq6kh1#

  • (跳至tl; dr部分(如果您只想了解einsum中涉及的步骤)*

在这个例子中,我将尝试一步一步地解释einsum是如何工作的,但我将使用numpy.einsumdocumentation),而不是使用torch.einsum,它的工作原理完全相同,但我只是总体上更适应它。
让我们用NumPy重写上面的代码-

import numpy as np

a = np.random.random((3, 5, 2, 10))
b = np.random.random((3, 4, 2, 10))
c = np.einsum('nxhd,nyhd->nhxy', a,b)
c.shape

#(3, 2, 5, 4)

一步一个脚印

Einsum由3个步骤组成:multiplysumtranspose
让我们看看维度,我们有一个(3, 5, 2, 10)和一个(3, 4, 2, 10),我们需要将它们基于'nxhd,nyhd->nhxy',变成(3, 2, 5, 4)

1.相乘

我们先不考虑n,x,y,h,d轴的顺序,只考虑你是想保留它们还是删除(减少)它们,把它们写下来,然后看看我们如何安排维度--

## Multiply ##
       n   x   y   h   d
      --------------------
a  ->  3   5       2   10
b  ->  3       4   2   10
c1 ->  3   5   4   2   10

要获得xy轴之间的广播乘法结果(x, y),我们必须在正确的位置添加一个新轴,然后相乘。

a1 = a[:,:,None,:,:] #(3, 5, 1, 2, 10)
b1 = b[:,None,:,:,:] #(3, 1, 4, 2, 10)

c1 = a1*b1
c1.shape

#(3, 5, 4, 2, 10)  #<-- (n, x, y, h, d)

2.加/减

接下来,我们要减少最后一个轴10。这将得到尺寸(n,x,y,h)

## Reduce ##
        n   x   y   h   d
       --------------------
c1  ->  3   5   4   2   10
c2  ->  3   5   4   2

这很简单,我们只在axis=-1上执行np.sum

c2 = np.sum(c1, axis=-1)
c2.shape

#(3,5,4,2)  #<-- (n, x, y, h)

3.转置

最后一步是使用转置来重新排列轴,我们可以使用np.transpose来实现这个目的,np.transpose(0,3,1,2)基本上将第3个轴带到第0个轴之后,并推动第1个和第2个轴,因此,(n,x,y,h)变为(n,h,x,y)

c3 = c2.transpose(0,3,1,2)
c3.shape

#(3,2,5,4)  #<-- (n, h, x, y)

4.最终检查

让我们做最后一个检查,看看c3是否与从np.einsum生成的c相同-

np.allclose(c,c3)

#True

TL; DR。

因此,我们将'nxhd , nyhd -> nhxy'实现为-

input     -> nxhd, nyhd
multiply  -> nxyhd      #broadcasting
sum       -> nxyh       #reduce
transpose -> nhxy

优势

np.einsum相对于多个步骤的优势在于,您可以选择计算所采用的"路径",并使用同一个函数执行多个操作。这可以通过optimize参数来实现,该参数将优化一个等式表达式的压缩顺序。
这些运算的非详尽列表(可通过einsum计算)与示例一起显示如下:

  • 数组numpy.trace的跟踪。
  • 返回对角线numpy.diag
  • 阵列轴总和,numpy.sum
  • 换位和排列,numpy.transpose
  • 矩阵乘法和点积,numpy.matmulnumpy.dot
  • 向量的内积和外积,numpy.innernumpy.outer
  • 广播,元素乘法和标量乘法,numpy.multiply
  • Tensor收缩,numpy.tensordot
  • 链式数组运算,计算顺序低效,numpy.einsum_path

基准

%%timeit
np.einsum('nxhd,nyhd->nhxy', a,b)
#8.03 µs ± 495 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%%timeit
np.sum(a[:,:,None,:,:]*b[:,None,:,:,:], axis=-1).transpose(0,3,1,2)
#13.7 µs ± 1.42 µs per loop (mean ± std. dev. of 7 runs, 100000 loops each)

它表明np.einsum执行操作的速度比单个步骤快。

相关问题