44 Variational Inference

44.1 Rationale

Performing the EM algorithm required us to be able to compute \(f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}})\) and also to optimize \(Q({\boldsymbol{\theta}}, {\boldsymbol{\theta}}^{(t)})\). Sometimes one or both of these steps is intractable. Variational inference takes advantage of the decomposition

\[ \log f({\boldsymbol{x}}; {\boldsymbol{\theta}}) = \mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}}) + {\text{KL}}(q({\boldsymbol{z}}) \|f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}})) \]

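and, rather than setting \(q({\boldsymbol{z}})\) equal to \(f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}})\) as EM does, restricts \(q({\boldsymbol{z}})\) to other forms that make the optimization tractable.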
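For reference, assuming \(\mathcal{L}\) is the lower bound from the EM derivation, \(\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}}) = \int q({\boldsymbol{z}}) \log [f({\boldsymbol{x}}, {\boldsymbol{z}}; {\boldsymbol{\theta}}) / q({\boldsymbol{z}})] d{\boldsymbol{z}}\), the decomposition follows from \(f({\boldsymbol{x}}, {\boldsymbol{z}}; {\boldsymbol{\theta}}) = f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}}) f({\boldsymbol{x}}; {\boldsymbol{\theta}})\) and \(\int q({\boldsymbol{z}}) d{\boldsymbol{z}} = 1\):

\[ \begin{aligned} \mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}}) &= \int q({\boldsymbol{z}}) \log \frac{f({\boldsymbol{x}}, {\boldsymbol{z}}; {\boldsymbol{\theta}})}{q({\boldsymbol{z}})} d{\boldsymbol{z}} \\ &= \int q({\boldsymbol{z}}) \left[ \log f({\boldsymbol{x}}; {\boldsymbol{\theta}}) + \log \frac{f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}})}{q({\boldsymbol{z}})} \right] d{\boldsymbol{z}} \\ &= \log f({\boldsymbol{x}}; {\boldsymbol{\theta}}) - {\text{KL}}(q({\boldsymbol{z}}) \|f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}})). \end{aligned} \]

Because the KL divergence is nonnegative, \(\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}})\) is a lower bound on \(\log f({\boldsymbol{x}}; {\boldsymbol{\theta}})\) for every choice of \(q({\boldsymbol{z}})\).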

44.2 Optimization Goal

Since

\[ \log f({\boldsymbol{x}}; {\boldsymbol{\theta}}) = \mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}}) + {\text{KL}}(q({\boldsymbol{z}}) \|f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}})) \]

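it follows that, for fixed \({\boldsymbol{\theta}}\), the closer \(q({\boldsymbol{z}})\) is to \(f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}})\), the smaller \({\text{KL}}(q({\boldsymbol{z}}) \|f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}}))\) becomes and the larger \(\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}})\) grows, because \(\log f({\boldsymbol{x}}; {\boldsymbol{\theta}})\) does not depend on \(q({\boldsymbol{z}})\). Maximizing \(\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}})\) over \(q({\boldsymbol{z}})\) is therefore equivalent to minimizing the KL divergence. The goal is typically to find the \(q({\boldsymbol{z}})\) within a restricted family of forms that maximizes \(\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}})\); this \(q({\boldsymbol{z}})\) then serves as an approximation to the posterior distribution \(f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}})\).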
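The trade-off can be checked numerically. The sketch below uses a hypothetical discrete model that is not from the text: a latent \(z \in \{0, 1, 2\}\) with assumed prior probabilities and \(x | z \sim \text{Normal}(\mu_z, 1)\) for a single observed \(x\), so sums replace integrals. Moving \(q(z)\) along a straight line from the uniform pmf toward the exact posterior, the KL term shrinks to zero while \(\mathcal{L}\) rises to \(\log f(x)\), and their sum stays fixed.

```python
import numpy as np

# Hypothetical toy model (an assumption for illustration): latent z in {0, 1, 2}
# with prior probabilities `prior`, and x | z ~ Normal(mu[z], 1).
prior = np.array([0.5, 0.3, 0.2])
mu = np.array([-2.0, 0.0, 3.0])
x = 0.7


def normal_pdf(v, mean, sd=1.0):
    return np.exp(-0.5 * ((v - mean) / sd) ** 2) / (sd * np.sqrt(2 * np.pi))


joint = prior * normal_pdf(x, mu)   # f(x, z) for each value of z
log_marginal = np.log(joint.sum())  # log f(x)
posterior = joint / joint.sum()     # f(z | x)


def elbo(q):
    # L(q) = sum_z q(z) log[f(x, z) / q(z)]
    return np.sum(q * (np.log(joint) - np.log(q)))


def kl(q, p):
    # KL(q || p) = sum_z q(z) log[q(z) / p(z)]
    return np.sum(q * (np.log(q) - np.log(p)))


# Interpolate from a poor q (uniform) toward the exact posterior.
uniform = np.full(3, 1 / 3)
for w in [0.0, 0.5, 0.9, 1.0]:
    q = (1 - w) * uniform + w * posterior
    print(f"w={w:.1f}  L={elbo(q):.4f}  KL={kl(q, posterior):.4f}  "
          f"L+KL={elbo(q) + kl(q, posterior):.4f}  log f(x)={log_marginal:.4f}")
```

Because KL divergence is convex in \(q\) and vanishes at \(w = 1\), it decreases along this path, so \(\mathcal{L}\) increases; the \(w = 1\) row is the unrestricted optimum that the EM E-step uses.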

44.3 Mean Field Approximation

A mean field approximation restricts \(q({\boldsymbol{z}})\) to the factorized form

\[ q({\boldsymbol{z}}) = \prod_{k=1}^K q_k({\boldsymbol{z}}_k) \]

for some partition \({\boldsymbol{z}}= ({\boldsymbol{z}}_1, {\boldsymbol{z}}_2, \ldots, {\boldsymbol{z}}_K)\). The choice of partition is very context-specific and is usually driven by the structure of the original model and by what is tractable.
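To make the restriction concrete, consider a hypothetical case (not from the text) with two binary latent variables, so \(K = 2\). A mean-field \(q\) is then an outer product of two marginal pmfs and cannot represent any dependence between \(z_1\) and \(z_2\). The brute-force grid search below, used only for illustration and not as a fitting method, finds no factorized member with zero KL divergence from an assumed correlated target.

```python
import numpy as np

# Hypothetical target: a joint pmf over two binary latent variables (z1, z2)
# with strong positive dependence. Rows index z1, columns index z2.
target = np.array([[0.45, 0.05],
                   [0.05, 0.45]])


def mean_field_q(a, b):
    # q(z1, z2) = q1(z1) q2(z2) with q1(1) = a and q2(1) = b.
    q1 = np.array([1 - a, a])
    q2 = np.array([1 - b, b])
    return np.outer(q1, q2)


def kl(q, p):
    # KL(q || p) for pmfs stored as arrays of the same shape.
    return np.sum(q * (np.log(q) - np.log(p)))


# Brute-force search over the factorized family (illustration only).
grid = np.linspace(0.01, 0.99, 99)
best = min((kl(mean_field_q(a, b), target), a, b) for a in grid for b in grid)
print(f"smallest KL over the grid: {best[0]:.4f} "
      f"at q1(1)={best[1]:.2f}, q2(1)={best[2]:.2f}")
```

Every factorized \(q\) is a rank-one table, while this target has rank two, so the smallest attainable KL divergence is strictly positive.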

44.4 Optimal \(q_k({\boldsymbol{z}}_k)\)

Under the above restriction, it can be shown that the \(\{q_k({\boldsymbol{z}}_k)\}\) that maximize \(\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}})\) have the form:

\[ q_k({\boldsymbol{z}}_k) \propto \exp \left\{ \int \log f({\boldsymbol{x}}, {\boldsymbol{z}}; {\boldsymbol{\theta}}) \prod_{j \not= k} q_j({\boldsymbol{z}}_j)d{\boldsymbol{z}}_j \right\}. \]
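Writing \(\operatorname{E}_{-k}\) for expectation with respect to \(\prod_{j \not= k} q_j({\boldsymbol{z}}_j)\) (notation introduced here), the update says \(\log q_k({\boldsymbol{z}}_k) = \operatorname{E}_{-k}[\log f({\boldsymbol{x}}, {\boldsymbol{z}}; {\boldsymbol{\theta}})] + \text{const}\). A sketch of why, assuming \(\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}}) = \int q({\boldsymbol{z}}) \log [f({\boldsymbol{x}}, {\boldsymbol{z}}; {\boldsymbol{\theta}}) / q({\boldsymbol{z}})] d{\boldsymbol{z}}\): with every \(q_j\), \(j \not= k\), held fixed, the entropy of the product splits into a sum over factors, and

\[ \mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}}) = \int q_k({\boldsymbol{z}}_k) \log \frac{\tilde{f}_k({\boldsymbol{z}}_k)}{q_k({\boldsymbol{z}}_k)} d{\boldsymbol{z}}_k + \text{const} = -{\text{KL}}(q_k({\boldsymbol{z}}_k) \| \tilde{f}_k({\boldsymbol{z}}_k)) + \text{const}, \]

where

\[ \tilde{f}_k({\boldsymbol{z}}_k) = \frac{\exp\left\{ \operatorname{E}_{-k}[\log f({\boldsymbol{x}}, {\boldsymbol{z}}; {\boldsymbol{\theta}})] \right\}}{\int \exp\left\{ \operatorname{E}_{-k}[\log f({\boldsymbol{x}}, {\boldsymbol{z}}; {\boldsymbol{\theta}})] \right\} d{\boldsymbol{z}}_k}. \]

The KL divergence is minimized, and hence \(\mathcal{L}\) is maximized over \(q_k\), by \(q_k = \tilde{f}_k\), which is the form displayed above.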

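These pdfs or pmfs can be calculated iteratively: after initializing them appropriately, cycle over \(k=1, 2, \ldots, K\), updating each \(q_k({\boldsymbol{z}}_k)\) with the other factors held fixed. No update can decrease \(\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}})\), which is bounded above by \(\log f({\boldsymbol{x}}; {\boldsymbol{\theta}})\), so the procedure is guaranteed to converge, although possibly to a local rather than a global maximum.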
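A minimal sketch of this coordinate-ascent cycle (often called CAVI) for the smallest nontrivial case, assuming a hypothetical model with two binary latent variables whose joint \(f({\boldsymbol{x}}, z_1, z_2; {\boldsymbol{\theta}})\) at the observed \({\boldsymbol{x}}\) is given as a 2-by-2 table. The integral in the update becomes a sum, so each step is a matrix-vector product followed by normalization. The function records \(\mathcal{L}\) after every cycle so that its monotone increase can be checked.

```python
import numpy as np

# Hypothetical example: f(x, z1, z2; theta) at the observed x, tabulated over
# two binary latent variables (rows index z1, columns index z2). As a joint
# density evaluated at fixed x, the table need not sum to one.
log_joint = np.log(np.array([[0.09, 0.01],
                             [0.01, 0.09]]))


def normalize_log(v):
    """Turn unnormalized log-probabilities into a pmf (max-shifted for stability)."""
    w = np.exp(v - v.max())
    return w / w.sum()


def elbo(q1, q2):
    """L(q) = sum over (z1, z2) of q(z1, z2) log[f(x, z1, z2) / q(z1, z2)], q = q1 q2."""
    q = np.outer(q1, q2)
    return np.sum(q * (log_joint - np.log(q)))


def cavi(q1, q2, n_iter=100, tol=1e-10):
    """Cycle the mean-field updates for k = 1, 2, recording L after each cycle."""
    history = [elbo(q1, q2)]
    for _ in range(n_iter):
        # log q1(z1) = sum_{z2} q2(z2) log f(x, z1, z2) + const
        q1 = normalize_log(log_joint @ q2)
        # log q2(z2) = sum_{z1} q1(z1) log f(x, z1, z2) + const
        q2 = normalize_log(q1 @ log_joint)
        history.append(elbo(q1, q2))
        if history[-1] - history[-2] < tol:
            break
    return q1, q2, history


q1, q2, history = cavi(np.array([0.5, 0.5]), np.array([0.3, 0.7]))
print("L after each cycle:", np.round(history, 4))
print("log f(x):", np.log(np.exp(log_joint).sum()))
print("q1:", q1, " q2:", q2)
```

From this start both factors settle at \(q_1(1) = q_2(1) = 0.75\), whereas the exact posterior gives \(z_1 = 1\) probability 0.5: the factorized \(q\) locks onto one of the two high-probability configurations. Starting from \(q_2(1) = 0.3\) instead leads, by symmetry, to the other one, which illustrates why initialization matters.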

44.5 Remarks

  • If \({\boldsymbol{\theta}}\) is also random, then it can be included in \({\boldsymbol{z}}\).

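  • The estimated \(\hat{f}({\boldsymbol{z}}| {\boldsymbol{x}})\) is typically concentrated around a high-density region of the true \(f({\boldsymbol{z}}| {\boldsymbol{x}})\), because minimizing \({\text{KL}}(q({\boldsymbol{z}}) \|f({\boldsymbol{z}}| {\boldsymbol{x}}))\) heavily penalizes \(q({\boldsymbol{z}})\) for placing mass where the posterior has little. It is therefore useful for calculations such as the MAP, but it is not guaranteed to be a good overall estimate of \(f({\boldsymbol{z}}| {\boldsymbol{x}})\); in particular, it often underestimates posterior variance.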

  • Variational inference is typically faster than MCMC (covered next).

  • Because variational inference is an optimization procedure, the usual care applies: choose initializations sensibly (for example, try several) to speed up convergence and avoid unintended local maxima.
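The concentration remark can be seen in a standard textbook case, sketched here under the assumption that the target posterior is a bivariate normal with mean \(\boldsymbol{\mu}\) and precision matrix \(\boldsymbol{\Lambda}\) (a hypothetical choice, not a model from this section). The mean-field factors are then themselves normal, \(q_k(z_k) = \text{Normal}(m_k, 1/\Lambda_{kk})\), with \(m_1 = \mu_1 - (\Lambda_{12}/\Lambda_{11})(m_2 - \mu_2)\) and symmetrically for \(m_2\). Iterating these coupled updates recovers the true means, but each variance \(1/\Lambda_{kk}\) is smaller than the true marginal variance \(\Sigma_{kk}\) whenever the components are correlated.

```python
import numpy as np

# Hypothetical target posterior f(z | x): bivariate normal with correlation 0.9.
mu = np.array([1.0, -1.0])
Sigma = np.array([[1.0, 0.9],
                  [0.9, 1.0]])
Lam = np.linalg.inv(Sigma)  # precision matrix

# Mean-field q(z) = q1(z1) q2(z2). For a Gaussian target the optimal factors are
# Gaussian, q_k = Normal(m_k, 1 / Lam[k, k]), with means coupled as below.
m = np.array([5.0, 5.0])  # deliberately poor starting means
for _ in range(200):
    m[0] = mu[0] - Lam[0, 1] / Lam[0, 0] * (m[1] - mu[1])
    m[1] = mu[1] - Lam[1, 0] / Lam[1, 1] * (m[0] - mu[0])

q_var = 1.0 / np.diag(Lam)  # mean-field variances do not depend on the means
print("mean-field means:", m, " true means:", mu)
print("mean-field variances:", q_var, " true variances:", np.diag(Sigma))
```

With correlation 0.9 the mean-field variances are \(1 - 0.9^2 = 0.19\), against true marginal variances of 1, even though the means are recovered exactly.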