将图保存到numpy数组

vdgimpew  于 2023-05-17  发布在  其他
关注(0)|答案(8)|浏览(126)

在Python和Matplotlib中,可以很容易地将图显示为弹出窗口或将图保存为PNG文件。如何将图保存为RGB格式的numpy数组?

fruv7luv

fruv7luv1#

当您需要对保存的图进行像素到像素的比较时,这对于单元测试等是一个方便的技巧。
一种方法是使用fig.canvas.tostring_rgb,然后使用numpy.fromstring和适当的dtype。还有其他方法,但这是我倾向于使用的方法。
例如:

import matplotlib.pyplot as plt
import numpy as np

# Make a random plot...
fig = plt.figure()
fig.add_subplot(111)

# If we haven't already shown or saved the plot, then we need to
# draw the figure first...
fig.canvas.draw()

# Now we can save it to a numpy array.
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
u1ehiz5o

u1ehiz5o2#

@JUN_NETWORKS的答案有一个更简单的选项。可以使用其他格式,如rawrgba,而不是将数字保存为png,并跳过cv2解码步骤。
换句话说,实际的plot-to-numpy转换归结为:

io_buf = io.BytesIO()
fig.savefig(io_buf, format='raw', dpi=DPI)
io_buf.seek(0)
img_arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8),
                     newshape=(int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1))
io_buf.close()

霍普这个有用

rggaifut

rggaifut3#

有些人提出了一种方法,它是这样的

np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')

当然,这个代码工作。但是,输出的numpy数组图像分辨率很低。
我的提案代码是这样的。

import io
import cv2
import numpy as np
import matplotlib.pyplot as plt

# plot sin wave
fig = plt.figure()
ax = fig.add_subplot(111)

x = np.linspace(-np.pi, np.pi)

ax.set_xlim(-np.pi, np.pi)
ax.set_xlabel("x")
ax.set_ylabel("y")

ax.plot(x, np.sin(x), label="sin")

ax.legend()
ax.set_title("sin(x)")

# define a function which returns an image as numpy array from figure
def get_img_from_fig(fig, dpi=180):
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi)
    buf.seek(0)
    img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    buf.close()
    img = cv2.imdecode(img_arr, 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    return img

# you can get a high-resolution image as numpy array!!
plot_img_np = get_img_from_fig(fig)

这段代码运行良好。
如果在dpi参数上设置了一个大的数字,则可以将高分辨率图像作为numpy数组。

kognpnkq

kognpnkq4#

是时候对您的解决方案进行基准测试了。

import io
import matplotlib
matplotlib.use('agg')  # turn off interactive backend
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.plot(range(10))

def plot1():
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))

def plot2():
    with io.BytesIO() as buff:
        fig.savefig(buff, format='png')
        buff.seek(0)
        im = plt.imread(buff)

def plot3():
    with io.BytesIO() as buff:
        fig.savefig(buff, format='raw')
        buff.seek(0)
        data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))
>>> %timeit plot1()
34 ms ± 4.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit plot2()
50.2 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit plot3()
16.4 ms ± 36 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

在这种情况下,IO原始缓冲区是将matplotlib图形转换为numpy数组的最快方法。
补充说明:

  • 如果你没有一个访问的数字,你总是可以提取它从轴:

fig = ax.figure

  • 如果需要channel x height x width的数组,请执行

im = im.transpose((2, 0, 1))

i86rm4rw

i86rm4rw5#

MoviePy使得将数字转换为numpy数组变得非常简单。它有一个内置函数,名为mplfig_to_npimage()。你可以这样使用它:

from moviepy.video.io.bindings import mplfig_to_npimage
import matplotlib.pyplot as plt

fig = plt.figure()  # make a figure
numpy_fig = mplfig_to_npimage(fig)  # convert it to a numpy array
ubby3x7f

ubby3x7f6#

如果有人想要一个即插即用的解决方案,而不需要修改任何先前的代码(获得对pyplot图的引用和所有内容),下面的代码对我很有用。只需在所有pyplot语句后添加此语句即可。在pyplot.show()之前

canvas = pyplot.gca().figure.canvas
canvas.draw()
data = numpy.frombuffer(canvas.tostring_rgb(), dtype=numpy.uint8)
image = data.reshape(canvas.get_width_height()[::-1] + (3,))
a6b3iqyw

a6b3iqyw7#

正如Joe Kington所指出的,一种方法是在画布上绘制,将画布转换为字节字符串,然后将其重塑为正确的形状。

import matplotlib.pyplot as plt
import numpy as np
import math

plt.switch_backend('Agg')

def canvas2rgb_array(canvas):
    """Adapted from: https://stackoverflow.com/a/21940031/959926"""
    canvas.draw()
    buf = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
    ncols, nrows = canvas.get_width_height()
    scale = round(math.sqrt(buf.size / 3 / nrows / ncols))
    return buf.reshape(scale * nrows, scale * ncols, 3)

# Make a simple plot to test with
t = np.arange(0.0, 2.0, 0.01)
s = 1 + np.sin(2 * np.pi * t)
fig, ax = plt.subplots()
ax.plot(t, s)

# Extract the plot as an array
plt_array = canvas2rgb_array(fig.canvas)
print(plt_array.shape)

但是,由于canvas.get_width_height()返回显示坐标中的宽度和高度,因此有时会出现缩放问题,在此答案中已解决。

at0kjp5o

at0kjp5o8#

Jonan Gueorguiev的回答:

with io.BytesIO() as io_buf:
  fig.savefig(io_buf, format='raw', dpi=dpi)
  image = np.frombuffer(io_buf.getvalue(), np.uint8).reshape(
      int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1)

相关问题