1 vote

I am using the ModelCheckpoint callback to monitor validation accuracy. However, sometimes random spikes in validation accuracy at the beginning of training produce a false signal. Is there a callback that is only called after a certain epoch, e.g. after epoch 100, when the "random spikes" have settled out? Thanks!

1 Answer

3 votes

You can just make your own callback:

from keras.callbacks import ModelCheckpoint
from keras.models import Sequential
from keras.layers import Dense

import numpy as np

# Subclass ModelCheckpoint
class MyModelCheckpoint(ModelCheckpoint):

    def __init__(self, *args, **kwargs):
        super(MyModelCheckpoint, self).__init__(*args, **kwargs)


    # Redefine on_epoch_end so the save logic only activates after 100 epochs
    # (note that Keras passes a 0-indexed epoch number here)
    def on_epoch_end(self, epoch, logs=None):
        if epoch > 100:
            super(MyModelCheckpoint, self).on_epoch_end(epoch, logs)


# A simple example neural net
model = Sequential()
model.add(Dense(1, input_dim=5))
model.compile(loss='mse', optimizer='adam')

# Toy dataset
X = np.random.rand(5, 5)
y = np.random.rand(5, 1)

# Create checkpointer as you would with a regular ModelCheckpoint
checkpointer = MyModelCheckpoint(filepath='{epoch}.h5')

# Fit the model using it as a callback
model.fit(X, y, callbacks=[checkpointer], verbose=1, epochs=200)
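
If you would rather not hardcode the threshold, the same idea can be parameterized. This is just a sketch: DelayedModelCheckpoint and its start_epoch argument are names I am introducing here, not part of Keras.

from keras.callbacks import ModelCheckpoint

class DelayedModelCheckpoint(ModelCheckpoint):
    # Hypothetical variant: behaves exactly like ModelCheckpoint, but stays
    # dormant until `start_epoch` (an argument introduced here) is reached.
    def __init__(self, *args, start_epoch=100, **kwargs):
        super(DelayedModelCheckpoint, self).__init__(*args, **kwargs)
        self.start_epoch = start_epoch

    def on_epoch_end(self, epoch, logs=None):
        # `epoch` is 0-indexed, so compare it directly against start_epoch
        if epoch >= self.start_epoch:
            super(DelayedModelCheckpoint, self).on_epoch_end(epoch, logs)

# Usage is the same as before, with the threshold passed in explicitly:
# checkpointer = DelayedModelCheckpoint(filepath='{epoch}.h5', start_epoch=100)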