Numpy将数组减少到给定的形状

qni6mghb  于 2023-08-05  发布在  其他
关注(0)|答案(2)|浏览(98)

我正在尝试编写一个函数,将numpy ndarray减少到给定的形状,这有效地“取消广播”数组。例如,使用add作为通用函数,我希望得到以下结果:

  • A = [1,2,3,4,5],reduced_shape =(1,)-> [1+2+3+4+5] = [15](轴上的总和=0)
  • A = [[1,2],[1,2]],reduced_shape =(2,)-> [1+1,2+2] = [2,4](轴上的总和=0)
  • A = [[1,2],[1,2]],reduced_shape =(1,)-> [1+2+1+2] = [6](轴上的总和=(0,1))
  • A = [1,2],[1,2]],[[1,2],[1,2],reduced_shape =(2,2)-> [[1+1,2+2],[1+1,2+2]] = [[2,4],[2,4]](轴上和=0)

这是我想出的解决方案:

def unbroadcast(A, reduced_shape):
    fill = reduced_shape[-1] if reduced_shape else None
    reduced_axes = tuple(i for i, (a,b) in enumerate(itertools.zip_longest(A, shape, fillvalue=fill)) if a!=b)
    return np.add.reduce(A, axis=reduced_axes).reshape(shape)

字符串
但是感觉不必要的复杂,有没有依赖Numpy的公共API的方法来实现它?

kknvjkwl

kknvjkwl1#

目前尚不清楚这是如何“非广播”的。
直接的计算方法是使用sumaxis参数:

In [124]: np.array([1,2,3,4,5]).sum()
Out[124]: 15
In [125]: np.array([[1,2],[1,2]]).sum(axis=0)
Out[125]: array([2, 4])
In [126]: np.array([[1,2],[1,2]]).sum(axis=(0,1))
Out[126]: 6
In [128]: np.array([[[1,2], [1,2]], [[1,2], [1,2]]]).sum(axis=0)
Out[128]: 
array([[2, 4],
       [2, 4]])

字符串
我不太使用reduce,但看起来axis也是这样:

In [130]: np.add.reduce(np.array([1,2,3,4,5]))
Out[130]: 15
In [132]: np.add.reduce(np.array([[[1,2], [1,2]], [[1,2], [1,2]]]),axis=0)
Out[132]: 
array([[2, 4],
       [2, 4]])


但是我还没有弄清楚从reduced_shape到必要的axis值的逻辑。对于像(2,)和(2,2,2)这样的形状,当你说将形状缩小到(2,2)时,可能会有歧义。如果使用np.arange(24).reshape(2,3,4)这样的示例数组,可能会更清楚

zfciruhq

zfciruhq2#

如果从最右边的dims开始迭代,您可以得到一个更通用的解决方案。这就是numpy在广播时所做的:https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules。
只需添加itertools.zip_longest(reversed(list(A)), reversed(list(shape)))

相关问题