I'm planning to migrate my standalone Python ML code to Spark. The ML pipeline in spark.ml turns out to be quite handy, with a streamlined API for chaining algorithm stages and running hyper-parameter grid searches.

Still, I found the existing documentation unclear about its support for one important feature: caching of intermediate results. This feature matters when the pipeline involves computationally intensive stages.

For example, I use a huge sparse matrix to compute multiple moving averages over time-series data in order to build the input features. The structure of the matrix is determined by a hyper-parameter. This step is the bottleneck of the entire pipeline because the matrix has to be constructed at runtime.
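
To make the setup concrete, here is a minimal Scala sketch of one way such a matrix could be built. It assumes the "structure parameter" is a list of window sizes, that each row computes a trailing average (truncated at the start of the series), and that the matrix is stored as coordinate triples. `SparseMatrix` and `MovingAverage` are hypothetical names, not part of Spark.

```scala
// Hypothetical sparse matrix in coordinate (COO) form: (row, col, value) triples.
final case class SparseMatrix(rows: Int, cols: Int, entries: Vector[(Int, Int, Double)]) {
  // Sparse matrix-vector product: y(r) += value * x(c) for every stored entry.
  def multiply(x: Array[Double]): Array[Double] = {
    require(x.length == cols, s"expected length $cols, got ${x.length}")
    val y = Array.fill(rows)(0.0)
    for ((r, c, v) <- entries) y(r) += v * x(c)
    y
  }
}

object MovingAverage {
  // n x n trailing moving-average matrix: row i averages the points ending at i.
  // Near the start of the series the window is truncated to the available points.
  def matrix(n: Int, window: Int): SparseMatrix = {
    require(window >= 1, "window must be positive")
    val entries = for {
      i <- 0 until n
      start = math.max(0, i - window + 1)
      width = i - start + 1
      j <- start to i
    } yield (i, j, 1.0 / width)
    SparseMatrix(n, n, entries.toVector)
  }

  // Stack one block per window size (the assumed "structure parameter"),
  // so a single multiplication yields every averaged series at once.
  def stacked(n: Int, windows: Seq[Int]): SparseMatrix = {
    val blocks = windows.zipWithIndex.flatMap { case (w, k) =>
      matrix(n, w).entries.map { case (r, c, v) => (r + k * n, c, v) }
    }
    SparseMatrix(n * windows.size, n, blocks.toVector)
  }
}
```

For example, `MovingAverage.matrix(4, 2).multiply(Array(1.0, 2.0, 3.0, 4.0))` gives `1.0, 1.5, 2.5, 3.5`. Every change to the window list changes the sparsity pattern, which is why the matrix has to be rebuilt whenever the structure parameter changes.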

During the parameter search I usually tune other parameters besides this "structure parameter", so reusing the huge matrix whenever the structure parameter is unchanged saves a lot of time. For this reason, my current code deliberately caches and reuses these intermediate results.
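
In plain Scala, the reuse described above amounts to memoizing the expensive step on the structure parameter. The sketch below is only an illustration under stated assumptions: `StructureCache`, `expensiveFeatures`, and the grid values are hypothetical, the "model fit" is a placeholder, and all cached values live in driver memory rather than in Spark.

```scala
import scala.collection.mutable

// Memoizes an expensive build step, keyed by the "structure parameter".
// Hypothetical helper for illustration; not part of Spark.
final class StructureCache[K, V](build: K => V) {
  private val cache = mutable.Map.empty[K, V]
  private var builds = 0

  // Return the cached value for `key`; run `build` only on the first request.
  def get(key: K): V = cache.getOrElseUpdate(key, { builds += 1; build(key) })

  def buildCount: Int = builds
}

object GridSearchDemo {
  // Stand-in for the expensive step (e.g. building the large sparse matrix).
  // The structure parameter here is a hypothetical list of window sizes.
  def expensiveFeatures(windows: List[Int]): Vector[Double] =
    windows.map(w => 1.0 / w).toVector

  // Hypothetical grid: structure parameter x one other (regularization) parameter.
  val structures: List[List[Int]] = List(List(5, 10), List(5, 10, 20))
  val regParams: List[Double] = List(0.01, 0.1, 1.0)

  // Returns (number of parameter combinations evaluated, number of expensive builds).
  def run(): (Int, Int) = {
    val features = new StructureCache[List[Int], Vector[Double]](expensiveFeatures)
    val results = for {
      s <- structures
      r <- regParams
    } yield {
      val f = features.get(s)
      // Placeholder for fitting and scoring a model on `f` with parameter `r`.
      f.sum * r
    }
    (results.size, features.buildCount)
  }
}
```

Here the six parameter combinations trigger only two expensive builds, one per distinct structure. The map keeps every structure it has seen, so memory grows with the number of distinct structure values; an eviction policy would be needed for large grids.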

So my question is: can Spark's ML pipeline cache intermediate results automatically, or do I have to write the caching code myself? If the latter, are there any best practices to follow?

P.S. I have looked through the official documentation and some other material, but none of it seems to discuss this topic.

Answer

I ran into the same problem and solved it by implementing my own PipelineStage that caches the input Dataset and returns it unchanged.

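import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType

// A pass-through pipeline stage that marks its input for caching.
class Cacher(override val uid: String) extends Transformer with DefaultParamsWritable {
  // cache() is lazy: the data is materialized the first time a later stage runs an action on it.
  override def transform(dataset: Dataset[_]): DataFrame = dataset.toDF.cache()

  override def copy(extra: ParamMap): Transformer = defaultCopy(extra)

  // The schema is unchanged because the data itself is unchanged.
  override def transformSchema(schema: StructType): StructType = schema

  def this() = this(Identifiable.randomUID("CacherTransformer"))
}

// Companion object so that a saved pipeline containing a Cacher can be loaded back.
object Cacher extends DefaultParamsReadable[Cacher]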

To use it, put it in the pipeline right after the stage whose output you want cached:

import org.apache.spark.ml.Pipeline

// stage1 and stage2 stand for your own pipeline stages
new Pipeline().setStages(Array(stage1, new Cacher(), stage2))