python 在numba中递增numpy数组

mutmk8jj  于 2023-06-28  发布在  Python
关注(0)|答案(2)|浏览(180)

我有一个问题,本质上相当于计数,我需要并行化我的实现。我很懒,想继续用Python编码,所以我希望使用numba。
串行实现是这样的:

import numpy as np

def count_stuff(stuff):
    count=np.zeros(len(stuff))
    for thing in stuff:
        if some_condition(thing):
            i=some_integer_function(thing)
            count[i]+=1
    return count

但是,我预计以下天真的numba实现肯定会由于竞争条件而失败:

import numpy as np
from numba import njit,prange

@njit(parallel=True)
def count_stuff_parallel(stuff):
    count=np.zeros(len(stuff))
    for i in prange(len(stuff)):
        thing=stuff[i]
        if some_condition(thing):
            j=some_integer_function(thing)
            count[j]+=1
    return count

问题是不同的线程试图同时访问数组count。一个典型的解决方案是给予每个线程一个私有的count,然后在最后把它们加起来。但是,我不知道如何在numba中做到这一点。请指示。

suzh9iv8

suzh9iv81#

通常有几种方法可以做到这一点。两种常见的方法是使用原子访问或基于数组的减少。AFAIK,前者在Numba中还不可能,因为CPU还不支持原子操作。第二个解决方案可以在Numba中完成,但这并不是那么简单,因为Numba没有提供任何简单的方法来完成它(例如,与C/C++/Fortran中的OpenMP不同)。您需要手动完成这项工作。还有一个坏消息:在Numba中没有标准的方法来创建线程私有变量了。话虽如此,你可以把问题分成固定数量的块,并行计算它们,然后按顺序进行约简。这是相当麻烦,但它可能值得。下面是一个例子:

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def count_stuff_parallel(stuff):
    nChunks = 8  # To tune
    count = np.zeros((nChunks, len(stuff)))
    for i1 in prange(nChunks):
        start = len(stuff) * i1 // nChunks
        stop = len(stuff) * (i1 + 1) // nChunks
        for i2 in range(start, stop):
            thing = stuff[i2]
            if some_condition(thing):
                j = some_integer_function(thing)
                count[i1, j] += 1
    return count.sum(axis=0)            # If not supported by Numba: write a loop instead

理想情况下,nChunks应该很大,或者是线程数的倍数。然而,len(stuff) / nChunks也需要相对较小,以使最终约简快速。请注意,只有在len(stuff)很大或者some_integer_function函数很慢的情况下,才值得使用多线程。

deyfvvtc

deyfvvtc2#

这真的取决于我认为的细节。如果你有多余的内存,一个更纯粹的Numpy方法,你一次计算所有的东西,可能会允许更好的并行化(以内存为代价)。

crit = some_condition(stuff)
idx = some_integer_function(crit)
count = np.bincount(idx)

如果没有具体的例子,就不可能说什么是最佳的。

相关问题