#create arrays
h,w,d = 3,3,4
x = np.arange(h*w*d).reshape(h,w,d)
y = np.random.randint(d,size=(h,w))
#categorical y
cat_y = np.eye(d)[y]
# select elements from x
x_sel = x[cat_y.astype(bool)]
#reshape to original form
x_out = x_sel.reshape(y.shape)
型 或者一行:
xx = x[ np.eye(d)[y].astype(np.bool)].reshape(y.shape)
3条答案
按热度按时间zpqajqem1#
我想你正在寻找numpy.take_沿着_axis:
字符串
输出量:
型
与Ulises' answer比较:
型
输出量:
型
0s0u357o2#
编辑
选择元素的最短方法是交换
x
中的尺寸,然后使用choose
选择y
值:字符串
旧方案
假设你有
x
和y
,你可以使用y
创建一个分类数组,并使用它来选择x
:示例:
型
或者一行:
型
sxissh063#
字符串
我们的想法是用
broadcast
和Y
的数组索引前两个dim。(h,1), (w,), (h,w)
这是所有高级索引的一般原则。这是我们在添加take_along
之前必须做的。ix_
、meshgrid
和ogrid
也可用于制造这些阵列。