44 Variational Inference

44.1 Rationale

Performing the EM algorithm required us to be able to compute $$f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}})$$ and also optimize $$Q({\boldsymbol{\theta}}, {\boldsymbol{\theta}}^{(t)})$$. Sometimes this is not possible. Variational inference takes advantage of the decomposition

$\log f({\boldsymbol{x}}; {\boldsymbol{\theta}}) = \mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}}) + {\text{KL}}(q({\boldsymbol{z}}) \|f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}}))$

and instead considers other forms of $$q({\boldsymbol{z}})$$ to identify a more tractable optimization.
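The decomposition can be verified numerically on a toy model. The sketch below (a hypothetical two-state latent variable $$z$$ with arbitrary joint probabilities, chosen only for illustration) computes $$\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}})$$ and the KL term for an arbitrary $$q({\boldsymbol{z}})$$ and checks that they sum to $$\log f({\boldsymbol{x}}; {\boldsymbol{\theta}})$$:

```python
import math

# Hypothetical model: latent z in {0, 1}; joint[z] = f(x, z) for the
# (fixed) observed x, with values chosen arbitrarily for illustration.
joint = {0: 0.3, 1: 0.1}
fx = sum(joint.values())                      # f(x) = sum_z f(x, z)
log_fx = math.log(fx)
posterior = {z: p / fx for z, p in joint.items()}  # f(z | x)

q = {0: 0.6, 1: 0.4}                          # any valid q(z)

# L(q) = E_q[log f(x, z) - log q(z)]  (the evidence lower bound)
elbo = sum(q[z] * (math.log(joint[z]) - math.log(q[z])) for z in q)
# KL(q || f(z | x)) = E_q[log q(z) - log f(z | x)]
kl = sum(q[z] * math.log(q[z] / posterior[z]) for z in q)

print(elbo + kl, log_fx)  # the two quantities agree
```

Because the KL term is nonnegative, $$\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}})$$ is always a lower bound on $$\log f({\boldsymbol{x}}; {\boldsymbol{\theta}})$$, which is why it is often called the evidence lower bound (ELBO).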

44.2 Optimization Goal

Since

$\log f({\boldsymbol{x}}; {\boldsymbol{\theta}}) = \mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}}) + {\text{KL}}(q({\boldsymbol{z}}) \|f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}}))$

it follows that the closer $$q({\boldsymbol{z}})$$ is to $$f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}})$$, the larger $$\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}})$$ becomes and the smaller $${\text{KL}}(q({\boldsymbol{z}}) \|f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}}))$$ becomes. The goal is typically to identify a restricted family of $$q({\boldsymbol{z}})$$ and maximize $$\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}})$$ over that family; the maximizing $$q({\boldsymbol{z}})$$ then serves as an approximation to the posterior distribution $$f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}})$$.

44.3 Mean Field Approximation

A mean field approximation implies we restrict $$q({\boldsymbol{z}})$$ to be

$q({\boldsymbol{z}}) = \prod_{k=1}^K q_k({\boldsymbol{z}}_k)$

for some partition $${\boldsymbol{z}}= ({\boldsymbol{z}}_1, {\boldsymbol{z}}_2, \ldots, {\boldsymbol{z}}_K)$$. This partition is very context specific and is usually driven by the original model and what is tractable.
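For concreteness, consider a Bayesian Gaussian mixture model (a hypothetical example; the assignments $$c_1, \ldots, c_n$$ and component means $$\mu_1, \ldots, \mu_M$$ are not part of the notation above). With $${\boldsymbol{z}}= (c_1, \ldots, c_n, \mu_1, \ldots, \mu_M)$$, a common mean field choice is

$q({\boldsymbol{z}}) = \prod_{i=1}^n q_i(c_i) \prod_{m=1}^M q_m(\mu_m),$

which breaks the posterior dependence between the assignments and the component means while keeping each factor tractable.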

44.4 Optimal $$q_k({\boldsymbol{z}}_k)$$

Under the above restriction, it can be shown that the $$\{q_k({\boldsymbol{z}}_k)\}$$ that maximize $$\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}})$$ have the form:

$q_k({\boldsymbol{z}}_k) \propto \exp \left\{ \int \log f({\boldsymbol{x}}, {\boldsymbol{z}}; {\boldsymbol{\theta}}) \prod_{j \not= k} q_j({\boldsymbol{z}}_j)d{\boldsymbol{z}}_j \right\}.$

These pdfs or pmfs can be calculated iteratively by cycling over $$k=1, 2, \ldots, K$$ after initializing them appropriately. Each update can only increase (or leave unchanged) $$\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}})$$, so convergence to a local maximum is guaranteed.
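This coordinate-wise cycling can be sketched on a standard textbook model: a univariate Gaussian with unknown mean $$\mu$$ and precision $$\tau$$, priors $$\mu \mid \tau \sim N(\mu_0, (\lambda_0 \tau)^{-1})$$ and $$\tau \sim \text{Gamma}(a_0, b_0)$$, and factorization $$q(\mu, \tau) = q(\mu)\, q(\tau)$$. Under these conjugate priors the optimal factors are Gaussian and Gamma, and the updates below follow from the general formula; the data, hyperparameters, and iteration count are illustrative assumptions, not prescriptions.

```python
import random

# Mean field (coordinate ascent) updates for a Gaussian with unknown
# mean mu and precision tau; q(mu, tau) = q(mu) q(tau).
random.seed(0)
x = [random.gauss(2.0, 1.0) for _ in range(200)]   # synthetic data
n, sx, sx2 = len(x), sum(x), sum(v * v for v in x)

mu0, lam0, a0, b0 = 0.0, 1.0, 1.0, 1.0             # prior hyperparameters
e_tau = a0 / b0                                     # initialize E[tau]

for _ in range(50):                                 # cycle the two factors
    # q(mu) = N(mu_n, 1 / lam_n), which depends on the current E[tau]
    mu_n = (lam0 * mu0 + sx) / (lam0 + n)
    lam_n = (lam0 + n) * e_tau
    e_mu, e_mu2 = mu_n, mu_n ** 2 + 1.0 / lam_n     # E[mu], E[mu^2]
    # q(tau) = Gamma(a_n, b_n), which depends on the current q(mu)
    a_n = a0 + (n + 1) / 2.0
    b_n = b0 + 0.5 * (sx2 - 2 * e_mu * sx + n * e_mu2
                      + lam0 * (e_mu2 - 2 * e_mu * mu0 + mu0 ** 2))
    e_tau = a_n / b_n                               # update E[tau]

print(mu_n, e_tau)  # approximate posterior mean of mu and E[tau]
```

Note that each update uses only expectations under the other factor, exactly as in the formula above: $$q(\mu)$$ needs $${\text{E}}[\tau]$$ and $$q(\tau)$$ needs $${\text{E}}[\mu]$$ and $${\text{E}}[\mu^2]$$.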

44.5 Remarks

• If $${\boldsymbol{\theta}}$$ is also random, then it can be included in $${\boldsymbol{z}}$$.

• The estimated $$\hat{f}({\boldsymbol{z}}| {\boldsymbol{x}})$$ is typically concentrated around the high density region of the true $$f({\boldsymbol{z}}| {\boldsymbol{x}})$$, so it is useful for calculations such as the MAP, but it is not guaranteed to be a good overall estimate of $$f({\boldsymbol{z}}| {\boldsymbol{x}})$$.

• Variational inference is typically faster than MCMC (covered next).

• Given this is an optimization procedure, care can be taken to speed up convergence and avoid unintended local maxima.