有没有更快的方法来匹配numpy数组中的公共元素?

vxbzzdmp  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(86)

给定一个长度为n的二维字符串数组,我需要返回一个n x n的数组,它在自身上进行简单的匹配操作。匹配操作是检查子数组是否完全相同(返回2),共享一个公共元素(返回1),或者两者都不相同(返回0)。这被实现为:

def match(a, b):
    a_set = set(a)
    b_set = set(b)
    if a_set == b_set:
        return 2
    elif a_set & b_set:
        return 1
    else:
        return 0

字符串
示例:

arr = np.array([['a', 'b'], ['a', 'b'], ['a', 'c'], ['d', 'e']])
array([['a', 'b'],
   ['a', 'b'],
   ['a', 'c'],
   ['d', 'e']], dtype='<U1')


应该返回:

[[2, 2, 1, 0]
[2, 2, 1, 0]
[1, 1, 2, 0]
[0, 0, 0, 2]]


我目前的解决方案工作正常,但扩展性不是很好。寻找可以在几秒钟内在50k元素上快速运行的东西。

np.reshape(np.array([match(i, j) for i, j in it.product(arr, repeat=2)]),(len(arr),len(arr)))

2mbi3lxu

2mbi3lxu1#

可能的解决方案:

(arr == arr[:, None]).sum(axis=2)

字符串
输出量:

array([[2, 2, 1, 0],
       [2, 2, 1, 0],
       [1, 1, 2, 0],
       [0, 0, 0, 2]])

相关问题