While training a convolutional neural network following this article, the accuracy on the training set keeps climbing while the accuracy on the test set plateaus.

Below is an example with 6400 training examples, randomly chosen at each epoch (so some examples might have been seen in previous epochs and some might be new), and the same 6400 test examples.

For a bigger data set (64,000 or 100,000 training examples), the increase in training accuracy is even more abrupt, reaching 98% by the third epoch.

I also tried using the same 6400 training examples each epoch, just randomly shuffled. As expected, the result is worse.

epoch 3  loss 0.54871 acc 79.01 
learning rate 0.1
nr_test_examples 6400    
TEST epoch 3  loss 0.60812 acc 68.48 
nr_training_examples 6400
tb 91
epoch 4  loss 0.51283 acc 83.52 
learning rate 0.1
nr_test_examples 6400
TEST epoch 4  loss 0.60494 acc 68.68 
nr_training_examples 6400
tb 91
epoch 5  loss 0.47531 acc 86.91 
learning rate 0.05
nr_test_examples 6400
TEST epoch 5  loss 0.59846 acc 68.98 
nr_training_examples 6400
tb 91
epoch 6  loss 0.42325 acc 92.17 
learning rate 0.05
nr_test_examples 6400
TEST epoch 6  loss 0.60667 acc 68.10 
nr_training_examples 6400
tb 91
epoch 7  loss 0.38460 acc 95.84 
learning rate 0.05
nr_test_examples 6400
TEST epoch 7  loss 0.59695 acc 69.92 
nr_training_examples 6400
tb 91
epoch 8  loss 0.35238 acc 97.58 
learning rate 0.05
nr_test_examples 6400
TEST epoch 8  loss 0.60952 acc 68.21

This is my model (I'm using a ReLU activation after each convolution):

conv 5x5 (1, 64)
max-pooling 2x2
dropout
conv 3x3 (64, 128)
max-pooling 2x2
dropout
conv 3x3 (128, 256)
max-pooling 2x2
dropout
conv 3x3 (256, 128)
dropout
fully_connected(18*18*128, 128)
dropout
output(128, 128)

What could be the cause?

I'm using the Momentum optimizer with learning rate decay:

    batch = tf.Variable(0, trainable=False)

    train_size = 6400

    learning_rate = tf.train.exponential_decay(
      0.1,                # Base learning rate.
      batch * batch_size,  # Current index into the dataset.
      train_size*5,          # Decay step.
      0.5,                # Decay rate.
      staircase=True)
    # Use simple momentum for the optimization.
    optimizer = tf.train.MomentumOptimizer(learning_rate,
                                           0.9).minimize(cost, global_step=batch)
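
With these values the staircase schedule halves the learning rate after every train_size * 5 examples (batch * batch_size counts the examples seen so far), i.e. roughly every 5 epochs:

    # learning_rate = 0.1 * 0.5 ** floor(batch * batch_size / (train_size * 5))
    # so it steps down 0.1 -> 0.05 -> 0.025 -> ... roughly every 5 epochs,
    # which matches the 0.1 / 0.05 values in the log above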
Is there a reason why you are trying to reproduce this article? It's probably best to try out a few of the "classical cases" like MNIST and CIFAR before going to a specific paper. – Guilherme de Lazari

1 Answer

This is very much expected. The problem is called over-fitting: your model starts "memorizing" the training examples without actually learning anything useful for the test set. In fact, this is exactly why we use a test set in the first place: with a complex enough model we can always fit the training data perfectly, even if not meaningfully. The test set is what tells us what the model has actually learned.

It's also useful to have a validation set, which is like a test set, except you use it to decide when to stop training: when the validation error stops decreasing, you stop. Why not use the test set for this? The test set is there to tell you how well your model would do in the real world. If you start using information from the test set to make choices about your training process, then it's like cheating, and you will be punished by your test error no longer representing your real-world error.
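
As a rough illustration (not tied to your code), early stopping on a validation set can look like this; `run_training_epoch` and `validation_loss` are hypothetical stand-ins for your own training loop:

    # Early stopping sketch: train while the validation loss keeps improving,
    # stop after `patience` epochs without improvement.
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    patience = 3
    max_epochs = 100

    for epoch in range(max_epochs):
        run_training_epoch()          # hypothetical: one pass over the training set
        val_loss = validation_loss()  # hypothetical: loss on the held-out validation set

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            epochs_without_improvement = 0
            # (optionally save a checkpoint of the best model so far here)
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break  # validation error stopped improving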

Lastly, convolutional neural networks are notorious for their ability to over-fit. It has been shown that conv-nets can reach zero training error even if you shuffle the labels, and even on random pixels. That means a conv-net doesn't need a real pattern in the data in order to fit it perfectly. So you have to regularize a conv-net, i.e. use things like dropout, batch normalization, and early stopping.
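
For example, in TensorFlow 1.x the usual pattern is to drive dropout and batch normalization with a training flag, so they are active during training and switched off at test time. A generic sketch (not your exact code; the input shape here is only an assumption):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 144, 144, 1])  # input shape assumed for illustration
    is_training = tf.placeholder(tf.bool)                 # True while training, False when evaluating

    h = tf.layers.conv2d(x, 64, 5, padding='same', activation=tf.nn.relu)
    h = tf.layers.batch_normalization(h, training=is_training)
    h = tf.layers.max_pooling2d(h, 2, 2)
    h = tf.layers.dropout(h, rate=0.5, training=is_training)  # no dropout at test time
    # ... remaining layers ...

    # Batch normalization updates its moving averages through UPDATE_OPS,
    # so the training op should depend on them:
    # with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    #     train_op = optimizer.minimize(cost)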

I'll leave a few links if you want to read more:

Over-fitting, validation, early stopping: https://elitedatascience.com/overfitting-in-machine-learning

Conv-nets fitting random labels: https://arxiv.org/pdf/1611.03530.pdf (this paper is a bit advanced, but it's interesting to skim through)

P.S. To actually improve your test accuracy, you will need to change your model or train with data augmentation. You might want to try transfer learning as well.
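
For the data augmentation part, even simple random transformations applied to the training images can help. A minimal sketch using the tf.data pipeline (`train_images` and `train_labels` stand in for however you currently load your data, and whether flips make sense depends on what the images contain):

    import tensorflow as tf

    def augment(image, label):
        # Random distortions applied independently to each training image.
        image = tf.image.random_flip_left_right(image)           # only if mirroring is meaningful for your data
        image = tf.image.random_brightness(image, max_delta=0.1)
        image = tf.image.random_contrast(image, lower=0.9, upper=1.1)
        return image, label

    dataset = (tf.data.Dataset.from_tensor_slices((train_images, train_labels))
               .shuffle(10000)
               .map(augment)
               .batch(64)
               .repeat())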