numpy 如何用Numba在数组上并行化for循环

oogrdqng  于 2023-05-17  发布在  其他
关注(0)|答案(1)|浏览(189)

我正在尝试编写一个Numba函数,该函数使用@numba.njit(parallel=True)循环遍历列表(而不是范围)。比如说

import numpy as np
import numba

arr = np.ones(10)
idx = np.array([4, 2, 5])

@numba.njit(parallel=True)
def foo(arr, idx):
    for i in idx:
        arr[i] = 0
    return arr

foo(arr, idx)

我得到的警告是NumbaPerformanceWarning: The keyword argument 'parallel=True' was specified but no transformation for parallel execution was possible.我知道可以使用numba.prange()显式并行化类似的循环,但我需要循环遍历索引数组。这是可能的,如何做到的?

pbgvytdp

pbgvytdp1#

for i in idx: ...在语义上等同于:

for j in range(len(idx)):
    i = idx[j]
    [...]

第二个版本的性能可能略有不同。在第二个版本的基础上,您可以将range替换为prange并获得并行版本。
但是,请注意,如果两个idx值相同,则会导致竞争条件。竞争条件导致未定义的行为(基本上,任何事情都可能发生)。

相关问题