Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm

In the old days we were happy with the mean-field approximation. These days we are not. As models become more complicated, Bayesian inference needs approximations to the exact posterior that are both accurate and fast, and mean-field is often not a very good choice. To enhance the power of variational approximations, people have started to use invertible transformations (e.g. see the normalizing flow paper) that warp simple distributions (e.g. a factorised Gaussian) into complicated ones that are still differentiable. However, this introduces an extra cost: you need to be able to compute the determinant of the Jacobian of that transform quickly. Solutions to this include constructing the transform from a sequence of "simple" functions -- simple in the sense that the Jacobian is low-rank or triangular (e.g. see this paper). Recently I found a NIPS preprint that provides another solution, which is absolutely amazing: through their clever design you don't even need to compute the Jacobian at all! I found this paper a very interesting read, so I decided to put up a blog post:

Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm
Qiang Liu and Dilin Wang
http://arxiv.org/abs/1608.04471 (to appear at NIPS 2016)

You might also want to check out their previous work at this year's ICML.

Before going into the details, let's recap the idea of invertible transformations of random variables. Assume z is a random variable distributed as q(z), e.g. a Gaussian. Now if a mapping T is differentiable and invertible, then the transformed variable x = T(z) is also a random variable, distributed as (using the notation of the paper) q_T(x) = q(T^{-1}(x)) | det(\nabla_x T^{-1}(x)) |. If we use this distribution in VI, the variational free energy becomes \mathrm{KL}[q_T(x)||p(x|\mathcal{D})] - \log p(\mathcal{D}), which means we need to evaluate the log determinant of the Jacobian, \log | det(\nabla_x T^{-1}(x)) |. As said, T often consists of a sequence of "simple" mappings, in math T(z) = F_K(F_{K-1}(...F_1(z))), so the determinant term becomes \sum_{k=1}^K \log |det(\nabla F_k^{-1}) |.
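As a quick sanity check on this change-of-variables formula, here is a minimal sketch (entirely my own, not from the paper) for an affine map T(z) = Az + b applied to a standard Gaussian, where the inverse map and its Jacobian determinant are available in closed form:

```python
# Minimal sketch of the change-of-variables formula for an affine map
# T(z) = A z + b applied to a standard Gaussian. All names here are my own
# illustration, not from the paper.
import numpy as np
from scipy.stats import multivariate_normal

A = np.array([[2.0, 0.5], [0.0, 1.5]])   # invertible linear part
b = np.array([1.0, -1.0])

def log_q(z):
    # base density q(z) = N(0, I)
    return multivariate_normal(mean=np.zeros(2)).logpdf(z)

def log_q_T(x):
    # q_T(x) = q(T^{-1}(x)) |det(grad_x T^{-1}(x))|
    A_inv = np.linalg.inv(A)
    z = A_inv @ (x - b)                   # T^{-1}(x)
    log_det_jac = np.log(np.abs(np.linalg.det(A_inv)))
    return log_q(z) + log_det_jac

# Sanity check against the known pushforward N(b, A A^T)
x = np.array([0.3, 0.7])
ref = multivariate_normal(mean=b, cov=A @ A.T).logpdf(x)
print(log_q_T(x), ref)                    # the two numbers should agree
```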

Previous approaches parameterised the F functions with carefully designed neural networks and optimised all the network parameters jointly through the free energy. This paper, although it does not say so explicitly, used sequential optimisation instead: it first finds F_1 by minimising the energy, then fixes it and proceeds to F_2, and so on, with no fine-tuning at the end. Optimisation becomes easier this way, but you potentially need more transformations and thus a longer total training time. Also, storing all these functions can be very demanding in memory.

To solve these problems the authors proposed using functional gradients. They assumed a very simple form for each transform, F_k(z) = z + f_k(z), where f_k belongs to some RKHS defined by a kernel K(x, \cdot), and used functional gradients to find f_k. Then, instead of finding a local optimum of f_k, we can do "early stopping", or even just run one gradient step, before moving to the next transform. Furthermore, if you start your search at f_k(z) = 0, there's no need to evaluate the determinant at all, since then \nabla F_{k}^{-1} = I ! This solves the running-time problem even when we need many more of these transforms than previous approaches. Another nice way to see it is that T becomes T(z) = z + f(z) with f(z) = \sum_{k=1}^K f_k(z), and if the norm of f_k is small enough, the above is also equivalent to one gradient step for f at the point z + \sum_{i < k} f_i(z). I like both ways of interpreting this procedure, and I would say it comes from the clever design of the sequential transform.
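To spell out why the determinant really disappears (my own one-line derivation, same notation as above): since F_k(z) = z + f_k(z),

\nabla_z F_k(z) = I + \nabla_z f_k(z), \quad \text{so at } f_k = 0: \; \nabla_z F_k = I \text{ and } \log |det(\nabla F_k^{-1})| = -\log |det(\nabla F_k)| = 0.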

The authors provided the analytical form of this functional gradient which links back to the kernel Stein discrepancy discussed in their ICML paper:

\nabla_{f_k} \mathrm{KL}[q_{F_k} || p] |_{f_k = 0} = - \mathbb{E}_{z \sim q} [ \nabla_z K(z, \cdot) + \nabla_z \log p(z) K(z, \cdot) ],

where p(z) is shorthand for the joint distribution p(z, \mathcal{D}). Since Monte Carlo estimation has become quite a standard approach in modern VI, we can use samples from the q distribution to compute the above gradient. Assume we take n samples z_1, ..., z_n \sim q. Taking only one gradient step with learning rate \epsilon, the current transform becomes F_k(z) = z + \frac{\epsilon}{n} \sum_{j=1}^n [ \nabla_{z_j} K(z_j, z) + \nabla_{z_j} \log p(z_j) K(z_j, z) ]. We can also use mini-batches to approximate \log p(z, \mathcal{D}) to make the algorithm scalable to large datasets. In fact they reported even faster speed than PBP on Bayesian neural networks, with results comparable to the state of the art, which looks very promising.
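To make this concrete, here is a rough numpy sketch of one such update with an RBF kernel K(z, z') = \exp(-||z - z'||^2 / h). The function names, the fixed bandwidth h, and the interface of grad_log_p (returning \nabla_z \log p(z, \mathcal{D}) row-wise, possibly from a mini-batch) are my own choices, not the paper's:

```python
# A rough sketch of one SVGD-style update, assuming we already have n samples
# `z` (shape [n, d]) and a function grad_log_p(z) returning the gradients of
# log p(z, D) for each sample (possibly a mini-batch estimate).
import numpy as np

def rbf_kernel(z, h):
    # K[j, i] = exp(-||z_j - z_i||^2 / h)
    diff = z[:, None, :] - z[None, :, :]             # diff[j, i] = z_j - z_i
    K = np.exp(-np.sum(diff ** 2, axis=-1) / h)
    # grad_{z_j} K(z_j, z_i) = -2 (z_j - z_i) K(z_j, z_i) / h
    grad_K = -2.0 * diff * K[:, :, None] / h
    return K, grad_K

def svgd_step(z, grad_log_p, eps=1e-2, h=1.0):
    n = z.shape[0]
    K, grad_K = rbf_kernel(z, h)
    # phi(z_i) = (1/n) sum_j [ K(z_j, z_i) grad log p(z_j) + grad_{z_j} K(z_j, z_i) ]
    phi = (K.T @ grad_log_p(z) + grad_K.sum(axis=0)) / n
    return z + eps * phi                              # F_k applied to all samples
```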

After the gradient step for f_k, the algorithm moves to the next transform F_{k+1}(z) = z + f_{k+1}(z) and again starts at f_{k+1}(z) = 0. Notice that the q distribution we simulate z from in the above gradient equation now becomes q_{F_k}, which does contain a non-identity Jacobian if we want to evaluate it directly. However, recall the core idea of invertible transformations: we can simulate x \sim q_T by first sampling z \sim q and then applying x = T(z), and we do have samples from q from the last step. This means we can first apply the transform to update the samples (with the Monte Carlo estimate) z_i \leftarrow z_i + \frac{\epsilon}{n} \sum_{j=1}^n [ \nabla_{z_j} K(z_j, z_i) + \nabla_{z_j} \log p(z_j) K(z_j, z_i) ], then use them to compute this step's gradient. This makes the algorithm behave like a sampling method, and the memory consumption comes only from storing these samples, which can be very cheap compared to storing lots of parametric functionals. If one is still unhappy with storing samples (say due to limited memory), one can fit a parametric density to the samples after the last transform F_K. Another strategy is to use only 1 sample, in which case the algorithm reduces to maximum a posteriori (MAP), which still finds a mode of the posterior for you.
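Putting the pieces together, here is a toy end-to-end run of this sample-update view on a 1D Gaussian target; the target, step size, bandwidth and iteration count are all arbitrary illustrative choices of mine, not the paper's experiments:

```python
# Toy end-to-end run: the particles z_i are repeatedly pushed through one
# small residual transform per iteration, as described above.
import numpy as np

rng = np.random.default_rng(0)
mu, sigma = 2.0, 0.5                                 # toy 1D Gaussian target p(z)
grad_log_p = lambda z: -(z - mu) / sigma ** 2

n, eps, h = 100, 0.1, 0.5
z = rng.normal(size=(n, 1))                          # initial samples from q = N(0, 1)

for _ in range(1000):
    diff = z[:, None, :] - z[None, :, :]             # diff[j, i] = z_j - z_i
    K = np.exp(-np.sum(diff ** 2, axis=-1) / h)      # K[j, i] = K(z_j, z_i)
    grad_K = -2.0 * diff * K[:, :, None] / h         # grad_{z_j} K(z_j, z_i)
    phi = (K.T @ grad_log_p(z) + grad_K.sum(axis=0)) / n
    z = z + eps * phi                                # update all particles at once

# The mean should land near mu; the spread should be on the order of sigma.
print(z.mean(), z.std())
```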

The authors used an RBF kernel and also pointed out that this algorithm recovers MAP when the bandwidth tends to zero. This makes sense, as you can understand it through an analogy with kernel density estimation, and in general running VI with q as a delta function is also equivalent to MAP. In other words, the number of samples n and the kernel implicitly define the variational family q: the kernel mainly controls the smoothness, and the number of samples roughly indicates the complexity of your fit. The main criticism I have also comes from the fact that n affects the complexity of q, which is not the case for other black-box VI techniques. This means that in high dimensions, where using a large number of samples has prohibitive memory cost, you need to carefully tune the kernel parameters to get a good fit, or you might not even be able to achieve significantly better results than MAP. Admittedly, using a mean-field Gaussian in high dimensions also seems insufficient, but at least the community has a relatively clear understanding of that case. For this method it's unclear to me right now what the distribution looks like and whether it is often better than mean-field. Perhaps I should raise these questions with the authors, and I'll be very excited to hear more about it from their poster at NIPS!
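As a footnote, a quick numerical check (my own, not an experiment from the paper) of the bandwidth-to-zero claim above: with a vanishing RBF bandwidth only the i = j terms of the update survive (K(z_i, z_i) = 1 and \nabla_{z_i} K(z_i, z_i) = 0), so each particle simply takes a damped gradient step on \log p, i.e. plain MAP.

```python
# Numerical check of the bandwidth -> 0 limit: the SVGD-style update direction
# collapses to (1/n) * grad log p(z_i) for every particle.
import numpy as np

rng = np.random.default_rng(1)
z = rng.normal(size=(5, 2))
grad_log_p = lambda z: -z                            # toy target: standard Gaussian

h = 1e-12                                            # bandwidth -> 0
diff = z[:, None, :] - z[None, :, :]
K = np.exp(-np.sum(diff ** 2, axis=-1) / h)          # ~identity matrix
grad_K = -2.0 * diff * K[:, :, None] / h             # ~zero everywhere
phi = (K.T @ grad_log_p(z) + grad_K.sum(axis=0)) / z.shape[0]

# True: the update direction is just grad log p(z_i) / n, i.e. each particle
# runs gradient ascent on log p independently of the others.
print(np.allclose(phi, grad_log_p(z) / z.shape[0]))
```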