我有两个相同长度的一维NumPy数组A和B。我想找到两个数组的交集,也就是说我想找到A中所有的元素,这些元素也存在于B中。
当数组A中索引处的元素也是数组B的成员时,结果应该是一个布尔数组True
,保留顺序,以便我可以使用结果索引另一个数组。
如果没有布尔掩码约束,我会将两个数组都转换为集合,并使用集合交集运算符(&
)。然而,我尝试过使用np.isin
和np.in1d
,发现使用普通的Python列表理解要快得多。
给定设置:
import numba
import numpy as np
primes = np.array([
2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41,
43, 47, 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101,
103, 107, 109, 113, 127, 131, 137, 139, 149, 151, 157, 163, 167,
173, 179, 181, 191, 193, 197, 199, 211, 223, 227, 229, 233, 239,
241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, 313,
317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397,
401, 409, 419, 421, 431, 433, 439, 443, 449, 457, 461, 463, 467,
479, 487, 491, 499, 503, 509, 521, 523, 541, 547, 557, 563, 569,
571, 577, 587, 593, 599, 601, 607, 613, 617, 619, 631, 641, 643,
647, 653, 659, 661, 673, 677, 683, 691, 701, 709, 719, 727, 733,
739, 743, 751, 757, 761, 769, 773, 787, 797, 809, 811, 821, 823,
827, 829, 839, 853, 857, 859, 863, 877, 881, 883, 887, 907, 911,
919, 929, 937, 941, 947, 953, 967, 971, 977, 983, 991, 997],
dtype=np.int64)
@numba.vectorize(nopython=True, cache=True, fastmath=True, forceobj=False)
def reverse_digits(n, base):
out = 0
while n:
n, rem = divmod(n, base)
out = out * base + rem
return out
flipped = reverse_digits(primes, 10)
def set_isin(a, b):
return a in b
vec_isin = np.vectorize(set_isin)
primes
包含1000以下的所有素数,总数为168。我选择它是因为它是体面的大小和预定的。我做了各种测试:
In [2]: %timeit np.isin(flipped, primes)
51.3 µs ± 1.55 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [3]: %timeit np.in1d(flipped, primes)
46.2 µs ± 386 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [4]: %timeit setp = set(primes)
12.9 µs ± 133 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [5]: %timeit setp = set(primes.tolist())
6.84 µs ± 175 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [6]: %timeit setp = set(primes.flat)
11.5 µs ± 54.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [7]: setp = set(primes.tolist())
In [8]: %timeit [x in setp for x in flipped]
23.3 µs ± 739 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [9]: %timeit [x in setp for x in flipped.tolist()]
12.1 µs ± 76.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [10]: %timeit [x in setp for x in flipped.flat]
19.7 µs ± 249 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [11]: %timeit vec_isin(flipped, setp)
40 µs ± 317 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [12]: %timeit np.frompyfunc(lambda x: x in setp, 1, 1)(flipped)
25.7 µs ± 418 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
In [13]: %timeit setf = set(flipped.tolist())
6.51 µs ± 44 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [14]: setf = set(flipped.tolist())
In [15]: %timeit np.array(sorted(setf & setp))
9.42 µs ± 78.9 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
setp = set(primes.tolist()); [x in setp for x in flipped.tolist()]
大约需要19微秒,比NumPy方法快。我想知道为什么会这样,如果有一种方法可以使它更快。
(我写了所有的代码,我使用AI建议编辑功能来编辑问题)
1条答案
按热度按时间ao218c7q1#
为什么提供的解决方案效率不高
np.isin
有两个实现。第一种方法是对两个数组进行排序(使用合并排序),然后合并它们。该解决方案在O(n log n + m log m + n+m)
中运行,即O(n log n + m log m)
。另一种实现方式基于查找表。第二个实现基于第二个数组创建一个布尔值数组,然后检查是否为第一个数组的每个项设置了lookupTable[item]
。对于包含小整数的数组,第二种实现可能会更快(这有点复杂,但explained in the documentation)。第二个解决方案运行在O(n + m + max(arr2))
(甚至理论上O(n + m)
在一些平台上有一个大的隐藏常数)。但是,它可以使用 * 更多的内存 *。Numpy尝试默认选择最好的一个。在你的例子中,两个数组都很小,里面的整数也相对较小,所以两个解决方案相对较快。对于包含小整数的较大数组,查找表应该更快。问题是Numpy在这里效率不高,因为与实际计算相比,调用这样的Numpy函数的开销相对较大。此外,第二个数组已经排序,所以再次排序效率不高。
实现更快
例如,可以使用二进制搜索在第二个数组中找到第一个数组的值,而无需分配任何额外的临时数组。你可以使用Numba来减少在小数组上调用Numpy几个函数的开销,甚至使用jitted循环更快地填充结果。下面是最终的实现:
在我的机器上,这个解决方案比
np.isin(flipped, primes)
快15倍,比所有其他替代方案都快(快得多)。在提供的输入上仅需约2 µs。它的规模也相对较好。大阵列最快解决方案
对于大型数组,使用查找表应该更快,因为上面的解决方案在
O(n log m)
时间内运行,而查找表实现可以在线性时间内运行。也就是说,查找表也会使用更多的内存。最好的方法是使用Bloom filter来使查找表更加紧凑(多亏了哈希)。然而,该解决方案实现起来明显更复杂。这里有一个setdif1d
的例子。最快的解决方案往往是以更复杂的代码为代价的(没有免费的午餐)。