5
votes

I have some Python code to train a network using Tensorflow's TFRecords and Dataset APIs. I built the network using tf.keras.layers, this being arguably the easiest and fastest way. The handy function model_to_estimator()

modelTF = tf.keras.estimator.model_to_estimator(
    keras_model=model,
    custom_objects=None,
    config=run_config,
    model_dir=checkPointDirectory
)

converts a Keras model to an estimator, which lets us take advantage of the Dataset API nicely and automatically saves checkpoints to checkPointDirectory during training and upon completion. The estimator API provides some invaluable features, such as automatically distributing the workload over multiple GPUs, e.g. with
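To make the workflow concrete, here is a minimal sketch of feeding the converted estimator from TFRecords via the Dataset API. The feature names, shapes, and filename are hypothetical stand-ins; the real ones depend on how the records were written:

```python
import tensorflow as tf

# Hypothetical feature spec; substitute the keys/shapes your records use.
feature_spec = {
    'image': tf.io.FixedLenFeature([784], tf.float32),
    'label': tf.io.FixedLenFeature([], tf.int64),
}

def input_fn():
    # Read serialized examples and parse them into tensors
    dataset = tf.data.TFRecordDataset(['train.tfrecord'])  # hypothetical file
    dataset = dataset.map(lambda rec: tf.io.parse_single_example(rec, feature_spec))
    dataset = dataset.map(lambda ex: (ex['image'], ex['label']))
    return dataset.shuffle(1000).batch(32).repeat()

# The estimator consumes the input_fn directly:
# modelTF.train(input_fn=input_fn, steps=1000)
```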

distribution = tf.contrib.distribute.MirroredStrategy()
run_config = tf.estimator.RunConfig(train_distribute=distribution)

Now for big models and lots of data, it is often useful to execute predictions after training using some form of saved model. It seems that as of Tensorflow 1.10 (see https://github.com/tensorflow/tensorflow/issues/19295), a tf.keras.Model object supports load_weights() from a Tensorflow checkpoint. This is mentioned briefly in the Tensorflow docs, but not the Keras docs, and I can't find anyone showing an example of this. After defining the model layers again in a new .py file, I have tried

checkPointPath = os.path.join('.', 'tfCheckPoints', 'keras_model.ckpt.index')
model.load_weights(filepath=checkPointPath, by_name=False)

but this gives a NotImplementedError:

Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.

2018-10-01 14:24:49.912087:
Traceback (most recent call last):
  File "C:/Users/User/PycharmProjects/python/mercury.classifier reductions/V3.2/wikiTestv3.2/modelEvaluation3.2.py", line 141, in <module>
    model.load_weights(filepath=checkPointPath, by_name=False)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\keras\engine\network.py", line 1526, in load_weights
    checkpointable_utils.streaming_restore(status=status, session=session)
  File "C:\Users\User\Anaconda3\lib\site-packages\tensorflow\python\training\checkpointable\util.py", line 880, in streaming_restore
    "Streaming restore not supported from name-based checkpoints. File a "
NotImplementedError: Streaming restore not supported from name-based checkpoints. File a feature request if this limitation bothers you.

I would like to do as the warning suggests and use the 'object-based saver' instead, but I haven't found a way to do this via a RunConfig passed to estimator.train().

So is there a better way to get the saved weights back into an estimator for use in prediction? The github thread seems to suggest that this is already implemented (though based on the error, probably in a different way than I am attempting above). Has anyone successfully used load_weights() on a TF checkpoint? I haven't been able to find any tutorials/examples on how this can be done, so any help is appreciated.

2
did you figure this out? could you share the solution if you have? – Srikar Appalaraju
I ended up having to work around this. I just used Keras without the estimator API. The nightly build of TF allows passing a distribution to Keras's compile function, e.g. here. This allows one to easily use GPUs in parallel without having to mix Keras models with TF estimators. – nadlr

2 Answers

2
votes

I'm not sure, but try changing keras_model.ckpt.index to keras_model.ckpt: load_weights() expects the checkpoint prefix, not one of the individual files.
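For context: a TF checkpoint is stored on disk as several files sharing a prefix (keras_model.ckpt.index, keras_model.ckpt.data-00000-of-00001, ...), and the restore APIs take that common prefix. A minimal sketch of deriving the prefix from the asker's path (the suffix-stripping step is illustrative, not part of any TF API):

```python
import os

checkPointPath = os.path.join('.', 'tfCheckPoints', 'keras_model.ckpt.index')

# Strip the per-file '.index' suffix to recover the checkpoint prefix
prefix, ext = os.path.splitext(checkPointPath)
if ext == '.index':
    checkPointPath = prefix

print(checkPointPath)  # e.g. ./tfCheckPoints/keras_model.ckpt on POSIX
# model.load_weights(filepath=checkPointPath, by_name=False)
```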

0
votes

You can create a separate graph, load your checkpoint normally, and then transfer the weights to your Keras model:

import numpy as np
import tensorflow as tf

_graph = tf.Graph()
_sess = tf.Session(graph=_graph)

# Load the SavedModel (tagged 'serve') into the session's graph
tf.saved_model.load(_sess, ['serve'], '../tf1_save/')

_weights_all, _bias_all = [], []
with _graph.as_default():
  for idx, t_var in enumerate(tf.trainable_variables()):
    # substitute variable_scope with your scope
    if 'variable_scope/' not in t_var.name:
      break

    print(t_var.name)
    val = _sess.run(t_var)
    # assumes variables alternate kernel, bias, kernel, bias, ...
    if idx % 2 == 0:
      _weights_all.append(val)
    else:
      _bias_all.append(val)

# self.model is your freshly defined Keras model
for layer, (weight, bias) in enumerate(zip(_weights_all, _bias_all)):
  self.model.layers[layer].set_weights([np.array(weight), np.array(bias)])
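Note that the even/odd split above relies on tf.trainable_variables() returning variables in creation order, so a stack of Dense layers yields kernel, bias, kernel, bias, ... A minimal illustration of that assumption with plain lists (the variable names are hypothetical):

```python
# Variables in creation order for two hypothetical Dense layers
flat_vars = ['dense0/kernel', 'dense0/bias', 'dense1/kernel', 'dense1/bias']

weights = flat_vars[0::2]  # even indices -> kernels
biases = flat_vars[1::2]   # odd indices  -> biases

print(weights)  # ['dense0/kernel', 'dense1/kernel']
print(biases)   # ['dense0/bias', 'dense1/bias']
```

If a layer has no bias, or extra non-layer variables exist in the scope, the alternation breaks and the pairs must be matched by name instead.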