Author

Zeju Li

Variational AutoEncoder (VAE)

Adapted from https://github.com/Jackson-Kang/Pytorch-VAE-tutorial

Consider downloading this Jupyter Notebook and running it locally, or try it in Colab.

Download Open In Colab

Install required packages

!pip install torch
!pip install matplotlib
Requirement already satisfied: torch in /opt/anaconda3/lib/python3.12/site-packages (2.8.0)
Requirement already satisfied: filelock in /opt/anaconda3/lib/python3.12/site-packages (from torch) (3.13.1)
Requirement already satisfied: typing-extensions>=4.10.0 in /opt/anaconda3/lib/python3.12/site-packages (from torch) (4.15.0)
Requirement already satisfied: setuptools in /opt/anaconda3/lib/python3.12/site-packages (from torch) (75.1.0)
Requirement already satisfied: sympy>=1.13.3 in /opt/anaconda3/lib/python3.12/site-packages (from torch) (1.14.0)
Requirement already satisfied: networkx in /opt/anaconda3/lib/python3.12/site-packages (from torch) (3.3)
Requirement already satisfied: jinja2 in /opt/anaconda3/lib/python3.12/site-packages (from torch) (3.1.4)
Requirement already satisfied: fsspec in /opt/anaconda3/lib/python3.12/site-packages (from torch) (2024.6.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/anaconda3/lib/python3.12/site-packages (from sympy>=1.13.3->torch) (1.3.0)
Requirement already satisfied: MarkupSafe>=2.0 in /opt/anaconda3/lib/python3.12/site-packages (from jinja2->torch) (2.1.3)
Requirement already satisfied: matplotlib in /opt/anaconda3/lib/python3.12/site-packages (3.9.2)
Requirement already satisfied: contourpy>=1.0.1 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (1.2.0)
Requirement already satisfied: cycler>=0.10 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (0.11.0)
Requirement already satisfied: fonttools>=4.22.0 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (4.51.0)
Requirement already satisfied: kiwisolver>=1.3.1 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (1.4.4)
Requirement already satisfied: numpy>=1.23 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (1.26.4)
Requirement already satisfied: packaging>=20.0 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (24.1)
Requirement already satisfied: pillow>=8 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (10.4.0)
Requirement already satisfied: pyparsing>=2.3.1 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (3.1.2)
Requirement already satisfied: python-dateutil>=2.7 in /opt/anaconda3/lib/python3.12/site-packages (from matplotlib) (2.9.0.post0)
Requirement already satisfied: six>=1.5 in /opt/anaconda3/lib/python3.12/site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)
import torch
import torch.nn as nn

import numpy as np

from tqdm import tqdm
from torchvision.utils import save_image, make_grid
# Model Hyperparameters

# Root directory handed to the MNIST dataset constructor.
dataset_path = '~/datasets'

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Optimisation settings: mini-batch size, Adam learning rate, number of epochs.
batch_size = 100
lr = 1e-3
epochs = 30

# Layer widths: flattened input length, hidden width, latent-space size.
x_dim = 784
hidden_dim = 400
latent_dim = 200

Step 1. Load (or download) Dataset

from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


# Convert every image to a tensor (torchvision's ToTensor also rescales pixel
# values to [0, 1], which is what the binary cross-entropy loss expects).
mnist_transform = transforms.Compose([
        transforms.ToTensor(),
])

# Pinned (page-locked) host memory only speeds up host-to-CUDA copies. Asking
# for it when CUDA is unavailable just makes DataLoader emit a UserWarning,
# so request it only when a CUDA device is actually present.
kwargs = {'num_workers': 1, 'pin_memory': torch.cuda.is_available()}

# Both splits are downloaded into `dataset_path` on first use.
train_dataset = MNIST(dataset_path, transform=mnist_transform, train=True, download=True)
test_dataset  = MNIST(dataset_path, transform=mnist_transform, train=False, download=True)

# Shuffle only the training split; keep the test order fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
test_loader  = DataLoader(dataset=test_dataset,  batch_size=batch_size, shuffle=False, **kwargs)

Step 2. Define our model: Variational AutoEncoder (VAE)

"""
    A simple implementation of Gaussian MLP Encoder and Decoder
"""

class Encoder(nn.Module):
    """Gaussian MLP encoder.

    Two LeakyReLU-activated hidden layers feed two parallel linear heads that
    output the mean and the log-variance of the approximate posterior q(z|x).
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()

        # Shared trunk: input -> hidden -> hidden.
        self.FC_input = nn.Linear(input_dim, hidden_dim)
        self.FC_input2 = nn.Linear(hidden_dim, hidden_dim)
        # Output heads: one for the mean, one for the log-variance.
        self.FC_mean = nn.Linear(hidden_dim, latent_dim)
        self.FC_var = nn.Linear(hidden_dim, latent_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)

        self.training = True

    def forward(self, x):
        """Return the tuple ``(mean, log_var)`` computed from ``x``."""
        hidden = self.FC_input(x)
        hidden = self.LeakyReLU(hidden)
        hidden = self.LeakyReLU(self.FC_input2(hidden))

        # Parameters of the simple, tractable normal distribution "q".
        return self.FC_mean(hidden), self.FC_var(hidden)
class Decoder(nn.Module):
    """MLP decoder: two LeakyReLU hidden layers followed by a sigmoid output layer."""

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.FC_hidden = nn.Linear(latent_dim, hidden_dim)
        self.FC_hidden2 = nn.Linear(hidden_dim, hidden_dim)
        self.FC_output = nn.Linear(hidden_dim, output_dim)

        self.LeakyReLU = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Decode the latent batch ``x`` into a sigmoid-squashed reconstruction."""
        activations = x
        for layer in (self.FC_hidden, self.FC_hidden2):
            activations = self.LeakyReLU(layer(activations))

        logits = self.FC_output(activations)
        return torch.sigmoid(logits)
        
class Model(nn.Module):
    """Variational autoencoder: encode, sample a latent vector, decode.

    Args:
        Encoder: callable/module returning ``(mean, log_var)`` for an input batch.
        Decoder: callable/module mapping a latent batch to a reconstruction.
    """

    def __init__(self, Encoder, Decoder):
        super().__init__()
        self.Encoder = Encoder
        self.Decoder = Decoder

    def reparameterization(self, mean, var):
        """Sample ``z = mean + var * epsilon`` with ``epsilon ~ N(0, I)``.

        NOTE: despite its (kept for compatibility) name, ``var`` is used as a
        scale, i.e. a standard deviation; ``forward`` passes
        ``exp(0.5 * log_var)``.
        """
        # randn_like already allocates on the same device and dtype as `var`,
        # so no explicit transfer to a global device is needed. This keeps the
        # module working wherever its inputs live.
        epsilon = torch.randn_like(var)
        return mean + var * epsilon  # reparameterization trick

    def forward(self, x):
        """Return ``(x_hat, mean, log_var)`` for the input batch ``x``."""
        mean, log_var = self.Encoder(x)
        std = torch.exp(0.5 * log_var)  # log-variance -> standard deviation
        z = self.reparameterization(mean, std)
        x_hat = self.Decoder(z)

        return x_hat, mean, log_var
# Build the encoder (x_dim -> hidden_dim -> latent_dim) and the decoder
# (latent_dim -> hidden_dim -> x_dim).
encoder = Encoder(input_dim=x_dim, hidden_dim=hidden_dim, latent_dim=latent_dim)
decoder = Decoder(latent_dim=latent_dim, hidden_dim = hidden_dim, output_dim = x_dim)

# Wrap both in the VAE model and move its parameters to the selected device.
model = Model(Encoder=encoder, Decoder=decoder).to(DEVICE)

Step 3. Define the loss function (reconstruction loss + KL divergence) and the optimizer

from torch.optim import Adam

# NOTE: this module-style loss is not referenced by loss_function below, which
# calls nn.functional.binary_cross_entropy with reduction='sum' directly.
BCE_loss = nn.BCELoss()

def loss_function(x, x_hat, mean, log_var):
    """Return the VAE training loss for one batch, summed over all elements.

    The loss is the binary cross-entropy between the reconstruction ``x_hat``
    and the target ``x`` plus the closed-form KL divergence between
    N(mean, exp(log_var)) and the standard normal prior N(0, I).
    """
    # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_terms = 1 + log_var - mean.pow(2) - log_var.exp()
    kl_divergence = -0.5 * kl_terms.sum()

    reconstruction = nn.functional.binary_cross_entropy(x_hat, x, reduction='sum')

    return reconstruction + kl_divergence


# Adam over every parameter of the VAE (encoder and decoder) with learning rate `lr`.
optimizer = Adam(model.parameters(), lr=lr)

Step 4. Train Variational AutoEncoder (VAE)

print("Start training VAE...")
model.train()

for epoch in range(epochs):
    overall_loss = 0.0
    samples_seen = 0  # exact number of images processed this epoch
    for x, _ in train_loader:
        # Flatten each image using the real batch length, so a shorter final
        # batch does not break the reshape.
        x = x.view(x.size(0), x_dim)
        x = x.to(DEVICE)

        optimizer.zero_grad()

        x_hat, mean, log_var = model(x)
        loss = loss_function(x, x_hat, mean, log_var)

        # loss_function returns a sum over the batch, so accumulate the sum
        # and the sample count separately.
        overall_loss += loss.item()
        samples_seen += x.size(0)

        loss.backward()
        optimizer.step()

    # Divide by the number of samples actually seen. The previous
    # `batch_idx * batch_size` used the last batch *index*, which is one
    # batch short of the true count.
    print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", overall_loss / max(samples_seen, 1))

print("Finish!!")
Start training VAE...
    Epoch 1 complete!   Average Loss:  175.0465806128704
    Epoch 2 complete!   Average Loss:  128.90606839850273
    Epoch 3 complete!   Average Loss:  116.97019198664441
    Epoch 4 complete!   Average Loss:  112.97430726014711
    Epoch 5 complete!   Average Loss:  110.6636990198508
    Epoch 6 complete!   Average Loss:  109.06746030167467
    Epoch 7 complete!   Average Loss:  107.74997078464106
    Epoch 8 complete!   Average Loss:  106.72171278302379
    Epoch 9 complete!   Average Loss:  105.93948053070221
    Epoch 10 complete!  Average Loss:  105.06694583746348
    Epoch 11 complete!  Average Loss:  104.39186422357575
    Epoch 12 complete!  Average Loss:  103.90174390585872
    Epoch 13 complete!  Average Loss:  103.5466619496035
    Epoch 14 complete!  Average Loss:  103.15721392750939
    Epoch 15 complete!  Average Loss:  102.90074143755217
    Epoch 16 complete!  Average Loss:  102.59622096397642
    Epoch 17 complete!  Average Loss:  102.38915857483828
    Epoch 18 complete!  Average Loss:  102.17129713259078
    Epoch 19 complete!  Average Loss:  101.88609652154632
    Epoch 20 complete!  Average Loss:  101.79162388159953
    Epoch 21 complete!  Average Loss:  101.54875074994783
    Epoch 22 complete!  Average Loss:  101.43628648659224
    Epoch 23 complete!  Average Loss:  101.28708972962751
    Epoch 24 complete!  Average Loss:  101.16868274924353
    Epoch 25 complete!  Average Loss:  100.9885125273894
    Epoch 26 complete!  Average Loss:  100.94756503351941
    Epoch 27 complete!  Average Loss:  100.77305422774937
    Epoch 28 complete!  Average Loss:  100.66967468567404
    Epoch 29 complete!  Average Loss:  100.5374796372861
    Epoch 30 complete!  Average Loss:  100.50892499869575
Finish!!
/opt/anaconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py:684: UserWarning: 'pin_memory' argument is set as true but not supported on MPS now, then device pinned memory won't be used.
  warnings.warn(warn_msg)

Step 5. Generate images from test dataset

import matplotlib.pyplot as plt
# Switch to evaluation mode and disable gradient tracking for inference.
model.eval()

with torch.no_grad():
    for batch_idx, (x, _) in enumerate(tqdm(test_loader)):
        # Flatten using the real batch length rather than the configured
        # batch_size, so this also works when the loader yields a shorter batch.
        x = x.view(x.size(0), x_dim)
        x = x.to(DEVICE)

        x_hat, _, _ = model(x)

        # Only the first test batch is needed for the visualisation that follows.
        break
  0%|                                                                                                                                   | 0/100 [00:00<?, ?it/s]/opt/anaconda3/lib/python3.12/site-packages/torch/utils/data/dataloader.py:684: UserWarning: 'pin_memory' argument is set as true but not supported on MPS now, then device pinned memory won't be used.
  warnings.warn(warn_msg)
  0%|                                                                                                                                   | 0/100 [00:01<?, ?it/s]

Let’s visualize test samples alongside their reconstruction.

def show_image_grid_detailed(x, nrows=10, ncols=10):
    """Display the images in ``x`` on an ``nrows`` x ``ncols`` grid.

    Args:
        x: tensor whose elements can be reshaped to ``(-1, 28, 28)``; images
            are placed row by row.
        nrows: number of grid rows.
        ncols: number of grid columns.

    Cells beyond the number of available images are left empty; every axis
    is hidden.
    """
    # detach() allows tensors that still track gradients, reshape() tolerates
    # non-contiguous input, and the data is moved to the host only once.
    images = x.detach().reshape(-1, 28, 28).cpu().numpy()

    # squeeze=False keeps `axes` two-dimensional even when nrows or ncols is 1.
    fig, axes = plt.subplots(nrows, ncols, figsize=(12, 12), squeeze=False)

    # axes.flat walks the grid row by row, matching idx = row * ncols + col.
    for idx, ax in enumerate(axes.flat):
        if idx < len(images):
            ax.imshow(images[idx], cmap='gray')
        ax.axis('off')

    plt.tight_layout()
    plt.show()
# Original test images from the batch evaluated in Step 5.
show_image_grid_detailed(x)

# The model's reconstructions of that same batch.
show_image_grid_detailed(x_hat)

Step 6. Generate images from noise vectors

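If q(z|x) is close “enough” to N(0, I) (but not too close, since that would indicate the posterior collapse problem), samples drawn from N(0, I) can stand in for the encoder’s output, so the decoder alone can be used to generate new images.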

with torch.no_grad():
    # Draw batch_size latent vectors from a standard normal distribution.
    noise = torch.randn(batch_size, latent_dim).to(DEVICE)
    # Decode the noise directly; the encoder is not involved here.
    generated_images = decoder(noise)
show_image_grid_detailed(generated_images)

Back to top

Reuse

Citation

For attribution, please cite this work as:
Li, Zeju. n.d. “Variational AutoEncoder (VAE).” https://zerojumpline.github.io//teaching/2025-08-08-Pattern Recognition/code_4_vae.html.