如何在不使用循环的情况下找到2D numpy数组每行的第一个正元素位置

xxls0lw8  于 2023-11-18  发布在  其他
关注(0)|答案(2)|浏览(98)

All.给定一个2D numpy数组,如何在不使用循环的情况下找出每行第一个正元素的行和列索引。例如,给定数据

data = np.array([[1.0, -2.0, 3.0,3],
                 [0.0, 1.5, -2.0,5],
                 [-1.0, -2.0, -3.0,-5]])

字符串
如何找到每行每个正元素的索引。它期望得到2个1d索引作为答案,row_idx=[0,1]col_idx=[0,1]对应于数据中的元素[1,1.5]。任何人都可以在这方面有所启发,提前感谢!
我试过用

import numpy as np
print(data)
row_idx = np.any(data>0, axis=1)
col_idx = np.argmax(data>0,axis=1)
print('row indices: ', row_idx)
print('col indices: ', col_idx)
print(data[row_idx,col_idx])


但得不到正确答案。

xuo3flqw

xuo3flqw1#

np.any返回boolean掩码,而不是索引。您可以使用np.where直接获取索引。

import numpy as np

data = np.array([[1.0, -2.0, 3.0, 3],
                 [0.0, 1.5, -2.0, 5],
                 [-1.0, -2.0, -3.0, -5]])

col_idx = np.argmax(data > 0, axis=1)

rows_with_positive = np.any(data > 0, axis=1)

row_idx = np.where(rows_with_positive)[0]

col_idx = col_idx[rows_with_positive]

print('row indices: ', row_idx)
print('col indices: ', col_idx)
print(data[row_idx, col_idx])

字符串

wnvonmuf

wnvonmuf2#

你几乎就在那里了,你需要过滤掉argmax的结果,以删除没有值大于0的情况:

m1 = data>0
m2 = m1.any(axis=1)

row_idx = np.arange(data.shape[0])[m2]
col_idx = np.argmax(m1, axis=1)[m2]

print('row indices: ', row_idx)
print('col indices: ', col_idx)
print(data[row_idx,col_idx])

字符串
输出量:

row indices:  [0 1]
col indices:  [0 1]
[1.  1.5]

相关问题