Free Energy and EM Algorithm

Latent Variable Model

Latent variable model is a generic term for a broad class of statistical models. Examples include mixtures of Gaussians, factor analysis, independent component analysis (ICA), principal component analysis (PCA), and so on.

In a latent variable model, we are able to observe random variables X, which are generated by latent variables Z and model parameters θ. Given a latent variable model and observed data X, we usually want to find the latent variables Z and model parameters θ that maximize the log likelihood.

$$\ell(\theta) \overset{\text{def}}{=} \ln p(X|\theta)$$

This is usually hard to maximize, mainly because computing the likelihood requires marginalizing the joint probability over the latent variables.

$$p(X|\theta) = \int dZ\, p(Z, X|\theta)$$

This integral can be computationally expensive or analytically intractable. A common alternative is to maximize a lower bound of the log likelihood instead of the log likelihood itself. This lower bound is called the "free energy", a term borrowed from physics.
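To make the marginalization concrete: in a mixture of Gaussians the latent variable is discrete (the component index), so the integral becomes a sum over components. A minimal sketch, assuming a hypothetical 1-D two-component mixture with hand-picked parameters:

```python
import numpy as np
from scipy.stats import norm

# Hypothetical 1-D mixture of two Gaussians; theta = (weights, means, stds).
weights = np.array([0.3, 0.7])
means = np.array([-2.0, 1.5])
stds = np.array([0.8, 1.2])

# Observed data X.
x = np.array([-1.9, 0.2, 1.4])

# p(X|theta) = sum_Z p(Z|theta) p(X|Z,theta): marginalize the joint over the component index.
joint = weights[None, :] * norm.pdf(x[:, None], loc=means[None, :], scale=stds[None, :])
log_likelihood = np.log(joint.sum(axis=1)).sum()
print(log_likelihood)
```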

Free Energy

Free energy is a lower bound on the log likelihood

We have defined the log likelihood, which can be written in terms of the marginalized joint probability.

$$\ell(\theta) \overset{\text{def}}{=} \ln p(X|\theta) = \ln \int dZ\, p(Z, X|\theta)$$

By Jensen's inequality and the concavity of the logarithm, any distribution q(Z) over the latent variables gives a lower bound on ℓ(θ), which we define as the free energy F(q, θ).

$$\ell(\theta) = \ln \int dZ\, q(Z)\, \frac{p(Z, X|\theta)}{q(Z)} \;\geq\; \int dZ\, q(Z) \ln \frac{p(Z, X|\theta)}{q(Z)} \overset{\text{def}}{=} F(q, \theta) \tag{1}$$

We did this "useless" thing of multiplying and dividing by q(Z) because we need to massage the equation into a ready-to-use-Jensen form. Recall what Jensen tells us: taking the mean of a random variable X and then plugging it into a concave function (here, the logarithm) gives a value at least as large as plugging X into the concave function first and then taking the mean.

$$\ln \langle X \rangle \;\geq\; \langle \ln X \rangle$$

The inequality in (1) has exactly the same form, with the ratio p(Z, X|θ)/q(Z) playing the role of X:

$$\ln \left\langle \frac{p(Z, X|\theta)}{q(Z)} \right\rangle_{q(Z)} \;\geq\; \left\langle \ln \frac{p(Z, X|\theta)}{q(Z)} \right\rangle_{q(Z)}$$
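As a quick sanity check on the direction of the bound, here is a small numeric sketch with arbitrary, made-up values playing the role of the ratio and an arbitrary q over a discrete latent; the log of the mean is indeed at least the mean of the log:

```python
import numpy as np

rng = np.random.default_rng(0)

# Arbitrary positive values playing the role of p(Z, X|theta) / q(Z)
# for a discrete latent Z with 10 states, and an arbitrary q(Z).
ratio = rng.uniform(0.1, 5.0, size=10)
q = rng.dirichlet(np.ones(10))

lhs = np.log(np.sum(q * ratio))   # ln < ratio >_q
rhs = np.sum(q * np.log(ratio))   # < ln ratio >_q
assert lhs >= rhs                 # Jensen: ln is concave
print(lhs, rhs)
```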

Two useful ways to rewrite the free energy

The free energy can be rewritten as the expected log joint probability under the distribution q(Z) plus the entropy of q(Z).

$$
\begin{aligned}
F(q, \theta) &= \int dZ\, q(Z) \ln \frac{p(Z, X|\theta)}{q(Z)} \\
&= \int dZ\, q(Z) \ln p(Z, X|\theta) - \int dZ\, q(Z) \ln q(Z) \\
&= \big\langle \ln p(Z, X|\theta) \big\rangle_{q(Z)} + H(q)
\end{aligned} \tag{2}
$$

Or as the log likelihood minus the KL divergence between q(Z) and the true posterior over the latents.

$$
\begin{aligned}
F(q, \theta) &= \int dZ\, q(Z) \ln \frac{p(X|\theta)\, p(Z|X, \theta)}{q(Z)} \\
&= \int dZ\, q(Z) \ln p(X|\theta) + \int dZ\, q(Z) \ln \frac{p(Z|X, \theta)}{q(Z)} \\
&= \ln p(X|\theta) \int dZ\, q(Z) - \int dZ\, q(Z) \ln \frac{q(Z)}{p(Z|X, \theta)} \\
&= \ln p(X|\theta) - \mathrm{KL}\big[\, q(Z) \,\|\, p(Z|X, \theta) \,\big]
\end{aligned} \tag{3}
$$

The derivations may look like ad hoc freestyling the first time you see them. We reorganize the terms like this because hindsight shows that factoring out the entropy or the KL divergence gives useful expressions. Otherwise, one is not expected to spontaneously want to mess around with the terms in this particular way.
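For a discrete latent variable, both identities are easy to check numerically. Below is a minimal sketch, assuming a hypothetical joint table p(Z, X=x|θ) over a four-state latent and an arbitrary q(Z); expressions (2) and (3) evaluate to the same number:

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical joint p(Z, X = x_obs | theta) over a discrete latent Z with 4 states.
joint = rng.uniform(0.01, 1.0, size=4)
q = rng.dirichlet(np.ones(4))            # arbitrary q(Z)

p_x = joint.sum()                        # p(X | theta)
posterior = joint / p_x                  # p(Z | X, theta)

# (2): expected log joint under q plus the entropy of q.
F_via_2 = np.sum(q * np.log(joint)) - np.sum(q * np.log(q))
# (3): log likelihood minus KL[q || posterior].
F_via_3 = np.log(p_x) - np.sum(q * np.log(q / posterior))

print(F_via_2, F_via_3)                  # the two forms agree (up to rounding)
assert np.isclose(F_via_2, F_via_3)
```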

EM

The expectation-maximization (EM) algorithm maximizes the free energy by updating the model parameters θ and (the distribution over) the latents Z in alternation.

E step

In an E step, we update the distribution over the latents. The free energy expression in (3) is useful for this, because the first term, the log likelihood, does not depend on q(Z). Hence, maximizing the free energy with respect to q(Z) is equivalent to minimizing the KL divergence.

$$q(Z) = \arg\max_q F(q, \theta) = \arg\min_q \mathrm{KL}\big[\, q(Z) \,\|\, p(Z|X, \theta) \,\big]$$

We know that the KL divergence attains its minimum of zero if and only if

$$q(Z) = p(Z|X, \theta)$$

The interpretation is that we are setting the distribution on Z to the true posterior under the data set and the current parameters.

Usually, we will then use Bayes' theorem to write out the posterior.

$$p(Z|X, \theta) = \frac{p(X|Z, \theta)\, p(Z|\theta)}{\int dZ'\, p(X|Z', \theta)\, p(Z'|\theta)}$$

In the Gaussian mixture model, q(Z) goes by the name of responsibility, measuring how much a given mixture component is responsible for having generated a data point.
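A minimal sketch of the E step for a 1-D Gaussian mixture, computing the responsibilities via Bayes' theorem (the parameter values below are hypothetical placeholders):

```python
import numpy as np
from scipy.stats import norm

def e_step(x, weights, means, stds):
    """E step for a 1-D Gaussian mixture: q(Z) = p(Z | X, theta)."""
    # Unnormalized posterior: p(X|Z, theta) * p(Z|theta) for each component.
    joint = weights[None, :] * norm.pdf(x[:, None], means[None, :], stds[None, :])
    # Normalize over components to get the responsibilities (Bayes' theorem).
    return joint / joint.sum(axis=1, keepdims=True)

x = np.array([-1.9, 0.2, 1.4, 2.0])
resp = e_step(x, np.array([0.5, 0.5]), np.array([-2.0, 1.5]), np.array([1.0, 1.0]))
print(resp)   # one row per data point, one column per component; rows sum to 1
```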

M step

In an M step, we update the model parameters. The free energy expression in (2) is useful for this, because the entropy term does not involve θ. Hence, maximizing the free energy with respect to θ is equivalent to maximizing the expected log joint under q(Z).

$$\theta = \arg\max_\theta F(q, \theta) = \arg\max_\theta \big\langle \ln p(Z, X|\theta) \big\rangle_{q(Z)}$$

Usually, we take the gradient of the expected log joint with respect to the parameters and set it to zero to solve for the parameter update.

$$\nabla_\theta \big\langle \ln p(Z, X|\theta) \big\rangle_{q(Z)} \overset{\text{set}}{=} 0$$
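For a Gaussian mixture, setting this gradient to zero yields closed-form updates: each component's weight, mean, and variance become responsibility-weighted statistics of the data. A sketch under the same hypothetical 1-D setup as above:

```python
import numpy as np

def m_step(x, resp):
    """M step for a 1-D Gaussian mixture, given responsibilities resp[n, k]."""
    n_k = resp.sum(axis=0)                                   # effective counts per component
    weights = n_k / len(x)                                   # mixing weights
    means = (resp * x[:, None]).sum(axis=0) / n_k            # responsibility-weighted means
    variances = (resp * (x[:, None] - means[None, :]) ** 2).sum(axis=0) / n_k
    return weights, means, np.sqrt(variances)

x = np.array([-1.9, 0.2, 1.4, 2.0])
resp = np.array([[0.9, 0.1], [0.4, 0.6], [0.2, 0.8], [0.1, 0.9]])  # from an E step
print(m_step(x, resp))
```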

Convergence of EM

Summarizing, we have the following inequality chain, where the superscript denotes the iteration number.

$$\ell(\theta^{(t-1)}) = F(q^{(t)}, \theta^{(t-1)}) \;\leq\; F(q^{(t)}, \theta^{(t)}) \;\leq\; \ell(\theta^{(t)})$$

The first equality is due to the E step. By making the KL divergence zero, the lower bound on the log likelihood (i.e., the free energy) is tightened to equal the log likelihood after an E step.

The second relation, an inequality, is due to the M step. In an M step, the entropy of q does not change, but the expected log joint increases (or stays the same) because the M step updates the model parameters to maximize it.

The third inequality is due to Jensen, which we have shown in equation (1).

Hence, it holds true for every iteration that

$$\ell(\theta^{(t-1)}) \;\leq\; \ell(\theta^{(t)})$$

This means the log likelihood never decreases during EM. This is a pretty nice property, because many other optimization-based methods cannot guarantee that one's objective function never decreases. It is also a (frustratingly) useful debugging tool, because we know for sure our code has a bug if the log likelihood decreases at any iteration.
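Putting the two steps together, here is a sketch of the full loop for a hypothetical 1-D two-component mixture fit to synthetic data, with an assertion implementing the debugging check described above (the log likelihood must never decrease):

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(2)
# Synthetic data drawn from a hypothetical two-component mixture.
x = np.concatenate([rng.normal(-2.0, 0.8, 200), rng.normal(1.5, 1.2, 300)])

# Initial parameters theta^(0).
weights, means, stds = np.array([0.5, 0.5]), np.array([-1.0, 1.0]), np.array([1.0, 1.0])

prev_ll = -np.inf
for t in range(100):
    joint = weights[None, :] * norm.pdf(x[:, None], means[None, :], stds[None, :])
    ll = np.log(joint.sum(axis=1)).sum()            # ln p(X|theta) under current parameters
    assert ll >= prev_ll - 1e-9                     # the log likelihood must never decrease
    prev_ll = ll

    resp = joint / joint.sum(axis=1, keepdims=True) # E step: responsibilities
    n_k = resp.sum(axis=0)                          # M step: responsibility-weighted statistics
    weights = n_k / len(x)
    means = (resp * x[:, None]).sum(axis=0) / n_k
    stds = np.sqrt((resp * (x[:, None] - means[None, :]) ** 2).sum(axis=0) / n_k)

print(prev_ll, weights, means, stds)
```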

© Yedi Zhang | Last updated: April 2023