Variational Inference algorithm

Sampling data or approximating probability densities is one of the core problems in modern statistics, especially in Bayesian statistics. Besides MCMC, Variational Inference is one of the two typical Bayesian approaches to this problem. Many recent powerful methods and models build on the Variational Inference algorithm: Stochastic Variational Inference, Variational Autoencoders… This article briefly introduces the Variational Inference algorithm.

The main idea of Variational Inference (VI) is to approximate the posterior with a simpler variational distribution by solving an optimization problem. Section 1 reviews Bayesian inference and latent variables. Section 2 presents the objective of the algorithm and why it leads to an optimization problem. Section 3 presents the key mathematical details behind this goal. Section 4 introduces Stochastic Variational Inference, and finally the summary is presented.

1. Bayesian Inference and Latent variables

  • Bayesian Inference

Bayesian inference supposes that the probability of the observations x depends on a quantity \theta, whose distribution is called the prior. The prior reflects our personal perspective about the observations; it comes from expert knowledge or from tuning through experiments.

Prior: p(\theta)

Given \theta, we can compute the probability of an outcome through the likelihood function:

Likelihood: p(x|\theta)

We do not always know the exact value of \theta. Therefore, we maintain a distribution reflecting how well each value of \theta matches the observations. With Bayes' theorem p(A|B)=\dfrac{p(B|A)p(A)}{p(B)}, the posterior is:

Posterior: p(\theta|x) = \dfrac{p(x|\theta)p(\theta)}{p(x)}

To test which parameter best describes the data, methods look for the one that maximizes the likelihood (maximum likelihood estimation) or the posterior (maximum a posteriori estimation), as in the sketch below. If you’re not familiar with Bayesian inference, please read this post.
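As a minimal sketch (a hypothetical coin-flip example of mine, not from the post): the Beta prior is conjugate to the Bernoulli likelihood, so the posterior has a closed form, and the MLE and MAP estimates can be compared directly.

import numpy as np
from scipy import stats

# Hypothetical coin-flip example: theta = P(heads), prior p(theta) = Beta(a, b).
# Beta is conjugate to Bernoulli, so the posterior p(theta|x) is Beta as well.
a, b = 2.0, 2.0
x = np.array([1, 0, 1, 1, 0, 1, 1])           # observations (1 = heads)

heads, tails = x.sum(), len(x) - x.sum()
posterior = stats.beta(a + heads, b + tails)  # p(theta|x) = Beta(a + heads, b + tails)

theta_mle = heads / len(x)                          # maximizes the likelihood p(x|theta)
theta_map = (a + heads - 1) / (a + b + len(x) - 2)  # maximizes the posterior p(theta|x)
print(f"MLE: {theta_mle:.3f}, MAP: {theta_map:.3f}, posterior mean: {posterior.mean():.3f}")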

  • Latent variables

In statistics, latent variables are variables that are not directly observed but are rather inferred (through a mathematical model) from other variables that are observed (directly measured). Mathematical models that aim to explain observed variables in terms of latent variables are called latent variable models.

See Wikipedia

Latent variables are not directly observed, but we have information about their structure and relationship to the observations; thus, they can be inferred. In Bayesian inference, they are often the factors that allow us to factorize the distribution of the observations into tractable distributions.

For example, in topic models, the only observed variables are the collections of documents with their words. The list of topics and the distributions linking topics to documents and words are all latent variables.

2. Model’s objective and optimization problem

Variational Inference strives to sample observations from, or approximate, a difficult-to-compute distribution. It uses hidden variables to encode hidden structure in the observed data and estimates the posterior. The posterior is not always tractable; therefore, a more generic way is to approximate it with a simpler distribution governed by variational parameters.

Goal distribution: the posterior p(z|x) = \dfrac{p(x, z)}{\int p(x,z) dz}

Approximating distribution: the variational distribution q(z;\lambda),
where \lambda is a set of variational parameters.

Loss function: the negative log likelihood L(\lambda) = - log p(x|\lambda)

VI measures the distance between distributions by KL divergence:

D_{KL}\bigg(q(z) || p(z)\bigg) = -\int q(z) log \dfrac{p(z)}{q(z)} dz

A lower distance means a better approximation, so VI minimizes this distance, and the problem becomes finding the parameter values that are optimal under this metric. The sketch below checks the definition above against the closed-form KL between two Gaussians.
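A quick numeric check (my own example, assuming 1D Gaussians for q(z) and p(z)):

import numpy as np
from scipy import stats
from scipy.integrate import quad

# D_KL(q || p) for two 1D Gaussians: numeric integration of the definition
# above versus the well-known closed form.
mu_q, s_q = 0.0, 1.0
mu_p, s_p = 1.0, 2.0
q, p = stats.norm(mu_q, s_q), stats.norm(mu_p, s_p)

integrand = lambda z: -q.pdf(z) * np.log(p.pdf(z) / q.pdf(z))
kl_numeric, _ = quad(integrand, -15, 15)

kl_closed = np.log(s_p / s_q) + (s_q**2 + (mu_q - mu_p)**2) / (2 * s_p**2) - 0.5
print(kl_numeric, kl_closed)   # both ~ 0.4431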

Instead of directly minimizing D_{KL}, VI maximizes the ELBO, ELBO(\lambda). In short, this term is closely related to D_{KL} and the log evidence log p(x):

log p(x) = ELBO(\lambda) + D_{KL}\bigg(q(z;\lambda) || p(z|x)\bigg)

Since log p(x) does not depend on \lambda, maximizing the ELBO is equivalent to minimizing D_{KL}.

3. Mathematical notes

This section explains the mathematical details and how VI solves the optimization problem.

3.1 Evidence Lower Bound (ELBO)

This term is called the ELBO since it is a lower bound of the log evidence of the data:

log p(x)= log \int p(x,z) dz = log \int \dfrac{p(x,z)q(z;\lambda)}{q(z; \lambda)} dz = log E_{q(z;\lambda)}\bigg[\dfrac{p(x,z)}{q(z; \lambda)} \bigg]

\geq E_{q(z;\lambda)}\bigg[log \dfrac{p(x,z)}{q(z; \lambda)} \bigg] = ELBO(\lambda) (applying Jensen's inequality)

For the proof of log p(x) = ELBO(\lambda) + D_{KL} and the application of Jensen's inequality to the ELBO, see Section 3 of this post. The sketch below verifies the bound numerically on a toy model.
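A small numeric sketch (my own toy model, not from the post): take z ~ N(0,1) and x|z ~ N(z,1), so that p(x) = N(0,2) and the true posterior is p(z|x) = N(x/2, 1/2). A Monte Carlo estimate of the ELBO stays below log p(x) for a poor q and matches it when q equals the exact posterior.

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)

# Toy conjugate model: z ~ N(0,1), x|z ~ N(z,1) => p(x) = N(0,2),
# true posterior p(z|x) = N(x/2, 1/2).
x = 1.5
log_evidence = stats.norm(0.0, np.sqrt(2.0)).logpdf(x)

def elbo(m, s, n_samples=200_000):
    # Monte Carlo ELBO = E_q[log p(x,z) - log q(z)] for q(z) = N(m, s^2)
    z = rng.normal(m, s, n_samples)
    log_joint = stats.norm(0.0, 1.0).logpdf(z) + stats.norm(z, 1.0).logpdf(x)
    log_q = stats.norm(m, s).logpdf(z)
    return np.mean(log_joint - log_q)

print("log p(x):      ", log_evidence)
print("ELBO, bad q:   ", elbo(0.0, 2.0))              # strictly below log p(x)
print("ELBO, exact q: ", elbo(x / 2, np.sqrt(0.5)))   # equals log p(x)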

3.2 Mean-field variational inference

There are many ways to express q(z;\lambda) as a function of \lambda, with a tradeoff between being expressive enough to approximate the posterior and simple enough to be tractable. The mean-field approach assumes that all latent variables are independent of each other, so correlation information between them is lost. In VI, the distribution is factorized into factors, each governed by its own parameters:

q(z; \lambda) = \prod_{i=1}^{N}q(z_i;\lambda_i)

Due to this assumption, VI iteratively updates each pair (z_i, \lambda_i) to optimize the ELBO by gradient descent or coordinate descent:

while (ELBO has not converged):
	for j in {1, ..., m}:
		q_j(z_j) = q_j*(z_j) # optimal value of q_j(z_j), holding the other factors fixed
	compute ELBO(q)

Coordinate descent optimizes one variable at a time while keeping the others fixed, then assigns the optimal value to that variable. In contrast, gradient descent updates all dimensions at the same time, but by a small fraction of the distance from the current point to the optimum rather than the full length. Each time all factors are updated, whether independently or simultaneously, the lower bound (ELBO) increases and the training process reaches a better local optimum.
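To make the loop concrete, here is a runnable sketch of coordinate-ascent updates for the classic textbook example (my own illustration, not from the post): a Gaussian with unknown mean mu and precision tau, with the factorization q(mu, tau) = q(mu) q(tau) and closed-form updates.

import numpy as np

rng = np.random.default_rng(1)

# Model: tau ~ Gamma(a0, b0), mu|tau ~ N(mu0, 1/(lam0*tau)), x_i ~ N(mu, 1/tau).
# Mean-field: q(mu) = N(mu_n, 1/lam_n), q(tau) = Gamma(a_n, b_n).
x = rng.normal(loc=2.0, scale=0.5, size=500)
N, xbar = len(x), x.mean()
mu0, lam0, a0, b0 = 0.0, 1.0, 1.0, 1.0

E_tau = a0 / b0                                  # initial guess for E_q[tau]
for _ in range(50):                              # coordinate updates
    # update q(mu) holding q(tau) fixed
    mu_n = (lam0 * mu0 + N * xbar) / (lam0 + N)
    lam_n = (lam0 + N) * E_tau
    # update q(tau) holding q(mu) fixed
    E_sq = np.sum((x - mu_n) ** 2) + N / lam_n   # sum of E_q[(x_i - mu)^2]
    b_n = b0 + 0.5 * (E_sq + lam0 * ((mu_n - mu0) ** 2 + 1.0 / lam_n))
    a_n = a0 + (N + 1) / 2
    E_tau = a_n / b_n

print(f"E[mu] = {mu_n:.3f} (true 2.0), noise scale ~ {1/np.sqrt(E_tau):.3f} (true 0.5)")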

Both gradient descent and coordinate descent work by finding the steepest direction of a locally convex loss function. How do we find it?

3.3 Natural gradient

The key to optimizing the loss function L(\lambda) in VI is finding the optimal update at each step. If the loss function lived in Euclidean space, it would be easy to compute the update by simply taking the first-order gradient. Here, however, the parameters are in Euclidean space but the distance metric is not: comparing two distributions by simply subtracting their means is insufficient, since a lot of information, such as the shape, variance, and skew of the distributions, is not considered. Therefore, we need another way to measure distance in distribution space, and then translate it into parameter updates in Euclidean space. This is done by the natural gradient.

  • Natural gradient and Fisher Information

The natural gradient is obtained by the equation:

\hat\nabla L(\lambda) = F^{-1}\nabla L(\lambda)
where F is the Fisher Information matrix of the distribution parameterized by \lambda.

The Fisher Information matrix can also be obtained as the second-order gradient (Hessian matrix) of D_{KL}, as proven below.

Let \{P_\theta\} denote a parametric family of distributions on a space X with density function p_\theta. The Fisher information is:

F_\theta := E_\theta \bigg[ \nabla_\theta log (p_\theta(x)) \nabla_\theta log(p_\theta(x))^T \bigg] = -E_\theta \bigg[ \nabla_\theta^2 log(p_\theta(x)) \bigg]
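Both forms of F can be checked by simulation; here is a quick sketch for a Bernoulli(theta) model (my own example), where the closed form is F = 1/(theta(1-theta)).

import numpy as np

rng = np.random.default_rng(2)

# Fisher information of Bernoulli(theta): compare the score outer product,
# the negative expected Hessian, and the closed form 1/(theta*(1-theta)).
theta = 0.3
x = rng.binomial(1, theta, size=1_000_000)

score = x / theta - (1 - x) / (1 - theta)            # d/dtheta log p(x|theta)
hessian = -x / theta**2 - (1 - x) / (1 - theta)**2   # d^2/dtheta^2 log p(x|theta)

print("E[score^2]  :", np.mean(score**2))
print("-E[hessian] :", -np.mean(hessian))
print("closed form :", 1 / (theta * (1 - theta)))    # ~ 4.762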

  • Proof that F = H_{KL}

Suppose that p(x|\lambda^*) is the true distribution, where \lambda^* is the optimal value of \lambda.

D_{KL}\bigg( p(x|\lambda^*) || p(x|\lambda) \bigg) = E_{p(x|\lambda^*)}\big[ log\, p(x|\lambda^*) \big] - E_{p(x|\lambda^*)}\big[ log\, p(x|\lambda) \big]

Taking the first-order gradient with respect to \lambda (the first term does not depend on \lambda):

\nabla_\lambda D_{KL}\bigg( p(x|\lambda^*) || p(x|\lambda) \bigg) = - E_{p(x|\lambda^*)}\big[ \nabla_\lambda log\, p(x|\lambda) \big]

Taking the second-order gradient:

\nabla_\lambda^2 D_{KL}\bigg( p(x|\lambda^*) || p(x|\lambda) \bigg) = - E_{p(x|\lambda^*)}\big[ \nabla_\lambda^2 log\, p(x|\lambda) \big] = - \int p(x|\lambda^*)\, \nabla_\lambda^2 log\, p(x|\lambda)\, dx

Evaluating \lambda at \lambda^*:

H_{KL} = - \int p(x|\lambda^*)\, \nabla_\lambda^2 log\, p(x|\lambda)\big|_{\lambda=\lambda^*}\, dx = -E_{p(x|\lambda^*)}\big[ \nabla_\lambda^2 log\, p(x|\lambda^*) \big]

Thus, H_{KL} = F_{\lambda^*}, the Fisher information evaluated at \lambda^*.
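This identity can also be verified numerically (my own sketch): for a Bernoulli model, the second derivative of D_{KL}(Bern(theta*) || Bern(theta)) at theta = theta* should equal the Fisher information 1/(theta*(1 - theta*)).

import numpy as np

# Finite-difference check that the Hessian of the KL divergence at the
# optimum equals the Fisher information, for Bernoulli(theta).
def kl(t_star, t):
    return t_star * np.log(t_star / t) + (1 - t_star) * np.log((1 - t_star) / (1 - t))

t_star, h = 0.3, 1e-4
# central finite difference for the second derivative at t = t_star
hess = (kl(t_star, t_star + h) - 2 * kl(t_star, t_star) + kl(t_star, t_star - h)) / h**2
print(hess, 1 / (t_star * (1 - t_star)))   # both ~ 4.7619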

  • Finding steepest points

Let KL(\lambda') = D_{KL}\bigg( p(x|\lambda) || p(x|\lambda') \bigg) denote the distance from the current parameter \lambda to a candidate update \lambda' = \lambda + d. The natural gradient step is the direction d that most decreases the loss while moving a fixed distance in distribution space:

min_d \big\{L(\lambda + d)\big\} subject to KL(\lambda + d) = constant

With the second-order Taylor series:

f(x + \Delta x) \approx f(x) + \nabla f(x)^T \Delta x  + \dfrac{1}{2}\Delta x^T H(x)\Delta x

we can approximate the constraint:

KL(\lambda + d) \approx KL(\lambda) + \nabla KL(\lambda)^Td + \dfrac{1}{2}d^T F d

KL(\lambda) = 0 because it is the distance from \lambda to itself, and \nabla KL(\lambda) = 0 because \lambda minimizes the divergence to itself, thus:

KL(\lambda + d) \approx \dfrac{1}{2}d^T F d

Minimizing the loss L = -log p(x|\lambda) under this constraint, using the first-order approximation L(\lambda + d) \approx L(\lambda) + \nabla L(\lambda)^T d and a Lagrange multiplier, gives the optimal direction:

d \propto F^{-1}\nabla L(\lambda), or \hat\nabla_\lambda{L} = F^{-1}\nabla_\lambda{L}

The term \hat\nabla_\lambda{L} = F^{-1}\nabla_\lambda{L} is called the natural gradient. In fact, F is a Riemannian metric, translating distances in distribution space into updates in Euclidean parameter space.

For more detail on natural gradient descent, check out here. A small numeric sketch follows.
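As a minimal sketch (my own example, not from the post): fit the mean and standard deviation of a Gaussian by gradient descent on the negative log likelihood. For the (mu, sigma) parameterization the Fisher matrix is diagonal, F = diag(1/sigma^2, 2/sigma^2), so the natural gradient simply rescales the ordinary gradient.

import numpy as np

rng = np.random.default_rng(3)

# Fit N(mu, sigma^2) to data by natural gradient descent on the negative
# log likelihood. F = diag(1/sigma^2, 2/sigma^2), so F^{-1} rescales the
# plain gradient components by sigma^2 and sigma^2/2 respectively.
x = rng.normal(5.0, 3.0, size=10_000)
mu, sigma, lr = 0.0, 1.0, 0.1

for _ in range(200):
    d_mu = -(x.mean() - mu) / sigma**2                        # dL/dmu
    d_sigma = 1 / sigma - np.mean((x - mu) ** 2) / sigma**3   # dL/dsigma
    mu -= lr * sigma**2 * d_mu                  # natural gradient steps
    sigma -= lr * (sigma**2 / 2) * d_sigma

print(f"mu = {mu:.3f} (true 5.0), sigma = {sigma:.3f} (true 3.0)")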

4. Stochastic Variational Inference (SVI)

SVI was inspired by Stochastic Gradient Descent, which splits the training set into many subsets and feeds the model smaller batches of data. This saves resources, so SVI is highly scalable, while the traditional method can fail to train because it must load the full dataset. The challenge of SVI is how to approximate the natural gradient over the full data using natural gradients computed on small subsets. The authors introduce global and local parameters for the hidden variables. The global parameters are shared across the whole dataset and are updated with noisy natural gradients estimated from each subset, while the local parameters describe only the current subset.
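A runnable toy sketch of the update structure (my own example; this model has no per-point local variables, to keep it short): track the posterior of a global Gaussian mean by blending rescaled mini-batch estimates of its natural parameters with a decaying step size.

import numpy as np

rng = np.random.default_rng(4)

# Toy SVI sketch: prior mu ~ N(0,1), data x_i ~ N(mu,1). The exact posterior
# is Gaussian with precision 1 + N and mean sum(x)/(1 + N). SVI tracks the
# natural parameters (h1, h2) = (precision*mean, precision) using noisy
# mini-batch estimates, rescaled as if each batch were the full dataset.
x = rng.normal(2.0, 1.0, size=10_000)
N, batch_size = len(x), 100
h1, h2 = 0.0, 1.0                 # start from the prior
tau, kappa, t = 1.0, 0.7, 0

for _ in range(20):                               # epochs
    rng.shuffle(x)
    for batch in x.reshape(-1, batch_size):
        t += 1
        rho = (t + tau) ** (-kappa)               # decaying step size
        h1_hat = (N / batch_size) * batch.sum()   # batch scaled up to full data
        h2_hat = 1.0 + N
        h1 = (1 - rho) * h1 + rho * h1_hat        # blend old and new estimates
        h2 = (1 - rho) * h2 + rho * h2_hat

print(f"posterior mean ~ {h1 / h2:.3f}, exact: {x.sum() / (1 + N):.3f}")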

Summary

Variational Inference methods use hidden variables to encode hidden structure in observed data. The relationship between the hidden variables and the observations is captured by a joint probability distribution. To sample observations, we estimate the posterior distribution (the conditional distribution of the hidden structure given the observations). This posterior may not be tractable to compute, so we must appeal to approximate methods; thus the approximation problem becomes an optimization problem.

When dealing with the optimization task, KL divergence is used as the distance metric. But we do not minimize it directly. Instead, the ELBO, which is closely related to the KL divergence, provides an easier way to reach a good local optimum by iteratively increasing the lower bound. This step is done by coordinate descent or gradient descent over the factors.

Another important ingredient is mean-field variational inference. It allows approximating a complicated distribution by a factorized one that is easier to handle. Each factor can then be updated independently by steepest-descent algorithms.

These algorithms would struggle to find optimal update steps without the help of the natural gradient, since the parameters live in Euclidean space while the distance metric lives in distribution space. The Fisher information matrix, a transformation between the two kinds of space, is the key ingredient of the natural gradient.

Finally, Stochastic Variational Inference inherits the ideas of Stochastic Gradient Descent and introduces the notions of global and local parameters, which respectively hold full-dataset and mini-batch information. SVI succeeds in improving the scalability of these methods, making Variational Inference widely known today.
