The pandas DataFrame has the following form, where Year is the index:

      A:Cat1  A:Cat2  B:Cat1  B:Cat2  B:Cat3
Year                                        
1977     0.5    0.25    0.15     0.1     0.1
1981     0.2     NaN    0.40     0.1     0.2
1983     0.1    0.10    0.30     0.2     0.3

The important thing is that the same categories, Cat1 and Cat2, appear in two different "super-categories", A and B. To plot the variations of all the categories, I use a stacked plot with a different set of colors for each super-category. All those colors are stored in the list colors.

This is how I currently draw the graph (plt is matplotlib.pyplot):

plt.stackplot(data.index.values, data.fillna(0).T.values, colors=colors, labels=data.columns.values)
plt.legend(loc="best")

This gives the following with the previous data:

[Image: result of the previous code]

Now I would like to avoid repeating the super-categories A and B in the legend, either by creating a separate legend for each super-category or by adding subheadings inside a single legend. I looked at this other question about subheadings, but I need to be able to specify the breaking point between the two columns of the legend. Just passing ncol=2 does not work: the super-categories do not have the same number of categories, so the column break does not fall in the right place.

1 Answer

One option is to add a blank placeholder entry to compensate for the unequal number of categories in each super-category. Another is to use horizontal group labels:

import io

import pandas as pd
import matplotlib.lines as mlines
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt

s = """1977 0.5 0.25    0.15    0.1 0.1
1981    0.2 NaN 0.40    0.1 0.2
1983    0.1 0.10    0.30    0.2 0.3"""
t = [('A', 'Cat1'), 
     ('A', 'Cat2'), 
     ('B', 'Cat1'), 
     ('B', 'Cat2'), 
     ('B', 'Cat3')]
index = pd.MultiIndex.from_tuples(t)
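# The sample data is whitespace-separated with no header row; the first column is the year
df = pd.read_table(io.StringIO(s), sep=r'\s+', header=None, index_col=0)
df.columns = index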
df.index.name = 'Year'
colors = ['b', 'c', 'k', 'g', 'w']
plt.stackplot(df.index.values, df.fillna(0).T.values, colors=colors)

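# Invisible handles that carry the group headings 'A' and 'B'
ha = mlines.Line2D([], [], marker='None', linestyle='None')
hb = mlines.Line2D([], [], marker='None', linestyle='None')
# One patch per category; facecolor and edgecolor are set separately
# because color= would override the black edge
ha1 = mpatches.Patch(facecolor=colors[0], edgecolor='k')
ha2 = mpatches.Patch(facecolor=colors[1], edgecolor='k')
hb1 = mpatches.Patch(facecolor=colors[2], edgecolor='k')
hb2 = mpatches.Patch(facecolor=colors[3], edgecolor='k')
hb3 = mpatches.Patch(facecolor=colors[4], edgecolor='k')
# Invisible placeholder used to pad the shorter group
hblank = mpatches.Patch(visible=False)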
l1 = plt.legend([ha, ha1, ha2, hblank, hb, hb1, hb2, hb3], 
                ['A', 'Cat1', 'Cat2', '', 'B', 'Cat1', 'Cat2', 'Cat3'], 
                loc=2, ncol=2) # Two columns, vertical group labels
l2 = plt.legend([ha, hblank, hb, hblank, hblank, ha1, ha2, hb1, hb2, hb3], 
                ['A', '', 'B', '', '', 'Cat1', 'Cat2', 'Cat1', 'Cat2', 'Cat3'], 
                loc=4, ncol=2) # Two columns, horizontal group labels

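ax = plt.gca()
# The second plt.legend() call replaced l1 on the axes, so add it back explicitly
ax.add_artist(l1)
# Show the years in full instead of using offset notation on the x axis
ax.get_xaxis().get_major_formatter().set_useOffset(False)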
plt.show()
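
If the number of categories per super-category varies, the padding for the vertical layout can be computed rather than typed by hand. The following is only a sketch of that idea: pad_legend_groups and draw_grouped_legend are hypothetical helper names, not matplotlib functions, and the group headings reuse an invisible patch instead of the empty Line2D handles used above.

```python
def pad_legend_groups(groups, blank_handle, blank_label=""):
    """Flatten {heading: [(handle, label), ...]} into legend handles and labels.

    Each group becomes one column: a heading row, its entries, then blank
    rows so that every column has the same height. With equal heights,
    ncol=len(groups) breaks the legend exactly at the group boundaries.
    """
    # Column height = heading row + size of the largest group
    height = 1 + max(len(entries) for entries in groups.values())
    handles, labels = [], []
    for heading, entries in groups.items():
        # Heading row: invisible handle, visible text
        handles.append(blank_handle)
        labels.append(heading)
        for handle, label in entries:
            handles.append(handle)
            labels.append(label)
        # Pad shorter groups with invisible, unlabeled rows
        padding = height - 1 - len(entries)
        handles.extend([blank_handle] * padding)
        labels.extend([blank_label] * padding)
    return handles, labels, len(groups)


def draw_grouped_legend(ax, groups, **legend_kwargs):
    """Draw one legend on ax with one column per group (hypothetical helper)."""
    # Imported here so the padding logic above has no matplotlib dependency
    import matplotlib.patches as mpatches

    blank = mpatches.Patch(visible=False)
    handles, labels, ncol = pad_legend_groups(groups, blank)
    return ax.legend(handles, labels, ncol=ncol, **legend_kwargs)
```

With the handles defined above, a call such as draw_grouped_legend(ax, {'A': [(ha1, 'Cat1'), (ha2, 'Cat2')], 'B': [(hb1, 'Cat1'), (hb2, 'Cat2'), (hb3, 'Cat3')]}, loc=2) should give the same layout as l1.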
