numpy 通过沿着最后一个轴对每第n个元素进行采样来构造数组

wlp8pajw  于 2022-11-29  发布在  其他
关注(0)|答案(3)|浏览(142)

a是一个NumPy数组(不一定是一维的),其中最后一个轴沿着有n * m个元素。我希望沿着最后一个轴“拆分”这个数组,这样我就可以从0开始,直到n,取每个n的元素。
明确地说,让a具有(k, n * m)的形状,那么我希望构造(n, k, m)形状的数组

np.array([a[:, i::n] for i in range(n)])

我问题是,尽管这确实返回了我所寻找的数组,但我仍然觉得可能有一个更高效、更简洁的NumPy例程。
干杯!干杯!

icnyk63a

icnyk63a1#

这是很难写一个更快的Numpy实现。一个有效的解决方案是使用Numba来加速。也就是说,内存访问模式可能是代码在相对较大的矩阵上运行缓慢的主要原因。因此,需要关注迭代顺序,以便访问可以相对缓存友好。此外,对于大型数组,使用多个线程是一个好主意,这样可以更好地减轻由于相对较高的内存延迟(由于内存访问模式)而产生的开销。

import numba as nb

# The first call is slower due to the build.
# Please consider specifying the signature of the function (ie. input types)
# to precompile the function ahead of time.
@nb.njit # Use nb.njit(parallel=True) for the parallel version
def compute(arr, n):
    k, m = arr.shape[0], arr.shape[1] // n
    assert arr.shape[1] == n * m

    out = np.empty((n, k, m), dtype=arr.dtype)

    # Use nb.prange for the parallel version
    for i2 in range(k):
        for i1 in range(n):
            outView = out[i1, i2]
            inView = a[i2]
            cur = i1
            for i3 in range(m):
                outView[i3] = inView[cur]
                cur += n

    return out

以下是我的计算机在配备i5- 9600 KF处理器(6核)的k=37n=42m=53a.dtype=np.int32上的测试结果:

John Zwinck's solution:    986.1 µs
Initial implementation:     91.7 µs
Sequential Numba:           62.9 µs
Parallel Numba:             14.7 µs
Optimal lower-bound:        ~7.0 µs
e5nszbig

e5nszbig2#

我认为这做了你想要的,没有循环。我测试了二维输入,它可能需要一些调整更多的维度。

indexes = np.arange(0, a.size*n, n) + np.repeat(np.arange(n), a.size/n)
np.take(a, indexes, mode='wrap').reshape(n, a.shape[0], -1)

在我的测试中,它比您原来的列表解决方案要慢一些。

6ojccjat

6ojccjat3#

如果我没记错的话,这是你所期望的,而且速度很快:

a.reshape(k, m, n).swapaxes(1, 2).swapaxes(0, 1)

示例:

import numpy as np
k=5; n=3; m=4
a = np.arange(k*n*m).reshape(k, n*m)
a.reshape(k, m, n).swapaxes(1, 2).swapaxes(0, 1)
"""
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11],
       [12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23],
       [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35],
       [36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
       [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59]])

is transformed into:

array([[[ 0,  3,  6,  9],
        [12, 15, 18, 21],
        [24, 27, 30, 33],
        [36, 39, 42, 45],
        [48, 51, 54, 57]],

       [[ 1,  4,  7, 10],
        [13, 16, 19, 22],
        [25, 28, 31, 34],
        [37, 40, 43, 46],
        [49, 52, 55, 58]],

       [[ 2,  5,  8, 11],
        [14, 17, 20, 23],
        [26, 29, 32, 35],
        [38, 41, 44, 47],
        [50, 53, 56, 59]]])
"""

时间安排:

from time import time
k=37; n=42; m=53
a = np.arange(k*n*m).reshape(k, n*m)

start = time()
for _ in range(1_000_000):
    res = a.reshape(k, m, n).swapaxes(1, 2).swapaxes(0,1)
time() - start

# 0.95 s per 1 mil repetitions

相关问题