You can use tf.train.CheckpointSaverListener, attached to a tf.train.CheckpointSaverHook: its after_save method is called right after each checkpoint is written, which is the place to decide whether to keep it. Two things to watch: the hook needs a global step in the graph and a save_steps or save_secs interval, and a checkpoint on disk is several files sharing a prefix, so "removing" one means removing all of them.
from __future__ import print_function
import glob
import os
import tensorflow as tf
from sacred import Experiment
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data

ex = Experiment('test-07-05-2018')
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
checkpoint_path = "/tmp/checkpoints/"
class ExampleCheckpointSaverListener(tf.train.CheckpointSaverListener):
    def begin(self):
        print('Starting the session.')
        self.prev_accuracy = 0.0
        self.acc = 0.0

    def after_save(self, session, global_step_value):
        # Only keep this checkpoint if it is better than the previous one.
        # self.acc is updated from the training loop below.
        if self.acc < self.prev_accuracy:
            # latest_checkpoint() needs the directory and returns a path
            # prefix; a checkpoint is several files (.index, .meta,
            # .data-*), so delete everything sharing that prefix.
            ckpt = tf.train.latest_checkpoint(checkpoint_path)
            for f in glob.glob(ckpt + '*'):
                os.remove(f)
        else:
            self.prev_accuracy = self.acc

    def end(self, session, global_step_value):
        print('Done with the session.')
@ex.config
def my_config():
    pass

@ex.automain
def main():
    # Build the graph of a vanilla multiclass logistic regression.
    x = tf.placeholder(tf.float32, [None, 784])
    y = tf.placeholder(tf.float32, [None, 10])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y_pred = tf.nn.softmax(tf.matmul(x, W) + b)
    loss = tf.reduce_mean(-tf.reduce_sum(y * tf.log(y_pred), axis=1))
    # CheckpointSaverHook refuses to run without a global step in the graph.
    global_step = tf.train.get_or_create_global_step()
    optimizer = tf.train.GradientDescentOptimizer(0.01).minimize(
        loss, global_step=global_step)
    y_pred_cls = tf.argmax(y_pred, axis=1)
    y_true_cls = tf.argmax(y, axis=1)
    correct_prediction = tf.equal(y_pred_cls, y_true_cls)
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    listener = ExampleCheckpointSaverListener()
    saver_hook = tf.train.CheckpointSaverHook(
        checkpoint_path,
        save_steps=mnist.train.num_examples // 100,  # save once per epoch
        listeners=[listener])

    # MonitoredTrainingSession initializes the variables itself and the
    # hook does the saving, so no manual init op or Saver is needed
    # (a raw Saver.save() on a MonitoredSession raises a TypeError anyway).
    with tf.train.MonitoredTrainingSession(chief_only_hooks=[saver_hook]) as sess:
        for epoch in range(25):
            avg_loss = 0.
            total_batch = int(mnist.train.num_examples / 100)
            # Loop over all batches.
            for i in range(total_batch):
                batch_xs, batch_ys = mnist.train.next_batch(100)
                _, l, acc = sess.run([optimizer, loss, accuracy],
                                     feed_dict={x: batch_xs, y: batch_ys})
                avg_loss += l / total_batch
                # Hand the latest accuracy to the listener so that
                # after_save() has something to compare against.
                listener.acc = acc
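
A cleaner variant, since after_save receives the session: let the listener run a held-out evaluation itself instead of having the training loop push values into it. A minimal sketch, assuming the accuracy tensor and a validation feed are passed in (the class name and constructor arguments are mine, not part of the TF API):

class BestCheckpointListener(tf.train.CheckpointSaverListener):
    def __init__(self, accuracy_op, feed_dict, checkpoint_dir):
        self.accuracy_op = accuracy_op      # accuracy tensor from the graph
        self.feed_dict = feed_dict          # e.g. the validation set
        self.checkpoint_dir = checkpoint_dir
        self.best_accuracy = 0.0

    def after_save(self, session, global_step_value):
        # Evaluate right after the checkpoint is written and drop it
        # if it does not beat the best one seen so far.
        acc = session.run(self.accuracy_op, feed_dict=self.feed_dict)
        if acc < self.best_accuracy:
            ckpt = tf.train.latest_checkpoint(self.checkpoint_dir)
            for f in glob.glob(ckpt + '*'):
                os.remove(f)
        else:
            self.best_accuracy = acc

You would build it inside main() as BestCheckpointListener(accuracy, {x: mnist.validation.images, y: mnist.validation.labels}, checkpoint_path), which also has the advantage of measuring accuracy on data the model was not trained on, rather than on the last training batch.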