具有不同维度的椭圆的Numpy.einsum

nkoocmlb  于 2022-11-10  发布在  其他
关注(0)|答案(2)|浏览(187)

我经常发现,我喜欢在两个数组的最后几个维度之间进行运算,其中第一个维度不一定匹配。作为一个例子,我想做一些类似的事情:

a = np.random.randn(10, 10, 3, 3)
b = np.random.randn(5, 3)
c = np.einsum('...ij, ,,,j -> ...,,,i', a, b)

其结果应满足c.shape = (10, 10, 5, 3)c[i, j, k] = a[i, j] @ b[k]。有没有办法用现有的界面来实现这一点?

qv7cva1a

qv7cva1a1#

In [82]: c = np.einsum('...ij,...j->...i', a, b)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Input In [82], in <cell line: 1>()
----> 1 c = np.einsum('...ij,...j->...i', a, b)

File <__array_function__ internals>:5, in einsum(*args,**kwargs)

File ~\anaconda3\lib\site-packages\numpy\core\einsumfunc.py:1359, in einsum(out, optimize, *operands,**kwargs)
   1357     if specified_out:
   1358         kwargs['out'] = out
-> 1359     return c_einsum(*operands,**kwargs)
   1361 # Check the kwargs to avoid a more cryptic error later, without having to
   1362 # repeat default values here
   1363 valid_einsum_kwargs = ['dtype', 'order', 'casting']

ValueError: operands could not be broadcast together with remapped shapes 
[original->remapped]: (10,10,3,3)->(10,10,3,3) (5,3)->(5,newaxis,3)

所以它试图使用broadcasting来匹配尺寸。
让我们做a(10,10,1,3,3)形状。这样,(10,10,1)部分用b的(5,)广播:

In [83]: c = np.einsum('...ij,...j->...i', a[:,:,None], b)
In [84]: c.shape
Out[84]: (10, 10, 5, 3)
jaql4c8m

jaql4c8m2#

最终通过以下助手函数绕过了这个问题

def batch_matvec(A, b):
    product = np.einsum('...ij, ...kj->...ki', A, b.reshape(-1, b.shape[-1]))
    return product.reshape((*A.shape[:-2], *b.shape[:-1], -1))

相关问题