检查3个numpy数组中的每一个是否都存在一个值,这些值都在x的区间内?

eimct9ow  于 2023-01-17  发布在  其他
关注(0)|答案(1)|浏览(73)

假设我有3个numpy数组,但也可以不止3个。

import numpy as np
INTERVAL = 2
array1 = np.array([1,5,10,15,20,25,30])
array2 = np.array([1,10,50,100,150,200,250,300])
array3 = np.array([3,8,12])

对于要在上述数组中匹配的给定元素集,每个元素必须在彼此的INTERVAL范围内。元素在数组中的实际索引位置在比较中并不重要。顺序不保证。它是指数组中任何位置的任何元素都在彼此的INTERVAL范围内。
将从上述3个数组返回的匹配示例:

Example#1
array1 : 1
array2 : 1
array3 : 3

Example#2
array1 : 10
array2 : 10
array3 : 8

Example#3
array1 : 10
array2 : 10
array3 : 12

加分:
如果同一元素可能有多个匹配项,则只返回总和最小的一个。例如,Example#2Example#3共享元素,但应返回Example#2Example#1,而不返回Example#3
你对我该怎么做有什么建议吗?

o8x7eapl

o8x7eapl1#

您可以使用np.meshgrid创建笛卡尔乘积(笛卡尔乘积代码为here)。然后,您可以使用np.where提取满足差异约束的组合:

product = np.array(np.meshgrid(array1, array2, array3)).T.reshape(-1,3)
result = product[np.where(product.max(axis=1) - product.min(axis=1) <= INTERVAL)]

repr(result)

这将输出:

array([[ 1,  1,  3],
       [10, 10,  8],
       [10, 10, 12]])

相关问题