
I have made a simple scatterplot using matplotlib that shows data from two numerical variables (varA and varB), colored by a third, categorical string variable (col) containing 10 unique colors. Each color corresponds to one of 10 unique names in another string column, and all of these columns are in the same pandas DataFrame with 100+ rows. Is there an easy way to create a legend for this scatterplot that shows the unique colored dots and their corresponding category names? Or should I group the data and plot each category in a subplot to do this? This is what I have so far:

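import matplotlib.pyplot as plt

# df is an existing DataFrame with numeric columns 'A' and 'B'
# and a 'Color' column holding a color name for each row
varA = df['A']
varB = df['B']
col = df['Color']

plt.scatter(varA, varB, c=col, alpha=0.8)
plt.legend()  # no artist has a label, so no legend entries appear

plt.show()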

2 Answers


Assuming Color is the column that holds both the colors and the labels, you can simply do the following:

import matplotlib.pyplot as plt

# One scatter call per unique color, each with its own legend label
colors = df['Color'].unique()
for color in colors:
    data = df.loc[df['Color'] == color]
    plt.scatter('A', 'B', data=data, color=color, label=color)
plt.legend()
plt.show()
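The loop above labels each group with its color string. Since the question also has a separate column of category names, the same idea can use those names as labels by grouping on both columns at once. A minimal sketch, assuming the name column is called `Name` (a hypothetical name, as is the small sample DataFrame) and that each name maps to exactly one color:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical sample data; 'Name' is an assumed column whose values
# correspond one-to-one with the values in 'Color'.
df = pd.DataFrame({
    "A": [1, 2, 3, 4, 5, 6],
    "B": [2, 4, 1, 5, 3, 6],
    "Color": ["red", "blue", "green", "red", "blue", "green"],
    "Name": ["alpha", "beta", "gamma", "alpha", "beta", "gamma"],
})

fig, ax = plt.subplots()
# Grouping by two keys yields (name, color) tuples, one group per category.
for (name, color), group in df.groupby(["Name", "Color"]):
    ax.scatter(group["A"], group["B"], color=color, alpha=0.8, label=name)

legend = ax.legend()
labels = [text.get_text() for text in legend.get_texts()]
print(labels)
```

Each legend entry then shows the group's color next to its category name rather than the color string.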

A simple way is to group your data by color and then plot every group on the same axes. pandas has a built-in groupby method. For example:

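import matplotlib.pyplot as plt

# Group by the color column (a plain string key, not a one-element list,
# so each group's key is the color string itself) and plot each group
for color, group in df.groupby('Color'):
    plt.scatter(group['A'], group['B'], color=color, alpha=0.8, label=color)

plt.legend()
plt.show()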

Notice that we call plt.scatter once for each grouping of data. Then we only need to call plt.legend and plt.show once all of the data is in our plot.
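If you would rather keep the single `plt.scatter` call from the question, a different technique (not the one used in this answer) is to build the legend from proxy artists: empty `Line2D` objects that only carry a marker, a color and a label. No subplots are needed. A sketch, again assuming a hypothetical `Name` column and sample data:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import pandas as pd

# Hypothetical sample data; 'Name' is an assumed category-name column.
df = pd.DataFrame({
    "A": [1, 2, 3, 4, 5, 6],
    "B": [2, 4, 1, 5, 3, 6],
    "Color": ["red", "blue", "green", "red", "blue", "green"],
    "Name": ["alpha", "beta", "gamma", "alpha", "beta", "gamma"],
})

fig, ax = plt.subplots()
# One scatter call for all points, each colored by its row's 'Color' value.
ax.scatter(df["A"], df["B"], c=df["Color"].tolist(), alpha=0.8)

# One proxy handle per unique (Name, Color) pair, in first-seen order.
pairs = df[["Name", "Color"]].drop_duplicates()
handles = [
    Line2D([], [], linestyle="none", marker="o", color=color, label=name)
    for name, color in zip(pairs["Name"], pairs["Color"])
]
legend = ax.legend(handles=handles)
labels = [text.get_text() for text in legend.get_texts()]
print(labels)
```

The trade-off is that the legend is built by hand, so it only stays correct as long as the Name-to-Color mapping in the data is consistent.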