numpy 如何获取至少两个连续值都大于某个阈值的指数?

i86rm4rw  于 2022-11-10  发布在  其他
关注(0)|答案(8)|浏览(93)

例如,让我们考虑下面的NumPy数组:

[1, 5, 0, 5, 4, 6, 1, -1, 5, 10]

另外,让我们假设阈值等于3。这就是说,我们正在寻找至少有两个连续的值都在阈值以上的序列。
输出将是这些值的
索引**,在我们的示例中为:

[[3, 4, 5], [8, 9]]

如果输出数组是扁平化的,那也可以!

[3, 4, 5, 8, 9]

输出说明

在我们的初始数组中,我们可以看到,对于index = 1,我们有一个值5,它大于阈值,但不是每个值都大于阈值的序列(至少包含两个值)的一部分。这就是为什么这个指数不会出现在我们的产出中。
另一方面,对于索引[3, 4, 5],我们有一系列(至少两个)相邻值[5, 4, 6],其中每个值都高于阈值,这就是它们的索引包含在最终输出中的原因!

我的代码到目前为止

我对这个问题的看法是这样的:

(arr > 3).nonzero()

上面的命令收集阈值以上的所有项目的索引。然而,我不能确定它们是否是连续的。我曾想过对上述代码片段的结果尝试diff,然后可能会定位其中的一个(也就是说,索引一个接一个)。这将给我们带来:

np.diff((arr > 3).nonzero())

但我还是会错过这里的一些东西。

w6mmgewl

w6mmgewl1#

如果将一个布尔数组与一个大小为win_size([1] * win_size)的充满1的窗口进行卷积,则将获得一个值为win_size的数组,其中win_size项的条件成立:

import numpy as np

def groups(arr, *, threshold, win_size, merge_contiguous=False, flat=False):
    conv = np.convolve((arr >= threshold).astype(int), [1] * win_size, mode="valid")
    indexes_start = np.where(conv == win_size)[0]
    indexes = [np.arange(index, index + win_size) for index in indexes_start]

    if flat or merge_contiguous:
        indexes = np.unique(indexes)
        if merge_contiguous:
            indexes = np.split(indexes, np.where(np.diff(indexes) != 1)[0] + 1)
    return indexes

arr = np.array([1, 5, 0, 5, 4, 6, 1, -1, 5, 10])
threshold = 3
win_size = 2

print(groups(arr, threshold=threshold, win_size=win_size))
print(groups(arr, threshold=threshold, win_size=win_size, merge_contiguous=True))
print(groups(arr, threshold=threshold, win_size=win_size, flat=True))
[array([3, 4]), array([4, 5]), array([8, 9])]
[array([3, 4, 5]), array([8, 9])]
[3 4 5 8 9]
gcmastyq

gcmastyq2#

您可以使用简单的NumPy操作来做您想做的事情

import numpy as np

arr = np.array([1, 5, 0, 5, 4, 6, 1, -1, 5, 10])

arr_padded = np.concatenate(([0], arr, [0]))
a = np.where(arr_padded > 3, 1, 0)

da = np.diff(a)

idx_start = (da == 1).nonzero()[0]
idx_stop = (da == -1).nonzero()[0]

valid = (idx_stop - idx_start >= 2).nonzero()[0]

result = [list(range(idx_start[i], idx_stop[i])) for i in valid]
print(result)

说明

数组a是原始数组的填充二进制版本,其中1的原始元素大于3。da包含1,其中“岛”以a开始,-1,“岛”以a结束。由于填充,保证在da中有相同数量的1-1。提取它们的指数,我们就可以计算出这些岛屿的长度。有效的索引对是那些各自“岛”的长度大于等于2的索引对。然后,只需生成有效“岛”的索引边界之间的所有数字即可。

67up9zun

67up9zun3#

我遵循你最初的想法。你就快做完了。
我使用另一个diff2来选取序列中第一个值的索引。有关详细信息,请参见代码中的注解。

import numpy as np

arr = np.array([ 1,  5,  0,  5,  4,  6,  1, -1,  5, 10])
threshold = 3

all_idx = (arr > threshold).nonzero()[0]

# array([1, 3, 4, 5, 8, 9])

result = np.empty(0)
if all_idx.size > 1:
    diff1 = np.zeros_like(all_idx)
    diff1[1:] = np.diff(all_idx)
    # array([0, 2, 1, 1, 3, 1])
    diff1[0] = diff1[1]
    # array([2, 2, 1, 1, 3, 1])
    #**Positions with a value 1 in diff1 should be reserved.**

    # But we also want the position before each 1. Create another diff2
    diff2 = np.zeros_like(all_idx)
    diff2[:-1] = np.diff(diff1)
    # array([ 2, -1,  0,  2, -2,  0])
    #**Positions with a negative value in diff2 should be reserved.**

    result = all_idx[(diff1==1) | (diff2<0)]

print(result)

# array([3, 4, 5, 8, 9])
pxy2qtax

pxy2qtax4#

我将使用窗口视图尝试一些不同的方法,我不确定这是否一直有效,因此欢迎使用反例。它的优点是不需要使用Python循环。

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view as window

def consec_thresh(arr, thresh):
    win = window(np.argwhere(arr > thresh), (2, 1))
    return np.unique(win[np.diff(win, axis=2).ravel() == 1, :,:].ravel())

它是怎么工作的?

因此,我们从数组开始,收集达到阈值的索引:

In [180]: np.argwhere(arr > 3)
Out[180]:
array([[1],
       [3],
       [4],
       [5],
       [8],
       [9]])

然后,我们构建一个滑动窗口,该窗口沿该列由两个值组成(这就是该窗口的(2, 1)形状的原因)。

In [181]: window(np.argwhere(arr > 3), (2, 1))
Out[181]:
array([[[[1],
         [3]]],

       [[[3],
         [4]]],

       [[[4],
         [5]]],

       [[[5],
         [8]]],

       [[[8],
         [9]]]])

现在我们想要取每一对内部的差异,如果它是一,那么指数是连续的。

In [182]: np.diff(window(np.argwhere(arr > 3), (2, 1)), axis=2)
Out[182]:
array([[[[2]]],

       [[[1]]],

       [[[1]]],

       [[[3]]],

       [[[1]]]])

我们可以将这些值插入到上面创建的窗口中,

In [185]: window(np.argwhere(arr > 3), (2, 1))[np.diff(window(np.argwhere(arr > 3), (2, 1)), axis=2).ravel() == 1, :, :]
Out[185]:
array([[[[3],
         [4]]],

       [[[4],
         [5]]],

       [[[8],
         [9]]]])

然后我们可以拆分(如果可能的话,平坦化而不复制),我们必须消除窗口创建的重复索引,所以我调用np.unique。我们再一次拆分,得到:

array([3, 4, 5, 8, 9])
gev0vcfq

gev0vcfq5#

下面的迭代代码应该有助于降低O(n)的复杂性

arr = [1, 5, 0, 5, 4, 6, 1, -1, 5, 10]

threshold = 3
sequence = 2

output = []
temp_arr = []        

for i in range(len(arr)):
    if arr[i] > threshold:
        temp_arr.append(i)
    else:
        if len(temp_arr) >= sequence:
            output.append(temp_arr)
        temp_arr = []
if len(temp_arr):
    output.append(temp_arr)
    temp_arr = []

print(output)

# Output

# [[3, 4, 5], [8, 9]]
fslejnso

fslejnso6#

我建议使用带有两个索引的for循环。您将有一个从j=1开始的索引,另一个从i=0开始的索引,两者都前进1。然后,您可以询问这两个值是否都大于阈值,如果是,则将索引添加到列表中,并继续向前移动j,直到阈值或.Next()不大于阈值。

values = [1, 5, 0, 5, 4, 6, 1, -1, 5, 10]
res=[]
threshold= 3

i=0
j=0
for _ in values:
  j=i+1
  lista=[]
  try:
      print(f"i: {i} j:{j}")
      # check if condition is met
      if(values[i] > threshold and values[j] > threshold):
        lista.append(i)
        # add sequence 
        while values[j] > threshold:
          lista.append(j)
          print(f"j while: {j}")
          j+=1
          if(j>=len(values)):
            break
        res.append(lista)

      i=j
      if(j>=len(values)): 
        break
  except:
    print("ex")

这很管用。但需要重构

dtcbnfnu

dtcbnfnu7#

让我们试一试以下代码:


# Simple is better than complex

# Complex is better than complicated

arr = [1, 5, 0, 5, 4, 6, 1, -1, 5, 10]

arr_3=[i if arr[i]>3 else 'a' for i in range(len(arr))]

arr_4=''.join(str(x) for x in arr_3)

i=0

while i<len(arr_5):
    if len(arr_5[i]) <=1:
        del arr_5[i]
    else:
        i+=1

arr_6=[list(map(lambda x: int(x), list(x))) for x in arr_5]

print(arr_6)

输出:

[[3, 4, 5], [8, 9]]
ejk8hzay

ejk8hzay8#

这里有一个利用PandasSeries的解决方案:

thresh = 3
win_size = 2

s = pd.Series(arr)

# locating groups of values where there are at least (win_size) consecutive values above the threshold

groups = s.groupby(s.le(thresh).cumsum().loc[s.gt(thresh)]).transform('count').ge(win_size)

0    False
1    False
2    False
3     True
4     True
5     True
6    False
7    False
8     True
9     True
dtype: bool

我们现在可以轻松地在一维数组中获取它们的索引:

np.flatnonzero(groups)

# array([3, 4, 5, 8, 9], dtype=int64)

或多个列表:

[np.arange(index.start, index.stop) for index in np.ma.clump_unmasked(np.ma.masked_not_equal(groups.values, value=True))]

# [array([3, 4, 5], dtype=int64), array([8, 9], dtype=int64)]

相关问题