I have trained a model with TensorFlow and used batch normalization during training. Batch normalization requires the user to pass a boolean, called is_training, that tells it whether the model is in the training or the testing phase.
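To see why the flag matters at inference time, here is a minimal NumPy sketch of what batch normalization typically switches between, assuming the common formulation with moving averages. The function name `batch_norm` and the `momentum`/`eps` defaults are illustrative choices, not TensorFlow's implementation.

```python
import numpy as np


def batch_norm(x, gamma, beta, moving_mean, moving_var, is_training,
               momentum=0.99, eps=1e-3):
    """Toy batch norm over axis 0. Returns (y, moving_mean, moving_var)."""
    if is_training:
        # Training: normalize with the current batch's statistics...
        mean = x.mean(axis=0)
        var = x.var(axis=0)
        # ...and fold them into the moving averages kept for inference.
        moving_mean = momentum * moving_mean + (1 - momentum) * mean
        moving_var = momentum * moving_var + (1 - momentum) * var
    else:
        # Inference: ignore the batch and use the stored moving averages.
        mean, var = moving_mean, moving_var
    y = gamma * (x - mean) / np.sqrt(var + eps) + beta
    return y, moving_mean, moving_var


x = np.array([[1.0], [3.0]])
gamma, beta = np.ones(1), np.zeros(1)
mean0, var0 = np.zeros(1), np.ones(1)

y_train, new_mean, new_var = batch_norm(x, gamma, beta, mean0, var0, True)
y_infer, _, _ = batch_norm(x, gamma, beta, mean0, var0, False)
print(y_train.ravel())  # depends on the batch: roughly [-1, 1]
print(y_infer.ravel())  # depends only on stored stats: roughly [1, 3]
```

With is_training left at True, each prediction depends on whatever else happens to be in the batch, which is why it has to be False for inference.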

When the model was trained, is_training was set as a constant, as shown below:

is_training = tf.constant(True, dtype=tf.bool, name='is_training')

I have saved the trained model; the files are the checkpoint file, a .meta file, a .index file, and a .data file. I'd like to restore the model and run inference with it. The model can't be retrained, so I'd like to restore the existing model, set the value of is_training to False, and then save the model back. How can I edit the boolean value associated with that node and save the model again?

It would have been easier if you had used is_training = tf.Variable(...) rather than a constant. - Ishant Mrinal
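A minimal sketch of the variable-based setup suggested in that comment, assuming TensorFlow with the `tf.compat.v1` graph-mode API (available in late 1.x and in 2.x). The checkpoint location is a temporary directory chosen only for illustration.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # use TF 1.x-style graph mode

prefix = os.path.join(tempfile.mkdtemp(), 'model')  # illustrative path

# The flag is a non-trainable variable, so its current value is written
# to the checkpoint together with the model weights.
with tf1.Graph().as_default():
    is_training = tf1.Variable(True, trainable=False, name='is_training')
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        sess.run(is_training.assign(False))  # switch to inference mode
        tf1.train.Saver().save(sess, prefix)

# Later: rebuild the same variable and restore it; the value stored in the
# checkpoint replaces the initial value True.
with tf1.Graph().as_default():
    is_training = tf1.Variable(True, trainable=False, name='is_training')
    with tf1.Session() as sess:
        tf1.train.Saver().restore(sess, prefix)
        restored_flag = bool(sess.run(is_training))

print(restored_flag)  # False
```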
Is there a reason why is_training needs to be a TensorFlow constant? Can't it be a Python bool? Note that changing is_training to a Python bool should not cause errors when restoring the model. - GeertH
@GeertH It can be; the question is how to set is_training to False after I load the model and then save it back, so that when it is restored again, the node has the value False. - Effective_cellist
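A minimal sketch of the Python-bool route from the comments, assuming you still have the model-building code and the `tf.compat.v1` API. `build_model` is a hypothetical stand-in for that code, and the "training" step just sets `w` to 3.0. Because Saver matches variables by name, a graph rebuilt with is_training=False can restore the old checkpoint and then be saved as a new one.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

prefix = os.path.join(tempfile.mkdtemp(), 'model')  # illustrative path


def build_model(x, is_training):
    """Hypothetical stand-in for the real model code.

    is_training is a plain Python bool, so the train/inference choice is
    made while the graph is built and no is_training node exists at all.
    """
    w = tf1.Variable(2.0, name='w')
    scale = 10.0 if is_training else 1.0
    return tf1.multiply(w * scale, x, name='out')


# Training graph: pretend training moved w to 3.0, then checkpoint it.
with tf1.Graph().as_default():
    out = build_model(tf1.constant(1.0), is_training=True)
    w = tf1.global_variables()[0]
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        sess.run(w.assign(3.0))
        tf1.train.Saver().save(sess, prefix)

# Inference graph: same variable names, built with is_training=False.
# Saving writes a new checkpoint whose .meta holds the inference graph.
with tf1.Graph().as_default():
    out = build_model(tf1.constant(1.0), is_training=False)
    with tf1.Session() as sess:
        saver = tf1.train.Saver()
        saver.restore(sess, prefix)
        inference_out = sess.run(out)
        saver.save(sess, prefix + '-inference')

print(inference_out)  # 3.0
```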

1 Answer

You can use the input_map argument of tf.train.import_meta_graph to remap a tensor of the imported graph to an updated value:

import tensorflow as tf  # TensorFlow 1.x API

config = tf.ConfigProto(allow_soft_placement=True)
with tf.Session(config=config) as sess:
    # define the new is_training tensor (False = inference mode)
    is_training = tf.constant(False, dtype=tf.bool, name='is_training')

    # import the graph from the checkpoint's .meta file; every op that
    # consumed the old 'is_training:0' now reads the new tensor instead
    saver = tf.train.import_meta_graph(
        '/path/to/model.meta', input_map={'is_training:0': is_training})

    # restore all weights from the model checkpoint
    saver.restore(sess, '/path/to/model')

    # save the updated graph and the variable values
    saver.save(sess, '/path/to/new-model-name')
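To check the remapping end to end, here is a small self-contained sketch using `tf.compat.v1` (so it also runs on TensorFlow 2.x in graph mode). The tiny graph (`w`, `out`, and the factor of 10 standing in for batch norm's train/inference difference) is hypothetical; only step 2 mirrors the code above.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # graph mode, as in the TF 1.x code above

ckpt_dir = tempfile.mkdtemp()
old_prefix = os.path.join(ckpt_dir, 'model')
new_prefix = os.path.join(ckpt_dir, 'new-model-name')

# 1. Toy "trained" model with is_training baked in as constant True.
with tf1.Graph().as_default():
    flag = tf1.constant(True, dtype=tf.bool, name='is_training')
    w = tf1.Variable(2.0, name='w')
    # stand-in for batch norm: the output differs between the two modes
    factor = 1.0 + 9.0 * tf1.cast(flag, tf.float32)
    out = tf1.multiply(w, factor, name='out')
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        trained_out = sess.run(out)  # 2.0 * 10.0
        tf1.train.Saver().save(sess, old_prefix)

# 2. Remap is_training:0 to a False constant, restore, save again.
with tf1.Graph().as_default():
    with tf1.Session() as sess:
        flag = tf1.constant(False, dtype=tf.bool, name='is_training')
        saver = tf1.train.import_meta_graph(
            old_prefix + '.meta', input_map={'is_training:0': flag})
        saver.restore(sess, old_prefix)
        saver.save(sess, new_prefix)

# 3. Reload the new checkpoint with no input_map and check the output.
with tf1.Graph().as_default():
    with tf1.Session() as sess:
        saver = tf1.train.import_meta_graph(new_prefix + '.meta')
        saver.restore(sess, new_prefix)
        inference_out = sess.run('out:0')  # 2.0 * 1.0

print(trained_out, inference_out)  # 20.0 2.0
```

input_map only rewires the consumers of is_training:0; the original True constant is still imported into the graph, but nothing reads it any more.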