
I am making a scatter plot for a dataset that looks like this:

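import matplotlib.pyplot as plt
import numpy as np

x = [1, 1, 2, 2, 3, 3, 4, 4]
y = [1, 2, 3, 4, 1, 2, 3, 4]
labels = [1, 3, 0, 2, 2, 1, 0, 3]

# Default color cycle as an array, so it can be indexed with the integer labels
colors = np.array(plt.rcParams['axes.prop_cycle'].by_key()['color'])

plt.scatter(x, y, color=colors[labels])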

If I call plt.legend(), only one entry is shown for the entire dataset, drawn with the first symbol. How do I create a legend with all four labels in it, displayed as if I had plotted four separate datasets?

Probably relevant: Matplotlib histogram with multiple legend entries
Based on: Matplotlib, how to loop?


3 Answers


You can plot empty lists, one per unique label. Nothing is drawn, but each call adds a labelled entry to the legend:

for l in sorted(set(labels)):
    plt.scatter([], [], color=colors[l], label=l)
plt.legend()
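
A self-contained sketch of the empty-scatter approach above, using the data from the question. The `fig, ax` object-oriented style and the loop variable `lab` are illustrative choices, not part of the original answer; add plt.show() (or fig.savefig) to actually display the figure.

```python
import matplotlib.pyplot as plt
import numpy as np

x = [1, 1, 2, 2, 3, 3, 4, 4]
y = [1, 2, 3, 4, 1, 2, 3, 4]
labels = [1, 3, 0, 2, 2, 1, 0, 3]

# Default color cycle as an array so it can be indexed by the integer labels
colors = np.array(plt.rcParams['axes.prop_cycle'].by_key()['color'])

fig, ax = plt.subplots()
# Draw all points in one call; it has no explicit label, so the legend skips it
ax.scatter(x, y, color=colors[labels])

# One empty scatter per unique label: nothing is drawn, but each call
# registers a labelled artist that ax.legend() picks up.
# sorted() makes the legend order deterministic (0, 1, 2, 3).
for lab in sorted(set(labels)):
    ax.scatter([], [], color=colors[lab], label=str(lab))

leg = ax.legend()
```

Because the proxies reuse colors[lab], each legend marker matches the color of the real points with that label.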

I think the simplest solution is to let matplotlib do all the work, as described in the "Automated legend creation" example: https://matplotlib.org/gallery/lines_bars_and_markers/scatter_with_legend.html#automated-legend-creation. In this case, the legend_elements method of the PathCollection returned by scatter (available since Matplotlib 3.1) is all you need:

s = plt.scatter(x, y, c=labels)
plt.legend(*s.legend_elements())
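
A runnable version of the two lines above, written against an explicit Axes. The legend title "label" and the unpacked names `handles, legend_labels` are illustrative additions. The labels that legend_elements() generates are formatted strings, so their exact text may vary between Matplotlib versions.

```python
import matplotlib.pyplot as plt

x = [1, 1, 2, 2, 3, 3, 4, 4]
y = [1, 2, 3, 4, 1, 2, 3, 4]
labels = [1, 3, 0, 2, 2, 1, 0, 3]

fig, ax = plt.subplots()
# Map the integer labels through the default colormap (viridis)
s = ax.scatter(x, y, c=labels)

# legend_elements() returns (handles, labels): with the default num="auto"
# and only four unique values, there is one marker handle per value
handles, legend_labels = s.legend_elements()
leg = ax.legend(handles, legend_labels, title="label")
```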


It's trivial to change the colormap or replace the labels with something else, like text:

text_labels = ['one', 'two', 'three', 'four']
s = plt.scatter(x, y, c=labels, cmap='jet', vmin=0, vmax=4)
plt.legend(s.legend_elements()[0], text_labels)


If labels is not already an array of integers in the range [0, n), you can convert it with np.unique, which also returns the matching text labels:

labels = ['b', 'd', 'a', 'c', 'c', 'b', 'a', 'd']
text_labels, labels = np.unique(labels, return_inverse=True)
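
Putting the np.unique conversion together with legend_elements() gives a complete sketch for string labels. The names `raw_labels` and `codes` are hypothetical, chosen so the original strings are not overwritten. The code relies on legend_elements() returning handles in ascending order of the unique codes, which lines up with the sorted output of np.unique.

```python
import matplotlib.pyplot as plt
import numpy as np

x = [1, 1, 2, 2, 3, 3, 4, 4]
y = [1, 2, 3, 4, 1, 2, 3, 4]
raw_labels = ['b', 'd', 'a', 'c', 'c', 'b', 'a', 'd']

# np.unique returns the sorted distinct labels and, with return_inverse=True,
# each point's index into that sorted array (integer codes 0..n-1)
text_labels, codes = np.unique(raw_labels, return_inverse=True)

fig, ax = plt.subplots()
s = ax.scatter(x, y, c=codes, cmap='jet', vmin=0, vmax=len(text_labels))

# One handle per unique code, in ascending order, so it pairs with text_labels
handles, _ = s.legend_elements()
leg = ax.legend(handles, list(text_labels))
```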

Another way to achieve the same result: plot each point separately with its own label, then get rid of the duplicate legend entries using this solution:

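for i, j, l in zip(x, y, labels):
    plt.scatter(i, j, c=colors[l], label=l)

# Each call adds its own legend entry; keep one handle per label.
# Use a new name so the original labels list is not overwritten.
handles, legend_labels = plt.gca().get_legend_handles_labels()
by_label = dict(zip(legend_labels, handles))
plt.legend(by_label.values(), by_label.keys())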
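
A self-contained sketch of the per-point approach that also sorts the legend, since the dict otherwise keeps first-seen order (1, 3, 0, 2 for this data). The sorting step and the names `xi`, `yi`, `lab`, `legend_labels`, and `order` are illustrative additions, not part of the original answer.

```python
import matplotlib.pyplot as plt
import numpy as np

x = [1, 1, 2, 2, 3, 3, 4, 4]
y = [1, 2, 3, 4, 1, 2, 3, 4]
labels = [1, 3, 0, 2, 2, 1, 0, 3]
colors = np.array(plt.rcParams['axes.prop_cycle'].by_key()['color'])

fig, ax = plt.subplots()
# One scatter call per point, each labelled with its class
for xi, yi, lab in zip(x, y, labels):
    ax.scatter(xi, yi, color=colors[lab], label=str(lab))

# One (handle, label) pair per call, so labels repeat. Building a dict keeps
# one handle per label (the last one seen, as later keys overwrite earlier
# ones); all handles for a label look the same, so this does not matter.
handles, legend_labels = ax.get_legend_handles_labels()
by_label = dict(zip(legend_labels, handles))

# Sort by label text (lexicographic, fine for these single-digit labels)
order = sorted(by_label)
leg = ax.legend([by_label[k] for k in order], order)
```

This draws one artist per point, which is fine for small datasets but slower than a single scatter call for large ones.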
