For basic definitions, see Information.

Consider the two-stage latent variable model:

  1. $z \sim p_\theta(z)$.
  2. $x \mid z \sim p_\theta(x \mid z)$.

Models of this form, built out of differentiable components, are generally more flexible than models where $p_\theta(x)$ is itself directly differentiable. However, they can be harder to optimize.
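
For concreteness, here is a minimal sketch of ancestral sampling from such a two-stage model (the particular prior and decoder below are arbitrary toy choices, not anything from the text):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_prior(n):
    # Stage 1: z ~ p_theta(z); here an arbitrary 2-dimensional standard normal latent.
    return rng.normal(size=(n, 2))

def sample_conditional(z):
    # Stage 2: x | z ~ p_theta(x | z); here an arbitrary nonlinear map of z plus Gaussian noise.
    W = np.array([[1.0, -0.5, 2.0], [0.3, 1.5, -1.0]])
    return np.tanh(z @ W) + 0.1 * rng.normal(size=(len(z), 3))

z = sample_prior(1000)
x = sample_conditional(z)   # ancestral samples from the marginal p_theta(x)
```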


As usual, given an IID sample $x_1, \dots, x_n$, we want to maximize the log-likelihood

$-\mathcal L(\theta) = \ell\ell(\mathcal U(x_1, \dots, x_n); p_\theta)$.

This is a sum of terms of the form $\log p_\theta(x_i) = \log\mathbb E_{z \sim p_\theta(z)}[p_\theta(x_i \mid z)]$, which in general cannot be evaluated or differentiated exactly, because of the expectation inside the logarithm.

And even when we can differentiate through the expectation, gradient descent often fails due to exponentially vanishing gradients.
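
To see where the expectation causes trouble, here is a sketch of the naive Monte-Carlo estimate of $\log p_\theta(x) = \log \mathbb E_{z \sim p_\theta(z)}[p_\theta(x \mid z)]$ for a toy one-dimensional model (my own stand-in, with a single scalar parameter): the estimate is a log-mean-exp over prior samples, it is biased low by Jensen's inequality, and it (and any gradient taken through it) is dominated by however many prior samples happen to land where $p_\theta(x \mid z)$ is non-negligible.

```python
import numpy as np
from scipy.special import logsumexp

rng = np.random.default_rng(0)

def log_p_x_given_z(x, z, theta):
    # Toy decoder: x | z ~ N(theta * z, 1), everything one-dimensional.
    return -0.5 * (x - theta * z) ** 2 - 0.5 * np.log(2 * np.pi)

def naive_log_marginal(x, theta, n_samples=1000):
    # log p(x) = log E_{z ~ N(0,1)}[p(x | z)], estimated by sampling the prior.
    z = rng.normal(size=n_samples)
    return logsumexp(log_p_x_given_z(x, z, theta)) - np.log(n_samples)

print(naive_log_marginal(x=8.0, theta=1.0))   # noisy, biased-low estimate of log p(8.0)
```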


Variational Inference

So we need a better general method to optimize such a model.

Given a target distribution $\mathcal D$ that we want to model, our true objective is to minimize the log-likelihood gap:

$$\ell\ell(\mathcal D(x); \mathcal D(x)) - \ell\ell(\mathcal D(x); p_\theta(x))$$

This describes how much worse our model is than the best possible model of the distribution (and it precisely determines the probability that our model assigns to long IID sequences from $\mathcal D$, relative to the probability assigned to those sequences by $\mathcal D$ itself).

In Information we showed that modeling an extra variable is at least as hard, in the sense that for any distribution $q(z \mid x)$, it is at least as hard to model the joint distribution $q(z \mid x)\mathcal D(x)$ as it is to just model $\mathcal D(x)$.

That is, if we define:

$$\mathcal L(\theta, \phi) := \ell\ell(q_\phi(z \mid x)\mathcal D(x); q_\phi(z \mid x)\mathcal D(x)) - \ell\ell(q_\phi(z \mid x)\mathcal D(x); p_\theta(x, z))$$

Then $\mathcal L(\theta, \phi)$ always upper-bounds the log-likelihood gap of our model wrt $\mathcal D$.

$$\mathcal L(\theta, \phi) \geq \ell\ell(\mathcal D(x); \mathcal D(x)) - \ell\ell(\mathcal D(x); p_\theta(x))$$

In other words: if $p_\theta(x, z)$ models the joint distribution $q_\phi(z \mid x)\mathcal D(x)$ with log-likelihood gap $\leq \varepsilon$, then $p_\theta(x)$ also models $\mathcal D(x)$ alone with log-likelihood gap $\leq \varepsilon$.

So to minimize the real objective, we can instead minimize $\mathcal L(\theta, \phi)$. Of course, if we minimize it only in $\phi$, then we only get a better approximation (a tighter bound) on the real log-likelihood gap of $p_\theta(x)$, without actually improving it. And if we minimize it only in $\theta$, then we are possibly improving the log-likelihood gap (or even making it worse, though never worse than $\mathcal L(\theta, \phi)$), but we may eventually be unable to improve it further without also optimizing $\phi$, for example if $p_\theta(z) \neq \mathbb E_{x \sim \mathcal D}[q_\phi(z \mid x)]$.

Therefore, the strategy which will allow us to minimize our guaranteed log-likelihood gap the most (i.e., minimize $\mathcal L(\theta, \phi)$ the most) will be to optimize both $\theta$ and $\phi$ for this goal.
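
Unpacking the definition of $\ell\ell$ as an expected log-probability (see Information), the surrogate can be written as a single expectation:

\begin{align*} \mathcal L(\theta, \phi) &= \mathbb E_{x \sim \mathcal D,\, z \sim q_\phi(\cdot \mid x)}\left[\log\left(q_\phi(z \mid x)\mathcal D(x)\right) - \log p_\theta(x, z)\right]\\ &= \mathbb E_{x \sim \mathcal D,\, z \sim q_\phi(\cdot \mid x)}\left[\log q_\phi(z \mid x) - \log p_\theta(x, z)\right] + \ell\ell(\mathcal D; \mathcal D). \end{align*}

So, up to the constant $\ell\ell(\mathcal D; \mathcal D)$, minimizing $\mathcal L(\theta, \phi)$ means minimizing the expected value of $\log q_\phi(z \mid x) - \log p_\theta(x, z)$ (the negative of the usual evidence lower bound), and both $\theta$ and $\phi$ sit inside a single expectation that we can estimate by sampling $x \sim \mathcal D$ and $z \sim q_\phi(\cdot \mid x)$.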


Example: Gaussian Latent (VAEs)

Suppose we take $q_\phi(z \mid x)$ to be conditionally Gaussian: $z \sim \mathcal N(\mu_\phi(x), \Sigma_\phi(x))$.

This is equivalent to $z \sim \mu_\phi(x) + \sqrt{\Sigma_\phi(x)}\,\mathcal N(0, I)$, so any expectation over $z$ can be rewritten as $\mathbb E_{z \sim q_\phi(\cdot \mid x)}[\dots] = \mathbb E_{\varepsilon \sim \mathcal N(0, I)}[\dots]_{z = \mu_\phi(x) + \sqrt{\Sigma_\phi(x)}\varepsilon}$ (the reparametrization trick). The density, for a $d$-dimensional latent, is:

\begin{align*} q_\phi(z \mid x) &= ((2\pi)^d|\Sigma_\phi(x)|)^{-\frac12}e^{-\frac12(z-\mu_\phi(x))^T\Sigma_\phi(x)^{-1}(z-\mu_\phi(x))}\\ &= ((2\pi)^d|\Sigma_\phi(x)|)^{-\frac12}e^{-\frac12(\sqrt{\Sigma_\phi(x)}\varepsilon)^T\Sigma_\phi(x)^{-1}(\sqrt{\Sigma_\phi(x)}\varepsilon)}\\ &= ((2\pi)^d|\Sigma_\phi(x)|)^{-\frac12}e^{-\frac12\|\varepsilon\|^2} \end{align*}

Therefore the loss becomes

\begin{align*} \mathcal L(\theta, \phi) &= \mathbb E_{x \sim \mathcal D, z \sim q_\phi(\cdot \mid x)}[\log (q_\phi(z \mid x)\mathcal D(x)) - \log p_\theta(x, z)]\\ &= \mathbb E_{x \sim \mathcal D, \varepsilon \sim \mathcal N(0, I)}[\log(((2\pi)^d|\Sigma_\phi(x)|)^{-\frac12})+ \log(e^{-\frac12\|\varepsilon\|^2}) + \log \mathcal D(x) - \log p_\theta(x, z)]\\ &= \mathbb E_{x \sim \mathcal D, \varepsilon \sim \mathcal N(0, I)}[-\log\sqrt{|\Sigma_\phi(x)|} - \log p_\theta(x, z)] + C\\ &= \mathbb E_{x \sim \mathcal D, \varepsilon \sim \mathcal N(0, I)}\left[-\log p_\theta(x,\,\, \mu_\phi(x) + \sqrt{\Sigma_\phi(x)}\varepsilon)\right] + \mathbb E_x\left[-\log\sqrt{|\Sigma_\phi(x)|}\right] + C \end{align*}

where $C = \mathbb E_{\varepsilon \sim \mathcal N(0, I)}[-\frac12\|\varepsilon\|^2] + \ell\ell(\mathcal D; \mathcal D) - \frac d2 \log(2\pi)$ is a constant (independent of $\theta$ and $\phi$).


Restrictions

If we restrict $\Sigma_\phi(x) = \sigma_\phi^2(x)I$ for the $d$-dimensional latent then this becomes

$\mathcal L(\theta, \phi) = \mathbb E_{x, \varepsilon}\left[-\log p_\theta(x,\,\, \mu_\phi(x) + \sigma_\phi(x)\varepsilon)\right] + d \cdot \mathbb E_x[-\log\sigma_\phi(x)] + C$


And if we restrict $p_\theta(z) = \mathcal N(0, \tau^2 I)(z)$ for a fixed $\tau$, then

\begin{align*} \mathbb E_{x,\varepsilon}[-\log p_\theta(\mu_\phi(x) + \sigma_\phi(x)\varepsilon)] &= \mathbb E_{x, \varepsilon}\left[ \frac1{2\tau^2}\|\mu_\phi(x) + \sigma_\phi(x)\varepsilon\|^2\right] + C \\ &= \frac1{2\tau^2} \cdot \mathbb E_x\left[\|\mu_\phi(x)\|^2 + d \cdot \sigma_\phi^2(x)\right] + C \end{align*}

and therefore

\begin{align*} \mathcal L(\theta, \phi) &= \mathbb E_{x, \varepsilon}\left[-\log p_\theta(x \mid \mu_\phi(x) + \sigma_\phi(x)\varepsilon)\right]\\ &+ \frac{1}{2\tau^2}\mathbb E_x\left[\|\mu_\phi(x)\|^2 + d \cdot \sigma_\phi^2(x)\right] + \mathbb E_x\left[- d \cdot \log \sigma_\phi(x)\right] + C \end{align*}

Note that the formula $\sigma \mapsto \sigma^2/2\tau^2 - \log \sigma$ is minimized by $\sigma = \tau$.


This final form of the loss contains:

  1. The standard auto-encoder loss, but with a parametrically noised latent.
  2. A regularization term which keeps the latent (a) small and (b) at the right noise level ($\sigma_\phi(x) \approx \tau$).

Variational AutoEncoder

The variational autoencoder is defined as a latent variable model where the latent is a restricted Gaussian as above, and $p_\theta(x \mid z)$, $\mu_\phi(x)$ and $\sigma_\phi^2(x)$ are all given by neural networks.
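
A minimal PyTorch sketch of this training loss, under the restrictions above (isotropic $\Sigma_\phi(x) = \sigma_\phi^2(x) I$, prior $\mathcal N(0, \tau^2 I)$ with $\tau = 1$) and assuming a unit-variance Gaussian decoder, so that $-\log p_\theta(x \mid z)$ is a squared reconstruction error up to a constant; the architectures and dimensions are placeholder choices:

```python
import torch
import torch.nn as nn

d_x, d_z, tau = 784, 16, 1.0   # data dim, latent dim, prior scale (illustrative choices)

# Encoder outputs mu_phi(x) and log sigma_phi(x); decoder outputs the mean of p_theta(x | z).
enc = nn.Sequential(nn.Linear(d_x, 256), nn.ReLU(), nn.Linear(256, d_z + 1))
dec = nn.Sequential(nn.Linear(d_z, 256), nn.ReLU(), nn.Linear(256, d_x))

def vae_loss(x):
    h = enc(x)
    mu, log_sigma = h[:, :d_z], h[:, d_z]       # one scalar sigma per example (Sigma = sigma^2 I)
    sigma = log_sigma.exp()
    eps = torch.randn_like(mu)
    z = mu + sigma.unsqueeze(1) * eps           # reparametrized sample z ~ q_phi(. | x)
    recon = ((x - dec(z)) ** 2).sum(dim=1) / 2  # -log p_theta(x | z) for a unit-variance Gaussian decoder, up to a constant
    reg = (mu.pow(2).sum(dim=1) + d_z * sigma ** 2) / (2 * tau ** 2) - d_z * log_sigma
    return (recon + reg).mean()

x = torch.rand(32, d_x)                         # stand-in minibatch
vae_loss(x).backward()                          # gradients flow to both theta (dec) and phi (enc)
```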



Example: Mixture of Gaussians (EM algorithm)

Consider the two-stage model

$$z \sim \mathrm{Categorical}(p_1, \dots, p_K)$$
$$x \mid z \sim \mathcal N(\mu_z, \Sigma_z)$$

parametrized by $\theta = ((p_1, \mu_1, \Sigma_1), \dots, (p_K, \mu_K, \Sigma_K))$. This is called a Mixture of Gaussians. Given a dataset $\{x_1, \dots, x_n\} \subset \mathbb R^d$ of samples, our goal will be to maximize the model’s log-likelihood of the uniform distribution over the dataset, which is equal to the following

\begin{align*} &\ell\ell(\mathcal U(x_1, \dots, x_n); p_\theta) \\ &\quad\quad\quad= \sum_{i=1}^n \log\left(\sum_{k=1}^K p_k \cdot ((2\pi)^d|\Sigma_k|)^{-\frac12}\exp\left(-\tfrac12(x_i-\mu_k)^T\Sigma_k^{-1}(x_i-\mu_k)\right)\right) \end{align*}

Technically, $\ell\ell(\mathcal U(x_1, \dots, x_n); p_\theta)$ is differentiable with respect to $\theta$. So perhaps we could just optimize it via gradient descent.
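
For instance, here is a sketch of this log-likelihood (the quantity one would feed to an autodiff-based gradient descent) in the one-dimensional case, using a log-sum-exp over components for numerical stability; the dataset is made up:

```python
import numpy as np
from scipy.special import logsumexp

def mog_log_likelihood(x, p, mu, sigma2):
    # x: (n,) data; p, mu, sigma2: (K,) mixture weights, means, variances (1-D case).
    log_terms = (np.log(p)
                 - 0.5 * np.log(2 * np.pi * sigma2)
                 - 0.5 * (x[:, None] - mu) ** 2 / sigma2)   # (n, K): log p_k + log N(x_i; mu_k, sigma2_k)
    return logsumexp(log_terms, axis=1).sum()               # sum over i of log sum_k exp(...)

rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(9, 1, 500), rng.normal(11, 1, 500)])   # made-up dataset
print(mog_log_likelihood(x, p=np.full(2, 0.5), mu=np.array([9.0, 11.0]), sigma2=np.ones(2)))
```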

Problem: Direct gradient descent can fail due to exponentially vanishing gradients

Suppose that the true distribution is a mixture of $K=2$ unidimensional Gaussians, with $p_1 = p_2 = 1/2$, $\sigma_1^2 = \sigma_2^2 = 1$, $\mu_1 = 9$, and $\mu_2 = 11$. But suppose we initialize with $\mu_1 = -1, \mu_2 = +1$. Then for a datapoint $x = 8.0$, we have

$$\frac{p_\theta(x \mid z=1)}{p_\theta(x \mid z=2)} = \frac{\exp(-(8+1)^2/2)}{\exp(-(8-1)^2/2)} = \exp(-16) \approx 0.$$

Therefore, almost all of the conditional probability comes from the latent being $z = 2$, and the inner sum gives a vanishingly small gradient to $\mu_1, \sigma^2_1$.

So gradient descent will simply find the single-Gaussian solution $p_1 = 0, p_2 = 1, \mu_2 = 10, \sigma_2^2 = 2$, with $\mu_1, \sigma_1^2$ not budging from their initialization for an exponentially long time.
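
A quick numerical check of the responsibilities for the single datapoint $x = 8$ under the bad initialization above (a sketch):

```python
import numpy as np

mu = np.array([-1.0, 1.0])           # bad initialization, equal weights, unit variances
x = 8.0
log_resp = -0.5 * (x - mu) ** 2      # unnormalized log responsibilities
resp = np.exp(log_resp - log_resp.max())
resp /= resp.sum()
print(resp)                          # ~ [1.1e-07, 1.0]: component 1 receives essentially no gradient signal
```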


So we need a method which is better than gradient descent to minimize the loss.

We will use variational inference. In this case, because the latent is a simple categorical, we can exactly minimize $\mathcal L(\theta, \phi)$ in either $\theta$ or $\phi$. We will alternate between the two; this is called the EM algorithm, where the E-step optimizes $\phi$ and the M-step optimizes $\theta$.


Theorem. The EM algorithm for the mixture of Gaussians model is as follows.

E-step:

$q_\phi(\cdot \mid x) = \mathrm{softmax}([\log p_z - \frac12\left((x-\mu_z)^T\Sigma_z^{-1}(x-\mu_z) + \log \det \Sigma_z\right)]_{z=1}^K)$

M-step:

$p_z = \frac1n\sum_{i=1}^n q_\phi(z\mid x_i)$

$\mu_z = \sum_{i=1}^n q_\phi(x_i \mid z)\,x_i$

$\Sigma_z = \sum_{i=1}^n q_\phi(x_i \mid z)\,(x_i-\mu_z)(x_i-\mu_z)^T$.

where $q_\phi(x_i \mid z) := \frac1nq_\phi(z \mid x_i)/p_z$, computed with the updated $p_z$ (so that $\sum_{i=1}^n q_\phi(x_i \mid z) = 1$).
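
Before the proof, here is a direct numpy transcription of the E- and M-step above (a sketch: the data layout, initialization, and use of scipy's logsumexp are my own choices):

```python
import numpy as np
from scipy.special import logsumexp

def em_step(x, p, mu, Sigma):
    # x: (n, d) data; p: (K,) weights; mu: (K, d) means; Sigma: (K, d, d) covariances.
    n = x.shape[0]
    diff = x[:, None, :] - mu[None, :, :]                                    # (n, K, d)
    maha = np.einsum('nkd,kde,nke->nk', diff, np.linalg.inv(Sigma), diff)    # Mahalanobis terms
    logits = np.log(p) - 0.5 * (maha + np.linalg.slogdet(Sigma)[1])
    # E-step: responsibilities q_phi(z | x_i), a softmax over components.
    q_z_given_x = np.exp(logits - logsumexp(logits, axis=1, keepdims=True))  # (n, K)
    # M-step.
    p_new = q_z_given_x.mean(axis=0)                                         # p_z = (1/n) sum_i q(z | x_i)
    q_x_given_z = q_z_given_x / n / p_new                                    # (n, K); each column sums to 1
    mu_new = np.einsum('nk,nd->kd', q_x_given_z, x)                          # mu_z = sum_i q(x_i | z) x_i
    diff_new = x[:, None, :] - mu_new[None, :, :]
    Sigma_new = np.einsum('nk,nkd,nke->kde', q_x_given_z, diff_new, diff_new)
    return p_new, mu_new, Sigma_new

# EM from the bad initialization used in the gradient-descent example above.
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(9, 1, (500, 1)), rng.normal(11, 1, (500, 1))])
p, mu, Sigma = np.full(2, 0.5), np.array([[-1.0], [1.0]]), np.tile(np.eye(1), (2, 1, 1))
for _ in range(50):
    p, mu, Sigma = em_step(x, p, mu, Sigma)
```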


Proof.

The M-step objective function is

$\mathcal L(\theta, \phi) = \mathbb E_{x \sim \mathcal D, z \sim q_\phi(\cdot \mid x)}[-\log p_\theta(x, z)] + C$

where $C$ is independent of $\theta$, and $\mathcal D = \mathcal U(x_1, \dots, x_n)$ denotes the uniform distribution over the dataset.

This splits into $\mathbb E_{x \sim \mathcal D, z \sim q_\phi(\cdot \mid x)}[-\log p_\theta(x \mid z) - \log p_\theta(z)]$.


The parameters $p_1, \dots, p_K$ only appear in the second term, which is $-\ell\ell(q_\phi(z); p_\theta(z))$ where $q_\phi(z) := \sum_{i=1}^n \frac1n q_\phi(z \mid x_i)$, and hence is minimized by $p_z = p_\theta(z) = q_\phi(z)$ (a log-likelihood is maximized by matching the distribution).


The first term, which is the only one that contains $(\mu_z, \Sigma_z)_{z=1}^K$, rearranges to

\begin{align*} \mathbb E_{x \sim \mathcal D, z \sim q_\phi(\cdot \mid x)}[-\log p_\theta(x \mid z)] &= -\sum_{i=1}^n\sum_{z=1}^K (1/n)\,q_\phi(z \mid x_i) \log p_\theta(x_i \mid z)\\ &= -\sum_{i=1}^n\sum_{z=1}^K p_zq_\phi(x_i \mid z)\log p_\theta(x_i \mid z)\\ &= -\sum_{z=1}^K p_z\sum_{i=1}^nq_\phi(x_i \mid z)\log p_\theta(x_i \mid z) \end{align*}

where $q_\phi(x_i \mid z) := \frac1nq_\phi(z \mid x_i)/p_z$ as in the theorem statement (note that $\sum_{i=1}^n q_\phi(x_i \mid z) = 1$). Therefore, the part of $\mathcal L$ which contains $\mu_z, \Sigma_z$ is $-p_z$ times:

\begin{align*} &\sum_{i=1}^nq_\phi(x_i \mid z)\log p_\theta(x_i \mid z)\\ &= \sum_{i=1}^n q_\phi(x_i \mid z) \cdot \left(-\frac12\left((x_i-\mu_z)^T\Sigma_z^{-1}(x_i-\mu_z) + \log\det\Sigma_z\right)\right) + \mathrm{const}\\ &=-\frac12 \log\det\Sigma_z - \frac12\sum_{i=1}^n q_\phi(x_i \mid z)(x_i-\mu_z)^T\Sigma_z^{-1}(x_i-\mu_z) + \mathrm{const}. \end{align*}

The solution for $\mu_z$ is as follows. We have:

$\nabla_{\mu_z}\mathcal L(\theta, \phi) = -p_z\sum_{i=1}^n q_\phi(x_i \mid z)\,\Sigma_z^{-1}(x_i-\mu_z) = p_z\,\Sigma_z^{-1}\left(\mu_z - \sum_{i=1}^n q_\phi(x_i \mid z)\,x_i\right),$

which vanishes exactly at $\mu_z' = \sum_{i=1}^n q_\phi(x_i \mid z)\,x_i$.


The solution for $\Sigma_z$ is as follows. Let $W := \Sigma_z^{-1/2}$ (which is symmetric) have columns $w_1, \dots, w_d$, and define $\hat \Sigma_z := \sum_{i=1}^n q_\phi(x_i \mid z)(x_i-\mu_z)(x_i-\mu_z)^T$. Then we can solve by setting the gradient wrt $W$ to zero.

\begin{align*}&\nabla_W\mathcal L(\theta, \phi) = -p_z\nabla_W\left(-\frac12 \log\det\Sigma_z - \frac12\sum_{i=1}^n q_\phi(x_i \mid z)(x_i-\mu_z)^T\Sigma_z^{-1}(x_i-\mu_z)\right)\\ &\propto \nabla_W\left(\log\det (WW^T)^{-1} + \sum_{i=1}^n q_\phi(x_i \mid z)\|W^T(x_i-\mu_z)\|^2\right)\\ &= \nabla_W\left(-\log\det WW^T + \sum_{i=1}^n q_\phi(x_i \mid z)\sum_{j=1}^d (w_j^T(x_i-\mu_z))^2\right)\\ &= \nabla_W\left(-\log\det WW^T + \sum_{j=1}^d w_j^T\left(\sum_{i=1}^n q_\phi(x_i \mid z)(x_i-\mu_z)(x_i-\mu_z)^T\right)w_j\right)\\ &= \nabla_W\left(-2\log\det W + \sum_{j=1}^d w_j^T\hat \Sigma_z w_j\right)\\&= -2W^{-T} + 2\hat\Sigma_zW. \end{align*}

where we use $\log\det WW^T = \log\det W^2 = 2\log\det W$ since $W$ is symmetric.

Setting this to zero gives $W = \hat\Sigma_z^{-1/2}$, hence $\Sigma_z' = \hat \Sigma_z = \sum_{i=1}^n q_\phi(x_i \mid z)(x_i-\mu_z)(x_i-\mu_z)^T$.


The E-step can also be viewed as inference of the latent. If $\Sigma_z = I$ is fixed, then it simplifies to the following.

\begin{align*} q_\phi(\cdot \mid x) &= \mathrm{softmax}([\log p_z - (1/2)\|x-\mu_z\|^2]_{z=1}^K)\\ &= \mathrm{softmax}([\mu_z^Tx + (\log p_z - (1/2)\|\mu_z\|^2)]_{z=1}^K) \end{align*}

So, for example, if the penultimate-layer features of a neural network are normally distributed with identity covariance and a mean that depends on the label, then the optimal final linear layer has the class means as its weights (and $\log p_z - \frac12\|\mu_z\|^2$ as its biases).

Also, for example, if we have a uniform mixture of unit-variance 1D normal distributions, then $\mathrm{softmax}(-\mu_1^2/2, \dots, -\mu_K^2/2)_z = p_\theta(z \mid x = 0)$.

  • If $\mu_1 = 0, \mu_2 = \mu$, then $p_\theta(z=2 \mid x = 0) = \mathrm{softmax}(0, -\mu^2/2)_2 = \sigma(-\mu^2/2)$.
  • If $\mu_1 = -1$, $\mu_2 = +1$, then $p_\theta(z = 2 \mid x) = \mathrm{softmax}(-x, x)_2 = \sigma(2x).$
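
As a quick sanity check of this reformulation of the E-step as a linear-softmax classifier, here is a sketch comparing the two expressions on made-up means and uniform weights:

```python
import numpy as np
from scipy.special import logsumexp

def softmax(v):
    return np.exp(v - logsumexp(v))

rng = np.random.default_rng(0)
K, d = 3, 5
mu = rng.normal(size=(K, d))           # made-up class means
log_p = np.full(K, -np.log(K))         # uniform mixture weights
x = rng.normal(size=d)

# E-step with Sigma_z = I ...
direct = log_p - 0.5 * ((x - mu) ** 2).sum(axis=1)
# ... versus the linear-softmax form: weights mu_z, biases log p_z - ||mu_z||^2 / 2.
linear = mu @ x + (log_p - 0.5 * (mu ** 2).sum(axis=1))

print(np.allclose(softmax(direct), softmax(linear)))   # True
```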