In Python and Matplotlib, it is easy to either display the plot as a popup window or save the plot as a PNG file. How can I instead save the plot to a numpy array in RGB format?
This is a handy trick for unit tests and the like, when you need to do a pixel-to-pixel comparison with a saved plot.
One way is to use fig.canvas.tostring_rgb and then numpy.frombuffer with the appropriate dtype. There are other ways as well, but this is the one I tend to use.
E.g.
import matplotlib.pyplot as plt
import numpy as np
# Make a plot (here, just an empty set of axes)...
fig = plt.figure()
fig.add_subplot(111)
# If we haven't already shown or saved the plot, then we need to
# draw the figure first...
fig.canvas.draw()
# Now we can save it to a numpy array.
# (np.frombuffer replaces np.fromstring, whose binary mode is deprecated.)
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
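For the unit-test use case mentioned above (pixel-to-pixel comparison), here is a minimal sketch. It assumes a recent Matplotlib where `FigureCanvasAgg.buffer_rgba()` returns a memoryview of the rendered RGBA buffer; newer Matplotlib releases deprecate `tostring_rgb` in favor of `buffer_rgba`. The helpers `render_rgb` and `make_line_plot` are hypothetical names for illustration, and attaching an Agg canvas explicitly keeps the sketch independent of the interactive backend.

```python
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure


def render_rgb(fig):
    """Hypothetical helper: render `fig` with Agg and return an (H, W, 3) uint8 array."""
    canvas = FigureCanvasAgg(fig)  # attach an Agg canvas explicitly
    canvas.draw()  # buffer_rgba() is only valid after a draw
    # buffer_rgba() exposes the renderer's RGBA buffer; drop alpha and copy
    rgba = np.asarray(canvas.buffer_rgba())
    return rgba[..., :3].copy()


def make_line_plot(ys):
    """Hypothetical helper that builds the figure under test."""
    fig = Figure(figsize=(4, 3), dpi=100)  # 400 x 300 pixels
    ax = fig.add_subplot(111)
    ax.plot(ys)
    return fig


reference = render_rgb(make_line_plot([0, 1, 4, 9]))
candidate = render_rgb(make_line_plot([0, 1, 4, 9]))
different = render_rgb(make_line_plot([9, 4, 1, 0]))

print(reference.shape)                       # (300, 400, 3)
print(np.array_equal(reference, candidate))  # True
print(np.array_equal(reference, different))  # False
```

In a real test suite, `reference` would typically be loaded from a stored file rather than rendered in the same run.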
There is a slightly simpler alternative to @JUN_NETWORKS's answer. Instead of saving the figure as PNG, one can use another format, such as raw or rgba, and skip the cv2 decoding step.
In other words, the actual plot-to-numpy conversion boils down to:
import io
import numpy as np
DPI = 180  # any resolution; the raw buffer is sized for DPI, not for fig.dpi
io_buf = io.BytesIO()
fig.savefig(io_buf, format='raw', dpi=DPI)
io_buf.seek(0)
w, h = (fig.get_size_inches() * DPI).astype(int)  # saved size in pixels
img_arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8), (h, w, -1))
io_buf.close()
Hope this helps.
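The same raw-buffer idea can be wrapped in a reusable function that returns RGB (as the question asks) at any resolution. This is a minimal sketch; `fig_to_rgb_via_raw` is a hypothetical name. It assumes the default `savefig` settings (no `bbox_inches='tight'`), so the saved image is exactly figure size times `dpi` pixels, and that the `raw` format is RGBA with 4 bytes per pixel.

```python
import io

import numpy as np
from matplotlib.figure import Figure


def fig_to_rgb_via_raw(fig, dpi):
    """Hypothetical helper: save `fig` as raw RGBA bytes at `dpi`, return (H, W, 3) uint8."""
    with io.BytesIO() as buf:
        fig.savefig(buf, format="raw", dpi=dpi)
        data = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    # The raw buffer is sized for `dpi`, not fig.dpi, so derive the shape from it
    width, height = (fig.get_size_inches() * dpi).astype(int)
    rgba = data.reshape(height, width, 4)
    return rgba[..., :3].copy()  # drop the alpha channel


fig = Figure(figsize=(4, 3), dpi=100)
fig.add_subplot(111).plot([0, 1, 4, 9])

low = fig_to_rgb_via_raw(fig, dpi=50)
high = fig_to_rgb_via_raw(fig, dpi=200)
print(low.shape)   # (150, 200, 3)
print(high.shape)  # (600, 800, 3)
```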
Some people propose a method like this:
np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
Of course, this code works. However, the output image has a low resolution, because the canvas is rendered at the figure's own dpi setting (rcParams['figure.dpi'], 100 by default).
Here is my proposed code:
import io
import cv2
import numpy as np
import matplotlib.pyplot as plt
# plot sin wave
fig = plt.figure()
ax = fig.add_subplot(111)
x = np.linspace(-np.pi, np.pi)
ax.set_xlim(-np.pi, np.pi)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.plot(x, np.sin(x), label="sin")
ax.legend()
ax.set_title("sin(x)")
# define a function that returns the figure as an image (numpy array)
def get_img_from_fig(fig, dpi=180):
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi)
    buf.seek(0)
    img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    buf.close()
    img = cv2.imdecode(img_arr, cv2.IMREAD_COLOR)  # OpenCV decodes to BGR order
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img
# get a high-resolution image as a numpy array
plot_img_np = get_img_from_fig(fig)
This code works well.
You can get a higher-resolution image by passing a larger value for the dpi argument.
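If you would rather not depend on OpenCV, the PNG can be decoded with Pillow, which Matplotlib itself already requires. This is a sketch of that swap, not the answer's own method; `get_img_from_fig_pil` is a hypothetical name. Pillow returns channels in RGB order, so no BGR conversion is needed.

```python
import io

import numpy as np
from matplotlib.figure import Figure
from PIL import Image


def get_img_from_fig_pil(fig, dpi=180):
    """Hypothetical variant of get_img_from_fig that decodes the PNG with Pillow."""
    with io.BytesIO() as buf:
        fig.savefig(buf, format="png", dpi=dpi)
        buf.seek(0)
        # convert() forces the lazy decode while buf is still open and drops alpha
        img = Image.open(buf).convert("RGB")
    return np.asarray(img)


fig = Figure(figsize=(4, 3))
fig.add_subplot(111).plot([0, 1, 4, 9])
img = get_img_from_fig_pil(fig, dpi=180)
print(img.shape)  # (540, 720, 3)
```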
In case somebody wants a plug-and-play solution that doesn't require modifying any prior code (getting a reference to the pyplot figure and so on), the following worked for me. Just add it after all pyplot statements, i.e. just before pyplot.show():
import numpy  # the snippet assumes these two names
from matplotlib import pyplot
canvas = pyplot.gca().figure.canvas
canvas.draw()
data = numpy.frombuffer(canvas.tostring_rgb(), dtype=numpy.uint8)
image = data.reshape(canvas.get_width_height()[::-1] + (3,))
Time to benchmark your solutions.
import io
import matplotlib
matplotlib.use('agg') # turn off interactive backend
import matplotlib.pyplot as plt
import numpy as np
fig, ax = plt.subplots()
ax.plot(range(10))
def plot1():
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))
def plot2():
    with io.BytesIO() as buff:
        fig.savefig(buff, format='png')
        buff.seek(0)
        im = plt.imread(buff)
def plot3():
    with io.BytesIO() as buff:
        fig.savefig(buff, format='raw')
        buff.seek(0)
        data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))
>>> %timeit plot1()
34 ms ± 4.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit plot2()
50.2 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit plot3()
16.4 ms ± 36 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
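In this scenario, saving to an in-memory raw buffer is the fastest way to convert a matplotlib figure to a numpy array: about twice as fast as tostring_rgb and about three times as fast as going through PNG.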
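`%timeit` is an IPython magic; in a plain script, the standard-library `timeit` module gives comparable measurements. The sketch below times the raw-buffer approach next to a variant that reads the Agg canvas through `buffer_rgba()`. The function names are hypothetical, the figure size is an assumption, and the timings depend on your machine, so none are reproduced here.

```python
import io
import timeit

import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure

fig = Figure(figsize=(4, 3), dpi=100)
canvas = FigureCanvasAgg(fig)  # non-interactive Agg canvas
fig.add_subplot(111).plot(range(10))


def via_buffer_rgba():
    canvas.draw()
    return np.asarray(canvas.buffer_rgba())[..., :3].copy()


def via_raw():
    with io.BytesIO() as buff:
        fig.savefig(buff, format="raw")
        data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
    w, h = canvas.get_width_height()
    return data.reshape(h, w, -1)[..., :3]


for func in (via_buffer_rgba, via_raw):
    # best of 3 repeats, 5 calls each, reported per call
    best = min(timeit.repeat(func, number=5, repeat=3)) / 5
    print(f"{func.__name__}: {best * 1e3:.1f} ms per call")
```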
Additional remarks:
- If you don't have access to the figure, you can always extract it from the axes: fig = ax.figure
- If you need the array in channel x height x width format, do im = im.transpose((2, 0, 1)).
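A small sketch of the channel-first conversion and its inverse. The random `im` is a hypothetical stand-in for an (H, W, C) image produced by any of the methods above. Note that `transpose` returns a view with rearranged strides, so `np.ascontiguousarray` is used here in case a downstream library expects contiguous memory.

```python
import numpy as np

# Hypothetical stand-in for an (H, W, C) RGBA image from a figure
rng = np.random.default_rng(0)
im = rng.integers(0, 256, size=(480, 640, 4), dtype=np.uint8)

chw = im.transpose((2, 0, 1))    # (C, H, W), a strided view, no copy
chw = np.ascontiguousarray(chw)  # copy into C-contiguous memory
hwc = chw.transpose((1, 2, 0))   # back to (H, W, C)

print(chw.shape, hwc.shape)  # (4, 480, 640) (480, 640, 4)
```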