Deep InfoMax: Learning good representations through mutual information maximization
As researchers continue to apply machine learning to more complex real-world problems, they’ll need to rely less on algorithms that require annotation. This is not only because labels are expensive, but also because supervised learners trained only to predict annotations tend to capture only the structure in the data that the given task requires, and so fail to generalize beyond it. For instance, a neural network trained to classify images tends to rely on texture that correlates with the class label rather than on shape or size, which may make the classifier less suitable in test settings. This issue, among others, is a core motivation for unsupervised learning of good representations.
Learning good representations without relying on annotations has been a long-standing challenge in machine learning. Our approach, which we call Deep InfoMax (DIM), does so by learning a predictive model of localized features of a deep neural network. The work appears at the 2019 International Conference on Learning Representations (ICLR).
DIM (code on GitHub) is based on two learning principles: mutual information maximization in the vein of the infomax optimization principle and self-supervision, an important unsupervised learning method that relies on intrinsic properties of the data to provide its own annotation. DIM is flexible, simple to implement, and incorporates a task we call self-prediction.
Mutual information estimation: Does this pair belong together?
DIM draws inspiration from the infomax principle, a guideline for learning good representations by maximizing the mutual information between the input and output of a neural network. In this setting, the mutual information is defined as the KL-divergence between the joint distribution (all inputs paired with their corresponding outputs) and the product-of-marginals distribution (all possible input/output pairs). While intuitive, infomax has had limited success with deep networks, partly because estimating mutual information is difficult when the input is high-dimensional and/or the representation is continuous.
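To make the definition concrete, here is a small Python sketch that computes mutual information for a discrete joint distribution as the KL-divergence between the joint and the product of its marginals. The function name and the two toy distributions are illustrative choices, not part of DIM; for the continuous, high-dimensional case discussed above, this exact summation is what becomes intractable.

```python
import math

def mutual_information(joint):
    """Mutual information (in nats) of a discrete joint distribution.

    joint[i][j] is P(X = i, Y = j); entries must sum to 1.
    Computed as KL(P(X, Y) || P(X) P(Y)).
    """
    p_x = [sum(row) for row in joint]        # marginal of X (row sums)
    p_y = [sum(col) for col in zip(*joint)]  # marginal of Y (column sums)
    mi = 0.0
    for i, row in enumerate(joint):
        for j, p_xy in enumerate(row):
            if p_xy > 0:  # terms with zero probability contribute nothing
                mi += p_xy * math.log(p_xy / (p_x[i] * p_y[j]))
    return mi

# Output fully determined by the input: MI equals H(X) = log 2.
print(mutual_information([[0.5, 0.0], [0.0, 0.5]]))    # ≈ 0.693
# Input and output independent: MI is zero.
print(mutual_information([[0.25, 0.25], [0.25, 0.25]]))  # 0.0
```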
The recently introduced Mutual Information Neural Estimator (MINE) trains a neural network to maximize a lower bound on the mutual information. MINE works by training a discriminator network to distinguish samples from the joint distribution, also known as positive samples, from samples from the product-of-marginals distribution, also known as negative samples. For a neural network encoder, the discriminator in MINE is tasked with answering the following question: Does this pair, the input and the output representation, belong together?
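MINE’s KL-based estimator rests on the Donsker-Varadhan representation of the KL-divergence, which gives the lower bound E_joint[T] - log E_marginals[exp(T)] for any discriminator T; training T to maximize it tightens the bound. The sketch below assumes the discriminator scores on positive and negative pairs have already been computed. The discriminator network and its training loop are omitted, and `dv_lower_bound` is a hypothetical helper name.

```python
import math

def dv_lower_bound(pos_scores, neg_scores):
    """Donsker-Varadhan lower bound on mutual information (in nats).

    pos_scores: discriminator outputs T(x, y) on pairs from the joint.
    neg_scores: outputs on pairs from the product of marginals.
    Returns mean(pos_scores) - log(mean(exp(neg_scores))).
    """
    mean_pos = sum(pos_scores) / len(pos_scores)
    # Log-mean-exp, stabilized by subtracting the largest score.
    m = max(neg_scores)
    mean_exp = sum(math.exp(s - m) for s in neg_scores) / len(neg_scores)
    return mean_pos - (m + math.log(mean_exp))

print(dv_lower_bound([2.0, 2.0], [0.0, 0.0]))  # 2.0
```

A discriminator that outputs the same value for every pair gives a bound of zero, which is the correct answer only when input and output are independent.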
DIM borrows this idea from MINE, using the gradients from a discriminator to help train the encoder network. This is similar to how the generator is trained in generative adversarial networks (GANs), except that the encoder makes the discriminator’s task easier, not harder. In addition, we don’t rely on the KL-based discriminator from MINE, as it works less effectively in practice than discriminators that use the Jensen-Shannon divergence (JSD) or infoNCE, an estimator used by Contrastive Predictive Coding (CPC).
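The two alternatives can be sketched in the same style, again from precomputed discriminator scores. The JSD-based objective uses the softplus form E_joint[-sp(-T)] - E_marginals[sp(T)], where sp(z) = log(1 + e^z); it is a surrogate that is maximized in place of the mutual information rather than an estimate in nats. The infoNCE version assumes a square score matrix whose diagonal holds positive pairs, with the rest of each row supplying negatives. Function names are illustrative.

```python
import math

def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def jsd_estimate(pos_scores, neg_scores):
    """JSD-based objective: E_joint[-sp(-T)] - E_marginals[sp(T)]."""
    pos_term = sum(-softplus(-s) for s in pos_scores) / len(pos_scores)
    neg_term = sum(softplus(s) for s in neg_scores) / len(neg_scores)
    return pos_term - neg_term

def infonce_estimate(scores):
    """infoNCE estimate from an N x N score matrix.

    scores[i][j] = T(x_i, y_j). The diagonal holds positive pairs; the
    off-diagonal entries of each row act as negatives. Returns the mean
    log-softmax of each row's positive plus log N.
    """
    n = len(scores)
    total = 0.0
    for i, row in enumerate(scores):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(s - m) for s in row))
        total += row[i] - log_sum_exp  # log-softmax of the positive pair
    return total / n + math.log(n)

print(jsd_estimate([0.0], [0.0]))                   # ≈ -1.386 (-2 log 2)
print(infonce_estimate([[10.0, 0.0], [0.0, 10.0]]))  # just under log 2
```

Because each log-softmax term is at most zero, the infoNCE value can never exceed log N, so the number of negatives caps what it can report.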
Learning shared information
Unfortunately, training an encoder to only maximize the mutual information between the input and output will yield representations that contain trivial or “noisy” information from the input. For example, in the cat picture below, there are many locations, or patches, from which a neural network could extract information that would increase the mutual information during optimization.
But we’re actually interested in only a subset of these locations. Worse still, maximizing mutual information between the whole image input and the output representation, which we refer to as global DIM in the paper, will be biased toward learning features that are unrelated to one another, because together such features carry more unique information than redundant ones do. For example, the ear, eye, and fur all carry information about the same thing, “a cat,” so encoding all of these locations won’t increase the mutual information as much as encoding the foliage in the background.
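The redundancy argument can be checked on a toy example. For a deterministic encoder of discrete inputs, the mutual information between input and features equals the entropy of the features, so comparing entropies shows which set of features would raise the mutual information more. The three-bit world and the ear/eye/fur framing are hypothetical stand-ins chosen for illustration.

```python
import itertools
import math
from collections import Counter

def joint_entropy_bits(samples):
    """Entropy (in bits) of the empirical distribution of equally likely samples."""
    counts = Counter(samples)
    n = len(samples)
    return -sum(c / n * math.log2(c / n) for c in counts.values())

# Every equally likely configuration of three independent fair bits.
worlds = list(itertools.product([0, 1], repeat=3))

# Redundant features: ear, eye, and fur patches all read off the same bit.
redundant = [(w[0], w[0], w[0]) for w in worlds]
# Unrelated features: each patch reads off a different, independent bit.
unrelated = [(w[0], w[1], w[2]) for w in worlds]

print(joint_entropy_bits(redundant))  # 1.0
print(joint_entropy_bits(unrelated))  # 3.0
```

Three redundant features are worth one bit, while three unrelated ones are worth three, which is why a pure mutual information objective prefers the background.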
So pure mutual information maximization isn’t exactly what we want. We want to maximize information that is shared across the input, in this case across relevant locations. To accomplish this, we maximize the mutual information between a global summary feature vector, which is computed from the full image, and feature vectors corresponding to local patches. This is a self-supervision task analogous to training the encoder to predict local patch features from a global summary feature. We call this local DIM, but for simplicity, we’ll refer to it as “DIM” below.
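The local objective can be written compactly. Assuming an encoder that produces M × M local feature vectors C_ψ^(i)(X) and a global summary E_ψ(X) computed from them, and writing Î_ω for whichever estimator is used (JSD, infoNCE, or the KL-based one), the encoder parameters ψ and estimator parameters ω are trained to maximize the estimated mutual information averaged over all locations. The notation is introduced here for illustration.

```latex
(\hat{\omega}, \hat{\psi}) = \arg\max_{\omega, \psi} \frac{1}{M^2} \sum_{i=1}^{M^2} \hat{I}_{\omega}\big(C_{\psi}^{(i)}(X);\, E_{\psi}(X)\big)
```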
Predicting the local given the whole
DIM is not unlike other self-supervision approaches that learn to predict local patches across an image at the feature level. Other works learn a representation of images by asking the encoder to pick the correct patch for a given location from a set of candidates, conditioned on another patch (for example, the center patch), with the task performed entirely in the representation space. In natural language processing, methods like Quick Thoughts use similar types of self-supervision to learn good sentence representations. However, these approaches are purely autoregressive: they all involve tasks where the conditioning and predicted features correspond to different locations in the input. In DIM, the conditioning global feature can be a function of all the local features being predicted, a task we call self-prediction, but DIM is also flexible enough to incorporate autoregression.
The base encoder architecture for DIM with images is very simple, requiring only a small modification on top of a convolutional neural network. First, a standard ConvNet yields a set of localized feature vectors. These local features are summarized using a standard neural network such as a combination of convolutional and fully connected layers. The output of this network is the global summary feature vector, and we maximize the mutual information between it and each of the local feature vectors.
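A minimal sketch of the summary step in plain Python, assuming the ConvNet has already produced an M × M grid of C-dimensional local features. As a hypothetical simplification, the summary network is replaced by mean pooling followed by a single linear layer instead of the convolutional and fully connected layers described above; the point is only that the global vector is a function of every local feature.

```python
import random

def global_summary(local_map, weights):
    """Summarize an M x M grid of C-dim local features into one global vector.

    Hypothetical stand-in for DIM's summary network: mean-pool over all
    locations, then apply one linear layer (`weights` has D rows of C values).
    """
    m = len(local_map)
    c = len(local_map[0][0])
    pooled = [0.0] * c
    for row in local_map:
        for feat in row:
            for k in range(c):
                pooled[k] += feat[k] / (m * m)  # average over all M*M locations
    return [sum(w * p for w, p in zip(w_row, pooled)) for w_row in weights]

rng = random.Random(0)
M, C, D = 4, 8, 16  # grid size, local feature size, global feature size
local_map = [[[rng.gauss(0.0, 1.0) for _ in range(C)] for _ in range(M)] for _ in range(M)]
weights = [[rng.gauss(0.0, 0.1) for _ in range(C)] for _ in range(D)]
y = global_summary(local_map, weights)
print(len(y))  # 16
```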
To estimate and maximize the mutual information, one can use a variety of neural architectures as small additions to the encoder architecture. One that worked well for us was a deep bilinear model in which both the local and global feature vectors were fed into separate fully connected networks with the same output size, followed by a dot product between the two for the score. This score was then fed into either JSD, infoNCE, or the KL-based estimator in MINE.
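The bilinear scoring and the resulting local objective can be sketched as follows, using the JSD-based form. Assumptions: each of the separate fully connected networks is a two-layer ReLU network with a shared output size H; a local feature and the global vector from the same image form a positive pair; pairing a local feature with another image’s global vector gives a negative. All names and sizes are hypothetical, the random global features stand in for the output of a real summary network, and no gradients are computed here; in practice the objective would be maximized with an automatic-differentiation framework.

```python
import math
import random

def linear(weights, x):
    # Matrix-vector product: one output per row of `weights`.
    return [sum(w * v for w, v in zip(row, x)) for row in weights]

def mlp(params, x):
    # Two-layer fully connected network: w2 @ relu(w1 @ x).
    w1, w2 = params
    hidden = [max(v, 0.0) for v in linear(w1, x)]
    return linear(w2, hidden)

def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))

def local_dim_jsd(local_feats, global_feats, local_net, global_net):
    """JSD-based local DIM objective for a batch of images.

    local_feats[b]:  list of local feature vectors for image b
    global_feats[b]: global summary vector for image b
    """
    g_emb = [mlp(global_net, g) for g in global_feats]
    pos, neg = [], []
    for b, locals_b in enumerate(local_feats):
        for feat in locals_b:
            l_emb = mlp(local_net, feat)
            for b2, g in enumerate(g_emb):
                # Bilinear score: dot product of the two embeddings.
                score = sum(u * v for u, v in zip(l_emb, g))
                # Same image -> positive pair; different image -> negative.
                (pos if b2 == b else neg).append(score)
    pos_term = sum(-softplus(-s) for s in pos) / len(pos)
    neg_term = sum(softplus(s) for s in neg) / len(neg)
    return pos_term - neg_term

def rand_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]

rng = random.Random(0)
B, L, C, D, H = 3, 4, 8, 16, 32  # batch, locations, local dim, global dim, embedding dim
local_net = (rand_matrix(H, C, rng), rand_matrix(H, H, rng))
global_net = (rand_matrix(H, D, rng), rand_matrix(H, H, rng))
local_feats = [[[rng.gauss(0.0, 1.0) for _ in range(C)] for _ in range(L)] for _ in range(B)]
global_feats = [[rng.gauss(0.0, 1.0) for _ in range(D)] for _ in range(B)]
print(local_dim_jsd(local_feats, global_feats, local_net, global_net))
```

Swapping `jsd_estimate`-style scoring for an infoNCE or Donsker-Varadhan computation over the same `pos` and `neg` scores changes only the last few lines.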
We evaluated DIM, along with various other unsupervised models, by training a small nonlinear classifier on top of the local representations on the CIFAR-10, CIFAR-100, Tiny ImageNet, and STL-10 datasets. DIM outperformed all methods we tested and proved comparable to supervised learning. We also explored other ways to measure the “goodness” of the representation, such as independence of the global variables and the ability to reconstruct, with results and analyses available in the paper.
While DIM favors self-prediction over the type of autoregression commonly found in self-supervision models like CPC, DIM is flexible and can be easily modified to incorporate some autoregression, which can ultimately improve the representation.
This can be done by computing the global vector with part of the input occluded. The local features are then computed using the complete original input. This is a type of orderless autoregression that allows us to make the task slightly harder for DIM and potentially improve results. Alternatively, we can perform mixed self-prediction and orderless autoregression by computing multiple global vectors using a simple convolutional layer on top of the local features.
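The occlusion variant can be sketched as follows. Assumptions: the input is a 2-D grid, occlusion means zeroing one randomly placed square block, and `local_encoder` and `global_encoder` are hypothetical placeholders for the DIM encoder and its summary network. Only the pairing matters here: local features come from the complete input, while the global vector sees the occluded one.

```python
import random

def occlude(image, block, rng):
    """Return a copy of `image` (a 2-D list) with a random block x block square zeroed."""
    h, w = len(image), len(image[0])
    top = rng.randrange(h - block + 1)
    left = rng.randrange(w - block + 1)
    out = [row[:] for row in image]  # copy so the original input stays intact
    for r in range(top, top + block):
        for c in range(left, left + block):
            out[r][c] = 0.0
    return out

def dim_views(image, local_encoder, global_encoder, block, rng):
    """Local features from the full input; global feature from the occluded input."""
    local_feats = local_encoder(image)
    global_feat = global_encoder(occlude(image, block, rng))
    return local_feats, global_feat

# Toy encoders for illustration only: each pixel is a "local feature",
# and the "global feature" is the sum of all pixels.
rng = random.Random(0)
img = [[1.0] * 4 for _ in range(4)]
locs, glob = dim_views(img, lambda im: [v for row in im for v in row],
                       lambda im: sum(sum(r) for r in im), 2, rng)
print(sum(locs), glob)  # 16.0 12.0
```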
We compared DIM with single and multiple global vectors (that is, without and with orderless autoregression) to CPC, which uses ordered autoregression, on classification tasks using ResNet architectures and the strided crops used in CPC. DIM and CPC performed comparably even though DIM’s task is much simpler and faster. This suggests that the strict ordered autoregression in CPC may not be necessary for these types of representation learning tasks on images.
| DIM (single global) | 80.95 | 76.97 |
| DIM (multiple globals) | 77.51 | 78.21 |
Overall, the ideas behind DIM are simple and can be easily extended to other domains. They’ve already been extended—with impressive results—to learning unsupervised representations of graphs and brain imaging data. We’re optimistic this is only the beginning for how DIM can help researchers advance machine learning by providing a more effective way to learn good representations.