My input function looks like this:
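    def input_fn():
        # repeat() before batch(): all epochs form one continuous stream
        return dataset.repeat(epochs).batch(16)

    estimator_model.train(input_fn, steps=steps)  # keyword: 2nd positional arg is hooks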
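Here is a quick check on how `steps` relates to `epochs` in this pipeline. Because repeat() is applied before batch(), the epochs are concatenated into one stream of num_examples * epochs elements. A batch can therefore straddle two epochs, and only the very last batch may be short (batch() keeps the remainder by default). The helper below is a plain-Python sketch. num_examples, the dataset size, is an assumed parameter because the dataset's length is not shown above.

```python
def total_steps(num_examples, epochs, batch_size=16):
    """Number of batches yielded by dataset.repeat(epochs).batch(batch_size).

    Integer ceiling division: the last batch may hold fewer than
    batch_size elements because the remainder is not dropped.
    """
    return (num_examples * epochs + batch_size - 1) // batch_size


if __name__ == "__main__":
    # e.g. 100 examples, 3 epochs, batch size 16
    print(total_steps(100, 3))
```

Per the Estimator.train documentation, passing steps=None trains until input_fn signals end of input. With this repeated dataset, that is exactly `epochs` passes, so a hand-computed `steps` is only needed to stop earlier.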
How can I tell my model which repetition (epoch) over the dataset it is currently on? I would like to implement things like a decaying learning rate, or training the model without the adversarial loss for the first n epochs. I am using tf.data.Dataset and tf.estimator.Estimator. If I instead call the train method multiple times:
    def input_fn():
        # a single pass over the data per train() call
        return dataset.batch(16)

    for epoch in range(epochs):
        estimator_model.train(input_fn, steps=steps)
it rebuilds the model on every call (different weights, a different checkpoint directory, different TensorFlow logs), which is unacceptable for me.
Before switching to Estimator, I would do:
    for epoch in range(epochs):
        for step, data in enumerate(dataset):
            # the current epoch is passed explicitly to each training step
            model.train(data, epoch)
Now that loop is buried in the internals of Estimator and Dataset, and I have no control over it, so it is hard to do things like decaying the learning rate or doing something special during the first/last n epochs.
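One possible approach keeps the single train call over dataset.repeat(epochs) and derives the epoch from the global step. This is a sketch under assumptions, not a documented Estimator feature. The global step is normally incremented once per batch by the train_op (e.g. via optimizer.minimize(loss, global_step=...)) and is saved in the checkpoint. The plain-Python helpers below show only the arithmetic. num_examples, base_lr, decay_rate and warmup_epochs are hypothetical parameters, and the schedules are illustrative.

```python
def epoch_from_global_step(global_step, num_examples, batch_size=16):
    """0-based epoch of the first example in batch number `global_step`.

    Assumes a dataset.repeat(epochs).batch(batch_size) pipeline, so examples
    are consumed as one continuous stream across epochs.
    """
    return (global_step * batch_size) // num_examples


def learning_rate(epoch, base_lr=1e-3, decay_rate=0.5):
    """Hypothetical per-epoch exponential decay: base_lr * decay_rate ** epoch."""
    return base_lr * decay_rate ** epoch


def adversarial_weight(epoch, warmup_epochs=5):
    """0.0 during the first `warmup_epochs` epochs, 1.0 afterwards."""
    return 1.0 if epoch >= warmup_epochs else 0.0


def in_last_n_epochs(epoch, total_epochs, n):
    """True during the final `n` of `total_epochs` epochs."""
    return epoch >= total_epochs - n


if __name__ == "__main__":
    # 100 examples, batch size 16: batch 6 straddles epochs 0 and 1
    for step in (0, 6, 7, 18):
        epoch = epoch_from_global_step(step, num_examples=100)
        print(step, epoch, learning_rate(epoch), adversarial_weight(epoch, warmup_epochs=1))
```

In a TF 1.x model_fn, the same expressions should work on the int64 tensor returned by tf.compat.v1.train.get_global_step(). For example, the adversarial loss could be multiplied by tf.cast(epoch >= warmup_epochs, tf.float32). For the learning rate, tf.compat.v1.train.exponential_decay with staircase=True and decay_steps of about num_examples // batch_size gives a comparable stepwise decay. Because the global step is restored from the checkpoint, the derived epoch stays consistent across restarts.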