移动Numpy二维数组中行的位置

v7pvogib  于 2022-12-23  发布在  其他
关注(0)|答案(3)|浏览(156)

给定数组:

np.array([[1, 2],
          [3, 4],
          [5, 6],
          [7, 8],
          [9, 10]])

如果我想将索引1处的行移动到索引3处,则输出应为:

[[1, 2],
 [5, 6],
 [7, 8],
 [3, 4],
 [9, 10]]

如果我想将索引4处的行移动到索引1处,则输出应为:

[[1, 2],
 [9, 10],
 [3, 4],
 [5, 6],
 [7, 8]]

执行此移动操作的最快方法是什么?

mjqavswn

mjqavswn1#

如果你仔细观察,如果你想把行i放在位置j,那么只有在ij之间的行会受到影响;外部的行不需要改变。并且这个改变基本上是roll操作。对于a,b,c,d,e项,将i=1处的项放置到j=3意味着b,c,d将变成c,d,b,得到a,c,d,b,e。移位是-1还是+1取决于i<j

i, j = 1,3
i, j, s = (i, j, -1) if i<j else (j, i, 1)
arr[i:j+1] = np.roll(arr[i:j+1],shift=s,axis=0)
xqnpmsa8

xqnpmsa82#

第一个轴上的tuple()索引如何?
例如:

arr[(0, 2, 3, 1, 4), :]

以及:

arr[(0, 4, 1, 2, 3), :]

分别为您的预期输出。
对于从两个索引开始生成索引的方法,您可以使用以下方法:

def inner_roll(arr, first, last, axis):
    stop = last + 1
    indices = list(range(arr.shape[axis]))
    indices.insert(first, last)
    indices.pop(last + 1)
    slicing = tuple(
        slice(None) if i != axis else indices
        for i, d in enumerate(arr.shape))
    return arr[slicing]

对于沿操作轴方向相对较小的输入(如问题中的输入),这是相当快的。
将它与@Mercury的答案的一个稍微改进的版本进行比较,以便将它 Package 在一个函数中,并使它对任意axis都能正确工作:

import numpy as np

def inner_roll2(arr, first, last, axis):
    if first > last:
        first, last = last, first
        shift = 1
    else:
        shift = -1
    slicing = tuple(
        slice(None) if i != axis else slice(first, last + 1)
        for i, d in enumerate(arr.shape))
    arr[slicing] = np.roll(arr[slicing], shift=shift, axis=axis)
    return arr

并获得一些计时:

funcs = inner_roll, inner_roll2
for n in (5, 50, 500):
    for m in (2, 20, 200):
        arr = np.arange(n * m).reshape((n, m))
        print(f'({n:<3d}, {m:<3d})', end='    ')
        for func in funcs:
            results = %timeit -o -q func(arr, 1, 2, 0)
            print(f'{func.__name__:>12s}  {results.best* 1e6:>7.3f} µs', end='    ')
        print()
# (5  , 2  )      inner_roll    5.613 µs     inner_roll2   15.393 µs    
# (5  , 20 )      inner_roll    5.592 µs     inner_roll2   15.468 µs    
# (5  , 200)      inner_roll    5.916 µs     inner_roll2   15.815 µs    
# (50 , 2  )      inner_roll   10.117 µs     inner_roll2   15.517 µs    
# (50 , 20 )      inner_roll   10.360 µs     inner_roll2   15.505 µs    
# (50 , 200)      inner_roll   12.067 µs     inner_roll2   15.886 µs    
# (500, 2  )      inner_roll   55.833 µs     inner_roll2   15.409 µs    
# (500, 20 )      inner_roll   57.364 µs     inner_roll2   15.319 µs    
# (500, 200)      inner_roll  194.408 µs     inner_roll2   15.731 µs

这表明inner_roll()是处理输入的最快方法,然而,inner_roll2()似乎可以更好地适应输入大小,即使对于中等大小的输入,inner_roll2()也比inner_roll()快。
请注意,inner_roll()创建拷贝时,inner_roll2()就地工作(修改输入arr)。可以通过在inner_roll2()的主体的开始处添加arr = arr.copy()来修改该行为,这将使该函数变慢(当然),并且其计时将更多地受到m的值(非滚动轴的大小)的影响。
另一方面,如果您要执行多个连续的滚动操作,inner_roll2()的时间就会累加起来,而对于inner_roll(),您只需要执行一次代价高昂的操作。

f8rj6qna

f8rj6qna3#

我喜欢@Mercury的解决方案,但发现避免与np.roll()相关的数组复制并完全就地操作会更快:

if index < index2:
      a[index:index2], a[index2] = a[index + 1:index2 + 1], a[index].copy()
    elif index2 < index:
      a[index2 + 1:index + 1], a[index2] = a[index2:index], a[index].copy()

相关问题