Briefly: I built a data input pipeline using the TensorFlow Dataset API. Then I implemented a CNN classification model in Keras, which I converted to an estimator. I passed my input_fn functions, which provide the training and evaluation data, to the estimator's TrainSpec and EvalSpec. As a final step, I launched training with tf.estimator.train_and_evaluate:
```python
import tensorflow as tf

def my_input_fn(tfrecords_path):
    dataset = (...)
    return batch_fbanks, batch_labels

def build_model():
    model = tf.keras.models.Sequential()
    model.add(...)
    model.compile(...)
    return model

model = build_model()
run_config = tf.estimator.RunConfig(model_dir, save_summary_steps=100, save_checkpoints_steps=1000)
estimator = tf.keras.estimator.model_to_estimator(model, config=run_config)

def serving_input_receiver_fn():
    # The key matches the name of the Keras model's input layer
    inputs = {'Conv1_input': tf.compat.v1.placeholder(shape=[None, 11, 120, 1], dtype=tf.float32)}
    return tf.estimator.export.ServingInputReceiver(inputs, inputs)

# BestExporter's first positional parameter is `name`, so the serving
# function has to be passed by keyword
exporter = tf.estimator.BestExporter(name="best_exporter",
                                     serving_input_receiver_fn=serving_input_receiver_fn,
                                     exports_to_keep=5)

train_spec_dnn = tf.estimator.TrainSpec(input_fn=lambda: my_input_fn(train_data_path), hooks=[hook])
eval_spec_dnn = tf.estimator.EvalSpec(input_fn=lambda: my_eval_input_fn(eval_data_path),
                                      exporters=exporter, start_delay_secs=0, throttle_secs=15)
tf.estimator.train_and_evaluate(estimator, train_spec_dnn, eval_spec_dnn)
```
I keep the 5 best exported models (SavedModels) using tf.estimator.BestExporter, as shown above. Once training is finished, I want to reload the best model as an estimator so that I can re-evaluate it and predict on a new dataset. My problem is restoring the checkpoint into an estimator: I tried several solutions, but none of them gives me the estimator object I need in order to call its evaluate and predict methods.
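For context, here is a simplified, pure-Python model of which exports BestExporter keeps. It assumes the default compare_fn (a lower eval "loss" counts as better, and ties do not trigger an export) and that older exports are garbage-collected down to the most recent exports_to_keep. The function simulate_best_exporter is hypothetical and is not part of TensorFlow; it only illustrates why the newest surviving export should also be the best one.

```python
from typing import List


def simulate_best_exporter(eval_losses: List[float], exports_to_keep: int = 5) -> List[int]:
    """Hypothetical model of BestExporter: indices of evaluations whose exports survive.

    Assumes the default loss-based comparison and keep-most-recent garbage collection.
    """
    if exports_to_keep < 1:
        raise ValueError("exports_to_keep must be at least 1")
    best_loss = float("inf")
    exported: List[int] = []
    for i, loss in enumerate(eval_losses):
        # An export only happens when this evaluation strictly beats the best so far
        if loss < best_loss:
            best_loss = loss
            exported.append(i)
    # Only the most recent exports are kept, so the last one is the best so far
    return exported[-exports_to_keep:]


if __name__ == "__main__":
    losses = [0.9, 0.7, 0.8, 0.6, 0.65, 0.5, 0.4, 0.45, 0.3]
    print(simulate_best_exporter(losses, exports_to_keep=3))  # prints [5, 6, 8]
```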
To be more specific, each of the best-model directories is organised as follows:
```
./
    variables/
        variables.data-00000-of-00002
        variables.data-00001-of-00002
        variables.index
    saved_model.pb
```
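A small sketch for locating the newest export and the checkpoint prefix stored inside it. It assumes the exports live under <model_dir>/export/best_exporter/<timestamp>/ (where Estimator exporters normally write them) and that each one has the layout above; the helper names latest_export_dir and variables_prefix are hypothetical. The files variables/variables.index and variables/variables.data-* form an ordinary TensorFlow checkpoint whose prefix is variables/variables.

```python
import os


def latest_export_dir(export_base: str) -> str:
    """Hypothetical helper: the export subdirectory with the largest timestamp name.

    Skips non-numeric entries (e.g. temporary 'temp-...' dirs) and
    directories that do not contain a saved_model.pb.
    """
    candidates = [
        name for name in os.listdir(export_base)
        if name.isdigit()
        and os.path.isfile(os.path.join(export_base, name, "saved_model.pb"))
    ]
    if not candidates:
        raise FileNotFoundError(f"no SavedModel exports found in {export_base}")
    return os.path.join(export_base, max(candidates, key=int))


def variables_prefix(export_dir: str) -> str:
    """Hypothetical helper: checkpoint prefix of the SavedModel's variables."""
    return os.path.join(export_dir, "variables", "variables")
```

One approach that may be worth trying (untested here, and it only works if the variable names saved in the export match those the rebuilt estimator creates) is to rebuild the estimator with model_to_estimator as during training and pass this prefix as the checkpoint_path argument of estimator.evaluate or estimator.predict, both of which accept that argument. For prediction alone, tf.saved_model.load(export_dir).signatures["serving_default"] is another option in TF 2.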
So the question is: how can I get an estimator object from the best checkpoint so that I can use it to evaluate my model and predict on new data?
Note: I found some proposed solutions relying on TensorFlow v1 features, which do not solve my problem because I work with TF v2.
Thanks a lot, any help is appreciated.