16 votes

I'm using TensorFlow to train a Convolutional Neural Network (CNN) for a sign language application. The CNN has to classify 27 different labels, so unsurprisingly, a major problem has been addressing overfitting. I've taken several steps to accomplish this:

  1. I've collected a large amount of high-quality training data (over 5000 samples per label).
  2. I've built a reasonably sophisticated pre-processing stage to help maximize invariance to things like lighting conditions.
  3. I'm using dropout on the fully-connected layers.
  4. I'm applying L2 regularization to the fully-connected parameters. (A simplified sketch of steps 3 and 4 follows this list.)
  5. I've done extensive hyper-parameter optimization (to the extent possible given hardware and time limitations) to identify the simplest model that can achieve close to 0% loss on the training data.
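
For concreteness, steps 3 and 4 are wired up roughly like the sketch below (a simplified, Keras-style version; the layer sizes, dropout rate, and L2 factor here are placeholders rather than my actual values):

```python
import tensorflow as tf
from tensorflow.keras import layers, regularizers

# Simplified sketch of the fully-connected "head" with dropout (step 3)
# and L2 weight decay (step 4). Sizes and rates are illustrative only.
def build_dense_head(num_classes=27, l2_factor=1e-4, drop_rate=0.5):
    return tf.keras.Sequential([
        layers.Flatten(),
        layers.Dense(512, activation="relu",
                     kernel_regularizer=regularizers.l2(l2_factor)),
        layers.Dropout(drop_rate),
        layers.Dense(num_classes,
                     kernel_regularizer=regularizers.l2(l2_factor)),
    ])
```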

Unfortunately, even after all these steps, I'm finding that I can't achieve much better than about 3% test error. (It's not terrible, but for the application to be viable, I'll need to improve that substantially.)

I suspect that the source of the overfitting lies in the convolutional layers since I'm not taking any explicit steps there to regularize (besides keeping the layers as small as possible). But based on examples provided with TensorFlow, it doesn't appear that regularization or dropout is typically applied to convolutional layers.
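
For what it's worth, it looks like the same mechanisms could be attached to the convolutional layers as well; a hypothetical block (with made-up filter counts and rates) might look like the following, though I haven't tried it:

```python
import tensorflow as tf
from tensorflow.keras import layers, regularizers

# Hypothetical conv block with the same L2 penalty attached to the
# convolution kernels, plus spatial dropout (drops whole feature maps).
# Filter count, kernel size, and rates are made up for illustration.
def conv_block(filters=32, l2_factor=1e-4, drop_rate=0.1):
    return tf.keras.Sequential([
        layers.Conv2D(filters, 3, padding="same", activation="relu",
                      kernel_regularizer=regularizers.l2(l2_factor)),
        layers.SpatialDropout2D(drop_rate),
        layers.MaxPooling2D(2),
    ])
```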

The only approach I've found online that explicitly deals with prevention of overfitting in convolutional layers is a fairly new approach called Stochastic Pooling. Unfortunately, it appears that there is no implementation for this in TensorFlow, at least not yet.
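
From my reading of the paper, I imagine stochastic pooling could be approximated with existing ops (using current tf.image/tf.random names) roughly as sketched below: at training time, sample one activation per pooling region with probability proportional to its value; at test time, use the probability-weighted average. This is untested, the helper name and all details are mine, and it's probably not efficient:

```python
import tensorflow as tf

def stochastic_pool(x, pool_size=2, training=True):
    # Untested sketch of stochastic pooling built from standard ops.
    # x: [batch, height, width, channels], assumed non-negative (post-ReLU).
    k = pool_size
    b = tf.shape(x)[0]
    c = int(x.shape[-1])  # channel count must be known statically here
    # Non-overlapping k x k regions -> [batch, out_h, out_w, k*k*channels]
    patches = tf.image.extract_patches(
        x, sizes=[1, k, k, 1], strides=[1, k, k, 1],
        rates=[1, 1, 1, 1], padding="VALID")
    out_h, out_w = tf.shape(patches)[1], tf.shape(patches)[2]
    patches = tf.reshape(patches, [b, out_h, out_w, k * k, c])
    # Probability of each activation, proportional to its value in the region.
    probs = patches / (tf.reduce_sum(patches, axis=3, keepdims=True) + 1e-8)
    if training:
        # Sample one activation per region according to probs.
        vals = tf.reshape(tf.transpose(patches, [0, 1, 2, 4, 3]), [-1, k * k])
        logits = tf.math.log(
            tf.reshape(tf.transpose(probs, [0, 1, 2, 4, 3]), [-1, k * k]) + 1e-8)
        idx = tf.squeeze(tf.random.categorical(logits, 1), axis=-1)
        picked = tf.gather(vals, idx, batch_dims=1)
        return tf.reshape(picked, [b, out_h, out_w, c])
    # At test time, use the probability-weighted average over each region.
    return tf.reduce_sum(probs * patches, axis=3)
```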

So, in short: is there a recommended approach for preventing overfitting in convolutional layers that can be implemented in TensorFlow? Or will it be necessary to create a custom pooling operator to support the Stochastic Pooling approach?

Thanks for any guidance!

Tensorflow does have max pool and avg pool for conv layers (link). – shekkizh
Thanks for the response, shekkizh. I'm currently using max pooling. I'm certainly not an expert in machine learning, but I don't believe either max or average pooling helps to reduce overfitting, because both are deterministic. Stochastic pooling is a novel idea that introduces randomness into the operation. The linked paper discusses it in detail, but I think the gist is that selecting activations randomly simulates data augmentation, thereby reducing the chance of overfitting. – Aenimated1
Consider section 3.2 of this paper (Network in Network). It is probably the fully connected layers that are overfitting. The shared weights in a convolutional layer are already a strong regularizer. – user728291
Thanks, user728291 - that was very helpful. Very interesting paper; I tried to read it in its entirety, although I lack some of the background needed to fully appreciate it. Anyway, I've actually done quite a bit of tuning of the L2 regularization and dropout without much success. I'm beginning to think the problem may actually be that my training data is too perfect: I have a huge amount of training data, but it was all collected under rather "sterile" conditions. – Aenimated1
You should definitely try using some sort of normalization on the weights in your convolution filters. In my experience it seems to be needed there much more than in the fully connected layers. – Frobot

1 Answer

15 votes

How can I fight overfitting?

How can I improve my CNN?

Thoma, Martin. "Analysis and Optimization of Convolutional Neural Network Architectures." arXiv preprint arXiv:1707.09725 (2017).

See chapter 2.5 for analysis techniques. As written in the beginning of that chapter, you can usually do the following:

  • (I1) Change the problem definition (e.g., the classes which are to be distinguished)
  • (I2) Get more training data
  • (I3) Clean the training data
  • (I4) Change the preprocessing (see Appendix B.1)
  • (I5) Augment the training data set (see Appendix B.2; a minimal sketch follows this list)
  • (I6) Change the training setup (see Appendices B.3 to B.5)
  • (I7) Change the model (see Appendices B.6 and B.7)
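
As a concrete illustration of (I5), a minimal tf.data augmentation step built from standard tf.image ops could look like the sketch below. The crop sizes and transforms are placeholders, and which augmentations are appropriate depends on the data (for sign language, for example, horizontal flips may change a sign's meaning):

```python
import tensorflow as tf

# Illustrative augmentation for (I5); sizes and deltas are placeholders.
def augment(image, label):
    image = tf.image.random_brightness(image, max_delta=0.2)
    image = tf.image.random_contrast(image, lower=0.8, upper=1.2)
    image = tf.image.resize_with_crop_or_pad(image, 72, 72)
    image = tf.image.random_crop(image, size=[64, 64, 3])
    return image, label

# Assuming train_ds is a tf.data.Dataset of (image, label) pairs:
# train_ds = train_ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
```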

Misc

"The CNN has to classify 27 different labels, so unsurprisingly, a major problem has been addressing overfitting."

I don't see how these are connected. You can have hundreds of labels without any overfitting problem.