numpy 使用广播将单个向量与向量数组相乘时出现Numba键入错误

6yjfywim  于 2023-10-19  发布在  其他
关注(0)|答案(1)|浏览(97)

我在将numba应用于一组函数时遇到了一个问题,我试图优化这些函数的性能。所有的函数在没有numba的情况下都能正常工作,但是当我尝试使用numba时,我得到了一个编译错误。
以下是我正在努力解决的编译错误:

Exception occurred:
Type: TypingError
Message: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
Cannot unify array(float64, 2d, C) and array(float64, 1d, C) for 'q1.2', defined at .\rotations.py (82)

File "rotations.py", line 82:
def quaternion_mult(q1, qa):
    <source elided>

    quat_result[:, 0] = (q1[:, 0] * q2[:, 0]) - (q1[:, 1] * q2[:, 1]) - (q1[:, 2] * q2[:, 2]) - (q1[:, 3] * q2[:, 3])
    ^

During: typing of assignment at .\rotations.py (82)

File "rotations.py", line 82:
def quaternion_mult(q1, qa):
    <source elided>

    quat_result[:, 0] = (q1[:, 0] * q2[:, 0]) - (q1[:, 1] * q2[:, 1]) - (q1[:, 2] * q2[:, 2]) - (q1[:, 3] * q2[:, 3])
    ^

During: resolving callee type: type(CPUDispatcher(<function quaternion_mult at 0x00000290EE6FE670>))
During: typing of call at .\rotations.py (102)

During: resolving callee type: type(CPUDispatcher(<function quaternion_mult at 0x00000290EE6FE670>))
During: typing of call at .\rotations.py (102)

File "rotations.py", line 102:
def quaternion_vect_mult(q1, vect_array):
    <source elided>

    temp = quaternion_mult(q1, q2)
    ^

下面是相应函数的完整代码:

@njit(cache=True)
def quaternion_conjugate_vect(q):
    """
    return the conjugate of a quaternion or an array of quaternions
    """
    return q * np.array([1, -1, -1, -1])

@njit(cache=True)
def quaternion_mult(q1, qa):
    """
    multiply an array of quaternions (Nx4) by a single quaternion.

    qa is always a (Nx4) array of quaternions np.ndarray
    q1 is always a single (1x4) quaternion np.ndarray

    """
    N = max(len(qa), len(q1))
    quat_result = np.zeros((N, 4), dtype=np.float64)

    if qa.ndim == 1:
        q2 = qa.copy().reshape((1, -1))
        # q2 = np.reshape(q1, (1,-1))
    else:
        q2 = qa

    if q1.ndim == 1:
        # q1 = q1.copy().reshape((1, -1))
        q1 = np.reshape(q1, (1, -1))

    quat_result[:, 0] = (q1[:, 0] * q2[:, 0]) - (q1[:, 1] * q2[:, 1]) - (q1[:, 2] * q2[:, 2]) - (q1[:, 3] * q2[:, 3])
    quat_result[:, 1] = (q1[:, 0] * q2[:, 1]) + (q1[:, 1] * q2[:, 0]) + (q1[:, 2] * q2[:, 3]) - (q1[:, 3] * q2[:, 2])
    quat_result[:, 2] = (q1[:, 0] * q2[:, 2]) + (q1[:, 2] * q2[:, 0]) + (q1[:, 3] * q2[:, 1]) - (q1[:, 1] * q2[:, 3])
    quat_result[:, 3] = (q1[:, 0] * q2[:, 3]) + (q1[:, 3] * q2[:, 0]) + (q1[:, 1] * q2[:, 2]) - (q1[:, 2] * q2[:, 1])

    return quat_result

@njit(cache=True)
def quaternion_vect_mult(q1, vect_array):
    """
    Multiplies an array of x,y,z coordinates by a single quaternion q1.
    """
    # q1 is the quaternion which the coordinates will be rotated by.

    # Add initial column of zeros to array
    # N = len(vect_array)
    q2 = np.zeros((len(vect_array), 4), dtype=np.float64)
    q2[:, 1::] = vect_array

    temp = quaternion_mult(q1, q2)
    result = quaternion_mult(temp, quaternion_conjugate_vect(q1))

    return result[:, 1::]

我不明白统一错误,因为我在乘法中广播,所以形状应该是无关紧要的?所有的数组都是np.float64,所以我指定它作为类型。唯一的区别是形状,但正常的numpy广播应该在这里工作,因为它没有numba。(我已经添加了额外的括号,以确保我正确地乘以东西,但它们根本不需要。
我假设这个问题与np.zeros存储阵列的创建有关,我已经添加了这个,因为之前我分别计算了每列,然后与np.stack合并。
我唯一的其他想法是,它与if ... else...有关,我检查单个四元数是否为shape(1,4)而不是(,4)
我有点被这个问题难住了,其他类似的问题通常似乎有一些类型的差异,像intfloatfloat32float64
任何帮助都是感激不尽的。
为了清楚起见,下面是一个不使用numba但启用它时失败的示例:

from numba import njit
import numpy as np

quat_single = np.random.random(,4)
coord_array = np.random.random([9,3])

Note: quat_single = np.random.random([1,4]) will work with `numba`

quaternion_vect_mult(quat_single, coord_array)
Out[18]: 
array([[ 0.12035005,  1.51894951,  0.26731225],
       [ 1.56889141,  0.56465019,  0.18818138],
       [ 0.58966646,  1.09653585, -0.19548354],
       [ 1.15044012,  1.56034916,  0.73943456],
       [ 0.83003034,  1.80861828,  0.02678796],
       [ 1.15572912,  0.54263501,  0.16206597],
       [ 1.34243762,  1.0802315 , -0.20735991],
       [ 1.5876305 ,  0.70017144,  0.80066164],
       [ 1.20734218,  1.2747372 , -0.47177605]])
aamkag61

aamkag611#

使用这些行:

temp = quaternion_mult(q1, q2)
    result = quaternion_mult(temp, quaternion_conjugate_vect(q1))

你每次都给quaternion_mult不同的参数类型,所以numba对如何编译这个函数感到困惑。
为您要支持的每个参数类型/维度分别创建quaternion_mult,例如:

@njit(cache=True)
def quaternion_conjugate_vect(q):
    """
    return the conjugate of a quaternion or an array of quaternions
    """
    return q * np.array([1, -1, -1, -1])

@njit(cache=True)
def quaternion_mult1(q1, qa):
    """
    multiply an array of quaternions (Nx4) by a single quaternion.

    qa is always a (Nx4) array of quaternions np.ndarray
    q1 is always a single (1x4) quaternion np.ndarray

    """
    N = max(len(qa), len(q1))
    quat_result = np.zeros((N, 4), dtype=np.float64)

    # if qa.ndim == 1:
    #     q2 = qa.copy().reshape((1, -1))
    #     # q2 = np.reshape(q1, (1,-1))
    # else:
    #     q2 = qa

    q2 = qa

    # if q1.ndim == 1:
    #     # q1 = q1.copy().reshape((1, -1))
    #     q1 = np.reshape(q1, (1, -1))

    quat_result[:, 0] = (
        (q1[0] * q2[:, 0])
        - (q1[1] * q2[:, 1])
        - (q1[2] * q2[:, 2])
        - (q1[3] * q2[:, 3])
    )
    quat_result[:, 1] = (
        (q1[0] * q2[:, 1])
        + (q1[1] * q2[:, 0])
        + (q1[2] * q2[:, 3])
        - (q1[3] * q2[:, 2])
    )
    quat_result[:, 2] = (
        (q1[0] * q2[:, 2])
        + (q1[2] * q2[:, 0])
        + (q1[3] * q2[:, 1])
        - (q1[1] * q2[:, 3])
    )
    quat_result[:, 3] = (
        (q1[0] * q2[:, 3])
        + (q1[3] * q2[:, 0])
        + (q1[1] * q2[:, 2])
        - (q1[2] * q2[:, 1])
    )

    return quat_result

@njit(cache=True)
def quaternion_mult2(q1, qa):
    N = max(len(qa), len(q1))
    quat_result = np.zeros((N, 4), dtype=np.float64)
    q2 = qa.copy().reshape((1, -1))

    quat_result[:, 0] = (
        (q1[:, 0] * q2[:, 0])
        - (q1[:, 1] * q2[:, 1])
        - (q1[:, 2] * q2[:, 2])
        - (q1[:, 3] * q2[:, 3])
    )
    quat_result[:, 1] = (
        (q1[:, 0] * q2[:, 1])
        + (q1[:, 1] * q2[:, 0])
        + (q1[:, 2] * q2[:, 3])
        - (q1[:, 3] * q2[:, 2])
    )
    quat_result[:, 2] = (
        (q1[:, 0] * q2[:, 2])
        + (q1[:, 2] * q2[:, 0])
        + (q1[:, 3] * q2[:, 1])
        - (q1[:, 1] * q2[:, 3])
    )
    quat_result[:, 3] = (
        (q1[:, 0] * q2[:, 3])
        + (q1[:, 3] * q2[:, 0])
        + (q1[:, 1] * q2[:, 2])
        - (q1[:, 2] * q2[:, 1])
    )

    return quat_result

@njit(cache=True)
def quaternion_vect_mult(q1, vect_array):
    """
    Multiplies an array of x,y,z coordinates by a single quaternion q1.
    """
    # q1 is the quaternion which the coordinates will be rotated by.

    # Add initial column of zeros to array
    # N = len(vect_array)
    q2 = np.zeros((len(vect_array), 4), dtype=np.float64)
    q2[:, 1::] = vect_array

    temp = quaternion_mult1(q1, q2)
    result = quaternion_mult2(temp, quaternion_conjugate_vect(q1))

    return result[:, 1::]

有:

np.random.seed(42)
quat_single = np.random.random(4)
coord_array = np.random.random([9, 3])

print(quaternion_vect_mult(quat_single, coord_array))

这将打印:

[[ 0.26852132  0.20522199  0.28520316]
 [ 1.89120847  1.35797162  0.79965888]
 [ 2.322112   -0.39389235  0.76960471]
 [ 0.51270351  0.3143128   0.24153831]
 [ 1.2691966   0.32325645  0.60666047]
 [ 0.85615508  0.20021656  1.01254022]
 [ 1.15864463  0.39780013  0.3251974 ]
 [ 1.17341506  1.41237398  0.29654629]
 [ 1.15734464  1.16277993 -0.14839415]]

根据我的“基准”,它应该比非jitted版本快30- 40倍。

相关问题