I am using the ModelCheckpoint callback to monitor validation accuracy. However, random spikes in validation accuracy early in training sometimes produce a false signal. Is there a callback that is only called after a certain epoch, e.g. after epoch 100, once the "random spikes" have settled down? Thanks!
1 vote
1 Answer
3 votes
You can write your own callback by subclassing ModelCheckpoint and skipping its save logic for the early epochs:
from keras.callbacks import ModelCheckpoint
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
# Subclass ModelCheckpoint
class MyModelCheckpoint(ModelCheckpoint):
    """ModelCheckpoint that ignores the first ``start_epoch`` epochs.

    Early in training a monitored metric (e.g. validation accuracy) can spike
    randomly.  By not forwarding those epochs to ``ModelCheckpoint`` at all,
    they neither trigger a save nor (with ``save_best_only=True``) get
    recorded as the "best" value seen so far.

    Args:
        *args: Forwarded unchanged to ``ModelCheckpoint.__init__``.
        start_epoch: Number of leading epochs to skip before checkpointing
            starts.  Keras passes a 0-based epoch index to ``on_epoch_end``,
            so with the default of 100 the first epoch considered is index
            100, i.e. the 101st epoch (shown as "Epoch 101" and formatted as
            ``{epoch}`` = 101 in the filepath).
        **kwargs: Forwarded unchanged to ``ModelCheckpoint.__init__``.
    """

    def __init__(self, *args, start_epoch=100, **kwargs):
        super().__init__(*args, **kwargs)
        self.start_epoch = start_epoch

    def on_epoch_end(self, epoch, logs=None):
        """Delegate to ``ModelCheckpoint`` only once warm-up epochs are over."""
        # ``epoch`` is 0-based, so ``>=`` is the correct test.  The previous
        # ``epoch > 100`` silently skipped one extra epoch (101 instead of 100).
        # NOTE(review): if ``save_freq`` is an integer (batch-based saving in
        # newer Keras), saving happens outside on_epoch_end and is not gated
        # here -- confirm the Keras version/config in use.
        if epoch >= self.start_epoch:
            super().on_epoch_end(epoch, logs)
# --- Toy demonstration -----------------------------------------------------
# A single Dense layer mapping 5 input features to 1 output, trained on MSE.
model = Sequential([Dense(1, input_dim=5)])
model.compile(optimizer='adam', loss='mse')

# Random toy data: 5 samples x 5 features, and one target value per sample.
X, y = np.random.rand(5, 5), np.random.rand(5, 1)

# The subclass is constructed exactly like a stock ModelCheckpoint.
checkpointer = MyModelCheckpoint(filepath='{epoch}.h5')

# Train for 200 epochs; the checkpointer only saves during the later ones.
model.fit(X, y, epochs=200, verbose=1, callbacks=[checkpointer])