I am trying to apply a VAE to a simple toy example to get familiar with its properties. However, I am stuck training the model: neither the total loss nor the reconstruction error seems to decrease.
The toy example is as follows:
- randomly generate 5000 observations from a 2-dimensional multivariate normal distribution.
- apply a transformation f: [x, y] --> [sin(x), sin(y)]
- train a VAE with 1 hidden layer of 5 neurons in both the encoder and the decoder. The VAE has 2 latent variables.
In this example, I cannot reduce the training loss to a sufficiently low level, and the reconstructions are messy.
I have made several attempts:
- increase the number of hidden layers to 2 and 3 (this does not help) --> I think the problem is not the model's capacity
- check the network on MNIST (the results are comparable to examples I found in other sources) --> the model design seems right
- remove the KL divergence from the loss function (the model then reconstructs well) --> the model design seems right
- balance the weight beta on the KL divergence --> when beta is low, the model reconstructs well, but the latent space is far from a standard normal; when beta is high, the latent space looks good, but the model cannot reconstruct well.
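One common way to handle this trade-off during optimization (not something the original setup uses) is KL annealing, also called KL warm-up: start with beta near 0 so the decoder first learns to reconstruct, then ramp beta up. A minimal sketch, assuming a linear schedule; `kl_weight`, `warmup_steps` and `beta_max` are hypothetical names:

```python
def kl_weight(step, warmup_steps=1000, beta_max=1.0):
    """Hypothetical linear KL-annealing schedule: beta grows from 0 to
    beta_max over warmup_steps optimizer steps, then stays constant."""
    if warmup_steps <= 0:
        return beta_max
    return beta_max * min(1.0, step / warmup_steps)


# beta at a few points during training
for step in (0, 250, 500, 1000, 5000):
    print(step, kl_weight(step))
```

At each optimizer step the value could be passed as the `penalty` argument of `simple_vae_loss`. Note that annealing mainly helps the optimizer avoid getting stuck early; the final value of beta still determines the reconstruction/KL trade-off.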
I now suspect several potential reasons, but I cannot tell which one is the cause:
- It seems that I need to find a balance between the weight on the KL divergence and the reconstruction loss.
- Is it appropriate to use MSE loss + KL divergence as the loss function?
- In low dimensions, does the VAE perform poorly because the ELBO is not tight enough?
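Regarding the second question, it may help to spell out what `MSE + penalty * KLD` corresponds to. Assuming a Gaussian decoder $p_\theta(x \mid z) = \mathcal{N}(x;\, f_\theta(z),\, \sigma^2 I)$ with a fixed variance $\sigma^2$ (an assumption; the code only uses plain MSE) and $d = 2$ data dimensions:

$$
-\log p_\theta(x \mid z) = \frac{1}{2\sigma^2}\,\lVert x - f_\theta(z)\rVert^2 + \frac{d}{2}\log(2\pi\sigma^2)
$$

$$
2\sigma^2 \cdot \bigl(-\mathrm{ELBO}(x)\bigr) = \mathbb{E}_{q_\phi(z \mid x)}\bigl[\lVert x - f_\theta(z)\rVert^2\bigr] + 2\sigma^2\,\mathrm{KL}\bigl(q_\phi(z \mid x)\,\Vert\,p(z)\bigr) + \text{const}
$$

So summed MSE plus `penalty * KLD` is, up to a constant and a positive scale, the negative ELBO with $\sigma^2 = \text{penalty}/2$; `penalty=1` corresponds to $\sigma^2 = 0.5$. For $X \sim \mathcal{N}(0, 1)$, $\operatorname{Var}(\sin X) = (1 - e^{-2})/2 \approx 0.43$, so with `penalty=1` the implied decoder noise is larger than the variance of each data coordinate, and the model may find it cheaper to explain the data as noise than to encode it in $z$. This would be consistent with the trade-off observed when tuning beta, and suggests the weight acts as an implicit choice of $\sigma^2$ rather than an arbitrary fudge factor.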
Could anyone help?
The code is attached below.
This part defines the model:
```python
import torch
import torch.nn as nn
import pandas as pd
import numpy as np


class VAE_Encoder(nn.Module):
    def __init__(self, input_size, hidden_size_list, latent_size):
        """
        The class is the builder of the encoder part of VAE. It does not need to be directly called.
        :param input_size: int
        :param hidden_size_list: list(int)
        :param latent_size: int
        """
        super().__init__()  # required before registering submodules
        encoder_size = [input_size] + hidden_size_list
        encoder_layers = []
        for in_size, out_size in zip(encoder_size[:-1], encoder_size[1:]):
            # Linear + activation per hidden layer (the activation is assumed;
            # the original loop body was not shown)
            encoder_layers += [nn.Linear(in_size, out_size), nn.ReLU()]
        self.encoder = nn.Sequential(*encoder_layers)
        # Separate heads for the mean and log-variance of q(z|x)
        self.encoder_mu = nn.Linear(encoder_size[-1], latent_size)
        self.encoder_logvar = nn.Linear(encoder_size[-1], latent_size)

    def encode(self, x):
        return self.encoder(x)

    def encode_gaussian_param(self, encode_x):
        return self.encoder_mu(encode_x), self.encoder_logvar(encode_x)

    def reparametrize(self, mu, logvar):
        # z = mu + std * eps with eps ~ N(0, I), so gradients flow through mu and logvar
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        encode_x = self.encoder(x)
        mu, logvar = self.encode_gaussian_param(encode_x)
        z = self.reparametrize(mu, logvar)
        return z, mu, logvar


class VAE_Decoder(nn.Module):
    def __init__(self, input_size, hidden_size_list, latent_size):
        """
        The class is the builder of the decoder part of VAE. It does not need to be directly called.
        :param input_size: int
        :param hidden_size_list: list(int)
        :param latent_size: int
        """
        super().__init__()
        decoder_size = [latent_size] + hidden_size_list
        decoder_layers = []
        for in_size, out_size in zip(decoder_size[:-1], decoder_size[1:]):
            # Same assumed Linear + activation block as in the encoder
            decoder_layers += [nn.Linear(in_size, out_size), nn.ReLU()]
        # Linear output layer back to the data dimension (assumed; needed so that
        # the reconstruction has the same shape as the input in the MSE loss)
        decoder_layers.append(nn.Linear(decoder_size[-1], input_size))
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, z):
        return self.decoder(z)


class VAE(nn.Module):
    def __init__(self, input_size, encoder_size, latent_size, decoder_size=None):
        """
        The class builds the whole VAE. It consists of an encoder model and a decoder model.
        The user has the flexibility to choose the number of layers in each part of the model by
        setting the encoder size and decoder size.
        :param input_size: int
        :param encoder_size: list(int)
        :param latent_size: int
        :param decoder_size: list(int)
        """
        super().__init__()
        if decoder_size is None:
            # Mirror the encoder by default
            decoder_size = encoder_size[::-1]
        self.encoder = VAE_Encoder(input_size, encoder_size, latent_size)
        self.decoder = VAE_Decoder(input_size, decoder_size, latent_size)

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        z, mu, logvar = self.encoder(x)
        x = self.decoder(z)
        return x, mu, logvar


def simple_vae_loss(real, recon, mu, logvar, penalty=1):
    # Summed squared reconstruction error over batch and data dimensions
    MSE = nn.functional.mse_loss(recon, real, reduction="sum")
    # Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over batch and latent dims
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return MSE + penalty * KLD
```
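Since the KL term drives the trade-off, a quick sanity check (not part of the original code) is to compare the closed-form expression used in `simple_vae_loss` with `torch.distributions.kl_divergence` on random inputs. If the two agree, the KL implementation itself is unlikely to be the problem:

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(4, 2)
logvar = torch.randn(4, 2)

# Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I)), as written in simple_vae_loss
kld_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Same quantity via torch.distributions: elementwise KL, then summed
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kld_dist = kl_divergence(q, p).sum()

print(kld_closed.item(), kld_dist.item())
print(torch.allclose(kld_closed, kld_dist, atol=1e-5))
```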
This part generates the toy example:
```python
def generate_value(x, y):
    # f: [x, y] --> [sin(x), sin(y)]
    v1 = np.sin(x)
    v2 = np.sin(y)
    return (v1, v2)


# 5000 draws from a 2-D standard normal
rand_number = torch.randn((5000, 2))
x = rand_number.numpy()[:, 0]
y = rand_number.numpy()[:, 1]
new_value = generate_value(x, y)
new_x = new_value[0].reshape((5000, 1))
new_y = new_value[1].reshape((5000, 1))
new_data = np.concatenate((new_x, new_y), axis=1)
```
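The post does not include the training loop, so the following is a minimal, self-contained training sketch on the same toy data. Everything in it is an assumption for illustration: `TinyVAE` and `vae_loss` are hypothetical stand-ins for the classes above, and the Tanh activations, Adam with `lr=1e-2`, batch size 128, 30 epochs and `beta = 1.0` are arbitrary choices. It reports the reconstruction and KL terms separately and per sample, which makes it easier to see which term is stuck:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy data as in the post: sin applied elementwise to 2-D standard normal samples
data = torch.sin(torch.randn(5000, 2))


class TinyVAE(nn.Module):
    # Hypothetical compact VAE: 2 -> 5 -> (mu, logvar) in a 2-D latent, then 2 -> 5 -> 2
    def __init__(self, input_size=2, hidden=5, latent=2):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(input_size, hidden), nn.Tanh())
        self.mu = nn.Linear(hidden, latent)
        self.logvar = nn.Linear(hidden, latent)
        self.dec = nn.Sequential(nn.Linear(latent, hidden), nn.Tanh(),
                                 nn.Linear(hidden, input_size))

    def forward(self, x):
        h = self.enc(x)
        mu, logvar = self.mu(h), self.logvar(h)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)  # reparameterization
        return self.dec(z), mu, logvar


def vae_loss(real, recon, mu, logvar, beta):
    # Summed squared error + beta * KL to N(0, I), all divided by batch size
    n = real.shape[0]
    mse = nn.functional.mse_loss(recon, real, reduction="sum") / n
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / n
    return mse + beta * kld, mse, kld


model = TinyVAE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loader = torch.utils.data.DataLoader(data, batch_size=128, shuffle=True)
beta = 1.0

history = []
for epoch in range(30):
    for batch in loader:
        recon, mu, logvar = model(batch)
        loss, mse, kld = vae_loss(batch, recon, mu, logvar, beta)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Record the terms from the last batch of each epoch
    history.append((loss.item(), mse.item(), kld.item()))
    if epoch % 5 == 0:
        print(f"epoch {epoch:2d}  loss {loss.item():.4f}  "
              f"mse {mse.item():.4f}  kld {kld.item():.4f}")
```

Tracking `mse` and `kld` separately while varying `beta` should make the trade-off described above visible directly: a flat total loss can hide one term falling while the other rises.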