Cannot use numpy.dot with numba

Posted by njthzxwz on 2023-08-05 in Other

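I get an error when I try to run numpy.dot with numba. It appears to be supported (e.g. numpy: Faster np.dot/ multiply(element-wise multiplication) when one array is the same), but this code, for example, gives me the error below (if I remove the njit decorator, it runs fine).
Code: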

import numpy as np
import numba

@numba.njit()
def tst_dot():
    a = np.array([[1, 0], [0, 1]])
    b = np.array([[4, 1], [2, 2]])

    return np.dot(a, b)

print(tst_dot())

Error message:

No implementation of function Function(<function dot at 0x00000280CC542EF0>) found for signature:
 
 >>> dot(array(int64, 2d, C), array(int64, 2d, C))
 
There are 4 candidate implementations:
      - Of which 2 did not match due to:
      Overload in function 'dot_2': File: numba\np\linalg.py: Line 525.
        With argument(s): '(array(int64, 2d, C), array(int64, 2d, C))':
       Rejected as the implementation raised a specific error:
         TypingError: Failed in nopython mode pipeline (step: native lowering)
       Failed in nopython mode pipeline (step: nopython frontend)
       No implementation of function Function(<function dot at 0x00000280CC542EF0>) found for signature:
        
        >>> dot(array(int64, 2d, C), array(int64, 2d, C), array(int64, 2d, C))
        
       There are 4 candidate implementations:
             - Of which 2 did not match due to:
             Overload in function 'dot_2': File: numba\np\linalg.py: Line 525.
               With argument(s): '(array(int64, 2d, C), array(int64, 2d, C), array(int64, 2d, C))':
              Rejected as the implementation raised a specific error:
                TypingError: too many positional arguments
         raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\typing\templates.py:784
             - Of which 2 did not match due to:
             Overload in function 'dot_3': File: numba\np\linalg.py: Line 784.
               With argument(s): '(array(int64, 2d, C), array(int64, 2d, C), array(int64, 2d, C))':
              Rejected as the implementation raised a specific error:
                LoweringError: Failed in nopython mode pipeline (step: native lowering)
              unsupported dtype for <BLAS function>()
              
              File "venv\lib\site-packages\numba\np\linalg.py", line 817:
                          def codegen(context, builder, sig, args):
                              <source elided>
              
                      return lambda left, right, out: _impl(left, right, out)
                      ^
              
              During: lowering "$10call_function.4 = call $2load_deref.0(left, right, out, func=$2load_deref.0, args=[Var(left, linalg.py:817), Var(right, linalg.py:817), Var(out, linalg.py:817)], kws=(), vararg=None, varkwarg=None, target=None)" at C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\np\linalg.py (817)
         raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\errors.py:837
       
       During: resolving callee type: Function(<function dot at 0x00000280CC542EF0>)
       During: typing of call at C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\np\linalg.py (460)
       
       
       File "venv\lib\site-packages\numba\np\linalg.py", line 460:
           def dot_impl(a, b):
               <source elided>
               out = np.empty((m, n), a.dtype)
               return np.dot(a, b, out)
               ^
       
       During: lowering "$8call_function.3 = call $2load_deref.0(left, right, func=$2load_deref.0, args=[Var(left, linalg.py:582), Var(right, linalg.py:582)], kws=(), vararg=None, varkwarg=None, target=None)" at C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\np\linalg.py (582)
  raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\typeinfer.py:1086
      - Of which 2 did not match due to:
      Overload in function 'dot_3': File: numba\np\linalg.py: Line 784.
        With argument(s): '(array(int64, 2d, C), array(int64, 2d, C))':
       Rejected as the implementation raised a specific error:
         TypingError: missing a required argument: 'out'
  raised from C:\Users\a_che\PycharmProjects\minCovTarget\venv\lib\site-packages\numba\core\typing\templates.py:784

During: resolving callee type: Function(<function dot at 0x00000280CC542EF0>)
During: typing of call at C:\Users\a_che\PycharmProjects\minCovTarget\tst4.py (164)

File "tst4.py", line 164:
def tst_dot(a, b):
    <source elided>

    return np.dot(a, b)
    ^


I also tried passing out=None as a third argument (even though it is optional), but that didn't help. I expected the same result I get without numba.

Answer #1 by pgccezyw

The docs say:
Basic linear algebra is supported on 1-D and 2-D contiguous arrays of *floating-point* and complex numbers:

  • numpy.dot()
  • ...

However, both of your arrays contain integers. Note the error message:

dot(array(int64, 2d, C), array(int64, 2d, C))

So the trick is to change the dtype:

import numpy as np
import numba

@numba.njit()
def tst_dot():
    a = np.array([[1, 0], [0, 1]], dtype=np.float32)
    b = np.array([[4, 1], [2, 2]], dtype=np.float32)

    return np.dot(a, b)

print(tst_dot())

[[4. 1.]
 [2. 2.]]

Answer #2 by yuvru6vn

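According to www.example.com, numba's implementation of numpy.dot only seems to support floating-point numbers, so you need to convert the arrays to float first; then the code runs.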
