numpy 更快的数字解决方案,而不是itertools,

cs7cruho  于 2022-12-26  发布在  其他
关注(0)|答案(3)|浏览(159)

我使用的itertools.combinations()如下所示:

import itertools
import numpy as np

L = [1,2,3,4,5]
N = 3

output = np.array([a for a in itertools.combinations(L,N)]).T

这就产生了我需要的输出:

array([[1, 1, 1, 1, 1, 1, 2, 2, 2, 3],
       [2, 2, 2, 3, 3, 4, 3, 3, 4, 4],
       [3, 4, 5, 4, 5, 5, 4, 5, 5, 5]])

我在多处理环境中反复地、过度地使用这个表达式,我需要它尽可能快。
通过this post,我了解到基于itertools的代码不是最快的解决方案,使用numpy可能是一种改进,但是我在numpy优化技巧方面不够好,无法理解和适应那里编写的迭代代码或提出自己的优化。
任何帮助都将不胜感激。
编辑:
L来自一个panda Dataframe ,因此它也可以看作一个numpy数组:

L = df.L.values
ogq8wdun

ogq8wdun1#

下面是一个比itertools UPDATE稍快的方法:另一个(nump2)实际上要快得多:

import numpy as np
import itertools
import timeit

def nump(n, k, i=0):
    if k == 1:
        a = np.arange(i, i+n)
        return tuple([a[None, j:] for j in range(n)])
    template = nump(n-1, k-1, i+1)
    full = np.r_[np.repeat(np.arange(i, i+n-k+1),
                           [t.shape[1] for t in template])[None, :],
                 np.c_[template]]
    return tuple([full[:, j:] for j in np.r_[0, np.add.accumulate(
        [t.shape[1] for t in template[:-1]])]])

def nump2(n, k):
    a = np.ones((k, n-k+1), dtype=int)
    a[0] = np.arange(n-k+1)
    for j in range(1, k):
        reps = (n-k+j) - a[j-1]
        a = np.repeat(a, reps, axis=1)
        ind = np.add.accumulate(reps)
        a[j, ind[:-1]] = 1-reps[1:]
        a[j, 0] = j
        a[j] = np.add.accumulate(a[j])
    return a

def itto(L, N):
    return np.array([a for a in itertools.combinations(L,N)]).T

k = 6
n = 12
N = np.arange(n)

assert np.all(nump2(n,k) == itto(N,k))

print('numpy    ', timeit.timeit('f(a,b)', number=100, globals={'f':nump, 'a':n, 'b':k}))
print('numpy 2  ', timeit.timeit('f(a,b)', number=100, globals={'f':nump2, 'a':n, 'b':k}))
print('itertools', timeit.timeit('f(a,b)', number=100, globals={'f':itto, 'a':N, 'b':k}))

时间:

k = 3, n = 50
numpy     0.06967267207801342
numpy 2   0.035096961073577404
itertools 0.7981023890897632

k = 3, n = 10
numpy     0.015058324905112386
numpy 2   0.0017436158377677202
itertools 0.004743851954117417

k = 6, n = 12
numpy     0.03546895203180611
numpy 2   0.00997065706178546
itertools 0.05292179994285107
bhmjp9jg

bhmjp9jg2#

这肯定比itertools.combinations * 不 * 快,但它 * 是 * 矢量化的numpy:

def nd_triu_indices(T,N):
    o=np.array(np.meshgrid(*(np.arange(len(T)),)*N))
    return np.array(T)[o[...,np.all(o[1:]>o[:-1],axis=0)]]

 %timeit np.array(list(itertools.combinations(T,N))).T
The slowest run took 4.40 times longer than the fastest. This could mean that an intermediate result is being cached.
100000 loops, best of 3: 8.6 µs per loop

%timeit nd_triu_indices(T,N)
The slowest run took 4.64 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 52.4 µs per loop

不确定这是否可以用另一种方式矢量化,或者这里的某个优化向导是否可以使这个方法更快。
编辑:想到了另一种方法,但仍然不比combinations快:

%timeit np.array(T)[np.array(np.where(np.fromfunction(lambda *i: np.all(np.array(i)[1:]>np.array(i)[:-1], axis=0),(len(T),)*N,dtype=int)))]
The slowest run took 7.78 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 34.3 µs per loop
5jvtdoz2

5jvtdoz23#

我知道这个问题很老了,但我最近一直在研究它,它仍然可能会有帮助。从我(相当广泛)的测试中,我发现首先生成每个索引的组合,然后使用这些索引对数组进行切片,比直接从数组中进行组合要快得多。我确信使用@Paul Panzer的nump2函数生成这些索引会更快。
下面是一个例子:

import numpy as np
from math import factorial
import itertools as iters
from timeit import timeit
from perfplot import show

def combinations_iter(array:np.ndarray, r:int = 3) -> np.ndarray:
    return np.array([*iters.combinations(array, r = r)], dtype = array.dtype)

def combinations_iter_idx(array:np.ndarray, r:int = 3) -> np.ndarray:
    n_items = array.shape[0]
    num_combinations = factorial(n_items)//(factorial(n_items-r)*factorial(r))
    combination_idx = np.fromiter(
        iters.chain.from_iterable(iters.combinations(np.arange(n_items, dtype = np.int64), r = r)),
        dtype = np.int64,
        count = num_combinations*r,
    ).reshape(-1,r)
    return array[combination_idx]

show(
    setup = lambda n: np.random.uniform(0,100,(n,3)),
    kernels = [combinations_iter, combinations_iter_idx],
    labels = ['pure itertools', 'itertools for index'],
    n_range = np.geomspace(5,300,10, dtype = np.int64),
    xlabel = "n",
    logx = True,
    logy = False,
    equality_check = np.allclose,
    show_progress = True,
    max_time = None,
    time_unit = "ms",
)

显然,索引方法要快得多。

相关问题