我在将numba
应用于一组函数时遇到了一个问题,我试图优化这些函数的性能。所有的函数在没有numba
的情况下都能正常工作,但是当我尝试使用numba
时,我得到了一个编译错误。
以下是我正在努力解决的编译错误:
Exception occurred:
Type: TypingError
Message: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
Cannot unify array(float64, 2d, C) and array(float64, 1d, C) for 'q1.2', defined at .\rotations.py (82)
File "rotations.py", line 82:
def quaternion_mult(q1, qa):
<source elided>
quat_result[:, 0] = (q1[:, 0] * q2[:, 0]) - (q1[:, 1] * q2[:, 1]) - (q1[:, 2] * q2[:, 2]) - (q1[:, 3] * q2[:, 3])
^
During: typing of assignment at .\rotations.py (82)
File "rotations.py", line 82:
def quaternion_mult(q1, qa):
<source elided>
quat_result[:, 0] = (q1[:, 0] * q2[:, 0]) - (q1[:, 1] * q2[:, 1]) - (q1[:, 2] * q2[:, 2]) - (q1[:, 3] * q2[:, 3])
^
During: resolving callee type: type(CPUDispatcher(<function quaternion_mult at 0x00000290EE6FE670>))
During: typing of call at .\rotations.py (102)
During: resolving callee type: type(CPUDispatcher(<function quaternion_mult at 0x00000290EE6FE670>))
During: typing of call at .\rotations.py (102)
File "rotations.py", line 102:
def quaternion_vect_mult(q1, vect_array):
<source elided>
temp = quaternion_mult(q1, q2)
^
下面是相应函数的完整代码:
@njit(cache=True)
def quaternion_conjugate_vect(q):
"""
return the conjugate of a quaternion or an array of quaternions
"""
return q * np.array([1, -1, -1, -1])
@njit(cache=True)
def quaternion_mult(q1, qa):
"""
multiply an array of quaternions (Nx4) by a single quaternion.
qa is always a (Nx4) array of quaternions np.ndarray
q1 is always a single (1x4) quaternion np.ndarray
"""
N = max(len(qa), len(q1))
quat_result = np.zeros((N, 4), dtype=np.float64)
if qa.ndim == 1:
q2 = qa.copy().reshape((1, -1))
# q2 = np.reshape(q1, (1,-1))
else:
q2 = qa
if q1.ndim == 1:
# q1 = q1.copy().reshape((1, -1))
q1 = np.reshape(q1, (1, -1))
quat_result[:, 0] = (q1[:, 0] * q2[:, 0]) - (q1[:, 1] * q2[:, 1]) - (q1[:, 2] * q2[:, 2]) - (q1[:, 3] * q2[:, 3])
quat_result[:, 1] = (q1[:, 0] * q2[:, 1]) + (q1[:, 1] * q2[:, 0]) + (q1[:, 2] * q2[:, 3]) - (q1[:, 3] * q2[:, 2])
quat_result[:, 2] = (q1[:, 0] * q2[:, 2]) + (q1[:, 2] * q2[:, 0]) + (q1[:, 3] * q2[:, 1]) - (q1[:, 1] * q2[:, 3])
quat_result[:, 3] = (q1[:, 0] * q2[:, 3]) + (q1[:, 3] * q2[:, 0]) + (q1[:, 1] * q2[:, 2]) - (q1[:, 2] * q2[:, 1])
return quat_result
@njit(cache=True)
def quaternion_vect_mult(q1, vect_array):
"""
Multiplies an array of x,y,z coordinates by a single quaternion q1.
"""
# q1 is the quaternion which the coordinates will be rotated by.
# Add initial column of zeros to array
# N = len(vect_array)
q2 = np.zeros((len(vect_array), 4), dtype=np.float64)
q2[:, 1::] = vect_array
temp = quaternion_mult(q1, q2)
result = quaternion_mult(temp, quaternion_conjugate_vect(q1))
return result[:, 1::]
我不明白统一错误,因为我在乘法中广播,所以形状应该是无关紧要的?所有的数组都是np.float64
,所以我指定它作为类型。唯一的区别是形状,但正常的numpy
广播应该在这里工作,因为它没有numba
。(我已经添加了额外的括号,以确保我正确地乘以东西,但它们根本不需要。
我假设这个问题与np.zeros
存储阵列的创建有关,我已经添加了这个,因为之前我分别计算了每列,然后与np.stack
合并。
我唯一的其他想法是,它与if ... else...
有关,我检查单个四元数是否为shape
(1,4)
而不是(,4)
。
我有点被这个问题难住了,其他类似的问题通常似乎有一些类型的差异,像int
和float
或float32
和float64
。
任何帮助都是感激不尽的。
为了清楚起见,下面是一个不使用numba
但启用它时失败的示例:
from numba import njit
import numpy as np
quat_single = np.random.random(,4)
coord_array = np.random.random([9,3])
Note: quat_single = np.random.random([1,4]) will work with `numba`
quaternion_vect_mult(quat_single, coord_array)
Out[18]:
array([[ 0.12035005, 1.51894951, 0.26731225],
[ 1.56889141, 0.56465019, 0.18818138],
[ 0.58966646, 1.09653585, -0.19548354],
[ 1.15044012, 1.56034916, 0.73943456],
[ 0.83003034, 1.80861828, 0.02678796],
[ 1.15572912, 0.54263501, 0.16206597],
[ 1.34243762, 1.0802315 , -0.20735991],
[ 1.5876305 , 0.70017144, 0.80066164],
[ 1.20734218, 1.2747372 , -0.47177605]])
1条答案
按热度按时间aamkag611#
使用这些行:
你每次都给
quaternion_mult
不同的参数类型,所以numba对如何编译这个函数感到困惑。为您要支持的每个参数类型/维度分别创建
quaternion_mult
,例如:有:
这将打印:
根据我的“基准”,它应该比非jitted版本快30- 40倍。