python 在numba njit环境中传递一个形状给numpy.reshape失败了,我如何为目标形状创建一个合适的可迭代对象呢?

kwvwclae  于 2022-12-17  发布在  Python
关注(0)|答案(1)|浏览(185)

我有一个函数,它接受一个数组,执行一个任意的计算,然后返回一个新的形状,可以在其中广播。我想在numba.njit环境中使用这个函数:

import numpy as np
import numba as nb

@nb.njit
def generate_target_shape(my_array):
    ### some functionality that calculates the desired target shape ###
    return tuple([2,2])
    
@nb.njit
def test():
    my_array = np.array([1,2,3,4])
    target_shape = generate_target_shape(my_array)
    reshaped = my_array.reshape(target_shape)
    print(reshaped)
test()

然而,numba不支持元组创建,当我尝试使用tuple()操作符将generate_target_shape的结果转换为元组时,我得到了以下错误消息:

No implementation of function Function(<class 'tuple'>) found for signature:
 
 >>> tuple(list(int64)<iv=None>)
 
There are 2 candidate implementations:
   - Of which 2 did not match due to:
   Overload of function 'tuple': File: numba/core/typing/builtins.py: Line 572.
     With argument(s): '(list(int64)<iv=None>)':
    No match.

During: resolving callee type: Function(<class 'tuple'>

如果我尝试将generate_target_shape的返回类型从tuple更改为listnp.array,则会收到以下错误消息:

Invalid use of BoundFunction(array.reshape for array(float64, 1d, C)) with parameters (array(int64, 1d, C))

有没有办法在nb.njit函数内部创建一个可迭代对象,并将其传递给np.reshape

mlnl4t2r

mlnl4t2r1#

看起来numba不支持标准的python函数tuple(),你可以通过重写代码来解决这个问题:

import numpy as np
import numba as nb

@nb.njit
def generate_target_shape(my_array):
    ### some functionality that calculates the desired target shape ###
    a, b = [2, 2] # (this will also work if the list is a numpy array)
    return a, b

相关问题