I am having trouble using an H2O model (in MOJO format) on a Spark cluster, but only when I try to run the predictions in parallel on the executors, not when I collect the data and run them on the driver.
Since the DataFrame I am predicting on has > 100 features, I am using the following function (adapted from an example I found) to convert DataFrame rows to the RowData format that H2O expects:
def rowToRowData(df: DataFrame, row: Row): RowData = {
  // turn the Row into a Map keyed by column name
  val rowAsMap = row.getValuesMap[Any](df.schema.fieldNames)
  val rowData = rowAsMap.foldLeft(new RowData()) { case (rd, (k, v)) =>
    if (v != null) { rd.put(k, v.toString) } // skip nulls; all other values go in as Strings
    rd
  }
  rowData
}
Then I load the MOJO model and create an EasyPredictModelWrapper:
val mojo = MojoModel.load("/path/to/mojo.zip")
val easyModel = new EasyPredictModelWrapper(mojo)
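For completeness, the imports in play are roughly the following (the H2O classes come from the h2o-genmodel jar; this is a sketch of what I'm assuming, not necessarily exhaustive):

import hex.genmodel.MojoModel                                // h2o-genmodel
import hex.genmodel.easy.{EasyPredictModelWrapper, RowData}  // h2o-genmodel
import org.apache.spark.sql.{DataFrame, Row}
import spark.implicits._  // assuming a SparkSession named spark (as in spark-shell); needed for toDF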
Now, I can make predictions on my dataframe (df) by mapping over the rows if I collect it first, so the following works:
val predictions = df.collect().map { r =>
  val rData = rowToRowData(df, r) // convert row to RowData using the function above
  val prediction = easyModel.predictBinomial(rData).label
  (r.getAs[String]("id"), prediction.toInt)
}
.toSeq
.toDF("id", "prediction")
However, I wish to do this in parallel on the cluster since the final df will be too large to collect on the driver. But if I try to run the same code without collecting first:
val predictions = df.map { r =>
  val rData = rowToRowData(df, r)
  val prediction = easyModel.predictBinomial(rData).label
  (r.getAs[String]("id"), prediction.toInt)
}
.toDF("id", "prediction")
I get the following errors:
18/01/03 11:34:59 WARN TaskSetManager: Lost task 0.0 in stage 118.0 (TID 9914, 213.248.241.182, executor 0): java.lang.ClassCastException: cannot assign instance of scala.collection.immutable.List$SerializationProxy to field org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$dependencies_ of type scala.collection.Seq in instance of org.apache.spark.rdd.MapPartitionsRDD
at java.io.ObjectStreamClass$FieldReflector.setObjFieldValues(ObjectStreamClass.java:2133)
at java.io.ObjectStreamClass.setObjFieldValues(ObjectStreamClass.java:1305)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2024)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2018)
at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942)
at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
at java.io.ObjectInputStream.readObject(ObjectInputStream.java:373)
at org.apache.spark.serializer.JavaDeserializationStream.readObject(JavaSerializer.scala:75)
at org.apache.spark.serializer.JavaSerializerInstance.deserialize(JavaSerializer.scala:114)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:80)
at org.apache.spark.scheduler.Task.run(Task.scala:108)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:335)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
at java.lang.Thread.run(Thread.java:745)
So it looks like a datatype mismatch. I have tried converting the dataframe to an RDD first (i.e. df.rdd.map, but I get the same errors), using df.mapPartitions, and placing the rowToRowData function code inside the map, but nothing has worked so far.
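For reference, the mapPartitions attempt looked roughly like this (a sketch rather than the exact code; it fails in the same way):

val predictions = df.mapPartitions { rows =>
  rows.map { r =>
    val rData = rowToRowData(df, r) // note: this still references df inside the closure
    val prediction = easyModel.predictBinomial(rData).label
    (r.getAs[String]("id"), prediction.toInt)
  }
}
.toDF("id", "prediction")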
Any ideas on the best way to achieve this?