使用Numba JIT与转置的NumPy数组的矩阵乘法不起作用

qmb5sa22  于 2023-06-23  发布在  其他
关注(0)|答案(1)|浏览(143)

环境

  • 操作系统:Windows 10
  • Python版本:3.10
  • Numba版本:0.57.0
  • NumPy版本:1.24.3

示例

import numpy as np
from numba import njit

@njit
def matmul_transposed(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # return a @ b.T  # also tried this with a similar result, np.matmul seems to be unsupported by Numba
    return a.dot(b.transpose())

matmul_transposed(np.array([[1.0, 1.0]]), np.array([[1.0, 1.0]]))

错误

上面的示例引发了一个错误

- Resolution failure for literal arguments:
No implementation of function Function(<function array_dot at 0x...>) found for signature:
 >>> array_dot(array(float64, 2d, C), array(float64, 2d, F))
There are 2 candidate implementations:
  - Of which 2 did not match due to:
  Overload in function 'array_dot': File: numba\np\arrayobj.py: Line 5929.
    With argument(s): '(array(float64, 2d, C), array(float64, 2d, F))':
   Rejected as the implementation raised a specific error:
     TypingError: Failed in nopython mode pipeline (step: nopython frontend)
   No implementation of function Function(<function dot at 0x...>) found for signature:
    >>> dot(array(float64, 2d, C), array(float64, 2d, F))
   There are 4 candidate implementations:
         - Of which 2 did not match due to:
         Overload in function 'dot_2': File: numba\np\linalg.py: Line 525.
           With argument(s): '(array(float64, 2d, C), array(float64, 2d, F))':
          Rejected as the implementation raised a specific error:
            LoweringError: Failed in nopython mode pipeline (step: native lowering)
          scipy 0.16+ is required for linear algebra
          
          File "[...]\numba\np\linalg.py", line 582:
                      def _dot2_codegen(context, builder, sig, args):
                          <source elided>
                  return lambda left, right: _impl(left, right)
                  ^
          
          During: lowering "$8call_function.3 = call $2load_deref.0(left, right, func=$2load_deref.0, args=[Var(left, linalg.py:582), Var(right, linalg.py:582)], kws=(), vararg=None, varkwarg=None, target=None)" at [...]\numba\np\linalg.py (582)
     raised from [...]\numba\core\errors.py:837
         - Of which 2 did not match due to:
         Overload in function 'dot_3': File: numba\np\linalg.py: Line 784.
           With argument(s): '(array(float64, 2d, C), array(float64, 2d, F))':
          Rejected as the implementation raised a specific error:
            TypingError: missing a required argument: 'out'
     raised from [...]\numba\core\typing\templates.py:784
   
   During: resolving callee type: Function(<function dot at 0x...>)
   During: typing of call at [...]\numba\np\arrayobj.py (5932)
   
   
   File "[...]\numba\np\arrayobj.py", line 5932:
       def dot_impl(arr, other):
           return np.dot(arr, other)
           ^
  raised from [...]\numba\core\typeinfer.py:1086
- Resolution failure for non-literal arguments:
None
During: resolving callee type: BoundFunction((<class 'numba.core.types.npytypes.Array'>, 'dot') for array(float64, 2d, C))
During: typing of call at [...]\example.py (7)
File "scratch_2.py", line 7:
def matmul_transposed(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    <source elided>
    """Return a @ b.T"""
    return a.dot(b.transpose())
    ^

解释

从错误信息中,我得出结论,Numba似乎通过将其布局样式从C更改为Fotran来转置数组,这是一种廉价的操作,因为它不必物理地更改布局,但它似乎不知道如何将C样式数组和Fotrtran样式数组相乘。

提问

有没有办法把这些矩阵相乘?最好不要复制整个b,而做换位?
这似乎是一个相当普通的操作,所以我很困惑,它不工作。

omqzjyyz

omqzjyyz1#

你的解释并不离谱:numba有四个候选项来乘以C和F布局数组,并给出了细节,为什么最终没有选择每一个。后两个被忽略是因为缺少一个参数,所以它们显然是用于另一个调用签名的。前两个被排除是因为某些东西在内部不起作用:

LoweringError: Failed in nopython mode pipeline (step: native lowering)
          scipy 0.16+ is required for linear algebra

虽然第一行非常晦涩,但第二行仍然是错误消息的一部分,并给出了一个很好的提示。手动安装scipy,它应该可以工作。
顺便说一句:这基本上是一个numpy函数的一行程序,它应该在单个CPU核心上执行得很好,因为numba没有太多的Python开销需要消除。当然,这取决于你还想做什么,但不要指望这一件事会得到显着的提升。

相关问题