找到两个NumPy数组的交集同时保持顺序的最快方法是什么?

w6mmgewl  于 2023-06-23  发布在  其他
关注(0)|答案(1)|浏览(88)

我有两个相同长度的一维NumPy数组A和B。我想找到两个数组的交集,也就是说我想找到A中所有的元素,这些元素也存在于B中。
当数组A中索引处的元素也是数组B的成员时,结果应该是一个布尔数组True,保留顺序,以便我可以使用结果索引另一个数组。
如果没有布尔掩码约束,我会将两个数组都转换为集合,并使用集合交集运算符(&)。然而,我尝试过使用np.isinnp.in1d,发现使用普通的Python列表理解要快得多。
给定设置:

import numba
import numpy as np

primes = np.array([
    2,   3,   5,   7,  11,  13,  17,  19,  23,  29,  31,  37,  41,
    43,  47,  53,  59,  61,  67,  71,  73,  79,  83,  89,  97, 101,
    103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167,
    173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239,
    241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313,
    317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397,
    401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467,
    479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569,
    571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643,
    647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733,
    739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823,
    827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911,
    919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997],
    dtype=np.int64)

@numba.vectorize(nopython=True, cache=True, fastmath=True, forceobj=False)
def reverse_digits(n, base):
    out = 0
    while n:
        n, rem = divmod(n, base)
        out = out * base + rem
    return out

flipped = reverse_digits(primes, 10)

def set_isin(a, b):
    return a in b

vec_isin = np.vectorize(set_isin)

primes包含1000以下的所有素数,总数为168。我选择它是因为它是体面的大小和预定的。我做了各种测试:

In [2]: %timeit np.isin(flipped, primes)
51.3 µs ± 1.55 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [3]: %timeit np.in1d(flipped, primes)
46.2 µs ± 386 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [4]: %timeit setp = set(primes)
12.9 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [5]: %timeit setp = set(primes.tolist())
6.84 µs ± 175 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [6]: %timeit setp = set(primes.flat)
11.5 µs ± 54.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [7]: setp = set(primes.tolist())

In [8]: %timeit [x in setp for x in flipped]
23.3 µs ± 739 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [9]: %timeit [x in setp for x in flipped.tolist()]
12.1 µs ± 76.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [10]: %timeit [x in setp for x in flipped.flat]
19.7 µs ± 249 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [11]: %timeit vec_isin(flipped, setp)
40 µs ± 317 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [12]: %timeit np.frompyfunc(lambda x: x in setp, 1, 1)(flipped)
25.7 µs ± 418 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [13]: %timeit setf = set(flipped.tolist())
6.51 µs ± 44 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [14]: setf = set(flipped.tolist())

In [15]: %timeit np.array(sorted(setf & setp))
9.42 µs ± 78.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

setp = set(primes.tolist()); [x in setp for x in flipped.tolist()]大约需要19微秒,比NumPy方法快。我想知道为什么会这样,如果有一种方法可以使它更快。
(我写了所有的代码,我使用AI建议编辑功能来编辑问题)

ao218c7q

ao218c7q1#

为什么提供的解决方案效率不高

np.isin有两个实现。第一种方法是对两个数组进行排序(使用合并排序),然后合并它们。该解决方案在O(n log n + m log m + n+m)中运行,即O(n log n + m log m)。另一种实现方式基于查找表。第二个实现基于第二个数组创建一个布尔值数组,然后检查是否为第一个数组的每个项设置了lookupTable[item]。对于包含小整数的数组,第二种实现可能会更快(这有点复杂,但explained in the documentation)。第二个解决方案运行在O(n + m + max(arr2))(甚至理论上O(n + m)在一些平台上有一个大的隐藏常数)。但是,它可以使用 * 更多的内存 *。Numpy尝试默认选择最好的一个。在你的例子中,两个数组都很小,里面的整数也相对较小,所以两个解决方案相对较快。对于包含小整数的较大数组,查找表应该更快。
问题是Numpy在这里效率不高,因为与实际计算相比,调用这样的Numpy函数的开销相对较大。此外,第二个数组已经排序,所以再次排序效率不高。

实现更快

例如,可以使用二进制搜索在第二个数组中找到第一个数组的值,而无需分配任何额外的临时数组。你可以使用Numba来减少在小数组上调用Numpy几个函数的开销,甚至使用jitted循环更快地填充结果。下面是最终的实现:

# Assume primes is sorted
@numba.njit('bool_[:](int64[:],int64[:])')
def compute(flipped, primes):
    assert primes.size > 0 and primes.size == flipped.size
    res = np.empty(flipped.size, dtype=np.bool_)
    idx = np.searchsorted(primes, flipped)
    for i in range(res.size):
        if idx[i] < len(primes) and primes[idx[i]] == flipped[i]:
            res[i] = True
        else:
            res[i] = False
    return res

在我的机器上,这个解决方案比np.isin(flipped, primes)15倍,比所有其他替代方案都快(快得多)。在提供的输入上仅需约2 µs。它的规模也相对较好。

大阵列最快解决方案

对于大型数组,使用查找表应该更快,因为上面的解决方案在O(n log m)时间内运行,而查找表实现可以在线性时间内运行。也就是说,查找表也会使用更多的内存。最好的方法是使用Bloom filter来使查找表更加紧凑(多亏了哈希)。然而,该解决方案实现起来明显更复杂。这里有一个setdif1d的例子。最快的解决方案往往是以更复杂的代码为代价的(没有免费的午餐)。

相关问题