numpy 了解4D ndarray上高级多维索引的行为

piztneat  于 2023-08-05  发布在  其他
关注(0)|答案(1)|浏览(130)

场景

我有一个4D ndarray,由多个3D图像/体素组成,尺寸为(体素,dim 1,dim 2,dim 3),比方说(12体素,96像素,96像素,96像素)。我的目标是从m个体素的体积的中间采样n个切片**的范围。
我已经查看了Numpy关于(高级)索引的文档,以及解释广播的this answer和解释numpy插入newaxisthis answer,但我仍然无法理解我的场景中的底层行为。

问题

最初,我试图通过使用以下代码一次性索引数组来实现上述目标:

import numpy as np

array = np.random.rand(12, 96, 96, 96)

n = 4
m_voxels = 6
samples_range = np.arange(0, m_voxels)

middle_slices = array.shape[1] // 2
middle_slices_range = np.arange(middle_slices - n // 2, middle_slices + n // 2)

samples_from_the_middle = array[samples_range, middle_slices_range, :, :]

字符串
我没有得到一个shape(6,4,96,96)数组,而是遇到了下面的IndexError:

IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (6,) (4,)


当我尝试显式地或分两步索引数组时,它按预期工作:

explicit_indexing = array[0:6, 46:50, :, :]
temp = array[samples_range]
samples_from_the_middle = temp[:, middle_slices_range, :, :]
explicit_indexing.shape # output: (6, 4, 96, 96)
samples_from_the_middle.shape  # output: (6, 4, 96, 96)


或者,如本answer中所述,另一种方法是:

samples_from_the_middle = array[samples_range[:, np.newaxis], middle_slices_range, :, :]  
samples_from_the_middle.shape # output: (6, 4, 96, 96)


我有以下问题:
1.为什么np.arange方法无法产生预期的结果,而显式索引(带冒号)可以正常工作,即使我们实际上是用相同范围的整数进行索引?
1.为什么在第一个索引一维数组中添加newaxis似乎解决了这个问题?
如有任何见解,将不胜感激。

wz3gfoph

wz3gfoph1#

因此,numpy处理索引的方式不同,这取决于你是使用slicesmy_array[a:b]时创建的)还是numpy数组。一个有用的方法来思考它是cartesian products。看看这个demo:

In [1]: import numpy as np

In [2]: x = np.array([[1,2,3],[4,5,6],[7,8,9]])

In [3]: x
Out[3]:
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])

In [4]: x[0:3, 0:3]
Out[4]:
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])

In [5]: x[np.arange(3), np.arange(3)]
Out[5]: array([1, 5, 9])

字符串
请注意,当我们使用切片时,我们得到了您想要的输出。当我们使用numpy数组时,我们得到的是一个只有3个元素而不是9个元素的一维数组。为什么?这是因为切片被自动用于创建笛卡尔积。Python会自动为两个切片中所有可能的值对生成[0, 0], [0, 1], [0, 2], [1, 0], ...形式的索引。
当使用numpy数组进行索引时,情况并非如此。相反,数组被匹配元素。这意味着只创建了[0, 0], [1, 1], [2, 2]对,我们只得到了3个对角元素。这与numpy没有将一维数组视为正确的行或列向量有关,除非我们显式说明数组有多少行和列。当我们这样做时,我们启用numpy执行broadcasting,其中,本质上,数组沿着长度为1的轴“重复”。这让我们做一些事情

In [10]: x = np.array([1,2,3,4,5])

In [11]: y = np.array([6,7,8])

In [12]: from numpy import newaxis as nax

In [13]: x = x[:, nax]

In [14]: y = y[nax, :]

In [15]: x + y
Out[15]:
array([[ 7,  8,  9],
       [ 8,  9, 10],
       [ 9, 10, 11],
       [10, 11, 12],
       [11, 12, 13]])


在那里你可以看到我们得到了你在索引时所寻找的行为! x阵列中的每个元素与y阵列中的每个元素配对。
现在我们可以使用这些知识如下:

In [16]: x = np.array([[1,2,3],[4,5,6],[7,8,9]])

In [17]: x
Out[17]:
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])

In [18]: x[0:3, 0:3]
Out[18]:
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])

In [19]: x[np.arange(3), np.arange(3)]
Out[19]: array([1, 5, 9])

In [20]: x[np.arange(3)[:, nax], np.arange(3)[nax, :]]
Out[20]:
array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])


我们完了!
为了完整起见,请注意numpy.ix_函数的存在正是为了解决这个问题。下面是一个例子:

In [21]: x = np.array([1,2,3,4,5])

In [22]: y = np.array([6,7,8])

In [23]: x, y = np.ix_(x,y)

In [24]: x
Out[24]:
array([[1],
       [2],
       [3],
       [4],
       [5]])

In [25]: y
Out[25]: array([[6, 7, 8]])


最后,所有这些都等同于使用numpy.meshgrid函数,该函数显式创建包含xy元素的所有可能配对的数组。但是,您不想使用它来索引,因为显式地同时创建这些配对并将它们保存在RAM中,这会非常浪费内存。最好让numpy为你施展魔法。

相关问题