
I'm currently training two autoencoders with model.train_on_batch from TensorFlow, and I have written my own for loop over the epochs. The models are loaded from .h5 files. If I stop training and later resume it, will the model simply continue from the state in which the weights were last saved? For example, if the weights were saved at the 10,000th epoch and I run the code again, will training automatically pick up from those saved weights, given that the .h5 files are loaded with model.load_weights?
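
For context, this is my mental model of the save/load round trip (a minimal toy sketch, not my actual networks; the toy model and the 'toy.h5' path are made up purely for illustration):

import numpy as np
from tensorflow import keras

# Toy model for illustration only; my real encoder/decoders are larger.
model = keras.Sequential([keras.layers.Dense(4, input_shape=(8,))])
model.compile(optimizer='adam', loss='mse')

model.train_on_batch(np.random.rand(2, 8), np.random.rand(2, 4))
model.save_weights('toy.h5')                 # weight values written to disk

restored = keras.Sequential([keras.layers.Dense(4, input_shape=(8,))])
restored.compile(optimizer='adam', loss='mse')
restored.load_weights('toy.h5')              # same weight values come back

# The layer weights match, so further train_on_batch calls would start from
# the saved state. As far as I understand, save_weights/load_weights does not
# include the optimizer's internal state (e.g. Adam moments); model.save() and
# keras.models.load_model() would be needed to keep that as well.
assert all(np.allclose(a, b) for a, b in
           zip(model.get_weights(), restored.get_weights()))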

I have the code below; would I need to make any alterations to it?

import numpy as np
import cv2

# enc, decoderA, decoderB, aeA, aeB, train_util, train_setA and train_setB
# are built/loaded elsewhere in my script.
enc.load_weights('models/encoder.h5')
decoderA.load_weights('models/decoder_A.h5')
decoderB.load_weights('models/decoder_B.h5')

def save_model_weights():
    enc.save_weights('models/encoder.h5')
    decoderA.save_weights('models/decoder_A.h5')
    decoderB.save_weights('models/decoder_B.h5')

for epoch in range(1000000):
    batch_size = 64
    # Draw a fresh batch of (warped, target) image pairs for each face set.
    warped_A, target_A = train_util.training_data( train_setA, batch_size )
    warped_B, target_B = train_util.training_data( train_setB, batch_size )

    loss_A = aeA.train_on_batch( warped_A, target_A )
    loss_B = aeB.train_on_batch( warped_B, target_B )
    print( loss_A, loss_B )
    print('Current epoch no... ' + str(epoch))

    if epoch % 50 == 0:
        # Checkpoint the weights and refresh the preview images every 50 epochs.
        save_model_weights()
        print('Model weights saved')
        test_A = target_A[0:14]
        test_B = target_B[0:14]

    # Preview grid: each row shows the original, its own autoencoder's
    # reconstruction, and the other autoencoder's output.
    figure_A = np.stack([
        test_A,
        aeA.predict( test_A ),
        aeB.predict( test_A ),
        ], axis=1 )
    figure_B = np.stack([
        test_B,
        aeB.predict( test_B ),
        aeA.predict( test_B ),
        ], axis=1 )

    figure = np.concatenate( [ figure_A, figure_B ], axis=0 )
    figure = figure.reshape( (4,7) + figure.shape[1:] )
    figure = train_util.stack_images( figure )

    figure = np.clip( figure * 255, 0, 255 ).astype('uint8')

    cv2.imshow( "", figure )
    key = cv2.waitKey(1)
    # Pressing 'q' saves the weights one last time and stops training.
    if key == ord('q'):
        save_model_weights()
        exit()
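
Related to the "continue from the last epoch" part: since my for loop always restarts at 0, I was also thinking of persisting the epoch counter next to the weights, along these lines (hypothetical helper functions and file path, not in my code yet):

import os

EPOCH_FILE = 'models/last_epoch.txt'   # hypothetical path for illustration

def load_last_epoch():
    # Resume from the epoch recorded at the last save; default to 0.
    if os.path.exists(EPOCH_FILE):
        with open(EPOCH_FILE) as f:
            return int(f.read().strip())
    return 0

def save_last_epoch(epoch):
    with open(EPOCH_FILE, 'w') as f:
        f.write(str(epoch))

start_epoch = load_last_epoch()
for epoch in range(start_epoch, 1000000):
    ...  # training step as above
    if epoch % 50 == 0:
        save_model_weights()
        save_last_epoch(epoch)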