I'm building a scikit-learn model on SageMaker.
I'd like to access the training data from within my predict_fn: instead of returning the indices produced by the nearest-neighbor search (NNS), I'd like to return the name and data of each neighbor.
I know this can be done by writing to and reading from S3, as described in https://aws.amazon.com/blogs/machine-learning/associating-prediction-results-with-input-data-using-amazon-sagemaker-batch-transform/, but I was wondering whether there is a more elegant solution.
Are there other ways to make the data used in the training job available to the prediction function?
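For context, here is a minimal sketch with hypothetical toy data. It illustrates why the training data is needed at prediction time: scikit-learn's NearestNeighbors.kneighbors returns distances and row positions into the fitted data, not the rows themselves. The values and n_neighbors=2 are assumptions chosen for illustration.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

# Toy training data (hypothetical values)
train_data = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
nn = NearestNeighbors(n_neighbors=2).fit(train_data)

# kneighbors() returns distances and row positions into train_data
distances, indices = nn.kneighbors(np.array([[0.1, 0.2]]))
print(indices)  # [[0 2]]

# Turning positions into actual neighbor rows requires the training data itself
neighbors = train_data[indices[0]]
print(neighbors)
```

Inside a SageMaker endpoint only the saved model artifacts are available, so train_data has to travel with the model in some form.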
Edit: Following the advice in the accepted answer, I was able to pass the training data to predict_fn by saving it together with the model in a dict:
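```python
import joblib
from sklearn.neighbors import NearestNeighbors

nn = NearestNeighbors()
model = nn.fit(train_data)  # fit() returns the fitted estimator
# Save the training data alongside the fitted model
model_dict = {
    "model": model,
    "reference": train_data,
}
joblib.dump(model_dict, path)
```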
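A slightly fuller sketch of the training-script side, assuming the file name model.joblib, n_neighbors=2, and a hypothetical "names" key holding one label per training row (so predictions can return names as well as data). In a real SageMaker training job, model_dir would normally be the directory given by the SM_MODEL_DIR environment variable, whose contents SageMaker packages into model.tar.gz; a temporary directory stands in for it here.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.neighbors import NearestNeighbors


def save_model_bundle(train_data, names, model_dir):
    """Fit a NearestNeighbors model and save it with its reference data."""
    model = NearestNeighbors(n_neighbors=2).fit(train_data)
    bundle = {
        "model": model,
        "reference": train_data,  # rows that kneighbors() indices point into
        "names": list(names),     # hypothetical: one name per training row
    }
    path = os.path.join(model_dir, "model.joblib")
    joblib.dump(bundle, path)
    return path


# Toy data (hypothetical); a temp dir stands in for SM_MODEL_DIR
train_data = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
saved_path = save_model_bundle(train_data, ["a", "b", "c", "d"], tempfile.mkdtemp())
loaded = joblib.load(saved_path)
```

Because the dict is pickled as a single object, the model and its reference data are always loaded together, which avoids a separate S3 read at inference time. The trade-off is that the whole training set ends up inside the model artifact, which may be impractical for very large datasets.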
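Then, in predict_fn:

```python
def predict_fn(input_data, model_dict):
    model = model_dict["model"]
    reference = model_dict["reference"]
    # Map neighbor indices back to rows of the saved training data
    _, indices = model.kneighbors(input_data)
    return reference[indices]  # assumes reference is a NumPy array
```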
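For the serving side, the SageMaker scikit-learn model server loads the model by calling a model_fn(model_dir) that you provide, and passes its return value as the second argument to predict_fn. The sketch below assumes the dict was saved as model.joblib and contains a hypothetical "names" key; it returns a name, the data row, and the distance for each neighbor. A temporary directory with toy data stands in for /opt/ml/model so the functions can be checked locally. Depending on the response content type, you may also need a custom output_fn to serialize this list of dicts.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.neighbors import NearestNeighbors


def model_fn(model_dir):
    """Load the dict written by the training script."""
    return joblib.load(os.path.join(model_dir, "model.joblib"))


def predict_fn(input_data, model_dict):
    """Return name, data, and distance for each neighbor of each input row."""
    model = model_dict["model"]
    reference = np.asarray(model_dict["reference"])
    names = model_dict["names"]  # hypothetical key, assumed saved at train time
    distances, indices = model.kneighbors(np.asarray(input_data))
    return [
        [
            {"name": names[i], "data": reference[i].tolist(), "distance": float(d)}
            for d, i in zip(row_dists, row_idx)
        ]
        for row_dists, row_idx in zip(distances, indices)
    ]


# Local check with toy data (hypothetical values)
train = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [5.0, 5.0]])
model_dir = tempfile.mkdtemp()
joblib.dump(
    {
        "model": NearestNeighbors(n_neighbors=2).fit(train),
        "reference": train,
        "names": ["a", "b", "c", "d"],
    },
    os.path.join(model_dir, "model.joblib"),
)
preds = predict_fn([[0.1, 0.2]], model_fn(model_dir))
print(preds)
```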