考虑下面显示的代码。它产生a
具有shape (1, 4, 2, 2)
,b
具有shape (4, 1, 2, 2)
,c
广播到shape (4, 4, 2, 2)
的数组。
import numpy as np
a = np.random.randint(0, 10, size=(4, 2, 2))
b = np.random.randint(0, 10, size=(4, 2, 2))
a = a[np.newaxis, :]
b = b[:, np.newaxis]
c = a+b
字符串
问题:如前所述,c
包含162x2数组。但是,我只想要当a
沿axis=1
的索引沿着axis=0
的索引比b
沿沿着axis=0
的索引大时形成的62x2数组。这些将被保存在shape (6, 2, 2)
的3d数组中。如何才能做到这一点?
(the对于greater、less和equal情况,2x2数组计数将是6 + 6 + 4 = 16,这是我的代码片段当前生成的情况)
1条答案
按热度按时间4zcjmb1e1#
我是这么做的。如果你是对的,我会让你核实数字。
字符串
6个较低的三指数:
型
高级索引:
型