在numpy中计算A * B * A'(`A @ B @ A.T`)并保持对称性

k4aesqcs  于 2023-11-18  发布在  其他
关注(0)|答案(1)|浏览(126)

我想计算矩阵ABA * B * A'项。
A'A的转置。
在Python上有没有一种有效的方法来计算这个?
我可以做A @ B @ A.T,但我想要的东西:
1.即利用对称性进行计算。
1.保证结果对称。
我有最直接的numba为基础的代码:

import numpy as np
from numba import jit, njit

A = np.random.rand(3, 3)
B = np.random.rand(3, 3)
B = B + B.T

@njit
def my_fun(A, B, C):
    for i in range(A.shape[0]):
        for j in range(A.shape[1]):
            for k in range(B.shape[1]):
                for l in range(i, A.shape[1]):
                    C[i, l] += A[i, j] * B[j, k] * A[l, k]
                    
    for i in range(1, A.shape[0]):
        for j in range(i):
            C[i, j] = C[j, i]
    
    return

C = np.zeros(shape = (A.shape[0], A.shape[0]))

my_fun(A, B, C)

np.all(C == C.T)

字符串
输出是对称的。我们可以做得更好的性能明智吗?
我还考虑过通过以下方法来修复numpy结果的对称性:

import numpy as np
from numba import jit, njit

A = np.random.rand(10, 10)
B = np.random.rand(10, 10)
B = B + B.T

@njit
def symmetrize_mat(A):
    
    for i in range(1, A.shape[0]):
        for j in range(i):
            A[i, j] = A[j, i]
    
    return

def my_faster_fun(A, B, C):
    C = np.matmul(A, B, out=C)
    C = np.matmul(C, A.T, out=C)
                    
    symmetrize_mat(C)
    
    return

C = np.zeros(shape = (A.shape[0], A.shape[0]))
D = A @ B @ A.T

my_faster_fun(A, B, C)

print(np.all(C == C.T))
print(np.all(D == D.T))
print(np.max(np.abs(C - D)))


我正在寻找中小型矩阵的最快解决方案。每个维度的范围为2-100。

0aydgbwb

0aydgbwb1#

所提供的代码在**O(n**4)中运行,而两个矩阵乘法在O(n**3)**中运行。因此,2个矩阵乘法肯定要快得多。人们可以尝试改变循环的顺序,然后分解一些计算,但结果可能与2个乘法矩阵相似。一个更简单的方法是编写执行2个乘法矩阵的代码,然后通过交换循环进行优化,使其更加SIMD友好,然后使其仅计算最后一个乘法矩阵的上三角部分。中间矩阵实际上可以逐行计算(对于大型矩阵来说更有效)。下三角部分可以像您在提供的实现中所做的那样计算。
下面是生成的代码:

@njit
def faster_fun(A, B, C):
    # Constraints coming from the 2 matrix multiplications
    assert A.shape[1] == B.shape[0] and A.shape[1] == B.shape[1]

    n, m = A.shape
    line = np.zeros(m)

    for i in range(n):
        line.fill(0.0)

        for k in range(m):
            factor = A[i, k]
            for j in range(m):
                line[j] += factor * B[k, j]

        for j in range(i, n):
            for k in range(m):
                C[i, j] += line[k] * A[j, k]

    for i in range(1, A.shape[0]):
        for j in range(i):
            C[i, j] = C[j, i]

字符串
这个解决方案比提供的要快得多。但是,它比A @ B @ A.T慢一点。这是因为Numpy使用BLAS库来计算矩阵乘法,而我机器上使用的BLAS是OpenBLAS:高度优化的实现。OpenBLAS以并行方式执行矩阵乘法而上面的Numba代码是顺序的。如果你打算从多线程代码运行Numba函数,那么Numba代码会比Numpy的代码快。否则,你可以这样并行化循环i

from numba import prange

@njit(parallel=True)
def faster_fun(A, B, C):
    assert A.shape[1] == B.shape[0] and A.shape[1] == B.shape[1]
    n, m = A.shape

    for i in prange(n):
        line = np.zeros(m)

        # [....] (same code later, but not need for line.fill)

    # [...]


这种解决方案比大多数机器上的顺序代码快,除了非常小的矩阵(因为产生线程,分发工作和等待它们是不自由的)。但是,它仍然比我机器上的基本Numpy代码A @ B @ A.T慢。这是因为line的分配根本不伸缩。确实,在我的6核CPU上,并行代码只比顺序代码快3.2倍。AFAIK,Numba中还没有简单的解决方案来解决这个(已知的)问题。在Cython中,解决方案是使用堆栈分配或线程本地分配来解决这个问题。尽管如此,即使有一个完美的缩放,上面的代码在60 × 60矩阵上也只会快30%。这表明要超过高度优化的BLAS实现是多么困难,而且对称性只给中等大小的矩阵给予小的性能增益
实际上,第二个矩阵乘法只需要计算一半(最佳),但第一个肯定需要完全计算。因此,理论上的最佳增益肯定是25%,这是相当小的。也就是说,在实践中,对于非常小的矩阵,顺序Numba实现应该比Numpy代码快得多(由于线程开销)。

相关问题