44 Variational Inference
44.1 Rationale
Performing the EM algorithm required us to be able to compute \(f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}})\) and also optimize \(Q({\boldsymbol{\theta}}, {\boldsymbol{\theta}}^{(t)})\). Sometimes one or both of these steps is intractable. Variational inference takes advantage of the decomposition
\[ \log f({\boldsymbol{x}}; {\boldsymbol{\theta}}) = \mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}}) + {\text{KL}}(q({\boldsymbol{z}}) \|f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}})) \]
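and, instead of setting \(q({\boldsymbol{z}})\) equal to the posterior \(f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}})\) as in the E-step, restricts \(q({\boldsymbol{z}})\) to a family of distributions for which the optimization is tractable.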
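As a minimal numerical check of this decomposition, the sketch below assumes a single observation \({\boldsymbol{x}}\) and a discrete latent variable \({\boldsymbol{z}}\) taking three values. The joint values \(f({\boldsymbol{x}}, {\boldsymbol{z}}; {\boldsymbol{\theta}})\) are made-up numbers, and `elbo` and `kl` are hypothetical helper names. For any distribution \(q({\boldsymbol{z}})\), the two terms should add up to \(\log f({\boldsymbol{x}}; {\boldsymbol{\theta}})\).

```python
import math

def elbo(q, joint):
    """L(q, theta) = sum_z q(z) [log f(x, z; theta) - log q(z)], using 0 log 0 = 0."""
    return sum(qz * (math.log(fz) - math.log(qz)) for qz, fz in zip(q, joint) if qz > 0)

def kl(q, p):
    """KL(q || p) = sum_z q(z) log(q(z) / p(z)), using 0 log 0 = 0."""
    return sum(qz * math.log(qz / pz) for qz, pz in zip(q, p) if qz > 0)

# Hypothetical joint values f(x, z; theta) for one fixed observed x and z in {0, 1, 2}.
joint = [0.10, 0.05, 0.15]
evidence = sum(joint)                          # f(x; theta) = sum_z f(x, z; theta)
posterior = [fz / evidence for fz in joint]    # f(z | x; theta)

q = [0.6, 0.3, 0.1]                            # an arbitrary distribution over z
print(math.log(evidence))                      # log f(x; theta)
print(elbo(q, joint) + kl(q, posterior))       # same value, up to rounding
```

Setting `q = posterior` makes the KL term zero, so \(\mathcal{L}\) then equals \(\log f({\boldsymbol{x}}; {\boldsymbol{\theta}})\), which is exactly the E-step choice in EM.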
44.2 Optimization Goal
Since
\[ \log f({\boldsymbol{x}}; {\boldsymbol{\theta}}) = \mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}}) + {\text{KL}}(q({\boldsymbol{z}}) \|f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}})) \]
and \(\log f({\boldsymbol{x}}; {\boldsymbol{\theta}})\) does not depend on \(q({\boldsymbol{z}})\), it follows that the closer \(q({\boldsymbol{z}})\) is to \(f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}})\), the larger \(\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}})\) becomes and the smaller \({\text{KL}}(q({\boldsymbol{z}}) \|f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}}))\) becomes. In other words, maximizing \(\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}})\) over \(q({\boldsymbol{z}})\) is equivalent to minimizing the KL divergence. The goal is typically to find, within a restricted family, the \(q({\boldsymbol{z}})\) that maximizes \(\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}})\); this maximizer then serves as an approximation to the posterior distribution \(f({\boldsymbol{z}}| {\boldsymbol{x}}; {\boldsymbol{\theta}})\).
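To illustrate optimization over a restricted family, the sketch below uses a hypothetical three-point joint and, as an assumed example family, restricts \(q({\boldsymbol{z}})\) to Binomial(2, \(\pi\)) pmfs on \(z \in \{0, 1, 2\}\). A grid search over \(\pi\) should find the same maximizer of \(\mathcal{L}\) as minimizer of the KL divergence, and the best attainable \(\mathcal{L}\) stays strictly below \(\log f({\boldsymbol{x}}; {\boldsymbol{\theta}})\) because the posterior does not belong to this family. The family, the numbers, and the helper names are illustrative choices only.

```python
import math

def elbo(q, joint):
    """L(q, theta) = sum_z q(z) [log f(x, z; theta) - log q(z)], using 0 log 0 = 0."""
    return sum(qz * (math.log(fz) - math.log(qz)) for qz, fz in zip(q, joint) if qz > 0)

def kl(q, p):
    """KL(q || p) = sum_z q(z) log(q(z) / p(z)), using 0 log 0 = 0."""
    return sum(qz * math.log(qz / pz) for qz, pz in zip(q, p) if qz > 0)

# Hypothetical joint values f(x, z; theta) for a fixed observed x and z in {0, 1, 2}.
joint = [0.10, 0.05, 0.15]
evidence = sum(joint)
posterior = [fz / evidence for fz in joint]

def binomial_q(pi):
    """Restricted family (an assumption for illustration): Binomial(2, pi) pmf on {0, 1, 2}."""
    return [(1 - pi) ** 2, 2 * pi * (1 - pi), pi ** 2]

grid = [i / 1000 for i in range(1, 1000)]
best_elbo_pi = max(grid, key=lambda pi: elbo(binomial_q(pi), joint))
best_kl_pi = min(grid, key=lambda pi: kl(binomial_q(pi), posterior))

# Gap between log f(x; theta) and the best L in the family equals the smallest KL.
gap = math.log(evidence) - elbo(binomial_q(best_elbo_pi), joint)
print(best_elbo_pi, best_kl_pi, gap)
```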
44.3 Mean Field Approximation
A mean field approximation implies we restrict \(q({\boldsymbol{z}})\) to be
\[ q({\boldsymbol{z}}) = \prod_{k=1}^K q_k({\boldsymbol{z}}_k) \]
for some partition \({\boldsymbol{z}}= ({\boldsymbol{z}}_1, {\boldsymbol{z}}_2, \ldots, {\boldsymbol{z}}_K)\). This partition is very context specific and is usually driven by the original model and what is tractable.
44.4 Optimal \(q_k({\boldsymbol{z}}_k)\)
Under the above restriction, it can be shown that the \(\{q_k({\boldsymbol{z}}_k)\}\) that maximize \(\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}})\) have the form:
\[ q_k({\boldsymbol{z}}_k) \propto \exp \left\{ \int \log f({\boldsymbol{x}}, {\boldsymbol{z}}; {\boldsymbol{\theta}}) \prod_{j \not= k} q_j({\boldsymbol{z}}_j)d{\boldsymbol{z}}_j \right\}. \]
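For discrete components, the integral is replaced by a sum. These pdfs or pmfs can be calculated iteratively by cycling over \(k=1, 2, \ldots, K\) after initializing them appropriately. Each update maximizes \(\mathcal{L}(q({\boldsymbol{z}}), {\boldsymbol{\theta}})\) over \(q_k\) with the other factors held fixed, so \(\mathcal{L}\) never decreases and is bounded above by \(\log f({\boldsymbol{x}}; {\boldsymbol{\theta}})\). Convergence is therefore guaranteed, although in general only to a local maximum.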
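A minimal sketch of these coordinate updates, assuming a toy model in which \({\boldsymbol{z}}= (z_1, z_2)\) has two binary components and the joint values \(f({\boldsymbol{x}}, z_1, z_2; {\boldsymbol{\theta}})\) for the fixed observed \({\boldsymbol{x}}\) are made-up numbers. With discrete components, each update is the normalized exponential of an expected log joint. The function names are hypothetical; the recorded values of \(\mathcal{L}\) should never decrease.

```python
import math

# Hypothetical joint values f(x, z1, z2; theta) for a fixed observed x,
# with z1, z2 in {0, 1}; rows index z1, columns index z2.
joint = [[0.30, 0.02],
         [0.03, 0.15]]
log_joint = [[math.log(v) for v in row] for row in joint]
log_evidence = math.log(sum(sum(row) for row in joint))   # log f(x; theta)

def normalize_exp(logits):
    """Return exp(logits) / sum(exp(logits)), shifted by the max for numerical stability."""
    m = max(logits)
    w = [math.exp(v - m) for v in logits]
    s = sum(w)
    return [v / s for v in w]

def elbo(q1, q2):
    """L(q) = E_q[log f(x, z)] - E_q[log q1(z1) + log q2(z2)] for q = q1 * q2."""
    total = 0.0
    for a in range(2):
        for b in range(2):
            w = q1[a] * q2[b]
            if w > 0:
                total += w * (log_joint[a][b] - math.log(q1[a]) - math.log(q2[b]))
    return total

def cavi(q1, q2, n_iter=50):
    """Cycle over the two factors, recording L after each full sweep."""
    history = [elbo(q1, q2)]
    for _ in range(n_iter):
        # q1(z1) proportional to exp{ sum_{z2} q2(z2) log f(x, z1, z2) }
        q1 = normalize_exp([sum(q2[b] * log_joint[a][b] for b in range(2)) for a in range(2)])
        # q2(z2) proportional to exp{ sum_{z1} q1(z1) log f(x, z1, z2) }
        q2 = normalize_exp([sum(q1[a] * log_joint[a][b] for a in range(2)) for b in range(2)])
        history.append(elbo(q1, q2))
    return q1, q2, history

q1, q2, history = cavi([0.5, 0.5], [0.5, 0.5])
print(q1, q2, history[-1], log_evidence)
```

Because this toy joint puts most of its mass on two separated configurations, a different initialization may lead to a different local maximum; comparing the final values of \(\mathcal{L}\) across starts is one simple way to choose among them.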
44.5 Remarks
If \({\boldsymbol{\theta}}\) is also random, then it can be included in \({\boldsymbol{z}}\).
The estimate \(\hat{f}({\boldsymbol{z}}| {\boldsymbol{x}})\) is typically concentrated around a high-density region of the true \(f({\boldsymbol{z}}| {\boldsymbol{x}})\), so it is useful for calculations such as the MAP estimate, but it is not guaranteed to be a good overall approximation of \(f({\boldsymbol{z}}| {\boldsymbol{x}})\); in particular, it often understates posterior variability.
Variational inference is typically faster than MCMC (covered next).
Because this is an optimization procedure, care can be taken to speed up convergence and to avoid unintended local maxima, for example by running the algorithm from several initializations.
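The remark about concentration can be seen in a standard worked case. Assume, purely for illustration, that the target posterior \(f({\boldsymbol{z}}| {\boldsymbol{x}})\) is bivariate normal with mean \(\mu\) and covariance \(\Sigma\) (made-up values below), and let \(\Lambda = \Sigma^{-1}\). Applying the optimal-factor formula to this target gives normal factors \(q_k\) with variance \(1/\Lambda_{kk}\) and means updated by \(m_1 = \mu_1 - (\Lambda_{12}/\Lambda_{11})(m_2 - \mu_2)\), and symmetrically for \(m_2\). The sketch shows the means converging to \(\mu\) while the approximate variances are smaller than the true marginal variances.

```python
import math

# Hypothetical target posterior f(z | x): bivariate normal with mean mu and covariance Sigma.
mu = [1.0, -1.0]
Sigma = [[1.0, 0.9],
         [0.9, 1.0]]

# Precision matrix Lambda = Sigma^{-1} for the 2 x 2 case.
det = Sigma[0][0] * Sigma[1][1] - Sigma[0][1] * Sigma[1][0]
Lam = [[Sigma[1][1] / det, -Sigma[0][1] / det],
       [-Sigma[1][0] / det, Sigma[0][0] / det]]

# Mean-field factors q_k(z_k) = N(m_k, 1 / Lambda_kk); only the means need iterating.
m = [0.0, 0.0]  # initial means
for _ in range(200):
    m[0] = mu[0] - Lam[0][1] / Lam[0][0] * (m[1] - mu[1])
    m[1] = mu[1] - Lam[1][0] / Lam[1][1] * (m[0] - mu[0])

q_var = [1.0 / Lam[0][0], 1.0 / Lam[1][1]]
true_var = [Sigma[0][0], Sigma[1][1]]
print("approximate means:", m)
print("approximate sds:", [math.sqrt(v) for v in q_var])     # about 0.436 each
print("true marginal sds:", [math.sqrt(v) for v in true_var])  # 1.0 each
```

Since \(1/\Lambda_{11} = \Sigma_{11} - \Sigma_{12}^2/\Sigma_{22}\) in this two-dimensional case, stronger correlation between the components makes the approximation narrower relative to the true marginal.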