numpy 如何使用数组索引第二个数组的最后一个dim

4xrmg8kj  于 2023-08-05  发布在  其他
关注(0)|答案(3)|浏览(107)

假设X.shape = (h, w, d)Y.shape = (h, w)包含range(d)中的值(索引)。
如何使用Y中的索引获取X中的元素(来自最后一个dim)?
也就是说,我想做一些类似X[Y]的事情,它将返回一个h x w数组,其中Y被用作最后一个dim d中的索引。

zpqajqem

zpqajqem1#

我想你正在寻找numpy.take_沿着_axis:

import numpy as np
rng = np.random.default_rng()

h, w, d = 2, 3, 4
x = rng.random((h, w, d))
y = rng.integers(0, d, (h, w))

x, y, np.take_along_axis(x, y[..., None], axis=-1)[..., 0]

字符串
输出量:

(array([[[0.51705108, 0.68891581, 0.84475703, 0.77938839],
         [0.02115493, 0.47689898, 0.19786926, 0.73959225],
         [0.40821923, 0.0119006 , 0.89595898, 0.81798467]],
 
        [[0.60350791, 0.11501983, 0.15932539, 0.35923036],
         [0.27939872, 0.13691148, 0.47528086, 0.71320657],
         [0.98294212, 0.75039413, 0.06087527, 0.68233282]]]),
 array([[0, 2, 0],
        [3, 0, 0]]),
 array([[0.51705108, 0.19786926, 0.40821923],
        [0.35923036, 0.27939872, 0.98294212]]))


Ulises' answer比较:

assert np.array_equal(np.take_along_axis(x, y[..., None], axis=-1)[..., 0], x[ np.eye(d)[y].astype(np.bool)].reshape(y.shape))
%timeit np.take_along_axis(x, y[..., None], axis=-1)[..., 0]
%timeit x[ np.eye(d)[y].astype(np.bool)].reshape(y.shape)


输出量:

9.42 µs ± 91.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
55.8 µs ± 463 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

0s0u357o

0s0u357o2#

编辑

选择元素的最短方法是交换x中的尺寸,然后使用choose选择y值:

np.choose(y, x.transpose(2,0,1))

字符串

旧方案

假设你有xy,你可以使用y创建一个分类数组,并使用它来选择x
示例:

#create arrays
h,w,d = 3,3,4
x = np.arange(h*w*d).reshape(h,w,d)
y = np.random.randint(d,size=(h,w))

#categorical y
cat_y =  np.eye(d)[y]

# select elements from x
x_sel = x[cat_y.astype(bool)]

#reshape to original form
x_out = x_sel.reshape(y.shape)


或者一行:

xx = x[ np.eye(d)[y].astype(np.bool)].reshape(y.shape)

sxissh06

sxissh063#

X[np.arange(h)[:,None], np.arange(w), Y]

字符串
我们的想法是用broadcastY的数组索引前两个dim。(h,1), (w,), (h,w)这是所有高级索引的一般原则。这是我们在添加take_along之前必须做的。
ix_meshgridogrid也可用于制造这些阵列。

相关问题