我有一个大小为N的数组A和另一个数组B,数组B的排列索引是数组A的,现在我想对数组A多次应用这个排列,有没有一个好的有效的numpy命令?
再现性示例:
import numpy as np
np.random.seed(121421)
# Example Array
A = np.random.uniform(0, 100, (100, 3))
# Permutation indices
B = np.random.choice(np.arange(A.shape[0]), A.shape[0], replace=False)
# Permute many times
npermut = 10000
C = np.array(B)
for n in range(npermut-1):
C = C[B]
print(A[C])
谢谢大家!
1条答案
按热度按时间klr1opcd1#
我不确定是否有任何numpy函数可以直接解决这个问题。但是,通过数学分析可以极大地提高性能。您所做的是计算排列s的n次幂。假设n = 2^m,则
s^n = (((s^2)^2)...)^2
其中括号中有m个项,因此我们将索引操作的数量减少到O(log(n))。如果n不是2的幂,则我们写
n = a_1*2^1 + a_2*2^2 + ... + a_m*2^m
其中m是满足2^m〈n的最大整数,序列(a_1,...,a_m)是n的二进制逆表示,则
s^n = g_m(s^(2^m))...g_1(s^(2^1))
其中,如果a_m = 1,则g_m(x)= x,否则g_m(x)= 1.最多有m+(m-1)+...+1 = m *(m +1)/2次索引操作,即索引操作的数量按O(m^2)= O(2log(n))= O(log(n))的比例缩放.
下面的代码说明了上述思想:
为了检查正确性并测试速度的提高,我还将定义函数
请注意,这个函数与您当前正在执行的操作完全相同,因此您只需用新的优化代码替换代码的这一部分。
如果我跑了
没有出现Assert错误,因此表明优化后的代码(很可能)工作正常。我尝试过D和N的其他组合,Assert总是通过。
我还在jupyter notebook中运行了以下代码来测试性能:
结果是
如您所见,当N = 100000时,性能增益约为1000倍。