scipy 稀疏矩阵的点积乘法中的矩阵大小错误

ma8fv8wu  于 2023-05-07  发布在  其他
关注(0)|答案(2)|浏览(162)

我正在编写一个线性回归。我在使用SciPy稀疏矩阵时遇到了.dot()产品的问题。这是一个最小可重复的例子,有200个观测值,50个回归量和400个输出。

import numpy as np
import scipy

n_row = 200
n_col = 50
n_outcomes = 400

x = np.random.rand(n_row, n_col)

y = scipy.sparse.rand(n_row, n_outcomes, format="csc", density=0.05)

print(x.shape)  # (200, 50)
print(y.shape)  # (200, 400)
print((x.T @ y).shape) # (50, 400)
print(((x.T).dot(y)).shape)  # (50, 200) <- WRONG, it should be (50, 400)
xx_inv = np.linalg.inv(x.T.dot(x))
xx_inv.dot((x.T).dot(y)) # this calculation takes a long time

阅读the documentation,它说对于二维矩阵,.dot()就像矩阵乘法@,但使用@也可以。
.dot()对稀疏矩阵做了什么,为什么它这样做而不是抛出错误?

vh0rcniy

vh0rcniy1#

使用@矩阵乘法的稀疏定义,产生与密集等价物相同的东西-但更快一点。

In [65]: np.allclose((x.T @ y), x.T.dot(y.A))
Out[65]: True

In [66]: timeit x.T@y
381 µs ± 1.12 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [67]: timeit x.T.dot(y.A)
521 µs ± 11.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

y.A是将稀疏转换为密集(.toarray())的正确方法。np.array(y)错误。

In [68]: timeit x.T.dot(y)
1.69 s ± 210 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

不仅形状不对,dtype也不对,而且时间要长得多。当形状出乎意料时,进一步挖掘,确保dtype,甚至元素值是有意义的。
另一种正确计算的方法

In [71]: timeit y.T.dot(x).T
379 µs ± 1.36 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

测试matmul建议:

In [75]: np.matmul(x.T,y).shape
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[75], line 1
----> 1 np.matmul(x.T,y).shape

ValueError: matmul: Input operand 1 does not have enough dimensions (has 0, gufunc core with signature (n?,k),(k,m?)->(n?,m?) requires 1)

此错误是因为np.array(y)是0 d对象dtype数组;只是一个粗糙的数组 Package 器围绕稀疏矩阵。与dot相比,matmul不能对0 d,标量,对象或数组进行操作。

chhqkbe1

chhqkbe12#

您可能已经从我们在评论中的讨论中了解到了这一点,但是您看到这种差异的原因可能是因为您对matmul使用了操作符,而对dot使用了函数/方法。首先使用调用对象的运算符实现。但是,如果失败,则调用另一个操作数的等效反射运算符,在本例中为__rmatmul__,它由稀疏矩阵实现。当调用dot时,它总是使用调用对象的dot实现,这无法正确处理稀疏数组。

相关问题