I am writing code to train an autoencoder on the CIFAR10 dataset and view the reconstructed images.
The requirement is to create:
Encoder with First Layer
- Input shape: (32,32,3)
- Conv2D Layer with 64 Filters of (3,3)
- BatchNormalization layer
- ReLU activation
- MaxPooling2D layer with (2,2) pool size
Encoder with Second Layer
- Conv2D layer with 16 filters (3,3)
- BatchNormalization layer
- ReLU activation
- MaxPooling2D layer with (2,2) pool size
- The final encoded output is the (2,2) MaxPool result of all previous layers
Decoder with First Layer
- Input shape: encoder output
- Conv2D Layer with 16 Filters of (3,3)
- BatchNormalization layer
- ReLU activation
- UpSampling2D layer with (2,2) size
Decoder with Second Layer
- Conv2D Layer with 32 Filters of (3,3)
- BatchNormalization layer
- ReLU activation
- UpSampling2D layer with (2,2) size
- The final decoded output is a sigmoid applied on top of all previous layers
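To make the spec above concrete, here is a small pure-Python sketch that traces the tensor shape through each stage. It assumes `padding='same'` convolutions (so height and width are preserved) and a final 3-filter convolution that carries the sigmoid; the requirement itself does not state either of these. The helper names `conv_same`, `pool`, `upsample`, and `trace_shapes` are made up for illustration and are not Keras functions.

```python
def conv_same(shape, filters):
    """Conv2D with padding='same' keeps height/width and sets channels to `filters`."""
    h, w, _ = shape
    return (h, w, filters)


def pool(shape, size=2):
    """MaxPooling2D with padding='same': ceil-divide height and width."""
    h, w, c = shape
    return (-(-h // size), -(-w // size), c)


def upsample(shape, size=2):
    """UpSampling2D repeats rows and columns `size` times."""
    h, w, c = shape
    return (h * size, w * size, c)


def trace_shapes(input_shape=(32, 32, 3)):
    # BatchNormalization and ReLU do not change the shape, so they are omitted.
    steps = [
        ("encoder conv 64", lambda s: conv_same(s, 64)),
        ("encoder pool 1", pool),
        ("encoder conv 16", lambda s: conv_same(s, 16)),
        ("encoded (pool 2)", pool),
        ("decoder conv 16", lambda s: conv_same(s, 16)),
        ("decoder upsample 1", upsample),
        ("decoder conv 32", lambda s: conv_same(s, 32)),
        ("decoder upsample 2", upsample),
        # Assumption: a 3-filter conv with sigmoid maps back to RGB.
        ("decoded (conv 3, sigmoid)", lambda s: conv_same(s, 3)),
    ]
    shape = input_shape
    trace = [("input", shape)]
    for name, fn in steps:
        shape = fn(shape)
        trace.append((name, shape))
    return trace


if __name__ == "__main__":
    for name, shape in trace_shapes():
        print(f"{name:28s} {shape}")
```

Under these assumptions the encoded tensor is 8x8x16 and the decoded output has the same shape as the input, 32x32x3.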
I understand that:
- When creating a convolutional autoencoder (or any AE), we need to pass the output of the previous layer to the next layer.
- So, when I create the first Conv2D layer with ReLU and then apply BatchNormalization, I pass the Conv2D output to it, right?
- But when I apply MaxPooling2D, what should I pass: the BatchNormalization output or the Conv2D output?
Also, is there a particular order in which I should perform these operations?
- Conv2D --> BatchNormalization --> MaxPooling2D
- OR
- Conv2D --> MaxPooling2D --> BatchNormalization
I am attaching my code below. I have attempted it in two different ways and therefore get different outputs (in terms of both the model summary and the training graph).
Can someone please explain which is the correct method (Method-1 or Method-2)? Also, how can I tell which graph shows better model performance?
Method - 1
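# Imports assume the Keras bundled with TensorFlow 2.x; adjust if using standalone Keras
from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization,
                                     MaxPooling2D, UpSampling2D)
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt

input_image = Input(shape=(32, 32, 3))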
### Encoder
conv1_1 = Conv2D(64, (3, 3), activation='relu', padding='same')(input_image)
bnorm1_1 = BatchNormalization()(conv1_1)
mpool1_1 = MaxPooling2D((2, 2), padding='same')(conv1_1)
conv1_2 = Conv2D(16, (3, 3), activation='relu', padding='same')(mpool1_1)
bnorm1_2 = BatchNormalization()(conv1_2)
encoder = MaxPooling2D((2, 2), padding='same')(conv1_2)
### Decoder
conv2_1 = Conv2D(16, (3, 3), activation='relu', padding='same')(encoder)
bnorm2_1 = BatchNormalization()(conv2_1)
up1_1 = UpSampling2D((2, 2))(conv2_1)
conv2_2 = Conv2D(32, (3, 3), activation='relu', padding='same')(up1_1)
bnorm2_2 = BatchNormalization()(conv2_2)
up2_1 = UpSampling2D((2, 2))(conv2_2)
decoder = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(up2_1)
model = Model(input_image, decoder)
model.compile(optimizer='adam', loss='binary_crossentropy')
model.summary()
history = model.fit(
    trainX, trainX,
    epochs=50,
    batch_size=1000,
    shuffle=True,
    verbose=2,
    validation_data=(testX, testX),
)
The model summary reports:
Total params: 18,851
Trainable params: 18,851
Non-trainable params: 0
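As a sanity check on this total, the count can be reproduced by hand, assuming the standard parameter formula for a 2D convolution with bias: kernel_h * kernel_w * in_channels * filters + filters. The helper name `conv2d_params` is mine, not a Keras function.

```python
def conv2d_params(in_channels, filters, kernel=(3, 3)):
    """Weights plus one bias per filter for a Conv2D layer."""
    kh, kw = kernel
    return kh * kw * in_channels * filters + filters


# (input channels, filters) for the five Conv2D layers in Method-1
conv_layers = [(3, 64), (64, 16), (16, 16), (16, 32), (32, 3)]

per_layer = [conv2d_params(c_in, f) for c_in, f in conv_layers]
total_conv = sum(per_layer)

print(per_layer)   # [1792, 9232, 2320, 4640, 867]
print(total_conv)  # 18851
```

The five Conv2D layers alone account for the whole total of 18,851, so no BatchNormalization parameters appear in this summary.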
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper right')
plt.show()
Method - 2
input_image = Input(shape=(32, 32, 3))
### Encoder
x = Conv2D(64, (3, 3), activation='relu', padding='same')(input_image)
x = BatchNormalization()(x)
x = MaxPooling2D((2, 2), padding='same')(x)
x = Conv2D(16, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
encoder = MaxPooling2D((2, 2), padding='same')(x)
### Decoder
x = Conv2D(16, (3, 3), activation='relu', padding='same')(encoder)
x = BatchNormalization()(x)
x = UpSampling2D((2, 2))(x)
x = Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = UpSampling2D((2, 2))(x)
decoder = Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)
model = Model(input_image, decoder)
model.compile(optimizer='adam', loss='binary_crossentropy')
model.summary()
history = model.fit(
    trainX, trainX,
    epochs=50,
    batch_size=1000,
    shuffle=True,
    verbose=2,
    validation_data=(testX, testX),
)
The model summary reports:
Total params: 19,363
Trainable params: 19,107
Non-trainable params: 256
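The difference from the Method-1 total can be checked with a little arithmetic. This sketch assumes the usual Keras BatchNormalization layout of four parameters per channel: gamma and beta (trainable), plus the moving mean and moving variance (non-trainable). It also relies on the Conv2D layers being identical in both methods, so their combined count is the 18,851 reported for Method-1.

```python
# Channel counts seen by the four BatchNormalization layers in Method-2
bn_channels = [64, 16, 16, 32]

conv_total = 18_851  # Conv2D parameters, taken from the Method-1 summary

bn_trainable = sum(2 * c for c in bn_channels)      # gamma + beta
bn_non_trainable = sum(2 * c for c in bn_channels)  # moving mean + moving variance

total = conv_total + bn_trainable + bn_non_trainable
trainable = conv_total + bn_trainable

print(total)             # 19363
print(trainable)         # 19107
print(bn_non_trainable)  # 256
```

These figures match the summary above, which suggests the extra 512 parameters come entirely from the four BatchNormalization layers.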
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('model loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'test'], loc='upper right')
plt.show()