在2D numpy数组中查找数字块

mtb9vblg  于 2023-10-19  发布在  其他
关注(0)|答案(1)|浏览(113)

首先构建以下numpy数组:

import numpy as np
import itertools as it

a = np.array([[j for j in range(1,5)] for i in range(3)])

table = np.array(list(it.product(*a)))

看起来像是

[[1 1 1]
 [1 1 2]
 [1 1 3]
 ....

也就是数组[1,2,3,4]的三个副本的乘积。接下来,如果对于任何一行,列1和2(不是3)包含3,我将下一列中的所有元素设置为3(例如,如果一行读取[3 1 1],则变为[3 3 3]):

for row in table:
    if len(np.where(row==3)[0])>0:
      row[np.where(row==3)[0][0]:]=3

接下来,我必须通过存储四个值来识别包含3的块:行,列在它们开始的地方,行,列在它们结束的地方。为此,我使用了另一个非Python的for循环:

nr, nc = table.shape
blocks = []
bh = nr
for i in range(nc-1):
  bh = (int)(bh/4) 
  for j in range((int)(nr/bh)):
    temp = table[j*bh:(j+1)*bh,i:]
    if temp[0,0]==3 and temp[-1,-1]==3:
      blocks.append([[j*bh+1,i+1],[j*bh+bh,nc]])
      temp[:,:]=0

并得到,

[[[33, 1], [48, 3]], [[9, 2], [12, 3]], [[25, 2], [28, 3]], [[57, 2], [60, 3]]]

显然,这是一种非常低效的使用numpy的方式。有人能提出一个更pythonic的方法来做到这一点吗?

oug3syen

oug3syen1#

我不能摆脱所有的循环,但至少删除了嵌套的。

a = np.array([[j for j in range(1,5)] for i in range(3)])
table = np.array(list(it.product(*a)))

def mask_approach(table):
  bool_mask = table == 3
  rows_of_3s, columns_of_3s = np.where(bool_mask == True)

  
  output = []
  flag = False
  stop_iter = len(rows_of_3s)-1
  for i, v in enumerate(rows_of_3s):
    if i == stop_iter:
      break        
    if rows_of_3s[i+1] - v <= 1:
      if flag == False:
        start = (v, columns_of_3s[i])
        flag = True
    elif flag == True:
      flag = False
      output.append((start, (v, columns_of_3s[i])))
return output

我利用掩码,只使用布尔数组,因为它包含解决这个问题所需的所有信息。
我的方法:

%%timeit
mask_approach(table)

22.3 µs ± 485 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

您的方法:

%%timeit
ted_black_approach(table)

171 µs ± 9.96 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

由于解决方案肯定不是完美的,我想它引入了将问题简化为更容易的问题的方法,并且它可以在进一步优化中使用。

相关问题