numpy lambda的输出大于1

gxwragnw  于 2023-05-07  发布在  其他
关注(0)|答案(2)|浏览(252)

可以用lambda来做以下事情吗?(我只是需要一些能够做到这一点真的很快)
因为这个答案而要求lambda:https://stackoverflow.com/a/35216364/3776738

import numpy as np

def one_to_more(i):
    //do some calculations etc...
    return [i*3,i*9]

x = np.array([1, 2, 3])
f = lambda x: one_to_more(x)
more = f(x)
print(more)

---〉[3,9,6,18,9,27]
编辑:一定不是LAMBDA。我只是在寻找以这种方式扩展大型列表或numpy数组的最快方法。这种方式意味着它将有两倍(上面的示例代码甚至更长的长度)

清洁

这是实际使用的函数:

MAX_NUM=100000
def num_to_arr (num):
    num = int(num) 
    if (num < 0 or num >= MAX_NUM):
        num = 0

    num3 = (num // 1600)
    num2 = ((num - num3 * 1600) // 40)
    num1 = int((num - num3 * 1600 - num2 * 40))
    arr = [num1 / 40, num2 / 40, num3 / 40]
    return arr

这样使用:

result=list(map(num_to_arr,large_array))

大数组由大约10 k个整数组成,执行时间大约为17 ms,这太高了。(CPU为AMD RYZEN 7950X)

vkc1a9a2

vkc1a9a21#

好的,谢谢你提供你的实际问题。我创建了你的函数的一个版本,它对整个数组起作用,并使用了一些向量化(尽管它没有完全向量化)。这提供了相当好的改进,虽然:

In [5]: MAX_NUM = 5000

In [6]: def num_to_arr (num):
    ...:     num = int(num)
    ...:     if (num < 0 or num >= MAX_NUM):
    ...:         num = 0
    ...:
    ...:     num3 = (num // 1600)
    ...:     num2 = ((num - num3 * 1600) // 40)
    ...:     num1 = int((num - num3 * 1600 - num2 * 40))
    ...:     arr = [num1 / 40, num2 / 40, num3 / 40]
    ...:     return arr
    ...:

In [7]: def num_to_arr_vec(arr):
    ...:     arr[(arr < 0) | (arr >= MAX_NUM)] = 0
    ...:     result = np.zeros((len(arr), 3))
    ...:     result[:,2] = arr//1600
    ...:     result[:,1] = ((arr - result[:,2] * 1600) // 40)
    ...:     result[:,0] = (arr - result[:,2] * 1600 - result[:,1] * 40)
    ...:     return result / 40
    ...:

In [8]: arr = np.random.randint(-10_000, 10_000, 10_000)

In [9]: %timeit list(map(num_to_arr, arr))
6.72 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [10]: %timeit num_to_arr_vec(arr)
333 µs ± 6.46 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

所以速度快了20倍。我怀疑要做得更好,您会希望使用numba here and write a very imperative loop
编辑:
添加numba示例:

In [51]: import numba

In [52]: @numba.jit(numba.float64[:,:](numba.int64[:]), nopython=True)
    ...: def num_to_arr_numba(arr):
    ...:     result = np.empty((len(arr), 3))
    ...:     for i in range(len(arr)):
    ...:         num = arr[i]
    ...:         if num < 0 or num >= MAX_NUM:
    ...:             num = 0
    ...:         num3 = num // 1600
    ...:         num2 = ((num - num3 * 1600) // 40)
    ...:         num1 = (num - num3 * 1600 - num2 * 40)
    ...:         result[i, 0] = num1
    ...:         result[i, 1] = num2
    ...:         result[i, 2] = num3
    ...:     result /= 40
    ...:     return result
    ...:

In [53]: %timeit num_to_arr_numba(arr)
85.7 µs ± 2.64 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

这似乎使我们的速度提高了80倍。

8ftvxx2r

8ftvxx2r2#

您可以使用广播和ravel

def one_to_more(i):
    return (np.array([3, 9])*i[:,None]).ravel()

out = one_to_more(x)

输出:array([ 3, 9, 6, 18, 9, 27])

相关问题