python 如果numpy数组中的任何项已经出现在前一个数组中,则删除其中的子数组

9q78igpj  于 2023-04-19  发布在  Python
关注(0)|答案(3)|浏览(108)

我有一个二维的numpy数组。我需要过滤掉重复的内容--如果一行中的任何一项在前一行中,那么它就被认为是重复的。

#i.e.:
arr =
array([[4580, 4581, 4657, 4658],
       [4580, 4581, 4657, 4659], #-> duplicate because of 4580
       [4650, 4652, 4654, 4655],
       [4651, 4655, 4652, 4656]]) #-> duplicate because of 4652

#Output should be: 
array([[4580, 4581, 4657, 4658],
       [4650, 4652, 4654, 4655]])

下面的脚本给出了我对小输入的预期输出。然而,它在大数组上会阻塞。我相信有一种更简单,更有效的方法来做到这一点,但我似乎找不到它。

check = np.array([not(np.in1d(a, np.unique(arr[:i])).any()) for i,a in enumerate(arr)])
arr[check]
qvtsj1bj

qvtsj1bj1#

可以使用np.unique查找数组的唯一元素。传递参数return_index=True返回唯一元素第一次出现的索引。请注意,由于unique隐式展平数组,因此这些值是 flattened 数组中的索引。

unique_elems, unique_indices = np.unique(arr, return_index=True)
# array([4580, 4581, 4650, 4651, 4652, 4654, 4655, 4656, 4657, 4658, 4659]),
# array([ 0,  1,  8, 12,  9, 10, 11, 15,  2,  3,  7], dtype=int64)

现在,我们要选择任何行,其中 * 所有 * 它的元素都在unique_indices数组中。首先,让我们创建一个数组,将扁平数组中的元素的索引Map到它在arr中的位置:

mapping = np.arange(arr.size).reshape(arr.shape)

现在,让我们看看unique_indices中有哪些索引:

select_elem = np.isin(mapping, unique_indices)
# array([[ True,  True,  True,  True],
#        [False, False, False,  True],
#        [ True,  True,  True,  True],
#        [ True, False, False,  True]])

最后,只选择select_elem的行,它们都是True

select_rows = select_elem.all(axis=1)
# array([ True, False,  True, False])

使用它来索引数组,我们得到了想要的结果:

result = arr[select_rows]
# array([[4580, 4581, 4657, 4658],
#        [4650, 4652, 4654, 4655]])

以下是性能随输入大小的变化:

Timeless's方法与您的方法(毫不奇怪)几乎相同,因为它具有相同的瓶颈,即在python中对数组进行迭代。我上面展示的pure-numpy方法运行速度明显更快。我没有对Yossi's方法计时,因为它给出了错误的结果。

xienkqul

xienkqul2#

以下是一个与您的方法一致的方法:

out = np.concatenate([arr[i:i+1] for i, a in enumerate(arr)
                      if not np.isin(a, np.unique(arr[:i])).any()], axis=0)

输出:

print(out)

[[4580 4581 4657 4658]
 [4650 4652 4654 4655]]
hyrbngr7

hyrbngr73#

import numpy as np
arr = np.array([[4580, 4581, 4657, 4658],
       [4580, 4581, 4657, 4659], #-> duplicate because of 4580
       [4650, 4652, 4654, 4655],
       [4651, 4655, 4652, 4656]]) #-> duplicate because of 4652

tot_mask = np.zeros_like(arr)
tmp_mask = np.zeros_like(arr)
arr_shifted = arr
for ii, cur_col in enumerate(range(arr.shape[0]-2)):
    arr_shifted = np.roll(arr_shifted,-1, axis=0)[:-1,:]
    for cur_col in range(arr.shape[1]):
        mask = arr[:-(ii+1),:] == arr_shifted
        tmp_mask[ii+1:,:] = mask
        tot_mask = np.logical_or(tot_mask, tmp_mask)
        tmp_mask = np.zeros_like(arr)
        arr_shifted = np.roll(arr_shifted,-1, axis=1)
any_eq = np.logical_not(np.any(tot_mask, 1))

return arr[any_eq,:]

输出:

[[4580 4581 4657 4658]
 [4650 4652 4654 4655]]

EDITED:由于pranav注解。其想法是将所有行向上滚动一行,并在列维中迭代滚动。如果任何数字与上排的数字之一相同,则忽略该行。
修复了检查所有先前行的实现。

相关问题