numpy操作中的批处理轴

xt0899hw  于 2023-10-19  发布在  其他
关注(0)|答案(1)|浏览(113)

为什么“批处理”轴总是NumPy中的前导轴?我把我所有的包都设计成使用尾随轴作为批处理轴,因为这对我来说更自然。现在我考虑切换到NumPy的约定-只是为了让NumPy用户更直观。有什么想法吗
从性能方面来说,这可能是一个非常糟糕的主意:

import numpy as np

np.random.seed(6512)
a = np.random.rand(50000, 8, 3, 3)

np.random.seed(85742)
b = np.random.rand(50000, 8, 3, 3)

c = a @ b
# 19.8 ms ± 543 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

d = np.einsum("...ik,...kj->...ij", a, b)
# 84.1 ms ± 2.4 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

# now use the trailing axes (ensure C-contiguous arrays for transposed data)
A = np.ascontiguousarray(np.transpose(a, [2, 3, 0, 1])) # A_ijab
B = np.ascontiguousarray(np.transpose(b, [2, 3, 0, 1])) # B_ijab

C = (B.T @ A.T).T # (C^T)_baji = B_bajk A_baki -> C_ijab
# 16.9 ms ± 1.82 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

D = np.einsum("ik...,kj...->ij...", A, B)
# 17.2 ms ± 842 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

assert np.allclose(c, d)
assert np.allclose(C, D)
assert np.allclose(np.transpose(D, [2, 3, 0, 1]), d)
assert np.allclose(np.transpose(C, [2, 3, 0, 1]), c)

或者更复杂的例子:

# crossed-dyadic product
# ----------------------

E = np.einsum("ik...,jl...->ijkl...", A, B)
# 76.5 ms ± 2.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

e = np.einsum("...ik,...jl->...ijkl", a, b)
# 207 ms ± 3.29 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

assert np.allclose(np.transpose(E, [4, 5, 0, 1, 2, 3]), e)
jslywgbw

jslywgbw1#

Numpy是用C写的,使用C的数组约定,即row-major array ordering。因此,在最后一个轴上应用操作(即,最右边的,也是最 * 连续 * 的),对于CPU缓存来说更有效。对于大型数组,调换数组会显著增加 * 对RAM* 的压力,因此通常会导致计算速度变慢(RAM通常是Numpy操作的限制因素)。
话虽如此,在您的情况下,Numpy显然没有针对3x3矩阵进行优化。在这种情况下,内部通用Numpy迭代器**(启用广播)的**开销非常巨大,以至于计算受到它们的约束。大多数BLAS库也没有针对这种极小的矩阵进行优化。一些线性代数库为此提供批处理操作(例如AFAIK CuBLAS这样做)。但是,Numpy还不支持它们。
现代主流CPU可以在几纳秒内完成3x3矩阵乘法运算,因此通用代码的开销太大,无法有效地计算它们。为了获得快速实现,您需要编写一个支持 * 特定 * 固定大小的3x3矩阵的 * 编译 * 代码。然后,编译器可以生成针对这种特定情况选择的高效指令。手工编写的汇编代码(或编译器 SIMD intrinsic)对于这个用例来说肯定会快得多,但它们很难编写,维护,而且容易出错。好的解决方案是使用Cython(带有内存视图和正确的 * 编译标志 *),甚至在这种情况下使用Numba(如果可能的话,带有快速数学标志)。你可以在这里找到一个Numba代码的例子来解决类似的问题。

相关问题