python 在Numpy中使用低内存的大型矩阵乘法

g6baxovj  于 2023-06-20  发布在  Python
关注(0)|答案(2)|浏览(126)

我有一个复杂的矩阵乘法,有几十万行和列。在某个时候,内存使用率增长到100%,然后计算机冻结,我必须手动重新启动它。
我尝试过使用Numba(使用装饰器在函数中编写代码)和Dask(将numpy数组转换为da.from_array(var,chunk)),但没有成功。我不是这方面的Maven。
我已经读了很多类似的问题,但没有找到一个很好的解决我的问题。
一个最小可重复的例子可能是

m = 100000
n = 100000
a1 = np.random.rand(m)
a2 = np.random.rand(n)
c = np.random.rand(m)+1j*np.random.rand(m)
b = np.random.rand(n)+1j*np.random.rand(n)
A = np.exp(1j*np.outer(a1,a2))
d = c*np.dot(A,b)

在内存使用方面,解决它的最佳选择是什么?(不一定是最快的)

pgky5nke

pgky5nke1#

主要问题

主要问题是1j*np.outer(a1,a2)需要100_000 * 100_000 * (8 * 2) = 149 GiB。最重要的是,np.exp需要读取这个矩阵并生成另一个相同大小的矩阵,因此您至少需要300 GiB的RAM。这是巨大的和低效的。
您应该避免以任何代价创建矩阵A(包括类似的临时矩阵)。

快速节省内存

Numba可以在这种情况下提供帮助:你可以在运行中计算数组d,避免了巨大的临时矩阵。下面是一个优化的Numba代码:

import numba as nb
import numpy as np

@nb.njit('(float64[::1], float64[::1], complex128[::1], complex128[::1])', parallel=True)
def compute(a1, a2, b, c):
    m, n = a1.size, a2.size
    assert n == m  # seems already mantatory in the initial code
    tmpDot = np.zeros(n, dtype=np.complex128)
    for i in nb.prange(n):
        for j in range(n):
            tmpDot[i] += np.exp(1j * (a2[j] * a1[i])) * b[j]
    return c * tmpDot

m = 100000
n = 100000
a1 = np.random.rand(m)
a2 = np.random.rand(n)
c = np.random.rand(m)+1j*np.random.rand(m)
b = np.random.rand(n)+1j*np.random.rand(n)
d = compute(a1, a2, b, c)

与最初的代码相比,这段代码只占用很少的内存:只有几个MiB。因此,它需要的内存减少100_000倍!此外,我还希望它运行得更快(因为它是多线程的,可以更好地使用CPU缓存和RAM)。在我的机器上只需要17.1秒(而我甚至不能运行初始代码)!

c9x0cxw0

c9x0cxw02#

估计内存使用情况:

m float64  # a1 = np.random.rand(m)
n float64  # a2 = np.random.rand(n)
m complex128   # c = np.random.rand(m)+1j*np.random.rand(m)
n complex   #b = np.random.rand(n)+1j*np.random.rand(n)

B、c行将有几个复杂的临时数组,但最终各有一个。

(m,n) complex   # A = np.exp(1j*np.outer(a1,a2))

outer构成一个(m,n)复形; 1j*生成另一个; exp其他
可以试试

A = np.zeros((m,n), dtype=complex)
np.outer(1j*a1,a2, out=A)
np.exp(A, out=A)

最后:

m complex    # d = c*np.dot(A,b)

dot做(m,n),其中(n,)=>(m,)。可能会让这个更紧凑

np.multiply(np.dot(A,b), c, out=c))

dot需要一个out,但我没有可用的(m,) complex
out的这些使用可以保存一些内存,消除一些(m,n)复杂的临时缓冲区。甚至可以保存一点时间。

相关问题