Gaussian Mixture Models

Author: Zeju Li

Adapted from https://mml-book.github.io/

In this notebook, we look at density modeling with Gaussian mixture models (GMMs), which describe the density of the data as \[ p(\boldsymbol x) = \sum_{k=1}^K \pi_k \mathcal{N}(\boldsymbol x|\boldsymbol \mu_k, \boldsymbol \Sigma_k)\,,\quad \pi_k \geq 0\,,\quad \sum_{k=1}^K\pi_k = 1\,, \] where the \(\pi_k\) are the mixture weights and \(\boldsymbol\mu_k, \boldsymbol\Sigma_k\) are the mean and covariance of the \(k\)-th Gaussian component.
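For concreteness, here is a minimal sketch (with made-up parameters, not the ones used later in this notebook) of evaluating such a mixture density at a few query points:

import numpy as np
from scipy.stats import multivariate_normal

# hypothetical two-component GMM in 2-D (illustrative parameters only)
pis = [0.4, 0.6]                                         # mixture weights, sum to 1
mus = [np.zeros(2), np.array([3.0, 1.0])]                # component means
Sigmas = [np.eye(2), np.array([[1.0, 0.5], [0.5, 2.0]])] # component covariances

query = np.array([[0.0, 0.0], [3.0, 1.0]])               # points at which to evaluate p(x)
p = sum(pi * multivariate_normal(mu, Sig).pdf(query)
        for pi, mu, Sig in zip(pis, mus, Sigmas))
print(p)                                                 # one density value per query point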

The goal of this notebook is to get a better understanding of GMMs and to write some code for training GMMs using the EM algorithm. We provide a code skeleton and mark the bits and pieces that you need to implement yourself.

Install required packages

!pip install numpy
!pip install matplotlib
!pip install scipy
!pip install IPython
# imports
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from IPython import display

%matplotlib inline

np.random.seed(42)

Define a GMM from which we generate data

Set up the true GMM from which we will generate data.

# Choose a GMM with 3 components

# means
m = np.zeros((3,2))
m[0] = np.array([1.2, 0.4])
m[1] = np.array([-4.4, 1.0])
m[2] = np.array([4.1, -0.3])

# covariances
S = np.zeros((3,2,2))
S[0] = np.array([[0.8, -0.4], [-0.4, 1.0]])
S[1] = np.array([[1.2, -0.8], [-0.8, 1.0]])
S[2] = np.array([[1.2, 0.6], [0.6, 3.0]])

# mixture weights
w = np.array([0.3, 0.2, 0.5])

Generate some data

N_split = 200 # number of data points per mixture component
N = N_split*3 # total number of data points
x = []
y = []
for k in range(3):
    x_tmp, y_tmp = np.random.multivariate_normal(m[k], S[k], N_split).T
    x = np.hstack([x, x_tmp])
    y = np.hstack([y, y_tmp])

data = np.vstack([x, y]) # shape (2, N): one column per data point
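One caveat: the loop above draws exactly N_split points from each component, so the empirical proportions are 1/3 each rather than the mixture weights w. If we wanted an i.i.d. sample from the GMM itself, ancestral sampling would first draw how many points each component generates; a sketch (not used for the data below):

counts = np.random.multinomial(N, w) # number of points generated by each component
samples = np.vstack([np.random.multivariate_normal(m[k], S[k], counts[k])
                     for k in range(3)]) # shape (N, 2), grouped by component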

Visualization of the dataset

X, Y = np.meshgrid(np.linspace(-10,10,100), np.linspace(-10,10,100))
pos = np.dstack((X, Y)) # grid of (x1, x2) locations at which densities are evaluated

# plot the dataset
plt.figure()
plt.title("Mixture components")
plt.plot(x, y, 'ko', alpha=0.3)
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")

# plot the individual components of the GMM
plt.plot(m[:,0], m[:,1], 'or')

for k in range(3):
    mvn = multivariate_normal(m[k,:].ravel(), S[k,:,:])
    xx = mvn.pdf(pos)
    plt.contour(X, Y, xx,  alpha = 1.0, zorder=10)

# plot the GMM
plt.figure()
plt.title("GMM")
plt.plot(x, y, 'ko', alpha=0.3)
plt.xlabel("$x_1$")
plt.ylabel("$x_2$")

# build the GMM
gmm = 0
for k in range(3):
    mix_comp = multivariate_normal(m[k,:].ravel(), S[k,:,:])
    gmm += w[k]*mix_comp.pdf(pos)

plt.contour(X, Y, gmm,  alpha = 1.0, zorder=10);

Train the GMM via EM

Initialize the parameters for EM

K = 3 # number of clusters

means = np.zeros((K,2))
covs = np.zeros((K,2,2))
for k in range(K):
    means[k] = np.random.normal(size=(2,))
    covs[k] = np.eye(2)

weights = np.ones(K)/K # uniform initial mixture weights
print("Initial mean vectors (one per row):\n" + str(means))
Initial mean vectors (one per row):
[[ 0.1252245  -0.42940554]
 [ 0.1222975   0.54329803]
 [ 0.04886007  0.04059169]]
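Since EM only converges to a local optimum of the likelihood, the result can depend on these random initial means. A common remedy is to run EM from several random restarts and keep the fit with the lowest negative log-likelihood; a sketch, where run_em is a hypothetical function wrapping one full EM fit (it is not defined in this notebook):

def best_of_restarts(run_em, n_restarts=10):
    # run_em(seed) is assumed to return (final_nll, fitted_parameters)
    fits = [run_em(seed=s) for s in range(n_restarts)]
    return min(fits, key=lambda fit: fit[0]) # keep the lowest-NLL fit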
#EDIT THIS FUNCTION
NLL = [] # negative log-likelihood of the GMM, one entry per EM iteration
gmm_pdf = 0 # GMM density evaluated at every data point
for k in range(K):
    gmm_pdf += weights[k]*multivariate_normal.pdf(mean=means[k,:], cov=covs[k,:,:], x=data.T)
NLL += [-np.sum(np.log(gmm_pdf))] # NLL of the initial model

plt.figure()
plt.plot(x, y, 'ko', alpha=0.3)
plt.plot(means[:,0], means[:,1], 'oy', markersize=25)

for k in range(K):
    rv = multivariate_normal(means[k,:], covs[k,:,:])
    plt.contour(X, Y, rv.pdf(pos), alpha = 1.0, zorder=10)

plt.xlabel("$x_1$");
plt.ylabel("$x_2$");

First, we define the responsibilities (which are updated in the E-step), given the model parameters \(\pi_k, \boldsymbol\mu_k, \boldsymbol\Sigma_k\) as \[ r_{nk} := \frac{\pi_k\mathcal N(\boldsymbol x_n|\boldsymbol\mu_k,\boldsymbol\Sigma_k)}{\sum_{j=1}^K\pi_j\mathcal N(\boldsymbol x_n|\boldsymbol \mu_j,\boldsymbol\Sigma_j)} \]
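The E-step below computes this ratio directly, which is fine for this well-scaled 2-D example. With many dimensions or components, the component densities can underflow, and a numerically stable variant works in log space; a sketch using scipy.special.logsumexp:

from scipy.special import logsumexp

def responsibilities(data, weights, means, covs):
    # unnormalized log-responsibilities, shape (K, N)
    log_r = np.stack([np.log(weights[k])
                      + multivariate_normal.logpdf(data.T, mean=means[k], cov=covs[k])
                      for k in range(len(weights))])
    # subtracting logsumexp normalizes each column to sum to 1
    return np.exp(log_r - logsumexp(log_r, axis=0))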

Given the responsibilities we just defined, we can update the model parameters in the M-step as follows: \[\begin{align*} \boldsymbol\mu_k^\text{new} &= \frac{1}{N_k}\sum_{n = 1}^Nr_{nk}\boldsymbol x_n\,,\\ \boldsymbol\Sigma_k^\text{new}&= \frac{1}{N_k}\sum_{n=1}^Nr_{nk}(\boldsymbol x_n-\boldsymbol\mu_k)(\boldsymbol x_n-\boldsymbol\mu_k)^\top\,,\\ \pi_k^\text{new} &= \frac{N_k}{N} \end{align*}\] where \[ N_k := \sum_{n=1}^N r_{nk} \]
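For reference, these three updates can also be written without a per-component Python loop; a vectorized sketch (r has shape (K, N) and data has shape (2, N), as in the code below):

def m_step(data, r):
    N_k = r.sum(axis=1)                            # (K,) effective counts
    means = (r @ data.T) / N_k[:, None]            # (K, 2) responsibility-weighted means
    diff = data.T[None, :, :] - means[:, None, :]  # (K, N, 2) centered data
    covs = np.einsum('kn,kni,knj->kij', r, diff, diff) / N_k[:, None, None]
    return N_k / r.shape[1], means, covs           # weights, means, covariances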

EM Algorithm

#EDIT THIS FUNCTION
r = np.zeros((K,N)) # will store the responsibilities

for em_iter in range(100):
    # E-step: update responsibilities
    for k in range(K):
        r[k] = weights[k]*multivariate_normal.pdf(mean=means[k,:], cov=covs[k,:,:], x=data.T)

    r = r/np.sum(r, axis=0) # normalize so the responsibilities for each point sum to 1

    # M-step
    N_k = np.sum(r, axis=1) # effective number of points assigned to each component

    for k in range(K):
        # update means
        means[k] = np.sum(r[k]*data, axis=1)/N_k[k]

        # update covariances
        diff = data - means[k:k+1].T
        _tmp = np.sqrt(r[k:k+1])*diff
        covs[k] = np.inner(_tmp, _tmp)/N_k[k]

    # update mixture weights
    weights = N_k/N

    # negative log-likelihood of the updated model
    gmm_pdf = 0 # GMM density evaluated at every data point
    for k in range(K):
        gmm_pdf += weights[k]*multivariate_normal.pdf(mean=means[k,:].ravel(), cov=covs[k,:,:], x=data.T)
    NLL += [-np.sum(np.log(gmm_pdf))]

    # plot the current fit, reusing one output area so hundreds of figures are not kept open
    plt.figure()
    plt.plot(x, y, 'ko', alpha=0.3)
    plt.plot(means[:,0], means[:,1], 'oy', markersize=25)
    for k in range(K):
        rv = multivariate_normal(means[k,:], covs[k])
        plt.contour(X, Y, rv.pdf(pos), alpha = 1.0, zorder=10)

    plt.xlabel("$x_1$")
    plt.ylabel("$x_2$")
    plt.text(x=3.5, y=8, s="EM iteration "+str(em_iter+1))
    display.clear_output(wait=True)
    display.display(plt.gcf())
    plt.close()

    # stop once the negative log-likelihood no longer changes
    if np.abs(NLL[em_iter+1]-NLL[em_iter]) < 1e-6:
        print("Converged after iteration ", em_iter+1)
        break

# plot the final mixture model
plt.figure()
gmm = 0
for k in range(K):
    mix_comp = multivariate_normal(means[k,:].ravel(), covs[k,:,:])
    gmm += weights[k]*mix_comp.pdf(pos)

plt.plot(x, y, 'ko', alpha=0.3)
plt.contour(X, Y, gmm,  alpha = 1.0, zorder=10)
plt.xlim([-8,8]);
plt.ylim([-6,6]);
plt.show()
Converged after iteration  89
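It is instructive to compare the recovered parameters with the ground truth. Note that a GMM is identifiable only up to a permutation of its components, so the learned rows may appear in a different order than m, S, and w:

# compare learned parameters to the ground truth (component order may be permuted)
print("learned weights:", np.round(weights, 3), " true weights:", w)
print("learned means:\n", np.round(means, 2))
print("true means:\n", m)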

plt.figure()
plt.semilogy(np.arange(1, len(NLL)+1), NLL)
plt.xlabel("EM iteration");
plt.ylabel("Negative log-likelihood");

# mark the initial NLL and a few EM iterations on the curve
idx = [0, 1, 9, em_iter+1]

for i in idx:
    plt.plot(i+1, NLL[i], 'or')
plt.show()
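As a final sanity check, the same model can be fitted with scikit-learn's GaussianMixture, which also uses EM internally and should recover similar parameters (this assumes scikit-learn is available; it is not installed at the top of this notebook):

from sklearn.mixture import GaussianMixture

# fit a 3-component GMM with full covariances to the same (N, 2) data
gm = GaussianMixture(n_components=3, covariance_type='full', random_state=0).fit(data.T)
print("sklearn weights:", np.round(gm.weights_, 3))
print("sklearn means:\n", np.round(gm.means_, 2))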



Citation

For attribution, please cite this work as:
Li, Zeju. n.d. “Gaussian Mixture Models.” https://zerojumpline.github.io//teaching/2025-08-08-Pattern Recognition/code_1_gmm.html.