I'm currently trying to figure out how to load a saved H2O MOJO model and use it on a Spark DataFrame without needing Sparkling Water. The approach I am trying is to load an h2o-genmodel.jar file when Spark starts up, and then use PySpark's Py4J interface to access it. My concrete question is about how to access the values generated by the py4j.java_gateway objects.
Below is a minimal example:
Train model
import h2o
from h2o.estimators.random_forest import H2ORandomForestEstimator
import pandas as pd
import numpy as np
h2o.init()
features = pd.DataFrame(np.random.randn(6,3),columns=list('ABC'))
target = pd.DataFrame(pd.Series(["cat","dog","cat","dog","cat","dog"]), columns=["target"])
df = pd.concat([features, target], axis=1)
df_h2o = h2o.H2OFrame(df)
rf = H2ORandomForestEstimator()
rf.train(["A","B","C"],"target",training_frame=df_h2o, validation_frame=df_h2o)
Save MOJO
model_path = rf.download_mojo(path="./mojo/", get_genmodel_jar=True)
print(model_path)
Load MOJO
from pyspark.sql import SparkSession
spark = SparkSession.builder.config("spark.jars", "/home/ec2-user/Notebooks/mojo/h2o-genmodel.jar").getOrCreate()
MojoModel = spark._jvm.hex.genmodel.MojoModel
EasyPredictModelWrapper = spark._jvm.hex.genmodel.easy.EasyPredictModelWrapper
RowData = spark._jvm.hex.genmodel.easy.RowData
mojo = MojoModel.load(model_path)
easy_model = EasyPredictModelWrapper(mojo)
Predict on a single row of data
r = RowData()
r.put("A", -0.631123)
r.put("B", 0.711463)
r.put("C", -1.332257)
score = easy_model.predictBinomial(r).classProbabilities
That is as far as I have been able to get. Where I am having trouble is inspecting what score gives back to me. print(score) yields the following: <py4j.java_gateway.JavaMember at 0x7fb2e09b4e80>. Presumably there must be a way to get the actual generated values from this object, but how would I do that?
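One possible direction, sketched under assumptions rather than verified against this exact setup: Py4J treats attribute access on a Java object as a method lookup, which would explain why score is a JavaMember. Public Java fields are normally read with py4j.java_gateway.get_field instead. The sketch assumes classProbabilities is a public double[] field on the prediction object. The helper names read_java_field and class_probabilities, and the field_reader parameter (present only so the helper can be exercised without a JVM), are hypothetical.

```python
def read_java_field(java_object, field_name, field_reader=None):
    """Read a public field from a Py4J-proxied Java object (hypothetical helper)."""
    if field_reader is None:
        # Imported lazily so this module loads even where py4j is not installed.
        from py4j.java_gateway import get_field
        field_reader = get_field
    return field_reader(java_object, field_name)


def class_probabilities(prediction, field_reader=None):
    """Return the class probabilities of a prediction as a plain Python list.

    Assumes `classProbabilities` is a public array field on `prediction`.
    """
    values = read_java_field(prediction, "classProbabilities", field_reader)
    # Index explicitly so this works for a Py4J JavaArray as well as a list.
    return [float(values[i]) for i in range(len(values))]


# Usage with the objects from the question (needs the running Spark session):
# prediction = easy_model.predictBinomial(r)
# print(class_probabilities(prediction))
```

Py4J also offers an auto_field option on the gateway, but PySpark creates the gateway itself, so calling get_field explicitly is probably the less intrusive choice here.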
Wrapping this in a predict function and then applying spark_df.map(predict) won't work, since there's no way to get these functions onto the executors (Py4J is only available on the driver) and they cannot simply be serialized. If you rewrote the above in Scala or Java it would work, though. – Karl
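The serialization point in the comment can be illustrated without Spark. The following is only an analogy: a hypothetical class, FakeGatewayBackedModel, holds a live socket in place of the Py4J connection to the driver's JVM, and the standard pickle module stands in for whatever serializer Spark uses to ship functions to executors. It shows that a function bound to such an object cannot be pickled.

```python
import pickle
import socket


class FakeGatewayBackedModel:
    """Hypothetical stand-in for a model proxy that owns a live connection."""

    def __init__(self):
        # A real Py4J proxy talks to the driver's JVM over a socket;
        # here an unconnected socket plays that role.
        self.connection = socket.socket()

    def predict(self, row):
        # Placeholder result; a real proxy would call into the JVM here.
        return 0.5

    def close(self):
        self.connection.close()


def can_pickle(obj):
    """Return True if `obj` survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (TypeError, pickle.PicklingError):
        return False


model = FakeGatewayBackedModel()
# The bound method drags the whole object, socket included, into the pickle.
predict_is_picklable = can_pickle(model.predict)
# Plain data, by contrast, serializes without trouble.
row_is_picklable = can_pickle({"A": -0.631123, "B": 0.711463, "C": -1.332257})
model.close()
```

Even if the proxy could be serialized, the executors would still have no Py4J gateway to the JVM that holds the model, which is the other half of the comment's objection.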