我有一个问题,本质上相当于计数,我需要并行化我的实现。我很懒,想继续用Python编码,所以我希望使用numba。
串行实现是这样的:
import numpy as np
def count_stuff(stuff):
count=np.zeros(len(stuff))
for thing in stuff:
if some_condition(thing):
i=some_integer_function(thing)
count[i]+=1
return count
但是,我预计以下天真的numba实现肯定会由于竞争条件而失败:
import numpy as np
from numba import njit,prange
@njit(parallel=True)
def count_stuff_parallel(stuff):
count=np.zeros(len(stuff))
for i in prange(len(stuff)):
thing=stuff[i]
if some_condition(thing):
j=some_integer_function(thing)
count[j]+=1
return count
问题是不同的线程试图同时访问数组count
。一个典型的解决方案是给予每个线程一个私有的count
,然后在最后把它们加起来。但是,我不知道如何在numba中做到这一点。请指示。
2条答案
按热度按时间suzh9iv81#
通常有几种方法可以做到这一点。两种常见的方法是使用原子访问或基于数组的减少。AFAIK,前者在Numba中还不可能,因为CPU还不支持原子操作。第二个解决方案可以在Numba中完成,但这并不是那么简单,因为Numba没有提供任何简单的方法来完成它(例如,与C/C++/Fortran中的OpenMP不同)。您需要手动完成这项工作。还有一个坏消息:在Numba中没有标准的方法来创建线程私有变量了。话虽如此,你可以把问题分成固定数量的块,并行计算它们,然后按顺序进行约简。这是相当麻烦,但它可能值得。下面是一个例子:
理想情况下,
nChunks
应该很大,或者是线程数的倍数。然而,len(stuff) / nChunks
也需要相对较小,以使最终约简快速。请注意,只有在len(stuff)
很大或者some_integer_function
函数很慢的情况下,才值得使用多线程。deyfvvtc2#
这真的取决于我认为的细节。如果你有多余的内存,一个更纯粹的Numpy方法,你一次计算所有的东西,可能会允许更好的并行化(以内存为代价)。
如果没有具体的例子,就不可能说什么是最佳的。