在numpy数组上高效实现某个操作

bq9c1y66  于 2023-08-05  发布在  其他
关注(0)|答案(2)|浏览(94)

考虑下面的代码块

N = 28 * 28
X = rng.randn(10000, N)
n_groups = group_size = 28
Q = X[:10]
Z = X[:, None] * Q[None]  # line 4: multiply every row of Q by every row of X
Z = Z.reshape((len(X), len(Q), n_groups, group_size)).mean(axis=3)

字符串

**Question.**如何重新实现上面的代码片段,以输出Z的相同结果,但没有第4行的昂贵(内存方面等)操作?

我希望这可以通过某种原生Tensor或多维点积来实现。
先谢了。

jrcvhitl

jrcvhitl1#

在Numpy中执行此操作的快速标准方法是使用**einsum**:

X2 = X.reshape(len(X), n_groups, group_size)
Q2 = Q.reshape(len(Q), n_groups, group_size)
Z = np.einsum('ikl,jkl->ijk', X2, Q2, optimize=True) / group_size

字符串
这不仅明显更快,而且内存效率更高,因为不需要创建 * 临时数组 *。请注意,einsum在这种情况下不是最佳的,因为最后一个维度相当小,并且是顺序执行的。如果速度不够快,可以编写优化的并行Numba/Cython代码以获得更好的性能。
请注意,/ group_size可以应用于Q2,而不是np.einsum的结果,以获得更好的性能(因为Q2更小,这在数学上是等效的)。

基准测试

以下是我的i5- 9600 KF处理器使用Numpy 1.24.3的结果:

Initial implementation:  167 ms
Reinderien's solution:    70 ms
Naive einsum:             53 ms
Optimized einsum:         49 ms

v8wbuo2f

v8wbuo2f2#

Jérôme Richard的基本见解,即这是一个乘法和除法,是正确的,并且可能是最快的,而不绕过Numpy。
einsum是一个高度广义的矩阵乘积函数,你所拥有的“只是”最后两个维度上的标准矩阵乘积。为了说明,我证明这等价于以下内容:

import time

import numpy as np
from numpy.random import default_rng

n_groups = group_size = 28
N = n_groups * group_size
rng = default_rng(seed=0)
X = rng.random((10_000, N))
Q = X[:10, :]

a = time.perf_counter()

xx, qq = np.broadcast_arrays(
    X.reshape((len(X), 1, -1)),
    Q.reshape((1, len(Q), -1)),
)

Z = (
    xx.reshape((len(X), len(Q), n_groups, 1, group_size)) @
    qq.reshape((len(X), len(Q), n_groups, group_size, 1))
)[...,0,0] / n_groups

b = time.perf_counter()
print(b - a)

字符串

相关问题