在两个不同大小的numpy数组之间选择最接近的元素

hgncfbus  于 2023-10-19  发布在  其他
关注(0)|答案(3)|浏览(85)

我有两个不同大小的numpy数组。其中一个“a”包含int值,而另一个(较大的)np数组“B”包含浮点值,其中每个元素/值有3-4个值。

a = np.random.randint(low = 1, high = 100, size = (7))

a
# array([35, 11, 48, 20, 13, 31, 49])

b = np.array([34.78, 34.8, 35.1, 34.99, 11.3, 10.7, 11.289, 18.78, 19.1, 20.05, 12.32, 12.87, 13.5, 31.03, 31.15, 29.87, 48.1, 48.5, 49.2])

a.shape, b.shape
# ((7,), (19,))

这个想法是找到“B”中的值与“a”中的每个唯一值在最近距离方面匹配,我使用abs值计算。要使用单个元素'a'或使用第一个元素'a'来执行此操作,请执行以下操作:

# Compare first element of a with all elements of b-
np.abs(a[0] - b).argsort()
'''
array([ 3,  2,  1,  0, 14, 13, 15, 16, 17, 18,  9,  8,  7, 12, 11, 10,  4,
        6,  5])
'''

# b[3]
# 34.99

# np.abs(a[0] - b).argsort()[0]
# 3

b[np.abs(a[0] - b).argsort()[0]]
# 34.99

因此,“B”(B[3])中的第4个元素是与a[0]最接近的匹配。
为了计算'a'中的所有值,我使用一个循环:

for e in a:
    idx = np.abs(e - b).argsort()
    print(f"{e} has nearest match = {b[idx[0]]:.4f}")
'''
35 has nearest match = 34.9900
11 has nearest match = 11.2890
48 has nearest match = 48.1000
20 has nearest match = 20.0500
13 has nearest match = 12.8700
31 has nearest match = 31.0300
49 has nearest match = 49.2000
'''

如果没有slow for循环,我怎么能做到这一点呢?

注:a.shape = 1400和B.shape = 150万(近似值)

yzxexxkh

yzxexxkh1#

如果你有很多值需要检查,你也可以尝试使用kdTree。对于很少的值,构建树的开销不会使其值得,但是对于大n,特别是在多维空间中搜索最近邻,它比计算所有对的距离要快得多:

import numpy as np
from scipy import spatial

a = np.random.randint(low = 1, high = 100, size = (7))

b = np.array([34.78, 34.8, 35.1, 34.99, 11.3, 10.7, 11.289, 18.78, 19.1, 20.05, 12.32, 12.87, 13.5, 31.03, 31.15, 29.87, 48.1, 48.5, 49.2])

kd_tree = spatial.KDTree(np.expand_dims(b, 1))
# This returns distance to the nearest neighbor d
# and position of the nearest neighbor i
d,i = kd_tree.query(np.expand_dims(a, 1), k=[1])
ekqde3dh

ekqde3dh2#

b[np.argmin(abs(a[:, None] - b[None, :]), axis=1)]

让我们来分解一下:

  • a[:, None] - b[None, :]在两个不同的方向上扩展ab。这使得a的所有元素与b的所有元素不同,并将它们存储到大小为len(a) * len(b)的2D数组中。这是替换for循环的部分
  • np.argmin(..., axis=1)为每行选择最小值。每个值都以range(0, len(b))为单位
  • b[...]然后选择值
rur96b6h

rur96b6h3#

你可以使用广播来计算所有对的绝对差,numpy.argmin来获得每行最小值的索引(具有O(n)复杂度,而排序是O(n*logn)):

out = b[np.argmin(abs(a[:,None]-b), axis=1)]

输出:array([34.99 , 11.289, 48.1 , 20.05 , 12.87 , 31.03 , 49.2 ])
要知道,两个数组的所有组合都将被一次性计算,中间数组的大小为len(a) * len(b),请确保您有足够的内存,因为ab都很大。
中间值abs(a[:,None]-b)(为了更好的显示,四舍五入到小数点后1位)。行对应于a的元素,列对应于b的元素:

array([[ 0.2,  0.2,  0.1,  0. , 23.7, 24.3, 23.7, 16.2, 15.9, 15. , 22.7,
        22.1, 21.5,  4. ,  3.9,  5.1, 13.1, 13.5, 14.2],
       [23.8, 23.8, 24.1, 24. ,  0.3,  0.3,  0.3,  7.8,  8.1,  9. ,  1.3,
         1.9,  2.5, 20. , 20.2, 18.9, 37.1, 37.5, 38.2],
       [13.2, 13.2, 12.9, 13. , 36.7, 37.3, 36.7, 29.2, 28.9, 28. , 35.7,
        35.1, 34.5, 17. , 16.8, 18.1,  0.1,  0.5,  1.2],
       [14.8, 14.8, 15.1, 15. ,  8.7,  9.3,  8.7,  1.2,  0.9,  0.1,  7.7,
         7.1,  6.5, 11. , 11.1,  9.9, 28.1, 28.5, 29.2],
       [21.8, 21.8, 22.1, 22. ,  1.7,  2.3,  1.7,  5.8,  6.1,  7. ,  0.7,
         0.1,  0.5, 18. , 18.2, 16.9, 35.1, 35.5, 36.2],
       [ 3.8,  3.8,  4.1,  4. , 19.7, 20.3, 19.7, 12.2, 11.9, 11. , 18.7,
        18.1, 17.5,  0. ,  0.1,  1.1, 17.1, 17.5, 18.2],
       [14.2, 14.2, 13.9, 14. , 37.7, 38.3, 37.7, 30.2, 29.9, 29. , 36.7,
        36.1, 35.5, 18. , 17.8, 19.1,  0.9,  0.5,  0.2]])

用于索引b的中间体np.argmin(abs(a[:,None]-b), axis=1)

array([ 3,  6, 16,  9, 11, 13, 18])

相关问题