Graphical Models Meet Deep Learning: Two Recent Attempts

We've seen the power of deep learning. But not everyone is happy with just "throwing the dataset at a big neural network". A good model design can still perform as well as, or even better than, a naive neural network method, with far fewer model parameters. Probabilistic graphical models (PGMs) are this type of thing: they have beautiful theory and lots of applications, and many machine learning researchers have been developing models and algorithms based on them. I remember the first deep learning paper I read was about restricted Boltzmann machines (RBMs), and today I'm still fascinated by the elegance of the maths.

However, inference in PGMs has become the barrier to large-scale applications and complicated graphical structures. Sure, people have developed fast approximate inference methods (VI and EP, to name a few) that work very well in practice. But the main problem comes from the memory cost. To make your model powerful you probably want to add a hierarchy of latent variables to describe both local and global behaviour. For large datasets, however, the number of approximate distributions you need to maintain grows with the size of the dataset. This is especially true for intermediate-level latent variables (consider the topic mixture for each document in LDA, for example), which you can't just throw away in SVI. Perhaps one of the best-known papers addressing this problem is Kingma & Welling's variational auto-encoder (VAE), which amortized inference by learning a mapping (represented by a neural network) from data to the approximate posterior over local latent variables. Since then this inference network/recognition model idea has been very popular. In this post I'm gonna briefly discuss two recent papers on using neural networks to assist approximate inference, which I think make for an interesting read.
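To make the memory argument concrete, here is a minimal counting sketch. It assumes a diagonal Gaussian approximation for each local latent variable and a one-hidden-layer recognition network; the function names and all the sizes are hypothetical, picked only for illustration.

```python
def local_param_count(num_datapoints, latent_dim):
    """Variational parameters when every datapoint keeps its own diagonal
    Gaussian q(z_n): one mean and one variance per latent dimension."""
    return num_datapoints * 2 * latent_dim


def recognition_param_count(input_dim, hidden_dim, latent_dim):
    """Weights and biases of a one-hidden-layer network that maps x_n to the
    mean and log-variance of q(z_n | x_n). Independent of the dataset size."""
    first_layer = input_dim * hidden_dim + hidden_dim
    output_layer = hidden_dim * 2 * latent_dim + 2 * latent_dim
    return first_layer + output_layer


# Hypothetical sizes, not taken from any of the papers discussed here.
input_dim, hidden_dim, latent_dim = 784, 500, 50
amortised = recognition_param_count(input_dim, hidden_dim, latent_dim)

for num_datapoints in (10_000, 1_000_000, 100_000_000):
    per_datapoint = local_param_count(num_datapoints, latent_dim)
    print(num_datapoints, per_datapoint, amortised)
```

The per-datapoint count grows linearly with the dataset, while the recognition network's count stays fixed, which is the whole point of amortisation.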

 

The first paper comes from Ryan Adams' group at Harvard:

Composing graphical models with neural networks for structured representations and fast inference
Johnson et al.
https://arxiv.org/pdf/1603.06277v2.pdf

I would say the general idea is very simple: the authors wanted to extend the original SVI method to non-conjugate models. They basically just approximated the non-exponential-family likelihood terms with conjugate exponential family potentials, and to reduce the memory cost the natural parameters of those potentials were parameterised by a recognition model. Then you can reuse the SVI derivations to get the fixed-point/natural gradient equations, and I think the algorithm can actually be very fast in practice.
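Here is a toy sketch of why a conjugate potential makes the local step cheap. It is not the authors' algorithm: it assumes a one-dimensional Gaussian latent variable with a fixed Gaussian prior, and `recognition_potential` with its two weights is a hypothetical stand-in for the recognition network. In natural parameters, combining the prior with the potential is just an addition.

```python
import math


def gaussian_to_natural(mean, var):
    """Natural parameters (eta1, eta2) of a 1-D Gaussian N(mean, var)."""
    return (mean / var, -0.5 / var)


def natural_to_gaussian(eta1, eta2):
    """Inverse map; requires eta2 < 0."""
    var = -0.5 / eta2
    return (eta1 * var, var)


def recognition_potential(x, weights):
    """Hypothetical recognition model: maps an observation x to the natural
    parameters of a Gaussian potential over z, peaked at w_mean * x."""
    w_mean, w_prec = weights
    precision = math.log1p(math.exp(w_prec * x))  # softplus keeps it positive
    return (precision * w_mean * x, -0.5 * precision)


def local_posterior(x, prior, weights):
    """Conjugate update: add natural parameters of prior and potential."""
    p1, p2 = gaussian_to_natural(*prior)
    r1, r2 = recognition_potential(x, weights)
    return natural_to_gaussian(p1 + r1, p2 + r2)


# Assumed numbers for illustration: standard normal prior, one observation.
mean, var = local_posterior(2.0, (0.0, 1.0), (1.0, 0.0))
print(mean, var)
```

Nothing per-datapoint has to be stored here: the local factor is recomputed from x and the shared weights whenever it is needed.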

I can sort of see why their approach works, and in fact I like their recipe, which has analytical forms. But the general framework is not that neat: they optimised different objective functions for different parts of the variational distribution. It's true that the proposed loss function has the nice lower-bounding property just like SVI; however, the authors proved it by saying "the learned local variational distribution using recognition model approximated likelihood terms is not optimal for the exact SVI". In other words, their argument holds for any recognition model, and it's not obvious to me how tight their objective is as a bound on the marginal likelihood. But I haven't read the paper in much detail (they have a long appendix), so maybe I'm missing something here.

UPDATE 19/08/16

I actually spent some time today reading the paper again. Well, I still have the above questions. However, I noticed another very interesting point: their new parameterisation can potentially alleviate the co-adaptation problem between the global and local variational approximations in neural-network-based latent variable models. The appendix of the VAE paper briefly sketches a fully Bayesian treatment (i.e. also approximating the posterior of the generative network weights). However, I've never seen any published results saying this works well. I think the main problem comes from the factorised approximation, which uses independent variational parameters for the posterior of the weights and of the latent variables, and then leaves their "interaction" to the optimisation process. This paper, though still assuming factorisation, explicitly introduced a dependence between the variational parameters of the latent variables and the variational approximation of the weights. My colleagues and I recently started to investigate fully Bayesian learning for VAEs, and I think this paper's observation can potentially be very helpful.

 

The second paper comes from Le Song's group at Georgia Tech:

Discriminative Embeddings of Latent Variable Models for Structured Data
Hanjun Dai, Bo Dai, Le Song
http://jmlr.org/proceedings/papers/v48/daib16.pdf

The problem they looked at is structured prediction, where PGMs are a major class of models. The original pipeline contains three steps: 1. learn the graphical model parameters with unsupervised learning, 2. extract features using posterior inference on the latent variables, and 3. learn a classifier on those features. But since classification performance is the main target here, the authors suggested chaining the three steps together and directly minimising the classification loss. Furthermore, since the inference step is the main bottleneck (imagine doing message passing on a big graph for every datapoint and every possible setting of the graphical model parameters), they amortized the messages using recognition models, which also implicitly capture the graphical model parameters. To further reduce the dimensionality they used distribution embeddings, computing the messages on a feature map. All these things are learned by supervised learning, and the results are pretty impressive.
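A rough sketch of what "amortized messages" can look like in code, in the spirit of the paper rather than a reproduction of it. The update rule (a ReLU of a linear map of the node feature plus the summed neighbour embeddings), the sum-pooling readout, and every weight below are my assumptions for illustration. In the actual method all weights would be trained end-to-end on the classification loss; here they are fixed by hand.

```python
import math


def matvec(matrix, vector):
    return [sum(w * a for w, a in zip(row, vector)) for row in matrix]


def add(u, v):
    return [a + b for a, b in zip(u, v)]


def embed_graph(node_features, neighbours, w_feat, w_msg, num_rounds):
    """Run a few rounds of embedded message passing and sum-pool the nodes.
    Each node embedding plays the role of an (embedded) belief."""
    dim = len(w_msg)
    emb = [[0.0] * dim for _ in node_features]
    for _ in range(num_rounds):
        new_emb = []
        for i, x in enumerate(node_features):
            incoming = [0.0] * dim
            for j in neighbours[i]:
                incoming = add(incoming, emb[j])
            pre = add(matvec(w_feat, x), matvec(w_msg, incoming))
            new_emb.append([max(0.0, a) for a in pre])  # ReLU
        emb = new_emb
    pooled = [0.0] * dim
    for e in emb:
        pooled = add(pooled, e)
    return pooled


def predict_probability(graph_emb, readout, bias):
    """Logistic classifier on top of the pooled graph embedding."""
    score = sum(w * a for w, a in zip(readout, graph_emb)) + bias
    return 1.0 / (1.0 + math.exp(-score))


# Hypothetical 3-node path graph 0 - 1 - 2 with scalar node features.
features = [[1.0], [0.0], [1.0]]
neighbours = [[1], [0, 2], [1]]
w_feat = [[1.0], [0.5]]
w_msg = [[0.1, 0.0], [0.0, 0.1]]

graph_emb = embed_graph(features, neighbours, w_feat, w_msg, num_rounds=3)
print(graph_emb, predict_probability(graph_emb, [0.3, -0.2], 0.0))
```

Note that the same weights are shared by every node of every graph, so no graphical model parameters or per-datapoint messages are stored explicitly.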

I have mixed feelings about this paper. I like the idea of discriminative training and of working directly on the messages, which are really the key ideas that make the algorithm scalable. Some might complain that it's unclear what PGM distribution it has learned from data, but I think loopy BP messages (at convergence) can provide accurate approximations to the marginals, and you can also use a tree-like construction to roughly see what's going on with the joint. My main criticism concerns the convergence of message passing. First, BP has no convergence guarantees on loopy graphs. But more importantly, the inference network approach actually puts constraints (i.e. shared parameters) on the beliefs across datapoints, which is unlikely to return the optimal answer for every one of them. It's also true for VAEs that you won't get the optimal posterior approximation for every latent variable. But at least VAEs still provide a lower bound on the marginal likelihood, while the Bethe free energy doesn't have this nice property. I guess it works fine in the models the authors have considered, but I could imagine possible failures of this method when it is extended to unsupervised learning on densely connected graphs.
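The lower-bound point is easy to check numerically on a toy model. The sketch below assumes a single binary latent variable and made-up probabilities; it only illustrates that the variational lower bound stays below the log marginal likelihood for any choice of q, including the suboptimal ones an amortised inference network might return. It says nothing about the Bethe free energy itself.

```python
import math


def elbo(q, prior, likelihood):
    """Variational lower bound for one observation x with a discrete latent z.
    q[z] and prior[z] are probabilities; likelihood[z] is p(x | z)."""
    total = 0.0
    for z in range(len(q)):
        if q[z] > 0.0:
            log_joint = math.log(prior[z] * likelihood[z])
            total += q[z] * (log_joint - math.log(q[z]))
    return total


# Made-up numbers for illustration.
prior = [0.7, 0.3]
likelihood = [0.2, 0.9]  # p(x | z) for the one observed x

evidence = sum(p * l for p, l in zip(prior, likelihood))
log_evidence = math.log(evidence)
exact_posterior = [p * l / evidence for p, l in zip(prior, likelihood)]

# Any q gives a valid bound; only the exact posterior makes it tight.
for q0 in (0.1, 0.5, 0.9):
    q = [q0, 1.0 - q0]
    print(q, elbo(q, prior, likelihood), log_evidence)
print(exact_posterior, elbo(exact_posterior, prior, likelihood), log_evidence)
```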