我正在尝试编写一个函数,将numpy ndarray减少到给定的形状,这有效地“取消广播”数组。例如,使用add作为通用函数,我希望得到以下结果:
- A = [1,2,3,4,5],reduced_shape =(1,)-> [1+2+3+4+5] = [15](轴上的总和=0)
- A = [[1,2],[1,2]],reduced_shape =(2,)-> [1+1,2+2] = [2,4](轴上的总和=0)
- A = [[1,2],[1,2]],reduced_shape =(1,)-> [1+2+1+2] = [6](轴上的总和=(0,1))
- A = [1,2],[1,2]],[[1,2],[1,2],reduced_shape =(2,2)-> [[1+1,2+2],[1+1,2+2]] = [[2,4],[2,4]](轴上和=0)
这是我想出的解决方案:
def unbroadcast(A, reduced_shape):
fill = reduced_shape[-1] if reduced_shape else None
reduced_axes = tuple(i for i, (a,b) in enumerate(itertools.zip_longest(A, shape, fillvalue=fill)) if a!=b)
return np.add.reduce(A, axis=reduced_axes).reshape(shape)
字符串
但是感觉不必要的复杂,有没有依赖Numpy的公共API的方法来实现它?
2条答案
按热度按时间kknvjkwl1#
目前尚不清楚这是如何“非广播”的。
直接的计算方法是使用
sum
的axis
参数:字符串
我不太使用
reduce
,但看起来axis也是这样:型
但是我还没有弄清楚从
reduced_shape
到必要的axis
值的逻辑。对于像(2,)和(2,2,2)这样的形状,当你说将形状缩小到(2,2)时,可能会有歧义。如果使用np.arange(24).reshape(2,3,4)
这样的示例数组,可能会更清楚zfciruhq2#
如果从最右边的dims开始迭代,您可以得到一个更通用的解决方案。这就是numpy在广播时所做的:https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules。
只需添加
itertools.zip_longest(reversed(list(A)), reversed(list(shape)))