I have a very simple question. I have a Keras model (TF backend) defined for classification. I want to dump the training images fed into my model during training for debugging purposes. I am trying to create a custom callback that writes TensorBoard image summaries for this.

But how can I obtain the real training data inside the callback?

Currently I am trying this:

class TensorboardKeras(Callback):
    def __init__(self, model, log_dir, write_graph=True):
        self.model = model
        self.log_dir = log_dir
        self.session = K.get_session()

        tf.summary.image('input_image', self.model.input)
        self.merged = tf.summary.merge_all()

        if write_graph:
            self.writer = tf.summary.FileWriter(self.log_dir, K.get_session().graph)
        else:
            self.writer = tf.summary.FileWriter(self.log_dir)

    def on_batch_end(self, batch, logs=None):
        summary = self.session.run(self.merged, feed_dict={})
        self.writer.add_summary(summary, batch)
        self.writer.flush()

But I am getting the error: InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'input_1' with dtype float and shape [?,224,224,3]

There must be a way to see what the model gets as input, right?

Or maybe I should try another way to debug it?

While you should definitely be able to accomplish this, why not check the data at the point of input, e.g. in your data generator? - matt
I am using a tf.data TFRecord dataset fed directly into the model.fit method. To check the data directly I would have to write wrapper code that retrieves the data batch by batch. Also, it would not be part of training, just auxiliary code to debug the data. Alternatively, I can employ a callback and keep the code simpler. - Dmytro Prylipko
I'm afraid that is the only solution in Keras... Never saw anything related to data in callbacks, only stats. - Daniel Möller

1 Answer

You don't need callbacks for this. All you need to do is implement a generator function that yields batches of images and their labels as tuples. The flow_from_directory function has a parameter called save_to_dir that may already cover your needs. If it doesn't, here is what you can do:

from keras.preprocessing.image import ImageDataGenerator

def trainGenerator(batch_size, train_path, image_size, seed=1):
    # preprocessing, see https://keras.io/preprocessing/image/ for details
    image_datagen = ImageDataGenerator(horizontal_flip=True)
    # create the image generator, see https://keras.io/preprocessing/image/#flow_from_directory for details
    train_generator = image_datagen.flow_from_directory(
        train_path,
        class_mode="categorical",
        target_size=image_size,  # (height, width) only, without the channel dimension
        batch_size=batch_size,
        save_prefix="augmented_train",  # only has an effect when save_to_dir is also set
        seed=seed)

    for (batch_imgs, batch_labels) in train_generator:
        # do other stuff here, such as dumping images or further augmenting them
        yield (batch_imgs, batch_labels)


t_generator = trainGenerator(32, "./train_data", (224, 224))
# newer Keras versions accept a generator in model.fit directly
model.fit_generator(t_generator, steps_per_epoch=10, epochs=1)
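
The "do other stuff" step can also be factored out into a small pass-through wrapper, so the dumping logic stays separate from the data pipeline. The following is only a minimal sketch under stated assumptions: tap_batches and its on_batch hook are hypothetical names, not part of Keras, and plain lists stand in for the image batches so the snippet runs without any dependencies.

```python
def tap_batches(generator, on_batch, max_batches=None):
    """Yield every batch from `generator` unchanged.

    `on_batch(index, images, labels)` is a hypothetical hook that is called
    for the first `max_batches` batches (or for all of them if it is None).
    In practice it might write the images to disk or to a summary writer.
    """
    for index, (images, labels) in enumerate(generator):
        if max_batches is None or index < max_batches:
            on_batch(index, images, labels)
        yield images, labels


# Stand-in for a Keras generator: three "batches" of (images, labels).
fake_generator = iter([([f"img{i}a", f"img{i}b"], [0, 1]) for i in range(3)])

seen = []  # records (batch index, batch size) for each inspected batch
tapped = tap_batches(
    fake_generator,
    lambda i, x, y: seen.append((i, len(x))),
    max_batches=2,
)
batches = list(tapped)  # safe here only because the stand-in generator is finite
```

With a real generator, something like tap_batches(t_generator, my_dump_function, max_batches=5) could be passed to the training call instead of t_generator, where my_dump_function is whatever saving routine you prefer. Keras generators loop forever, so consume them through steps_per_epoch rather than list().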