numpy 为什么np.dot比np.sum快这么多?

vq8itlhq  于 2023-03-02  发布在  其他
关注(0)|答案(2)|浏览(295)

为什么www.example.com比np.sum快这么多?根据这个答案,我们知道np.sum很慢,但有更快的替代方案。np.dot so much faster than np.sum? Following this answer we know that np.sum is slow and has faster alternatives.
例如:

In [20]: A = np.random.rand(1000)

In [21]: B = np.random.rand(1000)

In [22]: %timeit np.sum(A)
3.21 µs ± 270 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [23]: %timeit A.sum()
1.7 µs ± 11.5 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In [24]: %timeit np.add.reduce(A)
1.61 µs ± 19.6 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

但它们都慢于:

In [25]: %timeit np.dot(A,B)
1.18 µs ± 43.9 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

如果www.example.com将两个数组元素相乘,然后求和,这怎么会比求和一个数组快呢?如果B被设置为全1数组,那么np.dot将简单地求和A。np.dot is both multiplying two arrays elementwise and then summing them, how can this be faster than just summing one array? If B were set to the all ones array then np.dot would simply be summing A.
因此,对A求和最快的方法是:

In [26]: O = np.ones(1000)
In [27]: %timeit np.dot(A,O)
1.16 µs ± 6.37 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

这不可能是对的,是吗?
这是在Ubuntu上使用numpy 1.24.2,在Python 3.10.6上使用openblas64。
此NumPy安装中支持的SIMD扩展:

baseline = SSE,SSE2,SSE3
found = SSSE3,SSE41,POPCNT,SSE42,AVX,F16C,FMA3,AVX2
    • 更新**

如果数组更长,计时顺序将颠倒,即:

In [28]: A = np.random.rand(1000000)
In [29]: O = np.ones(1000000)
In [30]: %timeit np.dot(A,O)
545 µs ± 8.87 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
In [31]: %timeit np.sum(A)
429 µs ± 11 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)    
In [32]: %timeit A.sum()
404 µs ± 2.95 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
In [33]: %timeit np.add.reduce(A)
401 µs ± 4.21 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

对我来说,这意味着在调用np.sum(A)、A.sum()、np.add.reduce(A)时存在一些固定大小的开销,而在调用www.example.com()时不存在这些开销,但执行求和的代码部分实际上更快。np.dot() but the part of the code that does the summation is in fact faster.
-—————————-
任何使用cython,numba,python等的加速都是很好的。

kq4fsx7k

kq4fsx7k1#

numpy.dot在这里代表BLAS矢量-矢量乘法,而numpy.sum使用成对求和例程,切换到块大小为128个元素的8x展开求和循环。
我不知道NumPy使用的是什么BLAS库,但据我所知,一个好的BLAS通常会利用SIMD操作,而numpy.sum不会这样做。numpy.sum代码中的任何SIMD使用都必须是编译器自动向量化,这可能比BLAS效率低。
当您将数组大小增加到100万个元素时,此时可能会达到缓存阈值。dot代码使用大约16 MB的数据,sum代码使用大约8 MB的数据。dot代码可能会将数据转移到较慢的缓存级别或RAM。或者dotsum都使用较慢的缓存级别,而dot的性能较差,因为它需要读取更多的数据。与具有更高的每元件性能的X1 M10 N1 X相比,定时与某种阈值效应更一致。

jrcvhitl

jrcvhitl2#

此答案提供了更多详细信息,从而完善了@user2357112的正确答案。两个函数都经过优化。也就是说,成对求和通常速度稍慢,但通常提供更准确的结果。它也是次优的,但相对较好。Windows上默认使用的OpenBLAS不执行成对求和。
下面是Numpy代码的汇编代码:

下面是OpenBLAS代码的汇编代码:

Numpy代码的主要问题是它不使用AVX(256位SIMD指令集),而是使用SSE(128位SIMD指令集),这与OpenBLAS相反,至少在1.22.4(我使用的版本)及之前的版本中是这样。在Numpy代码中,指令是标量1!我们最近对此进行了研究,Numpy的最新版本现在应该使用AVX。也就是说,由于成对求和(特别是对于大数组),它可能仍然不如OpenBLAS快。

注意,两个函数都花费了不可忽略的时间在开销上,因为数组太小了。这样的开销可以使用Numba中的手写实现来消除。
如果阵列更长,则计时顺序相反。
这是意料之中的。实际上,当函数该高速缓存中操作时,它们是相当受计算限制的,但是当数组很大**并且适合L3缓存甚至RAM时,它们变成受内存限制的。因此,np.dot统计对于更大的数组会更慢,因为它需要从内存读取两倍大的数据。更具体地说,它需要从内存中读取8*1000000*2/1024**2 ~= 15.3 MiB,所以你可能需要从RAM中读取数据,而RAM的吞吐量非常有限。事实上,像我这样的好的双通道3200 MHz DDR4 RAM可以达到接近40 GiB和15.3/(40*1024) ~= 374 µs的实际吞吐量。也就是说,顺序代码很难完全饱和此吞吐量,因此顺序代码达到30 GiB/s已经很不错了,更不用说许多主流PC RAM在较低频率下运行。30 GHz/s吞吐量导致约500 µs,这与您的时序非常接近。同时,np.sumnp.add.reduce由于其低效率实现而更受计算限制,但是要读取的数据量小了两倍,并且实际上可以更好地适合具有显著更大吞吐量的L3高速缓存。
要证明此效果,您可以简单地尝试运行:

# L3 cache of 9 MiB

# 2 x 22.9 = 45.8 MiB
a = np.ones(3_000_000)
b = np.ones(3_000_000)
%timeit -n 100 np.dot(a, a)   #  494 µs => read from RAM
%timeit -n 100 np.dot(a, b)   # 1007 µs => read from RAM

# 2 x 7.6 = 15.2 MiB
a = np.ones(1_000_000)
b = np.ones(1_000_000)
%timeit -n 100 np.dot(a, a)   #  90 µs => read from the L3 cache
%timeit -n 100 np.dot(a, b)   # 283 µs => read from RAM

# 2 x 1.9 = 3.8 MiB
a = np.ones(250_000)
b = np.ones(250_000)
%timeit -n 100 np.dot(a, a)   # 40 µs => read from the L3 cache (quite compute-bound)
%timeit -n 100 np.dot(a, b)   # 46 µs => read from the L3 cache too (quite memory-bound)

在我的机器上,L3的大小只有9 MiB,因此第二次调用不仅需要读取两倍多的数据,而且从较慢的RAM读取的数据也比从L3缓存读取的数据多。

对于小型阵列,L1缓存速度很快,阅读数据应该不会成为瓶颈。在我的i5- 9600 KF机器上,L1缓存的吞吐量非常大:~268 GiB/s。这意味着读取两个大小为1000的数组的最佳时间是8*1000*2/(268*1024**3) ~= 0.056 µs。实际上,调用Numpy函数的开销要比这个大得多。

快速实施

下面是一个快速的Numba实现

import numba as nb

# Function eagerly compiled only for 64-bit contiguous arrays
@nb.njit('float64(float64[::1],)', fastmath=True)
def fast_sum(arr):
    s = 0.0
    for i in range(arr.size):
        s += arr[i]
    return s

以下是性能结果:

array items |    time    |  speedup (dot/numba_seq)
--------------------------|------------------------
 3_000_000   |   870 µs   |   x0.57
 1_000_000   |   183 µs   |   x0.49
   250_000   |    29 µs   |   x1.38

如果你使用parallel=Truenb.prange而不是range,Numba将使用多线程,这对于大型数组很好,但对于某些机器上的小型数组可能不太好(由于创建线程和共享工作的开销):

array items |    time    |  speedup (dot/numba_par)
--------------------------|--------------------------
 3_000_000   |   465 µs   |   x1.06
 1_000_000   |    66 µs   |   x1.36
   250_000   |    10 µs   |   x4.00

正如预期的那样,Numba对于小数组来说速度更快(因为Numpy调用开销基本上被消除了),并且在大数组方面可以与OpenBLAS竞争。

.LBB0_7:
        vaddpd  (%r9,%rdx,8), %ymm0, %ymm0
        vaddpd  32(%r9,%rdx,8), %ymm1, %ymm1
        vaddpd  64(%r9,%rdx,8), %ymm2, %ymm2
        vaddpd  96(%r9,%rdx,8), %ymm3, %ymm3
        vaddpd  128(%r9,%rdx,8), %ymm0, %ymm0
        vaddpd  160(%r9,%rdx,8), %ymm1, %ymm1
        vaddpd  192(%r9,%rdx,8), %ymm2, %ymm2
        vaddpd  224(%r9,%rdx,8), %ymm3, %ymm3
        vaddpd  256(%r9,%rdx,8), %ymm0, %ymm0
        vaddpd  288(%r9,%rdx,8), %ymm1, %ymm1
        vaddpd  320(%r9,%rdx,8), %ymm2, %ymm2
        vaddpd  352(%r9,%rdx,8), %ymm3, %ymm3
        vaddpd  384(%r9,%rdx,8), %ymm0, %ymm0
        vaddpd  416(%r9,%rdx,8), %ymm1, %ymm1
        vaddpd  448(%r9,%rdx,8), %ymm2, %ymm2
        vaddpd  480(%r9,%rdx,8), %ymm3, %ymm3
        addq    $64, %rdx
        addq    $-4, %r11
        jne     .LBB0_7

话虽如此,它并不是最佳的:LLVM-Lite JIT编译器使用4倍展开,而在我的Intel CoffeeLake处理器上,8倍展开应该是最佳的。实际上,vaddpd指令的延迟是4个周期,而每个周期可以执行2条指令,因此需要8个寄存器来避免延迟,并且生成的代码受到延迟限制。此外,该汇编代码在Intel Alderlake和Sapphire急流处理器上是最佳的,因为它们具有两倍低的vaddpd延迟。2饱和FMA SIMD处理单元远非易事。我认为编写更快函数的唯一方法是使用SIMD内部函数编写(C/C++)本机代码,尽管它的可移植性较差。
请注意,由于fastmath,Numba代码不支持NaN或Inf值等特殊数字(AFAIK OpenBLAS可以)。实际上,它应该仍然可以在x86-64机器上工作,但这并不能保证。此外,Numba代码对于非常大的数组来说在数值上是不稳定的。2 Numpy代码应该是三种变体中数值上最稳定的(然后是OpenBLAS代码)。你可以按块计算总和来提高数值稳定性,尽管这会使代码更加复杂。天下没有免费的午餐。

相关问题