numpy 将3D阵列压缩为2D图像

0g0grzrc  于 2023-03-18  发布在  其他
关注(0)|答案(2)|浏览(170)

我有一个numpy数组中的图像堆栈。

image_stack.shape

产出(36,51,51)。我有36个51乘51像素的图像。基本上,这36个图像是更大图像的块。我需要重新排列这些块以形成一个9乘4块的大图像,其中堆叠的前四个图像将在大图像的第一行,接下来的四个图像将在第二行,等等。我试图搞砸重塑,但我想我只是没有看到正确的技巧或正确理解它。谢谢

qzwqbdag

qzwqbdag1#

TLDR,用途:

out = arr.reshape(9, 4, 51, 51).swapaxes(1, 2).reshape(9 * 51, 4 * 51)

你需要一个整形和移调的组合。考虑下面的小例子:

import numpy as np

arr = np.multiply.outer(np.arange(1, 7), np.ones((3, 3), dtype=int))

它看起来像:

array([[[1, 1, 1],
        [1, 1, 1],
        [1, 1, 1]],

       [[2, 2, 2],
        [2, 2, 2],
        [2, 2, 2]],

       [[3, 3, 3],
        [3, 3, 3],
        [3, 3, 3]],

       [[4, 4, 4],
        [4, 4, 4],
        [4, 4, 4]],

       [[5, 5, 5],
        [5, 5, 5],
        [5, 5, 5]],

       [[6, 6, 6],
        [6, 6, 6],
        [6, 6, 6]]])

假设我们想把补丁排列并连接成一个3 x 2的补丁图像,我们可以这样做:

out = arr.reshape(3, 2, 3, 3).swapaxes(1, 2).reshape(9, 6)

输出:

array([[1, 1, 1, 2, 2, 2],
       [1, 1, 1, 2, 2, 2],
       [1, 1, 1, 2, 2, 2],
       [3, 3, 3, 4, 4, 4],
       [3, 3, 3, 4, 4, 4],
       [3, 3, 3, 4, 4, 4],
       [5, 5, 5, 6, 6, 6],
       [5, 5, 5, 6, 6, 6],
       [5, 5, 5, 6, 6, 6]])

说明:

# H = patch height
# W = patch width
# X = grid height (i.e. number of patches to put along the first axis)
# Y = grid width (i.e. number of patches to put along the second axis)

X, Y, H, W = 3, 2, 3, 3

out = (
    arr                     # (X Y) H W
    .reshape(X, Y, H, W)    #  X Y H W
    .swapaxes(1, 2)         #  X H Y W     
    .reshape(X * H, Y * W)  # (X H) (Y W)
)

# Note that reshaping can only ever "regroup" axes,
# and that transposing can only ever reorder "groups" of axes.
h7wcgrx3

h7wcgrx32#

使用numpy.block,您可以通过将子数组平铺到2D中来构造所需的数组,并使用所需的列表形式。

[
    [image_stack[0], image_stack[1], image_stack[2], image_stack[3]],
    [image_stack[4], image_stack[5], image_stack[6], image_stack[7]],
    ...
    [image_stack[32], image_stack[33], image_stack[34], image_stack[35]],
]

其可以使用一些基本的列表解析来构造:

target_shape = (9, 4)

result = np.block(
    [
        [
            image_stack[col + target_shape[1]*row, ...]
            for col in range(target_shape[1])
        ]
        for row in range(target_shape[0])
    ]
)
print(result.shape)  # (459, 204) == (9*51, 4*51)

相关问题