NumPy加速数组搜索和合并操作

yws3nbqq  于 12个月前  发布在  其他
关注(0)|答案(2)|浏览(83)

我有三个数组,一个arr的id,一个range数组和一个query数组。arrquery非常大。
query中的一些值存在于arr中,而另一些则不存在。类似地,query中的一些值在range范围内,而另一些则不在。
x1m10 n1x中存在的x1m9 n1x中的值的数据来自源A,x1m11 n1x中不存在的值的数据来自源B。然后我将它们合并在一起,保持与query相同的原始顺序。例如,在下面的代码中,值15的数据应该位于result数组的索引1处,因为15位于query数组的索引1处。range也是如此。
我有一个简单的代码:

arr = np.array([1,7,3,9,5,10,2,8,4,6])
range = np.array([2,6])
query = np.array([3,15,6,8,13,5,19])

contains     = []; contains_idx     = []
not_contains = []; not_contains_idx = []
in_range     = []; in_range_idx     = []
not_in_range = []; not_in_range_idx = []
contains_mask = np.full(len(query), False)
range_mask = np.full(len(query), False)

for idx,x in enumerate(query):
    if x in arr:
        contains.append(x)
        contains_idx.append(idx)
        contains_mask[idx] = True
    else:
        not_contains.append(x)
        not_contains_idx.append(idx)
    if x >= range[0] and x <= range[1]:
        in_range.append(x)
        in_range_idx.append(idx)
        range_mask[idx] = True
    else:
        not_in_range.append(x)
        not_in_range_idx.append(idx)

result = np.zeros((len(query), 5)) # 5 dimensions, n x 5 array.
# Now I get the data for the values in contains and not_contains arrays.
# For simplicity, I am generating random data here.
data_for_contains = np.random.random((len(contains), 5))
data_for_not_contains = np.random.random((len(not_contains), 5))
# Put the acquired data in the result array, keeping the same order as query.
result[contains_idx] = data_for_contains
result[not_contains_idx] = data_for_not_contains
# The same operations are carried out to generate result for the in_range and not_in_range arrays.
# Some other operations are carried out on the mask arrays, omitted here.

由于arr非常大,而query也非常大(尽管比arr小),所以这段代码很慢。有没有办法优化它的速度和性能?
另外,这可以使用GPU加速吗?

cl25kdpy

cl25kdpy1#

您可以使用Python Set加速**x in s的进程。
Python Set被实现为
hash map**,因此the average time complexity of x in s operation is O(1),这比NumPy数组的in操作(O(N))快得多。
我可以给你看一些简单的实验:

import numpy as np

arr = np.random.rand(10000)
query = np.random.rand(10000)

这是代码的简化版本:

def naive_find_if_exists(arr, query):
    indices = []
    for idx, item in enumerate(query):
        if item in arr:
            indices.append(idx)
    return indices
%timeit naive_find_if_exists(arr, query)
23.9 ms ± 153 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Python Set优化版本

arr_set = set(arr)
def better_find_if_exists(arr, query):
    indices = []
    for idx, item in enumerate(query):
        if item in arr_set:
            indices.append(idx)
    return indices
%timeit better_find_if_exists(arr, query)
691 µs ± 4.66 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

从23.9ms到691us,性能提升了34.5倍

44u64gxh

44u64gxh2#

不要循环。

import numpy as np

arr = np.array((1, 7, 3, 9, 5, 10, 2, 8, 4, 6))
range_lo, range_hi = 2, 6
query = np.array((3, 15, 6, 8, 13, 5, 19))

np.random.seed(0)

contains_mask = np.isin(query, arr)
n_contained = np.count_nonzero(contains_mask)
data_for_contains = np.random.random((n_contained, 5))
data_for_not_contains = np.random.random((contains_mask.size - n_contained, 5))
contains_result = np.empty((len(query), 5))
contains_result[contains_mask] = data_for_contains
contains_result[~contains_mask] = data_for_not_contains

assert np.array_equal(
    contains_mask,
    ( True, False,  True,  True, False,  True, False),
)
assert np.allclose(
    contains_result,
    np.array([
       [0.5488135 , 0.71518937, 0.60276338, 0.54488318, 0.4236548 ],
       [0.97861834, 0.79915856, 0.46147936, 0.78052918, 0.11827443],
       [0.64589411, 0.43758721, 0.891773  , 0.96366276, 0.38344152],
       [0.79172504, 0.52889492, 0.56804456, 0.92559664, 0.07103606],
       [0.63992102, 0.14335329, 0.94466892, 0.52184832, 0.41466194],
       [0.0871293 , 0.0202184 , 0.83261985, 0.77815675, 0.87001215],
       [0.26455561, 0.77423369, 0.45615033, 0.56843395, 0.0187898 ],
    ])
)

range_mask = (query >= range_lo) & (query <= range_hi)
n_in_range = np.count_nonzero(range_mask)
data_for_range = np.random.random((n_in_range, 5))
data_for_not_range = np.random.random((range_mask.size - n_in_range, 5))
range_result = np.empty((len(query), 5))
range_result[range_mask] = data_for_range
range_result[~range_mask] = data_for_not_range

assert np.array_equal(
    range_mask,
    (True, False,  True, False, False,  True, False),
)
assert np.allclose(
    range_result,
    [[0.6176355 , 0.61209572, 0.616934  , 0.94374808, 0.6818203 ],
       [0.57019677, 0.43860151, 0.98837384, 0.10204481, 0.20887676],
       [0.3595079 , 0.43703195, 0.6976312 , 0.06022547, 0.66676672],
       [0.16130952, 0.65310833, 0.2532916 , 0.46631077, 0.24442559],
       [0.15896958, 0.11037514, 0.65632959, 0.13818295, 0.19658236],
       [0.67063787, 0.21038256, 0.1289263 , 0.31542835, 0.36371077],
       [0.36872517, 0.82099323, 0.09710128, 0.83794491, 0.09609841]]
)

相关问题