6
votes

In scikit-learn I have a model (in my case a linear model):

from sklearn import linear_model
clf = linear_model.LinearRegression()

I can train this model with some data:

clf.fit(x1,y1)

But if I call fit again, it will continue training the model:

clf.fit(x2,y2)

Now clf is a model trained with both (x1,y1) and (x2,y2).

If I want to start training from scratch, I can recreate the model by redefining clf:

clf = linear_model.LinearRegression()
clf.fit(x1,y1)
# save the model
# ...
clf = linear_model.LinearRegression()
clf.fit(x2,y2)

However, I don't want to define clf again.

The type of regressor is chosen beforehand, something like:

if params.linear_algorithm == 'least_squares':
    clf = linear_model.LinearRegression()
elif params.linear_algorithm == 'ridge':
    clf = linear_model.Ridge()
elif params.linear_algorithm == 'lasso':
    clf = linear_model.Lasso()

So inside my train function I don't want to redefine clf with the whole conditional block. Instead, I just want to take clf, clear whatever it learned in previous trainings, and reuse it to train on another set of data.

Does clf have a method to clear what it has learned so far, so that when I call clf.fit(x2,y2) it is trained only on this data?

EDIT: You guys are right, the training is overwritten every time.

My problem is that I'm saving the model in a dictionary, and it just takes a reference to clf, so each time clf is retrained all previous saves are changed.

Redefining clf every time creates a new object, so each save now points to a different model.

Example:

model = {}
for i in range(3):
    # get the x and y
    # ...
    clf.fit(x, y)
    model[i] = clf  # stores a reference to the same clf object each time

Any idea how to save a different model each time, instead of having every model[i] point to the same clf?

2
Is it related to this question? stackoverflow.com/questions/32916255/… - JARS
Unless warm_start is True, the call to fit resets the classifier, according to the docs. - Sreeram TP

2 Answers

13
votes

Your assumption is wrong. According to the Scikit-Learn docs:

Calling fit() more than once will overwrite what was learned by any previous fit().

You can therefore use your code safely and it will achieve what you need.
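The overwrite behaviour quoted above can be checked with a small experiment. This is only a sketch on synthetic data: the arrays x1, y1, x2, y2 and the name fresh are made up for illustration and are not the asker's data. It fits one LinearRegression twice and compares it with a brand-new model fitted only on the second set.

```python
import numpy as np
from sklearn import linear_model

# Synthetic, noise-free data (an assumption for illustration only).
rng = np.random.default_rng(0)
x1 = rng.normal(size=(50, 2))
y1 = x1 @ np.array([1.0, 2.0])
x2 = rng.normal(size=(50, 2))
y2 = x2 @ np.array([-3.0, 0.5])

clf = linear_model.LinearRegression()
clf.fit(x1, y1)
clf.fit(x2, y2)  # second call replaces what was learned from (x1, y1)

# A new estimator that has only ever seen (x2, y2).
fresh = linear_model.LinearRegression().fit(x2, y2)

# True if the refitted model matches the fresh one.
same_as_fresh = np.allclose(clf.coef_, fresh.coef_) and np.isclose(
    clf.intercept_, fresh.intercept_
)
print(same_as_fresh)
```

If the refitted clf had kept anything from (x1, y1), its coefficients would differ from those of the fresh model.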

1
votes

I am pretty sure it overwrites any existing information from before; the scikit-learn docs specify that. Unless you use warm_start=True (on estimators that support it), fit() calls will overwrite whatever was learned previously.
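For the follow-up in the question's edit, where every model[i] ends up pointing at the same clf, one possible approach is to fit a copy on each iteration. The sketch below uses sklearn.base.clone, which returns a new unfitted estimator with the same hyperparameters, so the conditional block that picks the regressor does not have to be repeated. The synthetic data and the name fitted are assumptions for illustration.

```python
import numpy as np
from sklearn import linear_model
from sklearn.base import clone

# Chosen once, e.g. by the conditional block in the question.
clf = linear_model.Ridge()

rng = np.random.default_rng(0)
model = {}
for i in range(3):
    # Synthetic data standing in for "get the x and y" (assumption).
    x = rng.normal(size=(30, 2))
    y = x @ np.array([i + 1.0, 0.0])

    fitted = clone(clf)  # new, unfitted estimator with clf's parameters
    fitted.fit(x, y)
    model[i] = fitted    # each entry is its own object

# True if the three saved models are three separate objects.
distinct = len({id(m) for m in model.values()}) == 3
print(distinct)
```

An alternative under the same assumptions is to keep calling clf.fit(x, y) and store copy.deepcopy(clf) afterwards, which snapshots the fitted state instead of creating an unfitted copy first.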