python Numpy:在稀疏数组中查找序列,忽略NaN

lh80um4z  于 2023-08-02  发布在  Python
关注(0)|答案(1)|浏览(100)

我有一个大的1D数组,主要包含NaN和一些整数。我试图提取数组包含特定序列的开始和结束索引,忽略中间的NaN。
举例来说:

sequence = np.array([1, 2, 3])
sparse_array1 = np.array([np.nan, 2, 3, np.nan, 1, 2, np.nan, 3, 2, np.nan, np.nan, 3, 1, np.nan, np.nan, np.nan, 2, np.nan, 3, 1, 1, np.nan])
sparse_array2 = np.array([3, 2, 1, 2, 1])

print(find_sequence_indices(sparse_array1, sequence))
# prints [(4, 7), (12, 18)]

print(find_sequence_indices(sparse_array2, sequence))
# prints []

字符串
我想不出一个不涉及3个嵌套循环的方法。我的数组比示例大得多,我无法承受立方运行时间。

ffx8fchx

ffx8fchx1#

使用sliding_window_view、索引和numpy.where

from numpy.lib.stride_tricks import sliding_window_view as swv

def find_sequence_indices(sparse_array, sequence):
    N = len(sequence)
    m = np.isnan(sparse_array)

    # get only indices of non-NA
    idx = np.where(~m)[0]
    # find position of sequence
    idx2 = (swv(sparse_array[idx], N) == sequence).all(axis=1)
    # get start/end
    return list(zip(idx[:1-N][idx2], idx[N-1:][idx2]))

find_sequence_indices(sparse_array1, sequence)
# [(4, 7), (12, 18)]

find_sequence_indices(sparse_array2, sequence)
# []

字符串
中间体(对于sparse_array1):

# idx
array([ 1,  2,  4,  5,  7,  8, 11, 12, 16, 18, 19, 20])

# sparse_array1[idx]
array([2., 3., 1., 2., 3., 2., 3., 1., 2., 3., 1., 1.])

# swv(sparse_array1[idx], N)
array([[2., 3., 1.],
       [3., 1., 2.],
       [1., 2., 3.],  # match, index 4
       [2., 3., 2.],
       [3., 2., 3.],
       [2., 3., 1.],
       [3., 1., 2.],
       [1., 2., 3.],  # match, index 12
       [2., 3., 1.],
       [3., 1., 1.]])

# (swv(sparse_array1[idx], N) == sequence).all(axis=1)
array([False, False,  True, False, False, False, False,  True, False,
       False])

相关问题