Please take a look at the following code. I created a `_validate` function that is called inside `predict`, and a custom exception class.
Basically, I validate the instances before calling the model's `predict` method, and handle the exception if validation fails.
This validation may add some overhead to the response time, which you should measure for your use case.
# Sample request: four text instances plus one non-string instance (99999)
# that the server-side _validate check is expected to reject.
# NOTE(review): assumes `discovery` is googleapiclient.discovery and that
# PROJECT / MODEL_NAME are defined earlier in the notebook — confirm.
requests = [
    "god this episode sucks",
    "meh, I kinda like it",
    "what were the writer thinking, omg!",
    "omg! what a twist, who would'v though :o!",
    99999,
]
# Online prediction takes the inputs as a list under the "instances" key;
# that list is what the custom predictor's predict() receives.
request_data = {'instances': requests}

api = discovery.build('ml', 'v1')
# Target the model's default version. To pin a specific version instead,
# use 'projects/{}/models/{}/versions/{}' and also format in VERSION_NAME.
parent = 'projects/{}/models/{}'.format(PROJECT, MODEL_NAME)
response = api.projects().predict(body=request_data, name=parent).execute()
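Because `99999` is not a string, validation fails for the whole batch, so the response contains a single error entry and no predictions for the four valid strings:
{'predictions': [{'Error code': 1, 'Message': 'Invalid instance type'}]}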
Custom Prediction class:
import os
import pickle
import numpy as np
import logging
from datetime import date
import tensorflow.keras as keras
class CustomModelPredictionError(Exception):
    """Raised when a prediction request contains unusable instances.

    Attributes:
        code: Numeric error code echoed back to the client.
        message: Human-readable description; also what ``str()`` returns.
    """

    def __init__(self, code, message='Error found'):
        # Exception.__init__ is deliberately not called; the constructor
        # arguments are still recorded in ``args`` by BaseException.
        self.code, self.message = code, message

    def __str__(self):
        return str(self.message)
def isstr(s):
    """Return True when *s* is a text (``str``) or binary (``bytes``) string."""
    return isinstance(s, (str, bytes))
def _validate(instances):
    """Check that every prediction instance is a ``str`` or ``bytes`` value.

    Args:
        instances: Iterable of raw instances from the prediction request.

    Returns:
        The same ``instances`` object, unchanged, when every element passes.

    Raises:
        CustomModelPredictionError: With code 1 at the first element that is
            not a string; later elements are not inspected.
    """
    if any(not isstr(item) for item in instances):
        raise CustomModelPredictionError(1, 'Invalid instance type')
    return instances
class CustomModelPrediction(object):
    """Custom prediction routine wrapping a Keras model and a text processor.

    Instances are validated before preprocessing, so a malformed request
    produces a structured error entry instead of a server-side exception.
    """

    # Indexed by the rounded model output: 0 -> negative, 1 -> positive.
    _LABELS = ('negative', 'positive')

    def __init__(self, model, processor):
        """Store the loaded model and processor.

        Args:
            model: Object exposing ``predict(preprocessed_data)``; a Keras
                model when built via ``from_path``.
            processor: Object exposing ``transform(instances)``; unpickled
                from ``processor_state.pkl`` when built via ``from_path``.
        """
        self._model = model
        self._processor = processor

    def _postprocess(self, predictions):
        """Convert raw model outputs into label/score dictionaries.

        Each element of *predictions* must hold exactly one value (a scalar
        or a size-1 array), presumably the sigmoid probability of the
        positive class — TODO confirm against the model's output layer.

        Returns:
            One ``{"label": str, "score": float}`` dict per prediction, where
            ``score`` is the raw value rounded to 4 decimals.
        """
        results = []
        for prediction in predictions:
            # Collapse the single value to a 0-d array (keeping its dtype) so
            # int()/float() below do not rely on converting ndim > 0 arrays
            # to scalars, which is deprecated since NumPy 1.25.
            value = np.asarray(prediction).reshape(())
            results.append({
                "label": self._LABELS[int(np.round(value))],
                "score": float(np.round(value, 4)),
            })
        return results

    def predict(self, instances, **kwargs):
        """Validate, preprocess and score a batch of instances.

        Args:
            instances: List of raw instances from the prediction request.
            **kwargs: Extra arguments passed by the serving framework;
                accepted for interface compatibility and unused here.

        Returns:
            A list of label/score dicts, or, if any instance fails
            validation, a single-element list describing the error.
        """
        try:
            instances = _validate(instances)
        except CustomModelPredictionError as err:
            # Report the problem in-band so the client receives a structured
            # response rather than a server error.
            return [{"Error code": err.code, "Message": err.message}]
        preprocessed_data = self._processor.transform(instances)
        predictions = self._model.predict(preprocessed_data)
        return self._postprocess(predictions)

    @classmethod
    def from_path(cls, model_dir):
        """Build a predictor from the artifacts stored in *model_dir*.

        Loads ``keras_saved_model.h5`` and ``processor_state.pkl``.
        """
        model = keras.models.load_model(
            os.path.join(model_dir, 'keras_saved_model.h5'))
        # pickle can execute arbitrary code on load; model_dir must contain
        # only trusted artifacts.
        with open(os.path.join(model_dir, 'processor_state.pkl'), 'rb') as f:
            processor = pickle.load(f)
        return cls(model, processor)
The complete code is available in this notebook.