对numpy数组多次应用索引

inn6fuwd  于 2023-03-02  发布在  其他
关注(0)|答案(1)|浏览(173)

我有一个大小为N的数组A和另一个数组B,数组B的排列索引是数组A的,现在我想对数组A多次应用这个排列,有没有一个好的有效的numpy命令?
再现性示例:

import numpy as np
np.random.seed(121421)

# Example Array
A = np.random.uniform(0, 100, (100, 3))

# Permutation indices
B = np.random.choice(np.arange(A.shape[0]), A.shape[0], replace=False)

# Permute many times
npermut = 10000
C = np.array(B)
for n in range(npermut-1):
    C = C[B]
print(A[C])

谢谢大家!

klr1opcd

klr1opcd1#

我不确定是否有任何numpy函数可以直接解决这个问题。但是,通过数学分析可以极大地提高性能。您所做的是计算排列s的n次幂。假设n = 2^m,则
s^n = (((s^2)^2)...)^2
其中括号中有m个项,因此我们将索引操作的数量减少到O(log(n))。如果n不是2的幂,则我们写
n = a_1*2^1 + a_2*2^2 + ... + a_m*2^m
其中m是满足2^m〈n的最大整数,序列(a_1,...,a_m)是n的二进制逆表示,则
s^n = g_m(s^(2^m))...g_1(s^(2^1))
其中,如果a_m = 1,则g_m(x)= x,否则g_m(x)= 1.最多有m+(m-1)+...+1 = m *(m +1)/2次索引操作,即索引操作的数量按O(m^2)= O(2log(n))= O(log(n))的比例缩放.
下面的代码说明了上述思想:

def _perm_power_exp(perm, M):
    """
    Computes perm^(2^M) = ((perm^2)^2)...^2
    """
    for _ in range(M):
        perm = perm[perm]
    return perm

def perm_power(perm, N):
    """
    Fast computation of perm^N
    """
    if N == 1:
        return perm
    
    bin_rep = [int(i) for i in str(bin(N))[2:]] # binary representation of N
    
    perms = []
    for i, j in enumerate(reversed(bin_rep)):
        # j corresponds to a_{i+1} in the above explanation
        if j == 1:
            perms.append(_perm_power_exp(perm, i))
    
    perm = perms[0]
    for perm_i in perms[1:]:
        perm = perm_i[perm]
    
    return perm

为了检查正确性并测试速度的提高,我还将定义函数

def perm_power_naive(perm, N):
    new_perm = perm.copy()
    for _ in range(N-1):
        new_perm = new_perm[perm]
    return new_perm

请注意,这个函数与您当前正在执行的操作完全相同,因此您只需用新的优化代码替换代码的这一部分。
如果我跑了

D, N = 100, 1000
for _ in range(1000):
    B = np.random.choice(D, D, replace=False)
    C1 = perm_power(B, N)
    C2 = perm_power_naive(B, N)
    assert (C1 == C2).all()

没有出现Assert错误,因此表明优化后的代码(很可能)工作正常。我尝试过D和N的其他组合,Assert总是通过。
我还在jupyter notebook中运行了以下代码来测试性能:

D, N = 100, 100000
B = np.random.choice(D, D, replace=False)
%timeit perm_power(B, N)
%timeit perm_power_naive(B, N)

结果是

18.1 µs ± 18 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
21.2 ms ± 310 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

如您所见,当N = 100000时,性能增益约为1000倍。

相关问题