numpy数组沿着内轴的多维索引

q5lcpyga  于 2023-06-23  发布在  其他
关注(0)|答案(4)|浏览(131)
  • 我有一个numpy数组x,形状为[4, 5, 3]
  • 我有一个索引i的二维数组,其形状为[4, 3],指的是x中沿着维度1(长度为5)的索引
  • 我想从x中提取一个形状为[4, 3]的子数组y,使得y[j, k] == x[j, i[j, k], k]
  • 我该怎么做?
50pmv0ei

50pmv0ei1#

我认为正确答案如下:

y = x[np.arange(4).reshape(4, 1), i, np.arange(3).reshape(1, 3)]

示例:

import numpy as np

rng = np.random.default_rng(0)

x = np.arange(4 * 5 * 3)
rng.shuffle(x)
x = x.reshape(4, 5, 3)
i = rng.integers(5, size=[4, 3])

y = x[np.arange(4).reshape(4, 1), i, np.arange(3).reshape(1, 3)]

print("x:", x, "i:", i, "y:", y, sep="\n")

输出:

x:
[[[16 27 20]
  [ 8 42 34]
  [51  4 52]
  [57 10  2]
  [44 23 24]]

 [[43 11 35]
  [30 18 54]
  [ 3  1 55]
  [17 21 36]
  [ 0 28  6]]

 [[19 48 22]
  [26 37 46]
  [58 32 25]
  [53  9 38]
  [47 50 40]]

 [[13 12  7]
  [45 39 59]
  [ 5 49 14]
  [29 41 56]
  [33 15 31]]]
i:
[[1 3 4]
 [0 0 3]
 [1 2 0]
 [4 2 4]]
y:
[[ 8 10 24]
 [43 11 36]
 [26 32 22]
 [33 49 31]]

(橡皮鸭调试FTW)

n3ipq98p

n3ipq98p2#

使用np.take_along_axis

y = np.take_along_axis(x, i[:, None, :], axis=1)[:, 0, :]
gj3fmq9x

gj3fmq9x3#

使用np.ix_的通用方法,可应用于任何axis

axis = 1
t = np.ix_(*(range(s) for s in [*x.shape[:axis], *x.shape[axis+1:]]))
y = x[*t[:axis], i, *t[axis:]]

一个内存消耗稍少但特别的替代方案:

y = x[[[v] for v in range(x.shape[0])], i, range(x.shape[2])]
cbjzeqam

cbjzeqam4#

这是一种方法:

y = x[np.arange(4)[:, None], i, np.arange(3)]

相关问题