import numba as nb
# The first call is slower due to the build.
# Please consider specifying the signature of the function (ie. input types)
# to precompile the function ahead of time.
@nb.njit # Use nb.njit(parallel=True) for the parallel version
def compute(arr, n):
k, m = arr.shape[0], arr.shape[1] // n
assert arr.shape[1] == n * m
out = np.empty((n, k, m), dtype=arr.dtype)
# Use nb.prange for the parallel version
for i2 in range(k):
for i1 in range(n):
outView = out[i1, i2]
inView = a[i2]
cur = i1
for i3 in range(m):
outView[i3] = inView[cur]
cur += n
return out
from time import time
k=37; n=42; m=53
a = np.arange(k*n*m).reshape(k, n*m)
start = time()
for _ in range(1_000_000):
res = a.reshape(k, m, n).swapaxes(1, 2).swapaxes(0,1)
time() - start
# 0.95 s per 1 mil repetitions
3条答案
按热度按时间icnyk63a1#
这是很难写一个更快的Numpy实现。一个有效的解决方案是使用Numba来加速。也就是说,内存访问模式可能是代码在相对较大的矩阵上运行缓慢的主要原因。因此,需要关注迭代顺序,以便访问可以相对缓存友好。此外,对于大型数组,使用多个线程是一个好主意,这样可以更好地减轻由于相对较高的内存延迟(由于内存访问模式)而产生的开销。
以下是我的计算机在配备i5- 9600 KF处理器(6核)的
k=37
、n=42
、m=53
和a.dtype=np.int32
上的测试结果:e5nszbig2#
我认为这做了你想要的,没有循环。我测试了二维输入,它可能需要一些调整更多的维度。
在我的测试中,它比您原来的列表解决方案要慢一些。
6ojccjat3#
如果我没记错的话,这是你所期望的,而且速度很快:
示例:
时间安排: