numpy numba vstack不适用于数组列表

rslzwgfq  于 11个月前  发布在  其他
关注(0)|答案(2)|浏览(82)

对我来说很奇怪,当输入是数组列表时,vstack不能与Numba一起工作,它只在输入是数组元组时工作。示例代码:

@nb.jit(nopython=True)
def stack(items):
    return np.vstack(items)

stack((np.array([1,2,3]), np.array([4,5,6])))

返回

array([[1, 2, 3],
       [4, 5, 6]])

stack([np.array([1,2,3]), np.array([4,5,6])])

抛出一个错误

TypingError: No implementation of function Function(<function vstack at 0x0000027271963488>) found for signature:
>>>vstack(reflected list(array(int32, 1d, C)))

由于tuple不受支持,我努力寻找解决方法-我错过了什么吗?

pkwftd7m

pkwftd7m1#

这是@hpaulj提到的一个解决方案:

stack(tuple([np.array([1,2,3]), np.array([4,5,6])]))

[[1 2 3]
 [4 5 6]]
x0fgdtte

x0fgdtte2#

在numba中,vstackhstackconcatenate只支持tuple作为输入,而不支持list
他们说这是因为numba在编译Ref(https://github.com/numba/numba/issues/7476)时无法推断堆栈数组的维数。但我怀疑它实际上可以,因为你可以手动完成这一点,如下所示。
您可以通过以下间接方式堆叠list

from numba import njit, prange
import numpy as np

@njit()
def test_list_stack(i, array_to_be_stacked):
    shape = (i,) + array_to_be_stacked.shape
    list_of_array = [array_to_be_stacked] * i
    stacked_array = np.empty(shape)
    for j in prange(i):
        stacked_array[j] = list_of_array[j]
    return stacked_array

if __name__ == "__main__":
    test_list_stack(10, np.ones((2, 3)))

或者,您可以定义一个自定义函数来执行堆栈工作:

from numba import njit, prange
from numba.typed import List
import numpy as np

@njit()
def stack(list_of_array):
    shape = (len(list_of_array),) + list_of_array[0].shape
    stacked_array = np.empty(shape)
    for j in prange(len(list_of_array)):
        stacked_array[j] = list_of_array[j]
    return stacked_array

if __name__ == "__main__":
    # Note that you have to use typed list provided by numba here.
    typed_list = List()
    [typed_list.append(np.ones((2, 3))) for _ in range(10)]
    stacked = stack(typed_list)
    print(stacked.shape)
    print(stacked)

相关问题