numpy: Python version of ismember with 'rows' and index

Posted by vshtjzan on 2023-01-17 in Python

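Similar questions have been asked, but none of the answers quite fits what I need: some allow a multidimensional search (the 'rows' option in MATLAB) but do not return the index, while others return the index but do not allow rows. My array is very large (1M x 2). I have managed to write a loop that works, but it is obviously very slow. In MATLAB, the built-in ismember function takes about 10 seconds.
Here is what I am looking for: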

a=np.array([[4, 6],[2, 6],[5, 2]])

b=np.array([[1, 7],[1, 8],[2, 6],[2, 1],[2, 4],[4, 6],[4, 7],[5, 9],[5, 2],[5, 1]])

The MATLAB function that does the trick is:

[~,index] = ismember(a,b,'rows')

where

index = [6, 3, 9]

2sbarzqh1#

import numpy as np

def asvoid(arr):
    """
    View the array as dtype np.void (bytes)
    This views the last axis of ND-arrays as bytes so you can perform comparisons on
    the entire row.
    http://stackoverflow.com/a/16840350/190597 (Jaime, 2013-05)
    Warning: When using asvoid for comparison, note that float zeros may compare UNEQUALLY
    >>> asvoid([-0.]) == asvoid([0.])
    array([False], dtype=bool)
    """
    arr = np.ascontiguousarray(arr)
    return arr.view(np.dtype((np.void, arr.dtype.itemsize * arr.shape[-1])))

def in1d_index(a, b):
    voida, voidb = map(asvoid, (a, b))
    return np.where(np.in1d(voidb, voida))[0]    

a = np.array([[4, 6],[2, 6],[5, 2]])
b = np.array([[1, 7],[1, 8],[2, 6],[2, 1],[2, 4],[4, 6],[4, 7],[5, 9],[5, 2],[5, 1]])

print(in1d_index(a, b))

prints

[2 5 8]

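This is equivalent to MATLAB's [3, 6, 9], because Python uses 0-based indexing.
Some caveats:

  1. The indices are returned in ascending order, not in the order of the rows of a, so the k-th index does not necessarily locate the k-th row of a in b.
  2. asvoid works for integer dtypes, but be careful when using it on floating-point dtypes, because asvoid([-0.]) == asvoid([0.]) returns array([False]).
  3. asvoid works best on contiguous arrays. If the arrays are not contiguous, the data is copied into a contiguous array, which slows things down.

Despite these caveats, one may still choose in1d_index for the sake of speed. For comparison, here is a broadcasting-based ismember_rows and the timings of both: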
def ismember_rows(a, b):
    # http://stackoverflow.com/a/22705773/190597 (ashg)
    return np.nonzero(np.all(b == a[:,np.newaxis], axis=2))[1]

In [41]: a2 = np.tile(a,(2000,1))
In [42]: b2 = np.tile(b,(2000,1))

In [46]: %timeit in1d_index(a2, b2)
100 loops, best of 3: 8.49 ms per loop

In [47]: %timeit ismember_rows(a2, b2)
1 loops, best of 3: 5.55 s per loop

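So in1d_index is about 650 times faster (for arrays with thousands of rows), but note again that the comparison is not exactly apples-to-apples: in1d_index returns the indices in ascending order, while ismember_rows returns them in the order in which the rows of a appear in b.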
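The float caveat listed above can be reproduced without asvoid itself. The following is a minimal sketch, assuming a byte-wise row comparison through a uint8 view (a stand-in for the np.void view, chosen here only for illustration); the helper name rows_equal_bytewise is hypothetical. It also shows one possible workaround: adding 0.0 turns -0.0 into +0.0 before the bytes are compared.

```python
import numpy as np

def rows_equal_bytewise(x, y):
    """Compare equal-shape, equal-dtype 2-D arrays row by row via their raw bytes.

    Hypothetical helper for illustration; it mimics what a byte-level view
    such as asvoid compares, using a uint8 view instead of np.void.
    """
    xb = np.ascontiguousarray(x).view(np.uint8).reshape(len(x), -1)
    yb = np.ascontiguousarray(y).view(np.uint8).reshape(len(y), -1)
    return np.all(xb == yb, axis=1)

x = np.array([[-0.0, 1.0]])
y = np.array([[0.0, 1.0]])

numeric_equal = bool(np.all(x == y))                 # numeric comparison ignores the sign of zero
bytes_equal = rows_equal_bytewise(x, y)              # the sign bit differs, so the bytes differ
normalized_equal = rows_equal_bytewise(x + 0.0, y + 0.0)  # -0.0 + 0.0 yields +0.0

print(numeric_equal)      # True
print(bytes_equal)        # [False]
print(normalized_equal)   # [ True]
```

NaN values are a separate matter and are not addressed by this normalization.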


nsc4cvqm2#

import numpy as np 
def ismember_rows(a, b):
    '''Equivalent of 'ismember' from Matlab
    a.shape = (nRows_a, nCol)
    b.shape = (nRows_b, nCol)
    return the idx where b[idx] == a
    '''
    return np.nonzero(np.all(b == a[:,np.newaxis], axis=2))[1]

a = np.array([[4, 6],[2, 6],[5, 2]])
b = np.array([[1, 7],[1, 8],[2, 6],[2, 1],[2, 4],[4, 6],[4, 7],[5, 9],[5, 2],[5, 1]])
idx = ismember_rows(a, b)
print(idx)
print(np.all(b[idx] == a))

prints

[5 2 8]
True

I use broadcasting.
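To make the broadcasting step concrete, here is a small sketch that spells out the intermediate shapes behind ismember_rows. The variable names eq, match, rows_a, and rows_b are introduced only for this illustration.

```python
import numpy as np

a = np.array([[4, 6], [2, 6], [5, 2]])
b = np.array([[1, 7], [1, 8], [2, 6], [2, 1], [2, 4],
              [4, 6], [4, 7], [5, 9], [5, 2], [5, 1]])

# a[:, np.newaxis] has shape (3, 1, 2); comparing it with b (10, 2)
# broadcasts to every (row of a, row of b) pair: shape (3, 10, 2)
eq = b == a[:, np.newaxis]
print(eq.shape)      # (3, 10, 2)

# A pair matches only if all columns agree: shape (3, 10)
match = np.all(eq, axis=2)
print(match.shape)   # (3, 10)

# nonzero walks the matches row by row of a, so rows_b follows the order of a
rows_a, rows_b = np.nonzero(match)
print(rows_a)        # [0 1 2]
print(rows_b)        # [5 2 8]
```

The intermediate array grows with len(a) * len(b). For the tiled arrays used in the timing test (6000 and 20000 rows) it holds 6000 x 20000 x 2 = 240 million booleans, which helps explain why this approach is so much slower on large inputs.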

--------------------------[Update]----------------
def ismember(a, b):
    return np.flatnonzero(np.in1d(b[:,0], a[:,0]) & np.in1d(b[:,1], a[:,1]))

a = np.array([[4, 6],[2, 6],[5, 2]])
b = np.array([[1, 7],[1, 8],[2, 6],[2, 1],[2, 4],[4, 6],[4, 7],[5, 9],[5, 2],[5, 1]])
a2 = np.tile(a,(2000,1))
b2 = np.tile(b,(2000,1))

%timeit in1d_index(a2, b2)
# 100 loops, best of 3: 8.74 ms per loop
%timeit ismember(a2, b2)
# 100 loops, best of 3: 8.5 ms per loop

np.all(in1d_index(a2, b2) == ismember(a2, b2))
# True
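One thing to keep in mind with the column-wise version: it tests each column independently, so a row of b can be reported even when its values come from different rows of a. The sketch below uses a small counterexample made up for this illustration. The function name ismember_columnwise is hypothetical, and np.isin is used in place of np.in1d (the two behave the same on 1-D input).

```python
import numpy as np

def ismember_columnwise(a, b):
    # Same logic as the column-wise ismember above, written with np.isin
    return np.flatnonzero(np.isin(b[:, 0], a[:, 0]) & np.isin(b[:, 1], a[:, 1]))

a = np.array([[4, 6], [2, 6], [5, 2]])
b = np.array([[4, 2], [2, 6]])   # [4, 2] is NOT a row of a

columnwise = ismember_columnwise(a, b)
# Exact row comparison via broadcasting, for reference
exact = np.flatnonzero(np.all(b[:, np.newaxis] == a, axis=2).any(axis=1))

print(columnwise)  # [0 1]  (row 0 is a false positive: 4 and 2 come from different rows of a)
print(exact)       # [1]
```

On the example data from the question the two approaches happen to agree, which is why the equality check above returns True.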

As noted in the first answer, the indices are returned in ascending order.
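If the indices are needed in the order of the rows of a, as MATLAB's ismember(a, b, 'rows') returns them, one possible vectorized approach is sketched below. It is not taken from any of the answers. It assumes np.unique with axis=0 to label rows with integer ids, then a stable sort plus np.searchsorted to look up the first occurrence in b. The function name ismember_rows_ordered and the -1 marker for missing rows are choices made for this illustration.

```python
import numpy as np

def ismember_rows_ordered(a, b):
    """For each row of a, return the 0-based index of its first occurrence in b, or -1."""
    a = np.asarray(a)
    b = np.asarray(b)
    if len(b) == 0:
        return np.full(len(a), -1)
    # Label every distinct row of b and a with an integer id
    _, inv = np.unique(np.vstack([b, a]), axis=0, return_inverse=True)
    inv = inv.ravel()  # some NumPy versions return a 2-D inverse when axis is given
    b_ids, a_ids = inv[:len(b)], inv[len(b):]
    # A stable sort keeps the lowest original index first among equal ids
    order = np.argsort(b_ids, kind="stable")
    sorted_ids = b_ids[order]
    pos = np.searchsorted(sorted_ids, a_ids)
    pos = np.minimum(pos, len(b) - 1)   # keep positions in range for ids larger than all of b
    found = sorted_ids[pos] == a_ids
    return np.where(found, order[pos], -1)

a = np.array([[4, 6], [2, 6], [5, 2]])
b = np.array([[1, 7], [1, 8], [2, 6], [2, 1], [2, 4],
              [4, 6], [4, 7], [5, 9], [5, 2], [5, 1]])

idx = ismember_rows_ordered(a, b)
print(idx)  # [5 2 8]

idx_missing = ismember_rows_ordered(np.array([[4, 6], [6, 6], [2, 6], [5, 2]]), b)
print(idx_missing)  # [ 5 -1  2  8]
```

Adding 1 to the first result gives [6, 3, 9], the MATLAB output from the question. The cost is dominated by the sorts, so memory does not grow with len(a) * len(b).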


gcxthw6b3#

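This function first converts the multi-column rows into a single-column array, after which numpy.in1d can be used to find the desired answer. Try the following code: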

import numpy as np

def ismemberRow(A,B):
    '''
    Find which rows of A can also be found in B.
    The function first turns multiple columns of elements into a single column array, then numpy.in1d can be used

    Input: m x n numpy array (A), and p x q array (B)
    Output: boolean numpy array of length m, True for rows of A that can also be found in B
    '''

    def join_columns(M):
        # Join the columns of each row into one string key such as '1-3-4'
        # (astype(str) replaces np.str, which was removed in NumPy 1.24)
        s = M.astype(str)
        key = s[:,0]
        for i in range(1, M.shape[1]):
            key = np.char.add(np.char.add(key, '-'), s[:,i])
        return key

    return np.in1d(join_columns(A), join_columns(B))

A = np.array([[1, 3, 4],[2, 4, 3],[7, 4, 3],[1, 1, 1],[1, 3, 4],[5, 3, 4],[1, 1, 1],[2, 4, 3]])

B = np.array([[1, 3, 4],[1, 1, 1]])

d = ismemberRow(A,B)

print(A[np.where(d)[0],:])

#results:
#[[1 3 4]
# [1 1 1]
# [1 3 4]
# [1 1 1]]

ars1skjm4#

Here is a function based on libigl's igl::ismember_rows that closely mimics the behavior of MATLAB's ismember(A,B,'rows'):

import numpy as np

def ismember_rows(A,B, return_index=False):
    """
    Return whether each row in A occurs as a row in B
    
    Parameters
    ----------
    A : #A by dim array
    B : #B by dim array
    return_index : {True,False}, optional.
    
    Returns
    -------
    IA : #A 1D array, IA[i] == True if and only if
        there exists j = LOCB[i] such that B[j,:] == A[i,:]
    LOCB : #A 1D array of indices. LOCB[i] == -1 if IA[i] == False,
        only returned if return_index=True
    """
    IA = np.full(A.shape[0],False)
    LOCB = np.full(A.shape[0],-1)
    if len(A) == 0: return (IA,LOCB) if return_index else IA
    if len(B) == 0: return (IA,LOCB) if return_index else IA
    # Get rid of any duplicates
    uA,uIuA = np.unique(A, axis=0, return_inverse=True)
    uB,uIB = np.unique(B, axis=0, return_index=True)
    # Sort both
    sIA = np.lexsort(uA.T[::-1])
    sA = uA[sIA,:]
    sIB = np.lexsort(uB.T[::-1])
    sB = uB[sIB,:]
    #
    uF = np.full(sA.shape[0],False)
    uLOCB = np.full(sA.shape[0],-1)
    def row_greater_than(a,b):
        for c in range(sA.shape[1]):
            if(sA[a,c] > sB[b,c]): return True
            if(sA[a,c] < sB[b,c]): return False
        return False
    # loop over sA
    bi = 0
    past = False
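    for a in range(sA.shape[0]):  # sA holds only the unique rows of A
        # (no assert on bi here: once past is True, bi == sB.shape[0] and the checks below are skipped)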
        while not past and row_greater_than(a,bi):
            bi+=1
            past = bi>=sB.shape[0]
        if not past and np.all(sA[a,:]==sB[bi,:]):
            uF[sIA[a]] = True
            uLOCB[sIA[a]] = uIB[sIB[bi]]
    for a in range(A.shape[0]):
        IA[a] = uF[uIuA[a]]
        LOCB[a] = uLOCB[uIuA[a]]
    return (IA,LOCB) if return_index else IA

For example,

a=np.array([[4, 6],[6,6],[2, 6],[5, 2]])
b=np.array([[1, 7],[1, 8],[2, 6],[2, 1],[2, 4],[4, 6],[4, 7],[5, 9],[5, 2],[5, 1]])
(flag,index) = ismember_rows(a,b,return_index=True)

produces

>>> flag
array([ True, False,  True,  True])
>>> index
array([ 5, -1,  2,  8])
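For readers porting MATLAB code, the output above can be mapped to MATLAB's convention, where indices are 1-based and rows without a match get 0. This is a small sketch that starts from the flag and index arrays shown above; the name matlab_index is introduced only for this illustration.

```python
import numpy as np

# Output of ismember_rows(a, b, return_index=True) from the example above
flag = np.array([True, False, True, True])
index = np.array([5, -1, 2, 8])

# Shift found indices to 1-based and use 0 where no matching row exists
matlab_index = np.where(flag, index + 1, 0)
print(matlab_index)  # [6 0 3 9]
```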
