考虑以下两种对2d numpy数组中的所有值求和的方法。
import numpy as np
from numba import njit
a = np.random.rand(2, 5000)
@njit(fastmath=True, cache=True)
def sum_array_slow(arr):
s = 0
for i in range(arr.shape[0]):
for j in range(arr.shape[1]):
s += arr[i, j]
return s
@njit(fastmath=True, cache=True)
def sum_array_fast(arr):
s = 0
for i in range(arr.shape[1]):
s += arr[0, i]
for i in range(arr.shape[1]):
s += arr[1, i]
return s
看看sum_array_slow中的嵌套循环,它似乎应该以与sum_array_fast相同的顺序执行完全相同的操作。然而:
In [46]: %timeit sum_array_slow(a)
7.7 µs ± 374 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [47]: %timeit sum_array_fast(a)
951 ns ± 2.63 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
为什么sum_array_fast函数的速度比sum_array_slow函数快8倍,而它似乎是以相同的顺序执行相同的计算?
1条答案
按热度按时间cnjp1d6j1#
这是因为慢版本是不自动向量化(即编译器无法生成快速SIMD代码),而快版本是。这肯定是因为Numba在第一次循环中没有优化索引 Package ,所以这是Numba的一个错过的优化。
这可以通过分析汇编代码看出。下面是慢速版本的热循环:
我们可以看到,Numba产生了许多无用的索引检查,这使得循环效率非常低。我不知道有任何干净的方法来解决这个问题。这是可悲的,因为这样的问题在实践中远非罕见。使用像C和C++这样的本地语言可以解决这个问题(因为数组中没有索引 Package )。一种不安全/丑陋的方法是在Numba中使用指针,但是提取Numpy数据指针并将其交给Numba似乎相当痛苦(如果可能的话)。
这是最快的一个:
在这种情况下,循环得到了很好的优化。事实上,它对于大型数组几乎是最优的。对于小型数组,就像你的例子一样,它在像我这样的处理器上不是最优的。事实上,AFAIK,展开的指令没有使用足够的寄存器以隐藏FMA单元的等待时间(这是因为LLVM在内部生成了一个次优代码)。可能需要一个较低级别的本机代码来解决这个问题(至少,在Numba中没有简单的方法来解决这个问题)。
更新
感谢@max9111提供的this link,可以通过使用无符号整数来优化缓慢的代码。这个技巧大大提高了执行时间。下面是修改后的代码:
以下是英特尔至强W-2255处理器的性能:
将
opt=0
替换为opt=2
的解决方案(再次感谢@max911)在我的机器上没有给出很好的结果:更不用说编译的时间也稍微大一些。
可以实现更快的实现,以便更好地隐藏FMA指令的延迟:
这个需要1.08 µs,更好。
生成的Numba代码仍然有两个限制因素:
请注意,可以使用Numba函数的方法
inspect_asm
提取汇编代码。