Neural Graphical Models

By , Senior Applied Scientist , Principal Applied Scientist

This research paper was presented at the 17th European Conference on Symbolic and Quantitative Approaches to Reasoning with Uncertainty, a premier forum for advances in the theory and practice of reasoning under uncertainty.

In the field of reasoning under uncertainty, probabilistic graphical models (PGMs) stand out as a powerful tool for analyzing data. They can represent relationships between features and learn underlying distributions that model functional dependencies between them. Learning, inference, and sampling are operations that make graphical models useful for domain exploration.  

In a broad sense, learning involves fitting the distribution function parameters from data, and inference is the procedure of answering queries in the form of conditional distributions with one or more observed variables. Sampling entails the ability to extract samples from the underlying distribution as defined by the graphical model. A common challenge with graphical model representations lies in the high computational complexity of one or more of these operations.   
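To make the three operations concrete, here is a minimal sketch on a toy joint distribution over two binary variables. The probability table, the function names, and the brute-force enumeration are all assumptions made for illustration; they are not how NGMs work. Enumerating every joint assignment, as this sketch does, grows exponentially with the number of variables, which is one source of the computational cost mentioned above.

```python
import random
from collections import Counter

# Hypothetical joint distribution P(x1, x2) over two binary variables.
JOINT = {
    (0, 0): 0.60,
    (0, 1): 0.10,
    (1, 0): 0.05,
    (1, 1): 0.25,
}


def conditional(joint, query_index, evidence):
    """Inference: P(x_query | evidence), by summing matching entries and normalizing."""
    scores = Counter()
    for assignment, p in joint.items():
        if all(assignment[i] == v for i, v in evidence.items()):
            scores[assignment[query_index]] += p
    total = sum(scores.values())
    return {value: p / total for value, p in scores.items()}


def sample(joint, n, seed=0):
    """Sampling: draw n joint assignments according to their probabilities."""
    rng = random.Random(seed)
    assignments = list(joint)
    weights = [joint[a] for a in assignments]
    return rng.choices(assignments, weights=weights, k=n)


def learn(samples):
    """Learning: estimate the joint table from data by relative frequency."""
    counts = Counter(samples)
    return {a: c / len(samples) for a, c in counts.items()}


# Inference query: distribution of x1 given that x2 = 1 was observed.
posterior = conditional(JOINT, query_index=0, evidence={1: 1})

# Sample from the model, then re-learn the table from those samples.
learned = learn(sample(JOINT, n=20_000))
```

Here `posterior[1]` equals 0.25 / (0.10 + 0.25), and `learned` approaches `JOINT` as the number of samples grows.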

Various graphical models impose restrictions on the set of distributions or types of variables in the domain. Some graphical models work with continuous variables only (or categorical variables only) or place restrictions on the graph structure, for example, the constraint that continuous variables cannot be parents of categorical variables in a directed acyclic graph (DAG). Other restrictions affect the set of distributions the models can represent, for example, only multivariate Gaussian distributions.

In our paper, “Neural Graphical Models,” presented at ECSQARU 2023, we propose Neural Graphical Models (NGMs), a new type of PGM that learns to represent the probability function over the domain using a deep neural network. The network's parameters can be learned efficiently from data, with a loss function that jointly optimizes two goals: fit to the data and adherence to the dependency structure, which is given as input in the form of a directed or undirected graph. Probability functions represented by NGMs are free of the common restrictions inherent in other PGMs. NGMs can handle various input types: categorical, continuous, images, and embedding representations. They also support efficient inference and sampling.
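This post does not spell out the loss function. One way to write an objective of this shape, offered as a sketch rather than the paper's exact formulation, assumes a squared-error fit term and an L1 structure penalty. Here \(f_W\) is the network with weight matrices \(W_i\), \(X^{(k)}\) is the \(k\)-th of \(M\) training samples, \(S\) is the 0/1 adjacency matrix of the input graph, \(S^c\) is its complement (ones at non-edges, including the diagonal), \(\odot\) is the elementwise product, and \(\lambda\) is a hypothetical trade-off weight:

\[
\min_{W}\; \sum_{k=1}^{M} \left\| X^{(k)} - f_W\!\left(X^{(k)}\right) \right\|_2^2 \;+\; \lambda \left\| \Big(\prod_i |W_i|\Big) \odot S^c \right\|_1
\]

The first term rewards fit to the data. The second term is zero exactly when the network has no input-to-output path between two features that are not adjacent in the graph, so driving it toward zero enforces the dependency structure.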

Figure 1 - The image on the left shows an undirected network graph with five variables: x1, x2, x3, x4 and x5. The variable x3 is connected to all other variables, and x1 is directly connected to x3 and x4 only. The annotation next to the nodes indicates that the value of each variable is a function of the values of its neighbors. For example, the value of x1 is a function of x3 and x4, the value of x2 is a function of x3, and so on. On the right, we see a table representing the adjacency matrix for the same graph, with both rows and columns labeled with variable names from x1 to x5. Each cell contains either a one or a zero. A one indicates the presence of an edge, for example in the cell at the intersection of the row labeled x1 and the column labeled x3.
Figure 1: Graphical view of NGMs: The input graph G (undirected) for given input data X. Each feature \( x_i=f_i(\text{Nbrs}(x_i))\) is a function of the neighboring features. For a DAG, the functions between features will be defined by the Markov Blanket relationship \( x_i=f_i(\text{MB}(x_i))\). On the right, the adjacency matrix represents the associated dependency structure S.
Figure 2 - The image shows a neural network. The input layer has five variables: x1, x2, …, x5, and the corresponding output layer has the same five variables. Between the input and output layers there is one hidden layer with six nodes. Some of the units in the input layer are connected to the units in the hidden layer, and some of the units in the hidden layer are connected to the units in the output layer. A careful examination shows that there is a path from a unit xi in the input layer to a unit xj in the output layer whenever there is an edge from the xi node to the xj node in the graph in Figure 1. Note that there are no self-paths, that is, paths from xi in the input layer to xi in the output layer. Some of the remaining neural network connections representing zeroed-out weights are shown in dashed black lines.
Figure 2: Neural view of NGMs: The neural network, viewed as a multitask learning architecture, captures nonlinear dependencies among the features of the undirected graph in Figure 1. The presence of a path from an input feature to an output feature indicates a dependency between them. The dependency matrix between the input and output of the NN reduces to a matrix product operation \(S_{nn}=\Pi_i|W_i|=|W_1|\times|W_2|\). Note that, for the sake of clarity, not all the zeroed-out weights of the MLP (in black dashed lines) are shown.
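The matrix product in the Figure 2 caption can be illustrated with a small sketch. Several things below are assumptions: the edge list contains only the edges stated in the Figure 1 description (x3 to every other variable, plus x1 to x4), the weights are hand-built rather than learned, and the layout (rows index inputs, columns index outputs) and all function names are hypothetical. The sketch checks that the product of absolute weights has nonzero entries only where the graph has edges, and that a single stray weight shows up as a nonzero penalty.

```python
# Undirected edges from the Figure 1 description, 0-based (x1 -> 0, ..., x5 -> 4).
# Assumption: only the edges stated in the description are included.
EDGES = [(0, 2), (1, 2), (2, 3), (2, 4), (0, 3)]
D = 5  # number of features
H = 6  # hidden units, as in Figure 2


def adjacency(edges, d):
    """Symmetric 0/1 adjacency matrix S with a zero diagonal."""
    s = [[0] * d for _ in range(d)]
    for i, j in edges:
        s[i][j] = s[j][i] = 1
    return s


def abs_matmul(a, b):
    """Return |a| x |b| for matrices stored as lists of rows."""
    return [[sum(abs(a[i][k]) * abs(b[k][j]) for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def build_masked_weights(s, hidden):
    """Hand-built toy weights: hidden unit h feeds only output owner[h],
    and receives input only from that output's neighbors in the graph."""
    d = len(s)
    owner = [h % d for h in range(hidden)]
    w1 = [[(0.5 if (i + h) % 2 == 0 else -0.5) * s[i][owner[h]]
           for h in range(hidden)] for i in range(d)]
    w2 = [[1.0 if owner[h] == j else 0.0 for j in range(d)]
          for h in range(hidden)]
    return w1, w2


def structure_penalty(s_nn, s):
    """Sum of path strengths between features that share no edge in the graph."""
    d = len(s)
    return sum(s_nn[i][j] for i in range(d) for j in range(d) if s[i][j] == 0)


S = adjacency(EDGES, D)
W1, W2 = build_masked_weights(S, H)
S_nn = abs_matmul(W1, W2)  # entry [i][j] > 0 means a path from input i to output j

respects_graph = all((S_nn[i][j] > 0) == (S[i][j] == 1)
                     for i in range(D) for j in range(D))
penalty_ok = structure_penalty(S_nn, S)

# Add one stray weight: input x1 now reaches a hidden unit that feeds output x2,
# although x1 and x2 share no edge in the graph.
W1_bad = [row[:] for row in W1]
W1_bad[0][1] = 1.0
penalty_bad = structure_penalty(abs_matmul(W1_bad, W2), S)

print(respects_graph, penalty_ok, penalty_bad)
```

Because absolute values are taken before multiplying, positive and negative weights cannot cancel, so a zero entry in `S_nn` means there is no path at all between that input and that output. The zero diagonal of `S` also rules out self-paths.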

Experimental validations for NGMs

In our paper, we evaluate NGMs’ performance, inference accuracy, sensitivity to the input graph, and ability to recover the input dependency structure when trained on both real and synthetic data: infant mortality data from the Centers for Disease Control and Prevention (CDC), synthetic Gaussian graphical model data, and lung cancer data from Kaggle.

The infant mortality dataset describes pregnancy and birth variables for all live births in the US and, in instances of infant death before the first birthday, the cause of death. We used the latest available data, which includes information about 3,988,733 live births in the US during 2015. It was particularly challenging to evaluate the inference accuracy of NGMs using this dataset due to the (thankfully) rare occurrence of infant deaths during the first year of life, making queries concerning such low probability events hard to accurately estimate.

We used the CDC data to evaluate the NGMs’ inference accuracy. We compared predictions for four variables of various types: gestational age (ordinal, expressed in weeks), birth weight (continuous, specified in grams), survival until the first birthday (binary), and cause of death (multi-valued). For cause of death, we used the categories “alive,” each of the 10 most common causes of death, and “other” for the less common causes. Here, “alive” was indicated for 99.48% of infants. We compared the performance of logistic regression, Bayesian networks, Explainable Boosting Machines (EBM), and NGMs. In the case of NGMs, we trained two models: one using the Bayesian network graph and one using the uGLAD graph.
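A little arithmetic shows how severe this class imbalance is. The sketch below assumes the 99.48% figure applies to all 3,988,733 recorded births. The majority-class baseline is a hypothetical illustration, not one of the models evaluated in the paper.

```python
TOTAL_BIRTHS = 3_988_733   # live births in the 2015 data
ALIVE_FRACTION = 0.9948    # share of records labeled "alive"

# Records spread across the 10 named causes of death plus "other".
not_alive = round(TOTAL_BIRTHS * (1 - ALIVE_FRACTION))

# Hypothetical baseline that always predicts "alive": its overall accuracy
# equals the majority share, yet it never identifies a single rare category.
baseline_accuracy = ALIVE_FRACTION
baseline_rare_recall = 0.0

print(not_alive, baseline_accuracy, baseline_rare_recall)
```

Under this assumption, only about 20,700 of nearly four million records fall outside “alive.” Overall accuracy therefore says little about how well a model handles the rare categories.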

Our results demonstrate that NGMs are significantly more accurate than logistic regression, more accurate than Bayesian networks, and on par with EBM models for categorical and ordinal variables. They particularly shine when predicting very low probability categories for the multi-valued cause-of-death variable, where most models, both PGMs and classification models, typically struggle. Note that while we need to train a separate logistic regression (LR) and EBM model for each outcome variable evaluated, all variables can be predicted within one trained NGM. Interestingly, the two NGMs show similar accuracy despite the differences in the two dependency structures used in training.

We believe that NGMs are an interesting amalgam of the expressivity of deep learning architectures and the representation capabilities of PGMs. Because they place no restrictions on input types and distributions, they can be applied in many domains. We encourage you to explore NGMs and take advantage of the ability to work with a wider range of distributions and inputs. You can access the code for Neural Graphical Models on GitHub.
