我开始分析我的应用程序,并测试了以下代码:
a = np.random.random((100000,))
def get_first_index(value, arr):
firstIndex = np.argmax(arr > value)
if firstIndex <= 0:
raise Exception('No index found')
return firstIndex
for i in range(0, 1000):
get_first_index(0.5, a)
字符串
它只是返回大于给定值的元素的第一个索引。在我的机器上,数组大小为50k和1k的调用大约需要0.01s。我想知道是什么原因导致了慢下来。我的第一个怀疑是np.argmax
,但我将其归结为布尔比较arr > value
。它花费99%的时间创建bool比较。有没有我不知道的更快的方法?
性能分析的测试代码:
a = np.random.random((100000,))
def test_function(a, b):
return a < b
import cProfile, pstats
profiler = cProfile.Profile()
profiler.enable()
for i in range(0, 1000):
test_function(0.5, a)
profiler.disable()
stats = pstats.Stats(profiler).sort_stats('tottime')
stats.print_stats()
型
1条答案
按热度按时间q9yhzks01#
你的方法慢的原因是
arr
的所有元素在所有情况下都要进行比较,即使arr
的第一个元素大于value
。虽然numpy没有针对这种处理进行优化的API,但可以使用Numba来代替,这可以很容易地实现如下。
个字符
如果
arr
是一个有序数组,那么np.searchsorted甚至更快。