如何消除循环并使用numpy reshape?

h7appiyu  于 2023-04-21  发布在  其他
关注(0)|答案(2)|浏览(87)

我有以下代码

import numpy as np
from matplotlib import pyplot as plt
from PIL import Image as im

B = 9
pic = np.array(im.open('portrait.png'))
pic = pic[pic.shape[0] % B:, pic.shape[1] % B:, :3]
reduced, _i, _j = [], 0, 0
for i in range(B, pic.shape[0] + B, B):
    for j in range(B, pic.shape[1] + B, B):
        pix = pic[_i:i, _j:j].reshape(B * B, 3).mean(axis=0)
        reduced.append(pix.astype(np.uint8))
        _j = j
    _i, _j = i, 0
reduced = np.array(reduced)
reduced = reduced.reshape((pic.shape[0] // B, pic.shape[1] // B, 3))
plt.title(reduced.shape)
plt.imshow(reduced)
plt.show()

它本质上是通过B对图像的B块进行迭代,并通过取其平均值将它们变成一个像素。
所以这个

变成了这样,

我认为这可以通过改变形状来实现,沿着某个轴取平均值,然后再次改变形状,但我不确定如何消除这个循环。

p4tfgftt

p4tfgftt1#

您可以在图像中引入BxB块作为额外的轴,然后沿着这些轴取平均值:

import numpy as np
from matplotlib import pyplot as plt
from PIL import Image as im

B = 9
pic = np.array(im.open('portrait.png'))

h, w, c = pic.shape
H, W = h - h % B, w - w % B

reduced = (
    pic[-H:,-W:]                             # crop to multiple of block size
        .reshape((H // B, B, W // B, B, c))  # split y and x axes into blocks
        .mean(axis=(1,3))                    # take the mean
        .astype(np.uint8)                    # convert from float
)

plt.title(reduced.shape)
plt.imshow(reduced)
plt.show()
slmsl1lt

slmsl1lt2#

如果您不介意使用额外的库,可以使用einops以一种直接和自我解释的方式实现它,如下所示:

import einops

img = ...
B=9

im_reduced = einops.reduce(img, "(h Bh) (w Bw) C -> h w C", "mean", Bh=B, Bw=B)

相关问题