
I have 3 data folders for my CNN model: train_data, val_data, and test_data.

While training my model, I found that the accuracy varies from epoch to epoch, and the last epoch does not always give the best accuracy. For example, the last epoch's accuracy might be 71% while an earlier epoch reached a higher value. I want to save the checkpoint of the epoch with the higher accuracy and then use that checkpoint to make predictions on test_data.

I train my model on train_data, make predictions on val_data, and save a checkpoint of the model like this:

    print("{} Saving checkpoint of model...". format(datetime.now()))
    checkpoint_path = os.path.join(checkpoint_dir, 'model_epoch' + str(epoch) + '.ckpt')
    save_path = saver.save(session, checkpoint_path)

Before starting the tf.Session() I have this line:

    saver = tf.train.Saver()

How can I save the checkpoint of the epoch with the highest validation accuracy, and then use that checkpoint to predict on my test_data?

By the way, looking at your code, you should not add '.ckpt' to your path as you are doing. You should only specify the directory. - nairouz mrabah
A note: from what I understand, by doing this you are actually training on the validation data. There is no guarantee that it will perform better on test data. - geometrikal
@geometrikal, you are right. There is no guarantee that it performs well on my test data, but the aim of the validation set is to find the best hyperparameters for the model and then use them on the test data. That's why I want to save the best version of my model and then use it on test. - user2975921
Out of interest, do you use learning rate decay? I have found that can both improve and stabilise accuracy. - geometrikal
I am using AdamOptimizer with a fixed learning rate. - user2975921

2 Answers

0 votes

The tf.train.Saver() documentation describes the following:

saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

Note that if you pass global_step to the saver, the generated checkpoint files will contain the global step number. I generally save checkpoints every X minutes, then come back, review the results, and choose a checkpoint at the appropriate step value. If you're using TensorBoard you'll find this intuitive, since all your graphs can be displayed by global step as well.

https://www.tensorflow.org/api_docs/python/tf/train/Saver
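
For the question's setup, a minimal sketch of this idea is to track the best validation accuracy seen so far and only save when it improves. Here `accuracy`, `val_feed`, `test_feed`, `num_epochs` and `checkpoint_dir` are placeholder names standing in for your own tensors and variables:

# Sketch only: accuracy, val_feed, test_feed, num_epochs and checkpoint_dir
# stand in for your own tensors and variables.
saver = tf.train.Saver(max_to_keep=1)
best_acc = 0.0
best_ckpt = None

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for epoch in range(num_epochs):
        # ... run your training ops for this epoch ...
        val_acc = sess.run(accuracy, feed_dict=val_feed)
        if val_acc > best_acc:
            best_acc = val_acc
            # global_step puts the epoch number into the checkpoint filename
            best_ckpt = saver.save(sess, os.path.join(checkpoint_dir, 'model'),
                                   global_step=epoch)

# Later, with the same graph built: restore the best epoch and predict on test_data
with tf.Session() as sess:
    saver.restore(sess, best_ckpt)
    test_acc = sess.run(accuracy, feed_dict=test_feed)

Since we only save when validation accuracy improves and max_to_keep=1, the single checkpoint kept on disk is always the best one.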

0 votes

You can use a tf.train.CheckpointSaverListener:

from __future__ import print_function
import glob
import os

import tensorflow as tf
from sacred import Experiment

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data

ex = Experiment('test-07-05-2018')

mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
checkpoint_dir = "/tmp/checkpoints/"

class ExampleCheckpointSaverListener(tf.train.CheckpointSaverListener):
    def __init__(self, accuracy_op, val_feed_dict):
        # Tensor and feed dict needed to evaluate the model after every save
        self.accuracy_op = accuracy_op
        self.val_feed_dict = val_feed_dict

    def begin(self):
        print('Starting the session.')
        self.prev_accuracy = 0.0

    def after_save(self, session, global_step_value):
        # Only keep this checkpoint if it is better than the previous one
        acc = session.run(self.accuracy_op, feed_dict=self.val_feed_dict)
        if acc < self.prev_accuracy:
            # Remove the files belonging to the checkpoint that was just written
            latest = tf.train.latest_checkpoint(checkpoint_dir)
            for f in glob.glob(latest + '*'):
                os.remove(f)
        else:
            self.prev_accuracy = acc

    def end(self, session, global_step_value):
        print('Done with the session.')

@ex.config
def my_config():
    pass

@ex.automain
def main():
    # Build the graph of a vanilla multiclass logistic regression
    x = tf.placeholder(tf.float32, [None, 784])
    y = tf.placeholder(tf.float32, [None, 10])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y_pred = tf.nn.softmax(tf.matmul(x, W) + b)
    loss = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_pred), axis=1))
    global_step = tf.train.get_or_create_global_step()
    optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(loss, global_step=global_step)
    y_pred_cls = tf.argmax(y_pred, axis=1)
    y_true_cls = tf.argmax(y, axis=1)
    correct_prediction = tf.equal(y_pred_cls, y_true_cls)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    saver = tf.train.Saver()
    listener = ExampleCheckpointSaverListener(
        accuracy, {x: mnist.validation.images, y: mnist.validation.labels})
    saver_hook = tf.train.CheckpointSaverHook(
        checkpoint_dir, save_steps=500, saver=saver, listeners=[listener])

    with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]) as sess:
        # The monitored session initializes all variables itself
        for epoch in range(25):
            avg_loss = 0.
            total_batch = int(mnist.train.num_examples / 100)
            # Loop over all batches; the hook writes a checkpoint every
            # save_steps steps and the listener discards it if it is worse
            for i in range(total_batch):
                batch_xs, batch_ys = mnist.train.next_batch(100)
                _, l, acc = sess.run([optimizer, loss, accuracy],
                                     feed_dict={x: batch_xs, y: batch_ys})
                avg_loss += l / total_batch
            print('Epoch {}: average loss {:.4f}'.format(epoch, avg_loss))
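
To then use the surviving checkpoint on test_data, you can restore it and evaluate. A minimal sketch, assuming the same graph as above has been built so that `x`, `y`, `accuracy` and `checkpoint_dir` exist:

# Sketch: restore the best (surviving) checkpoint and evaluate on the test set
latest = tf.train.latest_checkpoint(checkpoint_dir)
restorer = tf.train.Saver()
with tf.Session() as sess:
    restorer.restore(sess, latest)
    test_acc = sess.run(accuracy,
                        feed_dict={x: mnist.test.images, y: mnist.test.labels})
    print('Test accuracy from best checkpoint: {:.4f}'.format(test_acc))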