I found a simple way to do this, and I'm posting it here in case someone finds it useful:
- Load the Estimator checkpoint.
- Create a placeholder and a copy of the model graph under a new name scope.
- Extract all trainable variables under the two scopes.
- Create `assign` ops for each pair of variables.
Code:
import tensorflow as tf

# `Model` and `config` come from your own training code; `config.ckpt_fullpath`
# is the checkpoint path without the '.meta' suffix.
sess = tf.Session()

# Load the trained model from the checkpoint.
new_saver = tf.train.import_meta_graph('{}.meta'.format(config.ckpt_fullpath))
new_saver.restore(sess, config.ckpt_fullpath)

# Create a copy of the model graph, with a placeholder for input.
new_model_scope = 'new_scope'
trained_model_scope = 'old_scope'  # taken from the original model function of the Estimator.
# Note: tf.name_scope does not prefix variables created with tf.get_variable.
# If your model uses tf.get_variable, open a tf.variable_scope here instead.
with tf.name_scope(new_model_scope):
    model = Model(config)
    input_tensor = tf.placeholder(
        tf.float32, [None, config.img_size[0], config.img_size[1], 3])
    model.build_model(input_tensor)

# Initialize the new graph's variables with the trained parameters,
# pairing the variables of the two scopes by sorted name.
trained_params = sorted(
    [t for t in tf.trainable_variables()
     if t.name.startswith(trained_model_scope)],
    key=lambda v: v.name)
new_params = sorted(
    [t for t in tf.trainable_variables()
     if t.name.startswith(new_model_scope)],
    key=lambda v: v.name)

update_ops = [new_v.assign(trained_v)
              for trained_v, new_v in zip(trained_params, new_params)]
sess.run(update_ops)
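
One caveat: zip() pairs the two sorted lists positionally and silently drops extras if their lengths differ. A small sanity check, worth running before the sess.run(update_ops) call above, can catch mismatched pairings. This is a sketch that assumes each scope adds exactly one path component to the variable names:

assert len(trained_params) == len(new_params)
for trained_v, new_v in zip(trained_params, new_params):
    # Compare names with the scope prefix stripped, e.g.
    # 'old_scope/dense/kernel:0' vs. 'new_scope/dense/kernel:0'.
    assert trained_v.name.split('/', 1)[1] == new_v.name.split('/', 1)[1]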
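
Once the assign ops have run, you can feed data straight through the new placeholder. A minimal usage sketch, assuming the model exposes an output tensor as model.logits (substitute whatever your build_model actually defines):

import numpy as np

# Dummy batch of one image; replace with real input data.
batch = np.zeros([1, config.img_size[0], config.img_size[1], 3], dtype=np.float32)
# model.logits is an assumed name; use your model's actual output tensor.
predictions = sess.run(model.logits, feed_dict={input_tensor: batch})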