环境
- 操作系统:Windows 10
- Python版本:3.10
- Numba版本:0.57.0
- NumPy版本:1.24.3
示例
import numpy as np
from numba import njit
@njit
def matmul_transposed(a: np.ndarray, b: np.ndarray) -> np.ndarray:
# return a @ b.T # also tried this with a similar result, np.matmul seems to be unsupported by Numba
return a.dot(b.transpose())
matmul_transposed(np.array([[1.0, 1.0]]), np.array([[1.0, 1.0]]))
错误
上面的示例引发了一个错误
- Resolution failure for literal arguments:
No implementation of function Function(<function array_dot at 0x...>) found for signature:
>>> array_dot(array(float64, 2d, C), array(float64, 2d, F))
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'array_dot': File: numba\np\arrayobj.py: Line 5929.
With argument(s): '(array(float64, 2d, C), array(float64, 2d, F))':
Rejected as the implementation raised a specific error:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function dot at 0x...>) found for signature:
>>> dot(array(float64, 2d, C), array(float64, 2d, F))
There are 4 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'dot_2': File: numba\np\linalg.py: Line 525.
With argument(s): '(array(float64, 2d, C), array(float64, 2d, F))':
Rejected as the implementation raised a specific error:
LoweringError: Failed in nopython mode pipeline (step: native lowering)
scipy 0.16+ is required for linear algebra
File "[...]\numba\np\linalg.py", line 582:
def _dot2_codegen(context, builder, sig, args):
<source elided>
return lambda left, right: _impl(left, right)
^
During: lowering "$8call_function.3 = call $2load_deref.0(left, right, func=$2load_deref.0, args=[Var(left, linalg.py:582), Var(right, linalg.py:582)], kws=(), vararg=None, varkwarg=None, target=None)" at [...]\numba\np\linalg.py (582)
raised from [...]\numba\core\errors.py:837
- Of which 2 did not match due to:
Overload in function 'dot_3': File: numba\np\linalg.py: Line 784.
With argument(s): '(array(float64, 2d, C), array(float64, 2d, F))':
Rejected as the implementation raised a specific error:
TypingError: missing a required argument: 'out'
raised from [...]\numba\core\typing\templates.py:784
During: resolving callee type: Function(<function dot at 0x...>)
During: typing of call at [...]\numba\np\arrayobj.py (5932)
File "[...]\numba\np\arrayobj.py", line 5932:
def dot_impl(arr, other):
return np.dot(arr, other)
^
raised from [...]\numba\core\typeinfer.py:1086
- Resolution failure for non-literal arguments:
None
During: resolving callee type: BoundFunction((<class 'numba.core.types.npytypes.Array'>, 'dot') for array(float64, 2d, C))
During: typing of call at [...]\example.py (7)
File "scratch_2.py", line 7:
def matmul_transposed(a: np.ndarray, b: np.ndarray) -> np.ndarray:
<source elided>
"""Return a @ b.T"""
return a.dot(b.transpose())
^
解释
从错误信息中,我得出结论,Numba似乎通过将其布局样式从C更改为Fotran来转置数组,这是一种廉价的操作,因为它不必物理地更改布局,但它似乎不知道如何将C样式数组和Fotrtran样式数组相乘。
提问
有没有办法把这些矩阵相乘?最好不要复制整个b
,而做换位?
这似乎是一个相当普通的操作,所以我很困惑,它不工作。
1条答案
按热度按时间omqzjyyz1#
你的解释并不离谱:numba有四个候选项来乘以C和F布局数组,并给出了细节,为什么最终没有选择每一个。后两个被忽略是因为缺少一个参数,所以它们显然是用于另一个调用签名的。前两个被排除是因为某些东西在内部不起作用:
虽然第一行非常晦涩,但第二行仍然是错误消息的一部分,并给出了一个很好的提示。手动安装
scipy
,它应该可以工作。顺便说一句:这基本上是一个numpy函数的一行程序,它应该在单个CPU核心上执行得很好,因为numba没有太多的Python开销需要消除。当然,这取决于你还想做什么,但不要指望这一件事会得到显着的提升。