python 列表理解总和运行速度不够快

j0pj023g  于 2022-12-28  发布在  Python
关注(0)|答案(1)|浏览(128)

我有三个列表,我使用了一个列表解析的总和,然而,由于这些列表的长度n〉= 1500,我无法让我的代码运行得比每个列表解析3s更高效,这段代码需要运行数千次,所以3s per并不能解决问题。
下面是我当前的尝试。split只是我代码前面确定的一个浮点数。

sum([list1[k] * (list2[k] == 1) if list3[k] < split else list1[k] * (list2[k] == -1) for k in range(n)])

list1包含1500个0到1之间的正浮点数,其总和为1。
list2包含1500个随机抽样的-1和1。
list3包含来自正态分布的1500个随机采样值,例如np.random.normal(5, 0.5, 3)

qyswt5oh

qyswt5oh1#

我最后针对你的问题写了三种方法:改进了Python,麻木和numba。

  • 基于@KellyBelly评论的改进版python运行良好。zip对性能的影响非常大。
  • 使用numpy,您希望利用矢量化操作的强大功能,将条件转换为掩码,并完全摆脱循环。
  • 如果你对numba的重要概念(njitprange等)感到轻松自如,numba通常是最快的解决方案,它比numpy方法需要更多的校对,但它得到了很好的回报。

请注意,这些只是实现同一算法的不同方法,如果你在追逐那些宝贵的毫秒,改进一个低效的算法也是非常重要的。

    • 计时:**

| * * 项目**| * * 列出理解**| * * 压缩迭代器**| * * 数值数组**| * * 伦巴·尼Git**| * * 数字. njit(并行=真)**|
| - ------| - ------| - ------| - ------| - ------| - ------|
| 1千|0.191毫秒|0.129毫秒|0.487毫秒|0.006毫秒|0.013毫秒|
| 10千|2.288毫秒|1.206毫秒|0.477毫秒|0.048毫秒|0.019毫秒|
| 10万|十八点九四一毫秒|十三点二四五毫秒|2.857毫秒|0.477毫秒|0.056毫秒|

    • 代码:**
# Imports.
import numba as nb
import numpy as np
np.random.seed(0)

# Data.
N = 100000
SPLIT = 50
array1 = np.random.randint(0, 100, N)
array2 = np.random.choice((+1, -1), N)
array3 = np.random.randint(0, 100, N)
list1, list2, list3 = map(lambda a: a.tolist(), (array1, array2, array3))
print(N)

# Helpful timing function.
from contextlib import contextmanager
import time

@contextmanager
def time_this():
    t0 = time.perf_counter()
    yield
    dt = time.perf_counter() - t0
    print(f"{dt*1000:.3f} ms")

# List comprehension.
def list_comprehension():
    n = len(list1)
    return sum([list1[k] * (list2[k] == 1) if list3[k] < SPLIT else list1[k] * (list2[k] == -1) for k in range(n)])

# Zipped iterator.
def zipped_iterator():
    return sum(l1 if l2 == (1 if l3 < SPLIT else -1) else 0 for l1, l2, l3 in zip(list1, list2, list3))

# Numpy array.
def numpy_arrays():
    mask = array3 < SPLIT
    positives = array1[mask] * (array2[mask] == 1)
    negatives = array1[~mask] * (array2[~mask] == -1)
    return positives.sum() + negatives.sum()

# Numba.
@nb.njit
def numba_count():
    total = 0
    n = len(array1)
    for k in nb.prange(n):
        if array3[k] < SPLIT:
            sign = +1
        else:
            sign = -1
        if array2[k] == sign:
              total += array1[k]
    return total

# Numba in parallel.
@nb.njit(parallel=True)
def numba_count2():
    total = 0
    n = len(array1)
    for k in nb.prange(n):
        if array3[k] < SPLIT:
            sign = +1
        else:
            sign = -1
        if array2[k] == sign:
              total += array1[k]
    return total

# Timings.
totals = []
with time_this():
    totals.append(list_comprehension())

with time_this():
    totals.append(zipped_iterator())

with time_this():
    totals.append(numpy_arrays())

numba_count() # Compile before we time anything.
with time_this():
    totals.append(numba_count())

numba_count2() # Compile before we time anything.
with time_this():
    totals.append(numba_count2())

# Assert that all the returned values are identical.
assert np.isclose(totals, totals[0]).all()

相关问题