numpy如何实现flatten?或一般的strided数组扁平化

roejwanj  于 2023-04-06  发布在  其他
关注(0)|答案(2)|浏览(148)

我试图理解当你执行一些操作时在python中发生了什么。例如,从this reply,我理解了strides是如何工作的,以及它是如何重要的。但是现在,我想知道,如果在转置之后,在内存中,数据没有被“物理”转置,当我在转置操作之后调用.flatten(order="C")时,数据是正确排序的。由于步幅,我知道它肯定是可以实现这个操作,不幸的是,我不能想出一个算法,适用于任何'转置'步幅。

import numpy as np

array = np.arange(24).reshape(2, 3, 4)
print(array.flatten(order='C'))
array = array.transpose(1, 0, 2)
print(array.flatten(order='C'))

>>> [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23]
>>> [ 0  1  2  3 12 13 14 15  4  5  6  7 16 17 18 19  8  9 10 11 20 21 22 23]
ukqbszuj

ukqbszuj1#

调用转置数组arrt,我们看到:

In [372]: arrt.shape, arrt.strides
Out[372]: ((3, 2, 4), (16, 48, 4))

所以步长,根据itemsize调整,是(4,12,1)。
ravel/flatten可以通过以下方式生成:

In [373]: res = np.zeros(arrt.size, int)
     ...: rcnt = 0
     ...: for i in range(0,3):
     ...:     for j in range(0,2):
     ...:         for k in range(0, 4):
     ...:             res[rcnt] = arrt.base[i*4+j*12+k*1]
     ...:             rcnt += 1
     ...:             

In [374]: res
Out[374]: 
array([ 0,  1,  2,  3, 12, 13, 14, 15,  4,  5,  6,  7, 16, 17, 18, 19,  8,
        9, 10, 11, 20, 21, 22, 23])

arrt.base是原始的np.arange。在最里面的循环(k)中,我们通过1步进底部,j循环通过12步进,外部通过4步进。

vof42yt1

vof42yt12#

您可以检查strides属性,以了解实际情况:

array = np.arange(24).reshape(2, 3, 4)
print(array.strides)                      # (48, 16, 4)
tmp = array.flatten(order='C')
print(tmp.strides)                        # (4,)
array = array.transpose(1, 0, 2)
print(array.strides)                      # (16, 48, 4)
tmp = array.flatten(order='C')
print(tmp.strides)                        # (4,)

正如我们所看到的,平坦化数组总是连续的,而(非平坦化)转置数组则不是。
实际上,Numpy试图永远不复制数据,除非你要求它这样做(或用于基本的out-of-place操作)。也就是说,在这种情况下,Numpy别无选择,只能创建目标数组的副本。事实上,步长总是沿着给定的轴均匀(根据设计),因此flatten转置数组必须是连续的

相关问题