
I'm having trouble with the "regress" API on a TensorFlow Serving server. The same content is in the gist below if it's easier to read: https://gist.github.com/krikit/32c918cc03b52315ade562267a91fa6b

I made a simple Keras model that takes two inputs (x1, x2) and produces a single output value (y). With this model, the TensorFlow Serving server returns errors when I call the "regress" REST API.

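# imports (this code uses TF 1.x APIs such as tf.saved_model.builder and tf.keras.backend.get_session)
import shutil

import numpy as np
import tensorflow as tf

# the model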
inputs = {
    'x1': tf.keras.layers.Input(shape=(1, ), name='x1', dtype='float32'),
    'x2': tf.keras.layers.Input(shape=(1, ), name='x2', dtype='float32'),
}
concat = tf.keras.layers.Concatenate(name='concat')([inputs['x1'], inputs['x2']])
dense = tf.keras.layers.Dense(10, use_bias=True, activation='relu', name='dense')(concat)
outputs = tf.keras.layers.Dense(1, use_bias=True, activation='sigmoid', name='y')(dense)
model = tf.keras.models.Model(inputs=inputs, outputs=outputs)
model.compile(optimizer='SGD', loss='binary_crossentropy')
model.summary()


# training
num_exam = 10000
model.fit({'x1': np.random.randn(num_exam), 'x2': np.random.rand(num_exam)}, np.random.randn(num_exam))


# save
input_infos = {name: tf.saved_model.build_tensor_info(tensor) for name, tensor in model.input.items()}
output_infos = {'y': tf.saved_model.build_tensor_info(model.outputs[0])}
signature = tf.saved_model.build_signature_def(
    inputs=input_infos,
    outputs=output_infos,
    method_name=tf.saved_model.signature_constants.REGRESS_METHOD_NAME
)
print(signature)

model_dir = './random_regression/1'
shutil.rmtree(model_dir, ignore_errors=True)
model_builder = tf.saved_model.builder.SavedModelBuilder(model_dir)
model_builder.add_meta_graph_and_variables(
    tf.keras.backend.get_session(),
    tags=[tf.saved_model.tag_constants.SERVING, ],
    signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature}
)
model_builder.save()

After saving the model, the output of the "saved_model_cli" tool looked fine.

$ saved_model_cli show --dir ./random_regression/1 --all

MetaGraphDef with tag-set: 'serve' contains the following SignatureDefs:

signature_def['serving_default']:
  The given SavedModel SignatureDef contains the following input(s):
    inputs['x1'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: x1:0
    inputs['x2'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: x2:0
  The given SavedModel SignatureDef contains the following output(s):
    outputs['y'] tensor_info:
        dtype: DT_FLOAT
        shape: (-1, 1)
        name: y/Sigmoid:0
  Method name is: tensorflow/serving/regress

I then started the serving server with this model and tested the REST API with the "regress" method, but got the following error:

$ curl -X POST -H "Content-Type: application/json" http://localhost:8501/v1/models/random_regression/versions/1:regress -d '
{
  "examples": [
    {
      "x1": [0.1],
      "x2": [0.2]
    },
    {
      "x1": [0.1],
      "x2": [0.3]
    }
  ]
}'

Response:
{ "error": "Expected one input Tensor." }
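This message appears to come from an input-count check in TensorFlow Serving's regression path. As far as I can tell, the regress (and classify) APIs convert the JSON examples into serialized tf.Example protos and feed them to the signature as one string tensor, rather than one tensor per feature. The sketch below mimics that check in plain Python against the two-input signature printed by saved_model_cli above; `check_regress_signature` is a hypothetical helper, not TensorFlow Serving code, and the dtype message is this sketch's own wording.

```python
# Hypothetical, simplified mimic of the regress-path sanity check.
# Not TensorFlow Serving's implementation; for illustration only.

def check_regress_signature(inputs):
    """Return None if `inputs` fits the regress API, else an error string.

    `inputs` maps signature input keys to dtype names as printed by saved_model_cli.
    """
    if len(inputs) != 1:
        return "Expected one input Tensor."  # message as reported in the question
    (dtype,) = inputs.values()
    if dtype != "DT_STRING":
        # Assumed requirement: the single input holds serialized tf.Example protos.
        return "input must be DT_STRING (serialized tf.Example protos)"
    return None


# The signature from the saved_model_cli output: two float inputs.
print(check_regress_signature({"x1": "DT_FLOAT", "x2": "DT_FLOAT"}))
# prints: Expected one input Tensor.
```

For comparison, TF 1.x provides `tf.saved_model.signature_def_utils.regression_signature_def(examples, predictions)`, which builds a regress signature whose single input is a string tensor of serialized examples.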

Even though I exported a regress signature, the predict API also worked:

$ curl -X POST -H "Content-Type: application/json" http://localhost:8501/v1/models/random_regression/versions/1:predict -d '
{
  "instances": [
    {
      "x1": [0.1],
      "x2": [0.2]
    },
    {
      "x1": [0.1],
      "x2": [0.3]
    }
  ]
}'

Response:
{
    "predictions": [[0.143165469], [0.124352224]
    ]
}
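The predict call works because its row-format "instances" map onto the signature's named inputs: each key is batched into its own tensor, here two tensors of shape (2, 1) matching x1 and x2 with shape (-1, 1). A rough plain-Python sketch of that regrouping, assuming every instance supplies the same keys; `instances_to_columns` is a hypothetical helper that approximates the idea and is not TensorFlow Serving's implementation.

```python
# Approximate sketch: regroup row-format "instances" into one batch per named input.

def instances_to_columns(instances):
    """Turn [{"x1": [..], "x2": [..]}, ...] into {"x1": [[..], ...], "x2": [[..], ...]}."""
    if not instances:
        return {}
    keys = set(instances[0])
    columns = {key: [] for key in sorted(keys)}
    for row in instances:
        if set(row) != keys:
            raise ValueError("every instance must provide the same named inputs")
        for key in columns:
            columns[key].append(row[key])
    return columns


instances = [{"x1": [0.1], "x2": [0.2]}, {"x1": [0.1], "x2": [0.3]}]
print(instances_to_columns(instances))
# prints: {'x1': [[0.1], [0.1]], 'x2': [[0.2], [0.3]]}
```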

The reason I use the "regress" method is that I need the "context" field, as below:

$ curl -X POST -H "Content-Type: application/json" http://localhost:8501/v1/models/random_regression/versions/1:regress -d '
{
  "context": {
    "x1": [0.1]
  },
  "examples": [
    {
      "x2": [0.2]
    },
    {
      "x2": [0.3]
    }
  ]
}'

Response:
{ "error": "Expected one input Tensor." }
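With "context", each example is combined with the shared context features before it reaches the model, so the request still ends up as one batch of serialized examples fed to a single input, which is why it hits the same error. A minimal sketch of that expansion, assuming that a feature given in an example overrides a context feature of the same name (TensorFlow Serving appears to concatenate the serialized context and example protos, so the later example fields win; treat this precedence as unverified). `merge_context` and `to_tf_example_dict` are hypothetical helpers, and the JSON-to-feature type mapping is an assumption.

```python
# Sketch: expand a regress request with "context" into one feature dict per example.

def merge_context(context, examples):
    """Return one merged feature dict per example; example values override context."""
    return [{**context, **example} for example in examples]


def to_tf_example_dict(features):
    """Hypothetical: mirror tf.train.Example field names as nested Python dicts.

    Assumed mapping: floats -> float_list, ints -> int64_list, anything else -> bytes_list.
    """
    feature = {}
    for name, values in features.items():
        if all(isinstance(v, float) for v in values):
            feature[name] = {"float_list": {"value": list(values)}}
        elif all(isinstance(v, int) for v in values):
            feature[name] = {"int64_list": {"value": list(values)}}
        else:
            feature[name] = {"bytes_list": {"value": [str(v).encode() for v in values]}}
    return {"features": {"feature": feature}}


merged = merge_context({"x1": [0.1]}, [{"x2": [0.2]}, {"x2": [0.3]}])
print(merged)
# prints: [{'x1': [0.1], 'x2': [0.2]}, {'x1': [0.1], 'x2': [0.3]}]
print(to_tf_example_dict(merged[0]))
```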

Sorry for the long question, but can anybody help me, please?


1 Answer


Just to make sure I understand the problem correctly: the predict API works without any errors, but you need the regress method because you need the context field?

Are you using TF 2.0?