Numpy索引,使用掩码挑选2D数组的特定条目

uqxowvwt  于 2023-04-21  发布在  其他
关注(0)|答案(1)|浏览(102)

假设我有两个数组:

x     # shape(n, m)
mask  # shape(n), where each entry is a number between 0 and m-1

我的目标是使用mask来挑选x的条目,这样结果的形状为n

out[i] = x[i, mask[i]]

这可以很容易地使用for循环进行编码

out = np.zeros(n)
for i in range(n):
    out[i] = x[i, mask[i]]

我想用numpy矢量化一下。有什么想法吗?

wgx48brx

wgx48brx1#

您可以使用高级索引:

import numpy as np

n, m = 6, 6

x = np.arange(n * m).reshape(n, m)
mask = np.random.randint(m, size=n)

out = x[np.arange(n), mask]
>>> x
array([[ 0,  1,  2,  3,  4,  5],
       [ 6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17],
       [18, 19, 20, 21, 22, 23],
       [24, 25, 26, 27, 28, 29],
       [30, 31, 32, 33, 34, 35]])
>>> mask
array([2, 4, 2, 2, 5, 3])
>>> out
array([ 2, 10, 14, 20, 29, 33])

相关问题