I have a LinearRegression model trained on historical data, and now I am trying to reuse the same model to make predictions on new data.

I know that a model can be saved with model.save and loaded with LinearRegression.load, but I can't find a way to pass new data to the loaded model to get predictions.

Code for creating and training the model:

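import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import spark.implicits._ // for the $"col" syntax; assumes a SparkSession named spark, as in spark-shell

// Combine the raw numeric columns into a single "features" vector column
val assembler = new VectorAssembler().setInputCols(Array("total", "connected", "c_403", "c_480", "c_503", "hour", "day_of_week")).setOutputCol("features")

val output = assembler.transform(df).select($"label", $"features")
val Array(training, test) = output.randomSplit(Array(0.7, 0.3), seed = 12)

val lr = new LinearRegression()

// Grid over regularization strength, intercept on/off, and the L1/L2 mix
val paramGrid = new ParamGridBuilder().addGrid(lr.regParam, Array(0.1, 0.01)).addGrid(lr.fitIntercept).addGrid(lr.elasticNetParam, Array(0.0, 0.25, 0.5, 0.75, 1.0)).build()

// Use 75% of the training rows to fit and 25% to pick the best parameter combination
val trainvalSplit = new TrainValidationSplit().setEstimator(lr).setEvaluator(new RegressionEvaluator()).setEstimatorParamMaps(paramGrid).setTrainRatio(0.75)
val model = trainvalSplit.fit(training)

val holdout = model.transform(test).select("prediction", "label")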
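The code above stops at holdout. As a minimal sketch of the next two steps (scoring the hold-out split and saving the fitted TrainValidationSplitModel), assuming a local SparkSession, a small synthetic DataFrame standing in for df, and hypothetical names (columns x1/x2, the /tmp/lreg-tvs path, the TrainAndSave object), it could look like this:

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit, TrainValidationSplitModel}
import org.apache.spark.sql.SparkSession

object TrainAndSave {
  // Hypothetical output path; the question used /home/waqas/models/lreg
  val ModelPath = "/tmp/lreg-tvs"

  /** Fits the tuned model, scores the hold-out split, then saves and reloads the model.
    * Returns (hold-out RMSE, rows scored by the reloaded model, rows in the hold-out split). */
  def run(): (Double, Long, Long) = {
    val spark = SparkSession.builder().master("local[*]").appName("lreg-train").getOrCreate()
    import spark.implicits._
    try {
      // Hypothetical toy data standing in for df: label ~ 2*x1 + 3*x2 + 1 plus small noise
      val df = (1 to 200).map { i =>
        val x1 = i.toDouble
        val x2 = (i % 7).toDouble
        (2 * x1 + 3 * x2 + 1 + (i % 3) * 0.1, x1, x2)
      }.toDF("label", "x1", "x2")

      val assembler = new VectorAssembler().setInputCols(Array("x1", "x2")).setOutputCol("features")
      val output = assembler.transform(df).select($"label", $"features")
      val Array(training, test) = output.randomSplit(Array(0.7, 0.3), seed = 12)

      val lr = new LinearRegression()
      val paramGrid = new ParamGridBuilder()
        .addGrid(lr.regParam, Array(0.1, 0.01))
        .addGrid(lr.fitIntercept)
        .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
        .build()
      val model = new TrainValidationSplit()
        .setEstimator(lr)
        .setEvaluator(new RegressionEvaluator())
        .setEstimatorParamMaps(paramGrid)
        .setTrainRatio(0.75)
        .fit(training)

      // Score the 30% hold-out with the best parameter combination found above
      val holdout = model.transform(test).select("prediction", "label")
      val rmse = new RegressionEvaluator().setMetricName("rmse").evaluate(holdout)

      // Save the whole TrainValidationSplitModel; overwrite() lets the sketch be re-run
      model.write.overwrite().save(ModelPath)
      val reloaded = TrainValidationSplitModel.load(ModelPath)
      (rmse, reloaded.transform(test).count(), test.count())
    } finally {
      spark.stop()
    }
  }

  def main(args: Array[String]): Unit = {
    val (rmse, scored, total) = run()
    println(s"hold-out RMSE = $rmse, scored $scored of $total rows")
  }
}
```

RegressionEvaluator uses "label" and "prediction" as its default column names, which is why the select above is enough to evaluate the hold-out rows.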
eliasah: Do you want to retrain your model with new data?
Waqas: Actually, I would like to use this trained model to make predictions on new data.
eliasah: How are you saving your model? I don't see that in the code you are sharing.
Waqas: Like this: model.save("/home/waqas/models/lreg")

Answer

OK, that's straightforward. Because you saved the object returned by trainvalSplit.fit(training) with model.save("/home/waqas/models/lreg"), what is on disk is a TrainValidationSplitModel, so you need to load it with TrainValidationSplitModel rather than LinearRegressionModel:

scala> import org.apache.spark.ml.tuning.TrainValidationSplitModel

scala> val model2 = TrainValidationSplitModel.load("/home/waqas/models/lreg")
// model2: org.apache.spark.ml.tuning.TrainValidationSplitModel = tvs_99887a2f788d

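// newData must already have a "features" column built by the same VectorAssembler used for training
scala> model2.transform(newData).show(3)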
// +-----+--------------------+--------------------+
// |label|            features|          prediction|
// +-----+--------------------+--------------------+
// |  0.0|(692,[121,122,123...| 0.11220528529664375|
// |  0.0|(692,[122,123,148...|  0.1727599038728312|
// |  0.0|(692,[123,124,125...|-0.09619225628995537|
// +-----+--------------------+--------------------+
// only showing top 3 rows
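In the question's code the TrainValidationSplit wraps only the LinearRegression, so the saved model does not include the VectorAssembler step; new rows have to be assembled into a "features" column the same way before calling transform. The sketch below shows that step, plus an alternative: saving just model.bestModel and reading it back with LinearRegressionModel.load. It assumes a local SparkSession, toy data, and hypothetical names (columns x1/x2, the /tmp paths, the PredictOnNewData object), so treat it as an illustration rather than a drop-in replacement.

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit, TrainValidationSplitModel}
import org.apache.spark.sql.{DataFrame, SparkSession}

object PredictOnNewData {
  // Hypothetical paths, for illustration only
  val TvsPath = "/tmp/lreg-tvs-model"
  val BestPath = "/tmp/lreg-best-model"

  // The saved model does not contain this step, so new data must go through an
  // assembler with the same input columns, in the same order, as in training
  def assemble(raw: DataFrame): DataFrame =
    new VectorAssembler().setInputCols(Array("x1", "x2")).setOutputCol("features").transform(raw)

  /** Returns the number of predictions produced by each reloaded model. */
  def run(): (Long, Long) = {
    val spark = SparkSession.builder().master("local[*]").appName("lreg-predict").getOrCreate()
    import spark.implicits._
    try {
      // Hypothetical historical data: label = 2*x1 + 3*x2 + 1
      val history = (1 to 200)
        .map(i => (2.0 * i + 3.0 * (i % 7) + 1.0, i.toDouble, (i % 7).toDouble))
        .toDF("label", "x1", "x2")
      val lr = new LinearRegression()
      val grid = new ParamGridBuilder().addGrid(lr.regParam, Array(0.1, 0.01)).build()
      val tvsModel = new TrainValidationSplit()
        .setEstimator(lr)
        .setEvaluator(new RegressionEvaluator())
        .setEstimatorParamMaps(grid)
        .setTrainRatio(0.75)
        .fit(assemble(history))

      // Option 1: save and load the TrainValidationSplitModel, as in the answer
      tvsModel.write.overwrite().save(TvsPath)
      val loadedTvs = TrainValidationSplitModel.load(TvsPath)

      // Option 2: save only the best fitted model and load it with LinearRegressionModel.load
      // (LinearRegression.load would return the unfitted estimator, which has no transform)
      tvsModel.bestModel.asInstanceOf[LinearRegressionModel].write.overwrite().save(BestPath)
      val loadedBest = LinearRegressionModel.load(BestPath)

      // New, unlabeled rows: only the raw feature columns are needed
      val newData = assemble(Seq((250.0, 3.0), (300.0, 5.0)).toDF("x1", "x2"))
      val fromTvs = loadedTvs.transform(newData).select("x1", "x2", "prediction")
      val fromBest = loadedBest.transform(newData).select("x1", "x2", "prediction")
      fromTvs.show()
      (fromTvs.count(), fromBest.count())
    } finally {
      spark.stop()
    }
  }

  def main(args: Array[String]): Unit = println(run())
}
```

Another option, not shown here, is to make a Pipeline of the assembler and the regression the estimator of the TrainValidationSplit, so that the saved model applies the feature assembly itself and new raw rows can be passed to transform directly.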