From Graph Neural Networks in Action by Keita Broadwater
In fields with no shortage of hype around new technologies and methods, GNNs are, without hyperbole, a leap forward for both the field of deep learning and the domain of graph analytics.
Graph data has characteristics that traditional machine learning and deep learning methods have struggled to penetrate. When represented in grid-like formats and data structures, graph data has qualities that prevent the consistent application of machine learning methods. Yet graphs are a data type rich with information. They are also ubiquitous: we can find network structures in nature (molecules), society (social networks), technology (the internet), and mundane settings (road maps). To use this rich and ubiquitous data type for machine learning, we need a specialized form of neural network dedicated to working on graph data: the graph neural network, or GNN.
Defining graphs and graph neural networks
Graphs are data structures whose elements are relationships (expressed as edges) between entities (expressed as nodes, also called vertices). In graphs, the links and relationships in our data are first-class citizens, and are key to how we model the world, analyze data, and learn from it. This is in contrast to the grid-like data common in traditional learning and analytics problems, such as the relational database table, the pandas dataframe, or the Excel/Google sheet. The relationships in such grid-like data are much more muted than in networks or graphs.
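To make the contrast concrete, here is a minimal sketch (plain Python, with made-up names and values) of the same tiny dataset in grid form and in graph form, the latter using a simple adjacency list:

```python
# A grid-like view: each row stands alone.
rows = [
    {"name": "Ada", "age": 36},
    {"name": "Ben", "age": 41},
    {"name": "Cara", "age": 29},
]

# A graph view: relationships are first-class citizens.
# Nodes carry attributes; edges record who is connected to whom.
nodes = {"Ada": {"age": 36}, "Ben": {"age": 41}, "Cara": {"age": 29}}
edges = [("Ada", "Ben"), ("Ben", "Cara")]  # e.g. "knows" relationships

# Adjacency list: for each node, the set of its neighbors.
adjacency = {n: set() for n in nodes}
for u, v in edges:
    adjacency[u].add(v)
    adjacency[v].add(u)  # undirected graph

print(adjacency["Ben"])  # Ben is connected to both Ada and Cara
```

The rows on their own cannot tell us that Ben links Ada and Cara; the adjacency list makes that structure explicit and queryable.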
A graph neural network (GNN) is an algorithm that allows you to represent and learn from graphs, including their constituent nodes, edges, and features. Graph neural networks provide advantages analogous to those of conventional neural networks, except the learning happens on graph data.
Why and when would you want to use graph data and a GNN instead of other neural networks? Let’s find out by casting a classic machine-learning problem in a new light using graphs.
When to use a GNN
The obvious place to use a GNN is when the data has been modeled as a graph. Graphs are a versatile data structure, so with a little imagination, many problems can be framed this way. A critical consideration is whether the relationships being modeled are relevant to the problem at hand.
In contrast, for many applications, individual data entries are meant to stand alone, and relationships are not useful. For example, consider a non-profit organization that has maintained a database of donors and their financial contributions. These donors are spread across the country, and for the most part don’t personally know each other. Also, the non-profit is primarily concerned with growing donor contributions. Representing their donor data with a graph data structure and storing it using a graph database would be a waste of time and money.
Another place to consider using a GNN is when a predictive task is involved. Applications where graph neural networks shine are problems that require predictive models. We either use the GNN to train such a model, or we use it to produce a representation of the graph data that can be used in a downstream model. These prediction tasks may occur at the level of the node, the edge, or the graph.
Node-Level Tasks Nodes in graphs often have attributes and labels. Node classification is the node analog of traditional machine learning classification: given a graph in which some nodes carry class labels, we aim to predict the class of an unlabeled node.
Edge-Level Tasks Edge tasks are very similar to node tasks. Link classification involves predicting a link’s label in a supervised or semi-supervised way. Edge prediction (also called link prediction) involves inferring a link between two nodes where one may not exist in the graph under study.
Graph-Level Tasks In real scenarios, graphs may consist of a large set of connected nodes, many disconnected graphs, or a combination of a large connected graph with many smaller unconnected graphs. In these situations the learning task may be different. Graph classification and regression involve predicting the label or a quantity associated with the graph. An example of graph classification will be outlined in the next section.
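The three task levels can be pictured on a single toy graph. The sketch below (all names and labels invented for illustration) shows where each kind of label lives:

```python
# One toy molecule-like graph, annotated at all three levels.
# Atom types, bond types, and the "toxic" label are made up.
graph = {
    "nodes": {0: "C", 1: "C", 2: "O"},              # node-level labels
    "edges": {(0, 1): "single", (1, 2): "double"},  # edge-level labels
    "label": "toxic",                                # graph-level label
}

# Node classification:  predict graph["nodes"][i] for an unlabeled node i.
# Link classification:  predict the label of a known edge, e.g. (0, 1).
# Edge prediction:      predict whether (0, 2) should also be an edge.
# Graph classification: predict graph["label"] for a whole new graph.
```

The same underlying structure supports all three, which is why a single GNN framework can serve such different prediction problems.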
There are also unsupervised tasks that involve GNNs. These include graph autoencoders, which embed graphs and, in the reverse process, generate a graph from an embedding. There is also a class of models that uses adversarial methods to generate graphs.
Casting the Titanic dataset as graphs
An example of a dataset that has historically been approached in a grid manner, but is full of unexplored links and relationships, is the Titanic dataset, whose observations span the passengers of the doomed ship and label their survivorship. Most machine learning challenges and projects based on this dataset express it in terms of a target variable (survivorship) and columns of features, which fit well in a table or dataframe. As you can see in the following figure, each row of such a table represents one person.
Figure 1. The Titanic Dataset is usually displayed and analyzed using a table format.
As expressive as tables like this are, their account of links and relationships is superficial at best. In particular, they fail to convey the social links between the people, the corridors that linked locations on the ship, and the communication network.
First, the people on the ship shared multiple types of relationships, including marital and blood relations (married, engaged, parent/child, siblings, and cousins), business and employment relationships, and, of course, friendships. Many table versions of the Titanic dataset give boolean indicators or counts of immediate family relationships (e.g., in the table above, the alone feature is True if a person had no immediate family on the ship).
Network representations of social relationships can add depth via specificity. For example, having relatives on the ship may be a factor in favor of survival, but having socially important relatives would probably give a greater chance at survival. Social networks can convey social importance via their structure. An example of a graph representation of families on the Titanic is shown in figure 2.
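A family graph like the one in figure 2 can be derived from the tabular data itself. The sketch below uses a shared-surname heuristic and invented example rows (the actual dataset schema and any real family-matching logic would be more involved):

```python
from collections import defaultdict
from itertools import combinations

# Illustrative rows only -- not the actual Titanic dataset schema.
passengers = [
    {"id": 1, "last_name": "Andersson", "survived": True},
    {"id": 2, "last_name": "Andersson", "survived": False},
    {"id": 3, "last_name": "Beckwith", "survived": True},
]

# Group passenger ids by surname.
families = defaultdict(list)
for p in passengers:
    families[p["last_name"]].append(p["id"])

# Connect every pair of passengers sharing a surname with an edge.
edges = [pair for ids in families.values() for pair in combinations(ids, 2)]
print(edges)  # [(1, 2)]
```

Even this crude heuristic turns flat rows into relational structure a GNN could exploit; richer signals (ticket numbers, cabins) would sharpen the edges further.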
The next example of a network on the Titanic is the ship’s physical layout: its corridors, consisting of connected hallways, stairways, junctions, and landings, along with the adjoining cabins and workrooms, can also be represented as a graph. This network is important to the survival question because the ease of quickly reaching a lifeboat in a crisis depended on one’s starting location. From certain cabins and decks, it would have been easier to reach a lifeboat in time than from others.
The third example of a network on the Titanic is the ship’s communication network, which the ship’s crew used to communicate among themselves, with the passengers, and with the outside world (figure 2). This is relevant to the question of survival because critical information about the crisis would reach only individuals in proximity to communication nodes. The communication networks of the Titanic and the surrounding ships (both by wireless telegraph and by analog signals, such as flares) also impacted survivability. There were two ships within 14 miles of the Titanic when it sank. They saw the distress flare but, for different reasons, decided not to help. If they had been linked by wireless, the compelling information about the rapid sinking of the ship might have swayed one or both ships to attempt a rescue.
Figure 2. The Titanic Dataset.
Left: with family relationships visualized as graphs (image by Matt Hagy: https://github.com/matthagy/titanic-families and https://www.kaggle.com/matthagy/discussion?sortBy=mostVotes&group=commentsAndTopics).
Right: An old illustration of the wireless communication between the Titanic and other ships. It’s easy to interpret the ships as nodes and the wireless connections as edges.
Assuming one had access to this graph data, and it was of sufficient quantity and quality, what could a GNN glean from it? First, since the historical records are not altogether intact, we could use it to fill in missing information. GNNs use node classification to predict node attributes, which in our case could be applied to predict missing passenger information. For example, if citizenship information was missing from some of the passenger data, we could use node classification to uncover these citizenship labels. GNNs also use edge prediction (or link prediction) to uncover hidden or missing links between nodes. In our case, we could use it to find non-obvious relationships between the passengers, such as extended family relationships (cousins on the ship) and business relationships. At the ship level, if we had similar data for several other large ships, we could classify the ships themselves and possibly find characteristics that portend disaster. Finally, we could encode this graph data as embeddings and use them as additional features in conventional Titanic-focused machine learning and deep learning solutions.
How do GNNs Work?
Though there are a variety of GNN architectures at this point, they all tackle the problem of learning from the irregular, non-grid structure of graph data. Essentially, the mechanism for tackling this problem is to exchange information across the graph structure during the learning process.
In a conventional neural network (and in machine learning methods in general), we initialize a set of parameters. Then we iteratively update these parameters by:
- Inputting our data
- Flowing that data through layers that transform it according to each layer’s parameters
- Outputting a prediction, which is used to update the parameters
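The loop above can be sketched in a few lines of plain Python. This is a deliberately minimal stand-in for a neural network: a single-parameter linear model fit by gradient descent, with the data and learning rate chosen for illustration.

```python
# A minimal "conventional" training loop: one-parameter linear model
# fit by gradient descent (pure Python, illustrative only).
data = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0)]  # x -> y = 2x

w = 0.0    # initialize the parameter
lr = 0.05  # learning rate
for _ in range(200):
    for x, y in data:               # 1. input our data
        pred = w * x                # 2. flow it through the (single) layer
        grad = 2 * (pred - y) * x   # 3. use the prediction's error...
        w -= lr * grad              #    ...to update the parameter

print(round(w, 2))  # converges toward 2.0
```

A real network has many layers and parameters, but the rhythm is the same: input, transform, predict, update.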
Figure 3. Comparison of (simple) non-graph neural network (above) and graph neural network. GNNs have a layer that distributes data amongst its vertices.
With a Graph Neural Network, we add a step. We still have to initialize the neural network parameters, but we also initialize a representation of the nodes of the graph. So in our iterative process we:
- Input the graph data
- Update the node representations using GNN layers
- Have the resulting data flow through the conventional neural network layers.
- Output a prediction, and an updated set of node representations, which are used to update the neural network parameters.
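The added step, updating node representations, is the heart of a GNN layer. Here is a minimal sketch of one such update, in which each node blends its own state with the mean of its neighbors' states; real GNN layers also apply learned weight matrices and nonlinearities, which are omitted here for clarity, and the 0.5/0.5 mixing weights are an arbitrary choice for illustration.

```python
# Sketch of one GNN-style update: each node's representation becomes
# a combination of its own state and the mean of its neighbors' states.
adjacency = {"a": ["b", "c"], "b": ["a"], "c": ["a"]}
h = {"a": 1.0, "b": 3.0, "c": 5.0}  # initial scalar node representations

def gnn_layer(h, adjacency):
    new_h = {}
    for node, neighbors in adjacency.items():
        neighbor_mean = sum(h[n] for n in neighbors) / len(neighbors)
        new_h[node] = 0.5 * h[node] + 0.5 * neighbor_mean  # blend self + neighbors
    return new_h

h = gnn_layer(h, adjacency)
print(h["a"])  # 0.5*1.0 + 0.5*((3.0 + 5.0)/2) = 2.5
```

After one such layer, every node's representation already carries information from its immediate neighbors; stacking layers spreads information further across the graph.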
These GNN layers are designed specifically for interrogating the graph structure. For each node in the graph, a stack of x GNN layers represents communication that can span nodes x hops away. For this reason, GNN stacks are almost never as ‘deep’ as a conventional deep learning network can be: from a performance point of view, there is a law of diminishing returns in applying many layers. We’ll get into the specifics of this as we review each architecture.
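The "x layers = x hops" intuition can be made concrete with a breadth-first search: after k layers, a node can have received information from exactly the nodes within k hops of it. A small sketch, using an invented path graph:

```python
from collections import deque

# A path graph 0-1-2-3, as an adjacency list.
adjacency = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}

def nodes_within(adjacency, start, k):
    """Return all nodes reachable from `start` in at most k hops (BFS)."""
    seen, frontier = {start}, deque([(start, 0)])
    while frontier:
        node, depth = frontier.popleft()
        if depth == k:
            continue  # don't expand beyond k hops
        for nbr in adjacency[node]:
            if nbr not in seen:
                seen.add(nbr)
                frontier.append((nbr, depth + 1))
    return seen

print(nodes_within(adjacency, 0, 1))  # {0, 1}
print(nodes_within(adjacency, 0, 2))  # {0, 1, 2}
```

Each extra layer widens this neighborhood by one hop, which is why a handful of layers often suffices to cover most of a graph.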
Most of this book will be focused on the middle sections of the diagram: the different ways graph updating layers are done, and the downstream neural network architecture of the GNN.
If you want to learn more, check out the book on Manning’s liveBook platform here.