
I need to train a Random Forest classifier using 3-fold cross-validation. For each sample, I need to retrieve the predicted probabilities from the fold in which that sample is in the test set.

I am using scikit-learn version 0.18.dev0.

This new version lets cross_val_predict() take an additional parameter, method, which selects the kind of prediction to request from the estimator.

In my case I want to use the predict_proba() method, which returns the probability for each class, in a multiclass scenario.

However, when I run it, the result is a matrix of prediction probabilities in which each row represents a sample and each column holds the predicted probability for one class.

The problem is that the method does not indicate which class corresponds to each column.

The value I need is the one that a fitted estimator (in my case a RandomForestClassifier) exposes in its classes_ attribute, documented as:

classes_ : array of shape = [n_classes] or a list of such arrays
The classes labels (single output problem), or a list of arrays of class labels (multi-output problem).

It is needed to interpret the output of predict_proba(), whose documentation says:

The order of the classes corresponds to that in the attribute classes_.

A minimal example is the following:

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_predict

clf = RandomForestClassifier()

X = np.random.randn(10, 10)
y = np.array([1] * 4 + [0] * 3 + [2] * 3)

# how to get classes from here?
# cv=3 requests the 3-fold split described above
proba = cross_val_predict(estimator=clf, X=X, y=y, cv=3, method="predict_proba")

# using the classifier without cross-validation
# it is possible to get the classes in this way:
clf.fit(X, y)
proba = clf.predict_proba(X)
classes = clf.classes_
For cases where you won't have access to y? - juanpa.arrivillaga
@juanpa.arrivillaga I have access to y, but I don't know in which order the labels are sorted. I suspect they are sorted in ascending order, but I am not completely sure. - gc5
Sorry, I don't have 0.18 yet so I can't test this, and it may be obvious, but doesn't the clf object now contain a classes_ attribute? - juanpa.arrivillaga
Np. No, I tested it and it doesn't have the classes_ attribute. Maybe the classifier is copied inside the cross_val_predict() method and the original is not updated. - gc5
You're right! It does. - juanpa.arrivillaga

1 Answer


Yes, the columns will be in sorted label order. DecisionTreeClassifier (the default base_estimator for RandomForestClassifier) builds its classes_ attribute with np.unique, which returns the sorted unique values of the input array.
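As a minimal sketch of how that can be used, the column labels can be recovered with np.unique(y) directly, since cross_val_predict fits clones of the estimator and leaves the original clf unfitted. The names rng, predicted and proba_by_class are illustrative only, n_estimators and random_state are set just to keep the run small and repeatable, and the approach assumes every class appears in each training fold (which holds for this stratified 3-fold split).

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_predict

rng = np.random.RandomState(0)  # fixed seed, for repeatability only
X = rng.randn(10, 10)
y = np.array([1] * 4 + [0] * 3 + [2] * 3)

clf = RandomForestClassifier(n_estimators=10, random_state=0)

# Out-of-fold probabilities: one row per sample, one column per class.
proba = cross_val_predict(clf, X, y, cv=3, method="predict_proba")

# cross_val_predict works on clones of clf, so rebuild the labels the
# same way the estimator does: sorted unique values of y.
classes = np.unique(y)

# Map columns back to labels.
predicted = classes[np.argmax(proba, axis=1)]
proba_by_class = {label: proba[:, i] for i, label in enumerate(classes)}

# Sanity check: a directly fitted classifier reports the same order.
assert np.array_equal(clf.fit(X, y).classes_, classes)
```

Here column i of proba corresponds to classes[i], so proba_by_class[2] holds the out-of-fold probability of label 2 for every sample.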