测试Numpy数组是否包含给定行

k7fdbhmy  于 2023-05-29  发布在  其他
关注(0)|答案(6)|浏览(127)

有没有一种Python的有效方法来检查Numpy数组是否包含给定行的至少一个示例?所谓“高效”,我的意思是它在找到第一个匹配行时终止,而不是迭代整个数组,即使已经找到了结果。
对于Python数组,这可以用if row in array:非常干净地完成,但这并不像我期望的Numpy数组那样工作,如下所示。
Python数组:

>>> a = [[1,2],[10,20],[100,200]]
>>> [1,2] in a
True
>>> [1,20] in a
False

但是Numpy数组给予了不同的、看起来相当奇怪的结果。(ndarray__contains__方法似乎没有文档记录。

>>> a = np.array([[1,2],[10,20],[100,200]])
>>> np.array([1,2]) in a
True
>>> np.array([1,20]) in a
True
>>> np.array([1,42]) in a
True
>>> np.array([42,1]) in a
False
w3nuxt5m

w3nuxt5m1#

可以使用.tolist()

>>> a = np.array([[1,2],[10,20],[100,200]])
>>> [1,2] in a.tolist()
True
>>> [1,20] in a.tolist()
False
>>> [1,20] in a.tolist()
False
>>> [1,42] in a.tolist()
False
>>> [42,1] in a.tolist()
False

或者使用视图:

>>> any((a[:]==[1,2]).all(1))
True
>>> any((a[:]==[1,20]).all(1))
False

或者在numpy列表上生成(可能非常慢):

any(([1,2] == x).all() for x in a)     # stops on first occurrence

或者使用numpy逻辑函数:

any(np.equal(a,[1,2]).all(1))

如果您对这些时间进行计时:

import numpy as np
import time

n=300000
a=np.arange(n*3).reshape(n,3)
b=a.tolist()

t1,t2,t3=a[n//100][0],a[n//2][0],a[-10][0]

tests=[ ('early hit',[t1, t1+1, t1+2]),
        ('middle hit',[t2,t2+1,t2+2]),
        ('late hit', [t3,t3+1,t3+2]),
        ('miss',[0,2,0])]

fmt='\t{:20}{:.5f} seconds and is {}'     

for test, tgt in tests:
    print('\n{}: {} in {:,} elements:'.format(test,tgt,n))

    name='view'
    t1=time.time()
    result=(a[...]==tgt).all(1).any()
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    name='python list'
    t1=time.time()
    result = True if tgt in b else False
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    name='gen over numpy'
    t1=time.time()
    result=any((tgt == x).all() for x in a)
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

    name='logic equal'
    t1=time.time()
    np.equal(a,tgt).all(1).any()
    t2=time.time()
    print(fmt.format(name,t2-t1,result))

你可以看到,无论命中或未命中,numpy例程都以相同的速度搜索数组。Python in运算符对于早期命中来说 * 可能 * 要快得多,如果你必须一路遍历数组,那么生成器就是个坏消息。
以下是300,000 x 3元素数组的结果:

early hit: [9000, 9001, 9002] in 300,000 elements:
    view                0.01002 seconds and is True
    python list         0.00305 seconds and is True
    gen over numpy      0.06470 seconds and is True
    logic equal         0.00909 seconds and is True

middle hit: [450000, 450001, 450002] in 300,000 elements:
    view                0.00915 seconds and is True
    python list         0.15458 seconds and is True
    gen over numpy      3.24386 seconds and is True
    logic equal         0.00937 seconds and is True

late hit: [899970, 899971, 899972] in 300,000 elements:
    view                0.00936 seconds and is True
    python list         0.30604 seconds and is True
    gen over numpy      6.47660 seconds and is True
    logic equal         0.00965 seconds and is True

miss: [0, 2, 0] in 300,000 elements:
    view                0.00936 seconds and is False
    python list         0.01287 seconds and is False
    gen over numpy      6.49190 seconds and is False
    logic equal         0.00965 seconds and is False

对于3,000,000 x 3阵列:

early hit: [90000, 90001, 90002] in 3,000,000 elements:
    view                0.10128 seconds and is True
    python list         0.02982 seconds and is True
    gen over numpy      0.66057 seconds and is True
    logic equal         0.09128 seconds and is True

middle hit: [4500000, 4500001, 4500002] in 3,000,000 elements:
    view                0.09331 seconds and is True
    python list         1.48180 seconds and is True
    gen over numpy      32.69874 seconds and is True
    logic equal         0.09438 seconds and is True

late hit: [8999970, 8999971, 8999972] in 3,000,000 elements:
    view                0.09868 seconds and is True
    python list         3.01236 seconds and is True
    gen over numpy      65.15087 seconds and is True
    logic equal         0.09591 seconds and is True

miss: [0, 2, 0] in 3,000,000 elements:
    view                0.09588 seconds and is False
    python list         0.12904 seconds and is False
    gen over numpy      64.46789 seconds and is False
    logic equal         0.09671 seconds and is False

这似乎表明np.equal是最快的纯numpy方式来做到这一点。

wkyowqbh

wkyowqbh2#

在写这篇文章的时候,Numpys __contains__(a == b).any(),可以说只有当b是标量时才是正确的(这有点麻烦,但我相信-只有在1.7中才能这样工作。或以后-这将是正确的通用方法(a == b).all(np.arange(a.ndim - b.ndim, a.ndim)).any(),这对ab维度的所有组合都有意义)...
编辑:只是要明确,这是 * 不 * 一定是预期的结果时,广播参与。另外,有人可能会争辩说,它应该像np.in1d那样单独处理a中的项。我不确定是否有一种明确的方法可以奏效。
现在,您希望numpy在找到第一个匹配项时停止。此AFAIK此时不存在。这很困难,因为numpy主要基于ufuncs,它对整个数组做同样的事情。Numpy确实优化了这类缩减,但实际上只有当被缩减的数组已经是布尔数组时才有效(即np.ones(10, dtype=bool).any())。
否则,它将需要一个不存在的__contains__的特殊函数。这可能看起来很奇怪,但你必须记住,numpy支持许多数据类型,并且有一个更大的机制来选择正确的数据类型并选择正确的函数来处理它。所以换句话说,ufunc机器不能做到这一点,并且由于数据类型的原因,实现__contains__或类似的特殊功能实际上并不是那么简单。
你当然可以用python来写,或者因为你可能知道你的数据类型,所以用Cython/C自己写非常简单。
也就是说,通常使用基于排序的方法来处理这些事情要好得多。这有点乏味,而且对于lexsort没有searchsorted这样的东西,但它可以工作(如果你喜欢,你也可以滥用scipy.spatial.cKDTree)。这假设您只想沿着最后一个轴进行比较:

# Unfortunatly you need to use structured arrays:
sorted = np.ascontiguousarray(a).view([('', a.dtype)] * a.shape[-1]).ravel()

# Actually at this point, you can also use np.in1d, if you already have many b
# then that is even better.

sorted.sort()

b_comp = np.ascontiguousarray(b).view(sorted.dtype)
ind = sorted.searchsorted(b_comp)

result = sorted[ind] == b_comp

这也适用于数组b,如果你保持有序数组,如果你一次对b中的单个值(行)执行,当a保持不变时(否则我将在将其视为recarray后仅为np.in1d)。* 重要提示:* 为了安全,您必须执行np.ascontiguousarray。它通常不会做任何事情,但如果它做了,这将是一个很大的潜在错误。

suzh9iv8

suzh9iv83#

我觉得

equal([1,2], a).all(axis=1)   # also,  ([1,2]==a).all(axis=1)
# array([ True, False, False], dtype=bool)

将列出匹配的行。正如Jamie指出的,要知道是否至少存在一个这样的行,使用any

equal([1,2], a).all(axis=1).any()
# True

旁白:
我怀疑in(和__contains__)和上面一样,但是使用any而不是all

gtlvzcf8

gtlvzcf84#

我将建议的解决方案与perfplot进行了比较,发现如果您在一个长的未排序列表中查找一个2元组,

np.any(np.all(a == b, axis=1))

是最快的解决方案。如果在前几行中找到匹配,则显式短路循环总是更快。

用于重现绘图的代码:

import numpy as np
import perfplot

target = [6, 23]

def setup(n):
    return np.random.randint(0, 100, (n, 2))

def any_all(data):
    return np.any(np.all(target == data, axis=1))

def tolist(data):
    return target in data.tolist()

def loop(data):
    for row in data:
        if np.all(row == target):
            return True
    return False

def searchsorted(a):
    s = np.ascontiguousarray(a).view([('', a.dtype)] * a.shape[-1]).ravel()
    s.sort()
    t = np.ascontiguousarray(target).view(s.dtype)
    ind = s.searchsorted(t)
    return (s[ind] == t)[0]

perfplot.save(
    "out02.png",
    setup=setup,
    kernels=[any_all, tolist, loop, searchsorted],
    n_range=[2 ** k for k in range(2, 20)],
    xlabel="len(array)",
)
ukdjmx9f

ukdjmx9f5#

如果你真的想在第一次出现时停止,你可以写一个循环,比如:

import numpy as np

needle = np.array([10, 20])
haystack = np.array([[1,2],[10,20],[100,200]])
found = False
for row in haystack:
    if np.all(row == needle):
        found = True
        break
print("Found: ", found)

然而,我强烈怀疑,它会比其他使用numpy例程对整个数组执行的建议慢得多。

qybjjes1

qybjjes16#

要知道一个特定的一维numpy数组(行)是否存在于二维numpy数组中,一种更简单的方法是使用以下条件。

if np.sum(np.prod(2-darray == 1-darray),axis = 1)) > 0

如果np.sum(np.prod(2-darray == 1-darray),axis = 1))大于0,则该行存在于2-D阵列中,否则不存在。

相关问题