numpy 检查一个数组中的哪些元素在另一个数组中

sc4hvdpw  于 2023-10-19  发布在  其他
关注(0)|答案(4)|浏览(163)

我有两个大数据数组[100000 x3],分别称为A和B
我想知道一个数组中的任何元素是否存在于另一个数组中。
如果两个数组中的一个条目和另一个条目中的值之间的差值小于1 e-9,那么它们被认为是相同的。
实现这一目标的最有效方法是什么?
我有一个工作示例:

A = np.random.rand(100000, 3)
B = np.random.rand(100000, 3)
B[10] = A[10] + 1e-11
a = []
for entry in A:
    if np.min(np.max(np.abs(entry - B), axis=1)) < 1e-9:
        a.append(entry)

但是对于这个例子的大小,这将需要我很长时间来检查。

ffvjumwh

ffvjumwh1#

使用math.isclose()来测试数字是否在彼此的公差范围内。
使用any()函数对另一个数组的每个元素进行测试。

flat_A = A.flatten()
flat_B = B.flatten()
result = [a for a in flat_A if any(math.isclose(a, b, abs_tol=1e-9) for b in flat_B)]

如果对B进行排序,可以改进算法。然后你可以使用二进制搜索来找到接近的元素,而不是线性搜索。

bjg7j2ky

bjg7j2ky2#

使用KDTree % s

from scipy.spatial import KDTree

A_Tree = KDTree(A)
B_Tree = KDTree(B)
a = np.where([len(i) >0 for i in B_Tree.query_ball_tree(A_Tree, 1e-9, p=100)])

a
Out:        v
(array([   10,  3884,  9511, 10878, 23977, 40623, 45204, 49036, 51976,
        54020, 56403, 60273, 64507, 77374, 82695, 91806, 93235, 97881],
       dtype=int64),)

高p范数近似于你正在寻找的切比雪夫距离度量,KDTree在两棵树之间进行二叉树搜索,这比顺序搜索方法快得多,比矢量化方法节省内存得多。

yhived7q

yhived7q3#

钝器

我认为我们可以检查舍入值的交集来回答数组是否有公共项的问题:

accuracy = 9
intersection = np.isin(A.round(accuracy), B.round(accuracy))
common_items = A[intersection]
has_commons = intersection.any()

我同意Barmar的评论,这不是一个准确的工具。我们可以忽略一些小案子。

更精准的工具

这将需要更多的时间,当使用numba时,我似乎可以接受,除非有其他要求(在i3-2100 3GHz上,需要4分钟才能得到答案):

import numpy as np
from numba import njit

@njit
def intersect(a, b):
    '''return an array c of a.shape with c[i]=1 if a[i] is close to any item in b'''
    a_shape = a.shape
    a = a.ravel()
    c = np.zeros_like(a)
    for i in range(c.size):
        c[i] = np.isclose(b, a[i]).any()
    return c.reshape(a_shape)
  • p.s.问题仍然是开放的:比较什么-矩阵单元格还是它们的行?*
lndjwyie

lndjwyie4#

您可以对数组进行排序以降低复杂性。

示例

import numpy as np
import numba as nb

@nb.njit() #With jit approx. 30ms, without approx. 500ms
def get_simmilar(A,B):
    """
    Returns a set with the indices to A_sorted
    """
    idx_A=np.argsort(A[:,0])
    A_sorted = A[idx_A,:] 

    idx_B=np.argsort(B[:,0])
    B_sorted = B[idx_B,:]

    a=set()
    ii=0
    for i in range(A_sorted.shape[0]):
        while A_sorted[i,0] > B_sorted[ii,0] and not ii > B_sorted.shape[0]-1:
            ii+=1
        if np.max(np.abs(A_sorted[i,:] - B_sorted[ii-1,:])) < 1e-9:
            a.add(i)
        if np.max(np.abs(A_sorted[i,:] - B_sorted[ii,:])) < 1e-9:
             a.add(i)
    return a,A_sorted

a,A_sorted =  get_simmilar(A, B)
A_sorted[list(a),:]

相关问题