I have some Python code to train a network using TensorFlow's TFRecords and Dataset APIs. I built the network with tf.keras.layers, which is arguably the easiest and fastest way. The handy function model_to_estimator()
modelTF = tf.keras.estimator.model_to_estimator(
    keras_model=model,
    custom_objects=None,
    config=run_config,
    model_dir=checkPointDirectory
)
converts a Keras model to an estimator, which lets us use the Dataset API nicely and automatically saves checkpoints to checkPointDirectory during training and upon training completion. The Estimator API also offers some invaluable features, such as automatically distributing the workload over multiple GPUs with, e.g.
distribution = tf.contrib.distribute.MirroredStrategy()
run_config = tf.estimator.RunConfig(train_distribute=distribution)
Now, for big models and lots of data, it is often useful to run predictions after training using some form of saved model. It seems that as of TensorFlow 1.10 (see https://github.com/tensorflow/tensorflow/issues/19295), a tf.keras.Model object supports load_weights() from a TensorFlow checkpoint. This is mentioned briefly in the TensorFlow docs, but not in the Keras docs, and I can't find anyone showing an example of it. After defining the model layers again in a new .py file, I tried
checkPointPath = os.path.join('.', 'tfCheckPoints', 'keras_model.ckpt.index')
model.load_weights(filepath=checkPointPath, by_name=False)
but this gives a NotImplementedError:
Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.
2018-10-01 14:24:49.912087:
Traceback (most recent call last):
  File "C:/Users/User/PycharmProjects/python/mercury.classifier reductions/V3.2/wikiTestv3.2/modelEvaluation3.2.py", line 141, in <module>
    model.load_weights(filepath=checkPointPath, by_name=False)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\network.py", line 1526, in load_weights
    checkpointable_utils.streaming_restore(status=status, session=session)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\training\checkpointable\util.py", line 880, in streaming_restore
    "Streaming restore not supported from name-based checkpoints. File a "
NotImplementedError: Streaming restore not supported from name-based checkpoints. File a feature request if this limitation bothers you.
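For context, here is how I currently read the "name-based" versus "object-based" distinction in that message, as a toy sketch in plain Python. This is not TensorFlow's implementation, and it only illustrates the "somewhat fragile" remark, not the cause of the NotImplementedError itself. Every name in it (ToyLayer, ToyModel, unique_name, the key formats) is made up for illustration.

```python
from collections import defaultdict

# Toy stand-in for a process-wide name registry: each new layer gets a
# globally unique name ("dense", "dense_1", ...). Hypothetical, for illustration.
_name_counts = defaultdict(int)


def unique_name(base):
    count = _name_counts[base]
    _name_counts[base] += 1
    return base if count == 0 else "{}_{}".format(base, count)


class ToyLayer:
    def __init__(self, weights):
        self.name = unique_name("dense")
        self.weights = list(weights)


class ToyModel:
    def __init__(self, layer_weights):
        self.layers = [ToyLayer(w) for w in layer_weights]


def save_name_based(model):
    # Keys are the global names the layers happened to receive.
    return {layer.name + "/kernel": list(layer.weights) for layer in model.layers}


def save_object_based(model):
    # Keys are positions in the model's own structure, independent of global names.
    return {"layers/{}/kernel".format(i): list(layer.weights)
            for i, layer in enumerate(model.layers)}


def restore(model, checkpoint, key_for):
    # Copy matching entries into the model and report the keys that had no match.
    missing = []
    for i, layer in enumerate(model.layers):
        key = key_for(i, layer)
        if key in checkpoint:
            layer.weights = list(checkpoint[key])
        else:
            missing.append(key)
    return missing


trained = ToyModel([[1.0, 2.0], [3.0]])    # layers are named dense, dense_1
name_ckpt = save_name_based(trained)
object_ckpt = save_object_based(trained)

# Rebuilding the same architecture in the same process yields new global names.
rebuilt = ToyModel([[0.0, 0.0], [0.0]])    # layers are named dense_2, dense_3

missing_by_name = restore(rebuilt, name_ckpt,
                          lambda i, layer: layer.name + "/kernel")
missing_by_object = restore(rebuilt, object_ckpt,
                            lambda i, layer: "layers/{}/kernel".format(i))

print("unmatched with name-based keys:", missing_by_name)
print("unmatched with object-based keys:", missing_by_object)
```

In this toy, the name-based restore matches nothing because the rebuilt layers received different global names, while the structure-based keys still line up. If the real distinction is along these lines, it would explain why the message recommends re-saving in the object-based format.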
I would like to do as the warning suggests and use the 'object-based saver' instead, but I haven't found a way to do this via the RunConfig passed to the estimator.
So is there a better way to get the saved weights back into an estimator for use in prediction? The GitHub thread seems to suggest that this is already implemented (though, judging by the error, probably in a different way than I am attempting above). Has anyone successfully used load_weights() on a TensorFlow checkpoint? I haven't been able to find any tutorials or examples showing how this can be done, so any help is appreciated.