sklearn estimators implement methods that make it easy to save the relevant trained properties of an estimator. Some estimators implement __getstate__ methods themselves, but others, like the GMM, just use the base implementation, which simply saves the object's inner dictionary:
    def __getstate__(self):
        try:
            state = super(BaseEstimator, self).__getstate__()
        except AttributeError:
            state = self.__dict__.copy()

        if type(self).__module__.startswith('sklearn.'):
            return dict(state.items(), _sklearn_version=__version__)
        else:
            return state
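To make the pattern above concrete, here is a minimal sketch using only the standard library. `TinyEstimator`, `TOY_VERSION`, and the `_toy_version` key are hypothetical names invented for illustration; they are not part of scikit-learn. The sketch only mirrors the idea of tagging the pickled instance dictionary with a version string.

```python
import pickle
import warnings

# Hypothetical version string standing in for sklearn.__version__.
TOY_VERSION = "1.0"


class TinyEstimator:
    """Toy stand-in for an estimator; not part of scikit-learn."""

    def fit(self, values):
        # Learned attributes live in the instance dictionary.
        self.mean_ = sum(values) / len(values)
        return self

    def __getstate__(self):
        # Copy the instance dictionary and tag it with the library version.
        state = self.__dict__.copy()
        state["_toy_version"] = TOY_VERSION
        return state

    def __setstate__(self, state):
        # Compare the stored version with the running one before restoring.
        state = dict(state)
        pickled_version = state.pop("_toy_version", None)
        if pickled_version != TOY_VERSION:
            warnings.warn(
                f"pickled with version {pickled_version}, "
                f"loading with {TOY_VERSION}"
            )
        self.__dict__.update(state)


est = TinyEstimator().fit([1.0, 2.0, 3.0])
restored = pickle.loads(pickle.dumps(est))
print(est.__getstate__())
print(restored.mean_)
```

Because pickle calls `__getstate__` when dumping and `__setstate__` when loading, the version tag travels with the saved state and can be checked when the object is restored.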
The recommended way to save your model to disk is to use the pickle module:
    import pickle

    from sklearn import datasets
    from sklearn.svm import SVC

    iris = datasets.load_iris()
    X = iris.data[:100, :2]
    y = iris.target[:100]

    model = SVC()
    model.fit(X, y)

    with open('mymodel', 'wb') as f:
        pickle.dump(model, f)
However, you should save additional data so you can retrain your model in the future, or suffer dire consequences (such as being locked into an old version of sklearn).
From the documentation:
In order to rebuild a similar model with future versions of scikit-learn, additional metadata should be saved along the pickled model:

- The training data, e.g. a reference to an immutable snapshot
- The python source code used to generate the model
- The versions of scikit-learn and its dependencies
- The cross validation score obtained on the training data
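One way to keep that metadata next to the pickle is a small JSON sidecar file. The sketch below is only an illustration under stated assumptions: `save_with_metadata`, `installed_version`, the `.json` sidecar naming, and all placeholder values are hypothetical choices made here, not something the scikit-learn documentation prescribes. It uses a plain dictionary as a stand-in for a fitted model so that it runs with the standard library alone.

```python
import json
import pickle
import platform
import tempfile
from importlib import metadata
from pathlib import Path


def installed_version(package):
    """Return the installed version of a package, or None if it is missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def save_with_metadata(model, path, training_data_ref, source_ref, cv_score):
    """Pickle `model` to `path` and write `<path>.json` next to it."""
    path = Path(path)
    with open(path, "wb") as f:
        pickle.dump(model, f)
    info = {
        "training_data": training_data_ref,  # e.g. a snapshot identifier
        "source_code": source_ref,           # e.g. a commit hash
        "cv_score": cv_score,
        "python": platform.python_version(),
        "versions": {
            name: installed_version(name)
            for name in ("scikit-learn", "numpy", "scipy")
        },
    }
    sidecar = path.with_name(path.name + ".json")
    sidecar.write_text(json.dumps(info, indent=2))
    return sidecar


# Usage with a stand-in object and placeholder values;
# a fitted estimator would be passed the same way.
with tempfile.TemporaryDirectory() as tmp:
    sidecar = save_with_metadata(
        {"stand_in": "model"},
        Path(tmp) / "mymodel",
        training_data_ref="iris-snapshot-placeholder",
        source_ref="train.py@placeholder",
        cv_score=0.0,
    )
    saved = json.loads(sidecar.read_text())
```

Keeping the metadata in a separate, human-readable file means it can still be read even if the pickle itself no longer loads under a newer library version.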
Saving this metadata is especially important for ensemble estimators that rely on the tree.pyx module written in Cython (such as IsolationForest), since pickling them creates a coupling to the implementation, which is not guaranteed to be stable between versions of sklearn. It has seen backwards-incompatible changes in the past.
If your models become very large and loading becomes a nuisance, you can also use the more efficient joblib. From the documentation:
In the specific case of the scikit, it may be more interesting to use joblib’s replacement of pickle (joblib.dump & joblib.load), which is more efficient on objects that carry large numpy arrays internally as is often the case for fitted scikit-learn estimators, but can only pickle to the disk and not to a string: