Numpy数组bool操作缓慢

o4tp2gmn  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(86)

我开始分析我的应用程序,并测试了以下代码:

a = np.random.random((100000,))
        
def get_first_index(value, arr):
   firstIndex = np.argmax(arr > value)
   if firstIndex <= 0:
      raise Exception('No index found')
   return firstIndex

for i in range(0, 1000):
   get_first_index(0.5, a)

字符串
它只是返回大于给定值的元素的第一个索引。在我的机器上,数组大小为50k和1k的调用大约需要0.01s。我想知道是什么原因导致了慢下来。我的第一个怀疑是np.argmax,但我将其归结为布尔比较arr > value。它花费99%的时间创建bool比较。有没有我不知道的更快的方法?
性能分析的测试代码:

a = np.random.random((100000,))

def test_function(a, b):
    return a < b

import cProfile, pstats

profiler = cProfile.Profile()
profiler.enable()

for i in range(0, 1000):
    test_function(0.5, a)

profiler.disable()
stats = pstats.Stats(profiler).sort_stats('tottime')
stats.print_stats()

q9yhzks0

q9yhzks01#

你的方法慢的原因是arr的所有元素在所有情况下都要进行比较,即使arr的第一个元素大于value
虽然numpy没有针对这种处理进行优化的API,但可以使用Numba来代替,这可以很容易地实现如下。

import timeit

import numba
import numpy as np

np.random.seed(0)
a = np.random.random((100000,))

def get_first_index(value, arr):
    firstIndex = np.argmax(arr > value)
    if firstIndex <= 0:
        raise Exception("No index found")
    return firstIndex

@numba.njit("i8(f8,f8[:])", cache=True)
def get_first_index2(value, arr):
    for i in range(len(arr)):
        if arr[i] > value:
            return i
    raise Exception("No index found")

x = 0.9999
print(timeit.timeit(lambda: get_first_index(x, a), number=1000))
print(timeit.timeit(lambda: get_first_index2(x, a), number=1000))

个字符
如果arr是一个有序数组,那么np.searchsorted甚至更快。

相关问题