我遇到过一个代码,它使用torch.einsum
来计算Tensor乘法I am able to understand the workings for lower order tensors,但不用于4DTensor,如下所示:
import torch
a = torch.rand((3, 5, 2, 10))
b = torch.rand((3, 4, 2, 10))
c = torch.einsum('nxhd,nyhd->nhxy', [a,b])
print(c.size())
# output: torch.Size([3, 2, 5, 4])
我需要以下方面的帮助:
1.这里执行了什么操作(解释矩阵如何相乘/转置等)?
torch.einsum
在这个场景中实际上是有益的吗?
1条答案
按热度按时间irlmq6kh1#
在这个例子中,我将尝试一步一步地解释
einsum
是如何工作的,但我将使用numpy.einsum
(documentation),而不是使用torch.einsum
,它的工作原理完全相同,但我只是总体上更适应它。让我们用NumPy重写上面的代码-
一步一个脚印
Einsum由3个步骤组成:
multiply
、sum
和transpose
让我们看看维度,我们有一个
(3, 5, 2, 10)
和一个(3, 4, 2, 10)
,我们需要将它们基于'nxhd,nyhd->nhxy'
,变成(3, 2, 5, 4)
1.相乘
我们先不考虑
n,x,y,h,d
轴的顺序,只考虑你是想保留它们还是删除(减少)它们,把它们写下来,然后看看我们如何安排维度--要获得
x
和y
轴之间的广播乘法结果(x, y)
,我们必须在正确的位置添加一个新轴,然后相乘。2.加/减
接下来,我们要减少最后一个轴10。这将得到尺寸
(n,x,y,h)
。这很简单,我们只在
axis=-1
上执行np.sum
3.转置
最后一步是使用转置来重新排列轴,我们可以使用
np.transpose
来实现这个目的,np.transpose(0,3,1,2)
基本上将第3个轴带到第0个轴之后,并推动第1个和第2个轴,因此,(n,x,y,h)
变为(n,h,x,y)
4.最终检查
让我们做最后一个检查,看看c3是否与从
np.einsum
生成的c相同-TL; DR。
因此,我们将
'nxhd , nyhd -> nhxy'
实现为-优势
np.einsum
相对于多个步骤的优势在于,您可以选择计算所采用的"路径",并使用同一个函数执行多个操作。这可以通过optimize
参数来实现,该参数将优化一个等式表达式的压缩顺序。这些运算的非详尽列表(可通过
einsum
计算)与示例一起显示如下:numpy.trace
的跟踪。numpy.diag
。numpy.sum
。numpy.transpose
。numpy.matmul
numpy.dot
。numpy.inner
numpy.outer
。numpy.multiply
。numpy.tensordot
。numpy.einsum_path
。基准
它表明
np.einsum
执行操作的速度比单个步骤快。