numpy dot()和Python 3.5+矩阵乘法之间的区别@

wbrvyc0a  于 2023-03-30  发布在  Python
关注(0)|答案(6)|浏览(129)

我最近转到Python 3.5,注意到新的矩阵乘法运算符(@)有时与numpy dot运算符的行为不同。例如,对于3d数组:

import numpy as np

a = np.random.rand(8,13,13)
b = np.random.rand(8,13,13)
c = a @ b  # Python 3.5+
d = np.dot(a, b)

@运算符返回一个形状数组:

c.shape
(8, 13, 13)

np.dot()函数返回:

d.shape
(8, 13, 8, 13)

我如何用numpy dot重现相同的结果?是否有其他显著差异?

zzwlnbp8

zzwlnbp81#

@运算符调用数组的__matmul__方法,而不是dot。该方法也作为函数np.matmul出现在API中。

>>> a = np.random.rand(8,13,13)
>>> b = np.random.rand(8,13,13)
>>> np.matmul(a, b).shape
(8, 13, 13)

来自文档:
matmuldot有两个重要的区别。

  • 不允许与标量相乘。
  • 矩阵堆栈一起广播,就好像矩阵是元素一样。

最后一点清楚地表明,dotmatmul方法在传递3D(或更高维)数组时的行为不同。引用文档中的一些内容:
对于matmul
如果任一参数是N-D,N〉2,则将其视为驻留在最后两个索引中的矩阵的堆栈并相应地广播。
对于np.dot
对于2-D数组,它等价于矩阵乘法,对于1-D数组,它等价于向量的内积(没有复共轭)。* 对于N维,它是a的最后一个轴和b的倒数第二个轴的和积 *

llew8vvj

llew8vvj2#

仅供参考,@和它的numpy等价物dotmatmul都同样快(用perfplot创建的图,我的一个项目)。

用于重现绘图的代码:

import perfplot
import numpy

def setup(n):
    A = numpy.random.rand(n, n)
    x = numpy.random.rand(n)
    return A, x

def at(A, x):
    return A @ x

def numpy_dot(A, x):
    return numpy.dot(A, x)

def numpy_matmul(A, x):
    return numpy.matmul(A, x)

perfplot.show(
    setup=setup,
    kernels=[at, numpy_dot, numpy_matmul],
    n_range=[2 ** k for k in range(15)],
)
7tofc5zh

7tofc5zh3#

@ajcr的回答解释了dotmatmul(由@符号调用)的不同之处。通过查看一个简单的示例,可以清楚地看到两者在操作“矩阵堆栈”或Tensor时的不同行为。
为了阐明差异,取一个4x 4数组,并返回dot乘积和matmul乘积以及3x 4x 2“矩阵堆栈”或Tensor。

import numpy as np
fourbyfour = np.array([
                       [1,2,3,4],
                       [3,2,1,4],
                       [5,4,6,7],
                       [11,12,13,14]
                      ])

threebyfourbytwo = np.array([
                             [[2,3],[11,9],[32,21],[28,17]],
                             [[2,3],[1,9],[3,21],[28,7]],
                             [[2,3],[1,9],[3,21],[28,7]],
                            ])

print('4x4*3x4x2 dot:\n {}\n'.format(np.dot(fourbyfour,threebyfourbytwo)))
print('4x4*3x4x2 matmul:\n {}\n'.format(np.matmul(fourbyfour,threebyfourbytwo)))

每个操作的乘积显示在下面。请注意点积是怎样的,
a的最后一个轴与b的倒数第二个轴的和积
以及如何通过一起广播矩阵来形成矩阵乘积。

4x4*3x4x2 dot:
 [[[232 152]
  [125 112]
  [125 112]]

 [[172 116]
  [123  76]
  [123  76]]

 [[442 296]
  [228 226]
  [228 226]]

 [[962 652]
  [465 512]
  [465 512]]]

4x4*3x4x2 matmul:
 [[[232 152]
  [172 116]
  [442 296]
  [962 652]]

 [[125 112]
  [123  76]
  [228 226]
  [465 512]]

 [[125 112]
  [123  76]
  [228 226]
  [465 512]]]
rqdpfwrv

rqdpfwrv4#

在数学上,我认为numpy中的更有意义

dot(a,B)_{i,j,k,a,b,c} =

因为当a和B是向量时它给出点积,或者当a和b是矩阵时它给出矩阵乘法
至于numpy中的matmul操作,它由dotresult的部分组成,可以定义为

matmul(a,B)_{i,j,k,c} =

所以,你可以看到**matmul(a,B)**返回一个小形状的数组,它的内存消耗更小,在应用中更有意义。特别是,结合broadcasting,你可以得到
(a,B)_{i,j,k,l} =

例如。
从上面的两个定义可以看出使用这两个操作的要求,假设a.shape=(s1,s2,s3,s4)b.shape=(t1,t2,t3,t4)

  • 要使用dot(a,B),需要

1.t3=s4;

  • 使用matmul(a,B)**需要
  1. t3=s4**
    1.t2=s2,或t2和s2之一为1
    1.t1=s1,或t1和s1之一为1
    使用下面的代码来说服自己。
import numpy as np
for it in range(10000):
    a = np.random.rand(5,6,2,4)
    b = np.random.rand(6,4,3)
    c = np.matmul(a,b)
    d = np.dot(a,b)
    #print ('c shape: ', c.shape,'d shape:', d.shape)
    
    for i in range(5):
        for j in range(6):
            for k in range(2):
                for l in range(3):
                    if c[i,j,k,l] != d[i,j,k,j,l]:
                        print (it,i,j,k,l,c[i,j,k,l]==d[i,j,k,j,l])  # you will not see them
j2qf4p5b

j2qf4p5b5#

下面是与np.einsum的比较,以显示索引是如何投影的

np.allclose(np.einsum('ijk,ijk->ijk', a,b), a*b)        # True 
np.allclose(np.einsum('ijk,ikl->ijl', a,b), a@b)        # True
np.allclose(np.einsum('ijk,lkm->ijlm',a,b), a.dot(b))   # True
i2byvkas

i2byvkas6#

我的经验与MATMUL和DOT
我经常收到“ValueError:当尝试使用MATMUL时,传递值的形状是(200,1),索引意味着(200,3)”。我想要一个快速的解决方案,并发现DOT提供了相同的功能。我使用DOT没有得到任何错误。我得到了正确的答案
关于MATMUL

X.shape
>>>(200, 3)

type(X)

>>>pandas.core.frame.DataFrame

w

>>>array([0.37454012, 0.95071431, 0.73199394])

YY = np.matmul(X,w)

>>>  ValueError: Shape of passed values is (200, 1), indices imply (200, 3)"

关于DOT

YY = np.dot(X,w)
# no error message
YY
>>>array([ 2.59206877,  1.06842193,  2.18533396,  2.11366346,  0.28505879, …

YY.shape

>>> (200, )

相关问题