
My input function looks like this:

    def input_fn():
        return dataset.repeat(epochs).batch(16)
    estimator_model.train(input_fn, steps)

How can I notify my model that this is the n-th repetition (epoch) over the dataset? I would like to implement things like a decaying learning rate, or training the model without the adversarial loss for the first n epochs. I am using tf.data.Dataset and tf.estimator.Estimator. If I call the train method multiple times:

    def input_fn():
        return dataset.batch(16)
    for epoch in range(epochs):
        estimator_model.train(input_fn, steps)

it will rebuild the model each time (different weights, a different checkpoint dir, different TensorFlow logs) - that's unacceptable for me.

Before estimator I would do:

    for epoch in range(epochs):
        for iter, data in enumerate(dataset):
            model.train(data, epoch)

Now such code is buried deep in the internals of Estimator and Dataset, and I have no control over it - so it is hard for me to do things like decaying the learning rate, or doing something only for the first/last n epochs.


1 Answer


If you know the size of your training set, you can compute steps_per_epoch = train_size // batch_size. Then, inside your model_fn, query the global step tensor with global_step = tf.train.get_global_step() and derive the number of epochs that have passed as a tensor: epochs_passed = tf.cast(global_step, tf.float32) / steps_per_epoch.
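The arithmetic behind this is easy to sketch in plain Python (inside model_fn the same computation runs on tensors, as shown above); train_size and batch_size here are hypothetical values:

```python
def epochs_passed(global_step, train_size, batch_size):
    """Fractional number of epochs completed after `global_step` training steps."""
    # Each training step consumes one batch, so one full pass over the
    # training set takes train_size // batch_size steps.
    steps_per_epoch = train_size // batch_size
    return global_step / steps_per_epoch

# Example: 800 training examples, batch size 16 -> 50 steps per epoch,
# so after 125 global steps we are 2.5 epochs into training.
print(epochs_passed(125, 800, 16))  # -> 2.5
```

In model_fn you would then branch on this tensor (e.g. with tf.cond) or feed it into a learning-rate formula, rather than using Python control flow.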

For many applications, like the learning rate schedule you mentioned, it is often more idiomatic to just use tf.train.piecewise_constant_decay, which is built around the same concept.
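piecewise_constant_decay picks a constant value per global-step interval; its selection rule can be mirrored in plain Python like this (the boundaries and learning-rate values below are illustrative, not taken from the question):

```python
def piecewise_constant(step, boundaries, values):
    """Mirror of the tf.train.piecewise_constant_decay selection rule.

    Returns values[0] while step <= boundaries[0], values[i] while
    boundaries[i-1] < step <= boundaries[i], and values[-1] afterwards.
    `values` must have exactly one more element than `boundaries`.
    """
    assert len(values) == len(boundaries) + 1
    for boundary, value in zip(boundaries, values):
        if step <= boundary:
            return value
    return values[-1]

# With 50 steps per epoch: use 1e-3 for the first two epochs,
# 1e-4 for the next two, then 1e-5 for the rest of training.
lr = piecewise_constant(125, boundaries=[100, 200], values=[1e-3, 1e-4, 1e-5])
print(lr)  # -> 0.0001
```

In the actual Estimator setup you would pass the global_step tensor and the boundary/value lists straight to tf.train.piecewise_constant_decay inside model_fn, so no per-epoch Python loop is needed.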