Numpy `matmul`在数组视图上的性能比`dot`差约100倍

omhiaaxx  于 2023-04-21  发布在  其他
关注(0)|答案(1)|浏览(194)

我注意到numpy中的matmul函数在乘以数组视图时的性能明显比dot函数差。在这种情况下,我的数组视图是复杂数组的真实的部分。下面是一些代码,重现了这个问题:

import numpy as np
from timeit import timeit
N = 1300
xx = np.random.randn(N, N) + 1j
yy = np.random.randn(N, N) + 1J

x = np.real(xx)
y = np.real(yy)
assert np.shares_memory(x, xx)
assert np.shares_memory(y, yy)

dot = timeit('np.dot(x,y)', number = 10, globals = globals())
matmul = timeit('np.matmul(x,y)', number = 10, globals = globals())

print('time for np.matmul: ', matmul)
print('time for np.dot: ', dot)

在我的机器上,输出如下:

time for np.matmul:  23.023062199994456
time for np.dot:  0.2706864000065252

这显然与共享内存有关,因为用np.real(xx).copy()替换np.real(xx)可以消除性能差异。
trolling numpy docs并不是特别有用,因为列出的差异并没有讨论处理内存视图时的实现细节。

8qgya5xd

8qgya5xd1#

这些时序表示dot正在执行copyreal的操作:

In [22]: timeit np.dot(xx.real,xx.real)
232 ms ± 3.34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [23]: timeit np.dot(xx.real.copy(),xx.real.copy())
232 ms ± 4.18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

将其应用于matmul产生几乎相同的时间:

In [24]: timeit np.matmul(xx.real.copy(),xx.real.copy())
231 ms ± 3.54 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

同样,matmulreal采用了一些缓慢的路线。当使用int阵列时,matmul/dot的性能都较差,尽管没有matmul real的速度慢。matmul/dot也可以处理object数据类型,但速度更慢。
因此,作为python级别的用户,我们看不到(也没有文档记录)有很多隐藏的东西。

编辑

我很想把标题改为关注复数实数,但决定检查另一个view浮点数组的一个片段

In [42]: y=xx.real.copy()[::2,::2];y.shape,y.dtype
Out[42]: ((650, 650), dtype('float64'))

In [43]: timeit np.dot(y,y)
36.4 ms ± 63.4 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
In [44]: timeit np.dot(y.copy(),y.copy())
35.6 ms ± 191 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

同样很明显,dot正在使用视图的copiesmatmul没有:

In [45]: timeit np.matmul(y,y)
1.89 s ± 3.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

但对于副本,时间与点相同:

In [46]: timeit np.matmul(y.copy(),y.copy())
35.3 ms ± 102 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

因此,我的猜测是,如果dot不能将数组直接发送到BLAS例程,它通常会生成copy

编辑

虽然dotmatmul处理2d数组的方式相似,但它们处理3+d数组的方式却截然不同。实际上,添加@的主要原因是为矩阵乘法提供一个方便的“批处理”概念。
坚持使用大型复杂数组,让我们将一个大3倍:

In [49]: yy=np.array([xx,xx,xx]);yy.shape
Out[49]: (3, 1300, 1300)

In [50]: timeit np.dot(xx,xx)
794 ms ± 12.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)    
In [51]: timeit np.dot(xx,yy)       # (yy,xx) same timings
55.5 s ± 151 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [52]: timeit np.matmul(xx,yy)    # (yy,yy) same
2.58 s ± 362 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

matmul刚刚将时间增加了3;我可以探索更多的东西,但不是在分钟范围内的时间。

相关问题