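I get an error when I try to run numpy.dot with numba. It seems to be supported (see, for example, numpy: Faster np.dot/ multiply(element-wise multiplication) when one array is the same), but the code below gives me the following error. If I remove the njit decorator, it runs fine.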
Code:
import numpy as np
import numba

@numba.njit()
def tst_dot():
    a = np.array([[1, 0], [0, 1]])
    b = np.array([[4, 1], [2, 2]])
    return np.dot(a, b)

print(tst_dot())
Error:
No implementation of function Function(<function dot at 0x00000280CC542EF0>) found for signature:
>>> dot(array(int64, 2d, C), array(int64, 2d, C))
There are 4 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'dot_2': File: numba\np\linalg.py: Line 525.
With argument(s): '(array(int64, 2d, C), array(int64, 2d, C))':
Rejected as the implementation raised a specific error:
TypingError: Failed in nopython mode pipeline (step: native lowering)
Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<function dot at 0x00000280CC542EF0>) found for signature:
>>> dot(array(int64, 2d, C), array(int64, 2d, C), array(int64, 2d, C))
There are 4 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'dot_2': File: numba\np\linalg.py: Line 525.
With argument(s): '(array(int64, 2d, C), array(int64, 2d, C), array(int64, 2d, C))':
Rejected as the implementation raised a specific error:
TypingError: too many positional arguments
raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\typing\templates.py:784
- Of which 2 did not match due to:
Overload in function 'dot_3': File: numba\np\linalg.py: Line 784.
With argument(s): '(array(int64, 2d, C), array(int64, 2d, C), array(int64, 2d, C))':
Rejected as the implementation raised a specific error:
LoweringError: Failed in nopython mode pipeline (step: native lowering)
unsupported dtype for <BLAS function>()
File "venv\lib\site-packages\numba\np\linalg.py", line 817:
def codegen(context, builder, sig, args):
<source elided>
return lambda left, right, out: _impl(left, right, out)
^
During: lowering "$10call_function.4 = call $2load_deref.0(left, right, out, func=$2load_deref.0, args=[Var(left, linalg.py:817), Var(right, linalg.py:817), Var(out, linalg.py:817)], kws=(), vararg=None, varkwarg=None, target=None)" at C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\np\linalg.py (817)
raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\errors.py:837
During: resolving callee type: Function(<function dot at 0x00000280CC542EF0>)
During: typing of call at C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\np\linalg.py (460)
File "venv\lib\site-packages\numba\np\linalg.py", line 460:
def dot_impl(a, b):
<source elided>
out = np.empty((m, n), a.dtype)
return np.dot(a, b, out)
^
During: lowering "$8call_function.3 = call $2load_deref.0(left, right, func=$2load_deref.0, args=[Var(left, linalg.py:582), Var(right, linalg.py:582)], kws=(), vararg=None, varkwarg=None, target=None)" at C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\np\linalg.py (582)
raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\typeinfer.py:1086
- Of which 2 did not match due to:
Overload in function 'dot_3': File: numba\np\linalg.py: Line 784.
With argument(s): '(array(int64, 2d, C), array(int64, 2d, C))':
Rejected as the implementation raised a specific error:
TypingError: missing a required argument: 'out'
raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\typing\templates.py:784
During: resolving callee type: Function(<function dot at 0x00000280CC542EF0>)
During: typing of call at C:\Users\a_che\PycharmProjects\minCovTarget\tst4.py (164)
File "tst4.py", line 164:
def tst_dot(a, b):
<source elided>
return np.dot(a, b)
^
I tried adding out=None as a third argument (even though it is optional), but it didn't help. I expect the same result as when I don't use numba.
2 answers

Answer 1:
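The docs say: Basic linear algebra is supported on 1-D and 2-D contiguous arrays of *floating-point* and complex numbers: numpy.dot()
However, both of your arrays contain integers. Note the error message: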
unsupported dtype for <BLAS function>()
So the trick is to change the dtype to a floating-point type.
Answer 2:
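The numba implementation of numpy.dot (www.example.com) appears to support only floating-point numbers, so the arrays have to be converted to floats first. After that, the code runs.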