0 votes

I am having trouble using an H2O model (in MOJO format) on a Spark cluster, but only when I try to run it in parallel, not when I collect the DataFrame and run it on the driver.

Since the DataFrame I am predicting on has more than 100 features, I am using the following function to convert DataFrame rows to the RowData format that H2O expects (from here):

def rowToRowData(df: DataFrame, row: Row): RowData = {
  // Map each column name to its value in this row
  val rowAsMap = row.getValuesMap[Any](df.schema.fieldNames)
  // Fold the map into a RowData, stringifying values and skipping nulls
  val rowData = rowAsMap.foldLeft(new RowData()) { case (rd, (k, v)) =>
    if (v != null) { rd.put(k, v.toString) }
    rd
  }
  rowData
}
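
For reference, here is a minimal usage sketch on a toy DataFrame (the data and column names are made up; it assumes spark.implicits._ is in scope, as it must be for the toDF calls below):

// Hypothetical two-column DataFrame, just to illustrate the conversion
val toyDf = Seq(("a1", 0.5), ("a2", 1.5)).toDF("id", "x")
val rd: RowData = rowToRowData(toyDf, toyDf.head())
// rd now contains "id" -> "a1" and "x" -> "0.5" (all values stringified)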

Then I load the MOJO model and wrap it in an EasyPredictModelWrapper:

import hex.genmodel.MojoModel
import hex.genmodel.easy.EasyPredictModelWrapper

val mojo = MojoModel.load("/path/to/mojo.zip")
val easyModel = new EasyPredictModelWrapper(mojo)
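
As a quick sanity check that the model loaded correctly (a sketch, assuming the standard h2o-genmodel accessor):

// Should print Binomial for a binomial classification model
println(mojo.getModelCategory)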

Now I can make predictions on my DataFrame (df) by mapping over the rows if I collect it first, so the following works:

val predictions = df.collect().map { r =>
  val rData = rowToRowData(df, r)  // convert Row to RowData using the function above
  val prediction = easyModel.predictBinomial(rData).label
  (r.getAs[String]("id"), prediction.toInt)
}.toSeq.toDF("id", "prediction")

However, I wish to do this in parallel on the cluster since the final df will be too large to collect on the driver. But if I try to run the same code without collecting first:

val predictions = df.map { r =>
  val rData = rowToRowData(df, r)
  val prediction = easyModel.predictBinomial(rData).label
  (r.getAs[String]("id"), prediction.toInt)
}.toDF("id", "prediction")

I get the following error:

18/01/03 11:34:59 WARN TaskSetManager: Lost task 0.0 in stage 118.0 (TID 9914, 213.248.241.182, executor 0): java.lang.ClassCastException: cannot assign instance of scala.collection.immutable.List$SerializationProxy to field org.apache.spark.rdd.RDD.org$apache$spark$rdd$RDD$$dependencies_ of type scala.collection.Seq in instance of org.apache.spark.rdd.MapPartitionsRDD
    at java.io.ObjectStreamClass$FieldReflector.setObjFieldValues(ObjectStreamClass.java:2133)
    at java.io.ObjectStreamClass.setObjFieldValues(ObjectStreamClass.java:1305)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2024)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
    at java.io.ObjectInputStream.defaultReadFields(ObjectInputStream.java:2018)
    at java.io.ObjectInputStream.readSerialData(ObjectInputStream.java:1942)
    at java.io.ObjectInputStream.readOrdinaryObject(ObjectInputStream.java:1808)
    at java.io.ObjectInputStream.readObject0(ObjectInputStream.java:1353)
    at java.io.ObjectInputStream.readObject(ObjectInputStream.java:373)
    at org.apache.spark.serializer.JavaDeserializationStream.readObject(JavaSerializer.scala:75)
    at org.apache.spark.serializer.JavaSerializerInstance.deserialize(JavaSerializer.scala:114)
    at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:80)
    at org.apache.spark.scheduler.Task.run(Task.scala:108)
    at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:335)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1142)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:617)
    at java.lang.Thread.run(Thread.java:745)

So it looks like a type mismatch during task deserialization. I have tried converting the DataFrame to an RDD first (i.e. df.rdd.map, which gives the same error), using df.mapPartitions, and placing the rowToRowData code inside the map, but nothing has worked so far.
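
For reference, the mapPartitions variant I tried looks roughly like this (a sketch; note that df and easyModel are still referenced inside the closure, and it fails with the same error):

val predictions = df.mapPartitions { rows =>
  rows.map { r =>
    val rData = rowToRowData(df, r)  // df is still referenced inside the closure
    (r.getAs[String]("id"), easyModel.predictBinomial(rData).label.toInt)
  }
}.toDF("id", "prediction")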

Any ideas on the best way to achieve this?


2 Answers

0 votes

I've found a somewhat messy Spark ticket, https://issues.apache.org/jira/browse/SPARK-18075, that describes the same problem in relation to different ways of submitting a Spark application. Take a look; maybe it will give you a clue about your problem.

0 votes

You can't call prediction.toInt directly. The prediction returned is a tuple, and you need to extract its second element to get the actual score for level 1. I have a complete example here: https://stackoverflow.com/a/47898040/9120484
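
A minimal sketch of what that extraction looks like (assuming the label and classProbabilities fields of h2o-genmodel's BinomialModelPrediction; see the linked answer for the full version):

val p = easyModel.predictBinomial(rData)
val predictedLabel = p.label                 // predicted class as a String
val scoreForLevel1 = p.classProbabilities(1) // probability of class 1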