Defining your own outputs is absolutely supported. A typical TensorFlow training program will:
- Build a training graph
- Train the model using that graph
- Build a prediction graph
- Export a SavedModel
This sample code, for example, follows that pattern.
When you build your prediction graph, you'll create placeholders for your inputs, something like:
import tensorflow as tf  # TensorFlow 1.x API

# tf.Graph() is not itself a context manager; use as_default().
with tf.Graph().as_default() as prediction_graph:
    # dtypes can be anything.
    # The first dimension of shape is the "batch size", which must be None
    # so the system can send variable-length batches. Beyond that,
    # there are no other restrictions on shape.
    x = tf.placeholder(dtype=tf.int32, shape=(None,))
    y = tf.placeholder(dtype=tf.float32, shape=(None,))
    z = build_prediction_graph(x, y)
    saver = tf.train.Saver()
When you export a SavedModel, you declare your inputs and outputs in what's called a "signature". In the process you give them friendly names (the keys in the dicts), since TensorFlow mangles tensor names. These keys are what you use in your JSON when sending data, and they are the keys in the JSON you get back from prediction.
For example:
# Define the inputs and the outputs.
inputs = {"x": tf.saved_model.utils.build_tensor_info(x),
          "y": tf.saved_model.utils.build_tensor_info(y)}
outputs = {"z": tf.saved_model.utils.build_tensor_info(z)}
signature = tf.saved_model.signature_def_utils.build_signature_def(
    inputs=inputs,
    outputs=outputs,
    method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME
)
A hypothetical request to the service using that signature might look like:
{"instances": [{"x": 6, "y": 3.14}, {"x": 3, "y": 1.0}]}
With a hypothetical response looking like:
{"predictions": [{"z": [1, 2, 3]}, {"z": [4, 5, 6]}]}
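To make the key mapping concrete, here is a minimal sketch in plain Python (no TensorFlow needed) that builds a request body like the one above and reads the outputs back by key. The helper names `build_request` and `extract_output` are hypothetical, and the response is the made-up one from above, not real service output.

```python
import json

# Friendly input names declared in the signature above.
INPUT_KEYS = {"x", "y"}


def build_request(instances):
    """Serialize instances, checking each one uses exactly the signature's input keys."""
    for i, instance in enumerate(instances):
        if set(instance) != INPUT_KEYS:
            raise ValueError(
                "instance %d has keys %s, expected %s"
                % (i, sorted(instance), sorted(INPUT_KEYS)))
    return json.dumps({"instances": instances})


def extract_output(response_body, key):
    """Return the value stored under `key` for every prediction in the response."""
    return [p[key] for p in json.loads(response_body)["predictions"]]


request_body = build_request([{"x": 6, "y": 3.14}, {"x": 3, "y": 1.0}])

# Hypothetical response copied from the example above.
response_body = '{"predictions": [{"z": [1, 2, 3]}, {"z": [4, 5, 6]}]}'
z_values = extract_output(response_body, "z")
```

Checking the keys on the client side is just a convenience here: a misspelled input name fails early instead of surfacing as an error from the service.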
Finally, you will need to actually save out the SavedModel:
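with tf.Session(graph=prediction_graph) as session:
    # Restore the most recently checkpointed variables from training
    saver.restore(session, tf.train.latest_checkpoint(job_dir))
    # Save the SavedModel
    b = tf.saved_model.builder.SavedModelBuilder('/path/to/export')
    # The tag marks the graph for serving; 'serving_default' is the
    # key of the signature inside the signature_def_map.
    b.add_meta_graph_and_variables(
        session,
        [tf.saved_model.tag_constants.SERVING],
        signature_def_map={'serving_default': signature})
    b.save()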
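After exporting, a quick sanity check is to confirm the export directory has the usual SavedModel layout: a `saved_model.pb` (or `saved_model.pbtxt`) file next to a `variables` directory. Below is a minimal sketch using only the standard library. `looks_like_saved_model` is a hypothetical helper, and it only checks file names, not whether the signature inside is what you intended.

```python
import os


def looks_like_saved_model(export_dir):
    """Return True if export_dir has the usual SavedModel files (names only)."""
    has_graph = any(
        os.path.isfile(os.path.join(export_dir, name))
        for name in ("saved_model.pb", "saved_model.pbtxt"))
    has_variables = os.path.isdir(os.path.join(export_dir, "variables"))
    return has_graph and has_variables
```

You would call it with the same path you passed to the builder, e.g. `looks_like_saved_model('/path/to/export')`.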