Graph Neural Networks
Suppose we’re interested in the ogbg-molhiv dataset, which consists of roughly 40,000 molecules, each with a binary label that indicates whether the molecule inhibits HIV replication. We’ll follow this jraph library example in predicting that label using a Graph Neural Network (GNN). In particular, we’ll focus on describing the forward pass of the model (see this colab).
We start by representing each molecule in the dataset as a graph in which each vertex corresponds to an atom in the molecule and each edge corresponds to a bond between atoms. The dataset includes both vertex and edge features. For a single training example, the vertex feature matrix $\textbf{V}$ has a row for each vertex in the graph ($N^v$ vertices) and 9 columns (one for each vertex feature). The edge feature matrix $\textbf{E}$ has a row for each edge in the graph ($N^e$ edges) and 3 columns (one for each edge feature). Let $\textbf{v}_k$ denote the $k$th row of $\textbf{V}$ and $\textbf{e}_k$ denote the $k$th row of $\textbf{E}$. Each edge connects a sender vertex to a receiver vertex. A training example also includes vectors $\textbf{s}$ and $\textbf{r}$ of length $N^e$, where $s_k$ is the index of the row of $\textbf{V}$ for the sender vertex of the $k$th edge and $r_k$ is the index of the row of $\textbf{V}$ for the receiver vertex of the $k$th edge. An important property of this model is that it is permutation invariant: it gives the same output no matter how the vertices and edges are ordered when constructing the input data.
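Here is a minimal NumPy sketch of this representation for a hypothetical three-atom molecule. Several details are assumptions for illustration and are not taken from the dataset: the random feature values, the variable names, and storing each bond as two directed edges (one per direction). The last lines relabel the vertices. This shows that a different vertex ordering is just another encoding of the same graph.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy molecule: 3 atoms and 2 bonds, each bond stored as two
# directed edges, so N^v = 3 and N^e = 4. Random values stand in for real features.
n_vertices, n_edges = 3, 4
V = rng.normal(size=(n_vertices, 9))  # vertex features, shape [N^v, 9]
E = rng.normal(size=(n_edges, 3))     # edge features,   shape [N^e, 3]

# s[k] and r[k] are the rows of V holding edge k's sender and receiver.
s = np.array([0, 1, 1, 2])
r = np.array([1, 0, 2, 1])

# Row k of V[s] is v_{s_k}, the features of edge k's sender vertex.
sender_features = V[s]
print(sender_features.shape)  # (4, 9)

# Relabel the vertices: new row i holds old vertex perm[i].
perm = np.array([2, 0, 1])
V_perm = V[perm]
new_index = np.argsort(perm)  # new_index[old] = position of that vertex in V_perm
s_perm, r_perm = new_index[s], new_index[r]

# Every edge still joins the same atoms, only under different indices.
assert np.allclose(V_perm[s_perm], V[s])
assert np.allclose(V_perm[r_perm], V[r])
```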
The model proceeds in these steps (with notation inspired by Algorithm 1 in Battaglia et al. 2018):

Embed the features: The first step applies one linear layer to $\textbf{E}$ and another to $\textbf{V}$, embedding both in a space of dimension $d$, so the embedded matrices have shapes $[N^e, d]$ and $[N^v, d]$. We use $\textbf{E}$ and $\textbf{V}$ to refer to the embedded matrices from here on out.
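The embedding step amounts to two independent affine maps, each applied row by row. The sketch below uses NumPy with randomly initialised, untrained weights. The value $d = 16$, the initialisation scale, and the helper name `linear` are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 16  # embedding dimension (arbitrary choice for this sketch)

# Toy inputs with the raw feature widths from the text.
V_raw = rng.normal(size=(3, 9))  # [N^v, 9]
E_raw = rng.normal(size=(4, 3))  # [N^e, 3]

def linear(x, W, b):
    """Affine map x @ W + b, applied independently to every row of x."""
    return x @ W + b

# Separate parameters for vertices and edges (random, untrained).
W_v, b_v = 0.1 * rng.normal(size=(9, d)), np.zeros(d)
W_e, b_e = 0.1 * rng.normal(size=(3, d)), np.zeros(d)

V = linear(V_raw, W_v, b_v)  # embedded vertices, [N^v, d]
E = linear(E_raw, W_e, b_e)  # embedded edges,    [N^e, d]
print(V.shape, E.shape)  # (3, 16) (4, 16)
```

Each row is transformed on its own. As a result, shuffling the rows of `V_raw` shuffles the rows of `V` in exactly the same way.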

Update the edge features: The next step is to compute an “updated” edge matrix. We form the matrices $\textbf{S}^{v}$ and $\textbf{R}^{v} \in \mathbb{R}^{N^e \times d}$ so that $\textbf{S}^{v}_{k, :} = \textbf{v}_{s_k}$ and $\textbf{R}^{v}_{k, :} = \textbf{v}_{r_k}$. The edge update function $\phi^{e}$ then takes $\textbf{E}$, $\textbf{S}^{v}$ and $\textbf{R}^{v}$ as input, concatenates them on the last dimension, and feeds the resulting matrix of shape $[N^e, 3 \cdot d]$ through a neural network to get the updated edge matrix $\textbf{E}^{\prime}$ of shape $[N^e, d^{\prime}]$.
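Here is a NumPy sketch of the edge update. We take the neural network inside $\phi^{e}$ to be a two-layer ReLU MLP with hidden width 64, and set $d = 16$ and $d^{\prime} = 32$. These settings, the toy graph, and the helper names are assumptions, not the jraph example's configuration.

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_prime, hidden = 16, 32, 64  # arbitrary widths for this sketch

V = rng.normal(size=(3, d))  # embedded vertices, [N^v, d]
E = rng.normal(size=(4, d))  # embedded edges,    [N^e, d]
s = np.array([0, 1, 1, 2])   # sender of each edge
r = np.array([1, 0, 2, 1])   # receiver of each edge

def mlp(x, W1, b1, W2, b2):
    """Two-layer perceptron with a ReLU, applied to every row of x."""
    return np.maximum(x @ W1 + b1, 0.0) @ W2 + b2

# Random, untrained parameters for phi^e; its input width is 3 * d.
params_e = (0.1 * rng.normal(size=(3 * d, hidden)), np.zeros(hidden),
            0.1 * rng.normal(size=(hidden, d_prime)), np.zeros(d_prime))

S_v = V[s]  # row k is v_{s_k}, [N^e, d]
R_v = V[r]  # row k is v_{r_k}, [N^e, d]

# phi^e: concatenate on the last axis, then apply the MLP to each edge.
edge_inputs = np.concatenate([E, S_v, R_v], axis=-1)  # [N^e, 3 * d]
E_prime = mlp(edge_inputs, *params_e)                 # [N^e, d']
print(E_prime.shape)  # (4, 32)
```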

Update the vertex features: In a similar vein, we also compute an updated vertex matrix. Following the jraph implementation (and Algorithm 1 in Battaglia et al. 2018), the vertex update aggregates the updated edge features. Let $\textbf{e}^{\prime}_j$ denote the $j$th row of $\textbf{E}^{\prime}$. We form the matrices $\textbf{S}^{e}$ and $\textbf{R}^{e} \in \mathbb{R}^{N^v \times d^{\prime}}$ so that $\textbf{S}^{e}_{k, :} = \sum_{j : s_j = k} \textbf{e}^{\prime}_{j}$ and $\textbf{R}^{e}_{k, :} = \sum_{j : r_j = k} \textbf{e}^{\prime}_{j}$. Instead of taking the sum, we could take the mean or use any other permutation invariant aggregation function $\rho^{e \to v}$. The vertex update function $\phi^{v}$ then takes $\textbf{V}$, $\textbf{S}^{e}$ and $\textbf{R}^{e}$ as input, concatenates them on the last dimension, and feeds the resulting matrix of shape $[N^v, d + 2 d^{\prime}]$ through a neural network to get the updated vertex matrix $\textbf{V}^{\prime}$ of shape $[N^v, d^{\prime}]$.
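The sum aggregation $\rho^{e \to v}$ is a segment sum: each row of the edge matrix is added into the row of its sender (or receiver) vertex. The sketch below applies it to the updated edge matrix $\textbf{E}^{\prime}$ produced by the edge update, as jraph's `GraphNetwork` does. This means $\phi^{v}$ sees inputs of width $d + 2d^{\prime}$. The `segment_sum` helper is a NumPy stand-in for `jax.ops.segment_sum`. The MLP, the widths, and the toy graph are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
d, d_prime, hidden = 16, 32, 64  # arbitrary widths for this sketch
n_vertices = 3

V = rng.normal(size=(n_vertices, d))     # embedded vertices, [N^v, d]
E_prime = rng.normal(size=(4, d_prime))  # updated edges,     [N^e, d']
s = np.array([0, 1, 1, 2])
r = np.array([1, 0, 2, 1])

def segment_sum(data, segment_ids, num_segments):
    """Sum rows of `data` sharing a segment id (stand-in for jax.ops.segment_sum)."""
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    np.add.at(out, segment_ids, data)  # unbuffered: repeated ids accumulate
    return out

def mlp(x, W1, b1, W2, b2):
    """Two-layer perceptron with a ReLU, applied to every row of x."""
    return np.maximum(x @ W1 + b1, 0.0) @ W2 + b2

S_e = segment_sum(E_prime, s, n_vertices)  # row k: sum over edges sent by vertex k
R_e = segment_sum(E_prime, r, n_vertices)  # row k: sum over edges received by vertex k

# Random, untrained parameters for phi^v; its input width is d + 2 * d'.
params_v = (0.1 * rng.normal(size=(d + 2 * d_prime, hidden)), np.zeros(hidden),
            0.1 * rng.normal(size=(hidden, d_prime)), np.zeros(d_prime))
V_prime = mlp(np.concatenate([V, S_e, R_e], axis=-1), *params_v)  # [N^v, d']
print(V_prime.shape)  # (3, 32)
```

The helper uses `np.add.at` rather than `out[segment_ids] += data`. With repeated indices, the latter writes only once per index. Vertex 1, which sends two edges here, would then keep only one of its two contributions.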

Read out the aggregated vertex and edge features: We then aggregate the rows of $\textbf{V}^{\prime}$ with the function $\rho^{v \to u}$ to get a vector $\textbf{v}^{u} \in \mathbb{R}^{d^{\prime}}$ that summarizes the vertices of the graph, and aggregate the rows of $\textbf{E}^{\prime}$ with the function $\rho^{e \to u}$ to get a vector $\textbf{e}^{u} \in \mathbb{R}^{d^{\prime}}$ that summarizes the edges of the graph. The global update function $\phi^{u}$ then takes $\textbf{v}^{u}$ and $\textbf{e}^{u}$ as input, concatenates them and uses the resulting vector as features in a logistic regression layer.
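A sketch of the readout and the final logistic regression follows. Several choices here are assumptions for illustration: $\rho^{v \to u}$ and $\rho^{e \to u}$ are means, $\phi^{u}$ is plain concatenation, and the weights are random. Because the means ignore row order, shuffling the rows of $\textbf{V}^{\prime}$ and $\textbf{E}^{\prime}$ leaves the predicted probability unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
d_prime = 32
V_prime = rng.normal(size=(3, d_prime))  # updated vertices, [N^v, d']
E_prime = rng.normal(size=(4, d_prime))  # updated edges,    [N^e, d']

# Random, untrained logistic-regression weights over the 2 * d' global features.
w = 0.1 * rng.normal(size=2 * d_prime)
b = 0.0

def predict(V_prime, E_prime):
    v_u = V_prime.mean(axis=0)      # rho^{v->u}: mean over vertices, [d']
    e_u = E_prime.mean(axis=0)      # rho^{e->u}: mean over edges,    [d']
    u = np.concatenate([v_u, e_u])  # phi^u: concatenation,           [2 * d']
    logit = u @ w + b
    return 1.0 / (1.0 + np.exp(-logit))  # P(label = 1)

p = predict(V_prime, E_prime)
p_shuffled = predict(V_prime[[2, 0, 1]], E_prime[[3, 1, 0, 2]])
print(np.isclose(p, p_shuffled))  # True
```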
We can operate on a batch of graphs by combining them into one larger graph whose disconnected subgraphs are the individual graphs, with each graph's sender and receiver indices offset accordingly. We then apply the same model described above, taking care to aggregate vertex and edge features only within each subgraph (the jraph implementation achieves this by using jax.ops.segment_sum for the aggregation functions).
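The sketch below batches two hypothetical graphs by hand. It offsets the second graph's sender and receiver indices and records which graph each vertex and edge belongs to. It then sums features per graph with a segment sum. In practice `jraph.batch` handles this bookkeeping. Here `segment_sum` is a NumPy stand-in for `jax.ops.segment_sum`, and the graphs and features are made up.

```python
import numpy as np

def segment_sum(data, segment_ids, num_segments):
    """Sum rows of `data` sharing a segment id (stand-in for jax.ops.segment_sum)."""
    out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
    np.add.at(out, segment_ids, data)  # unbuffered: repeated ids accumulate
    return out

# Two hypothetical graphs: A has 3 vertices and 4 edges, B has 2 and 2.
n_node = np.array([3, 2])
n_edge = np.array([4, 2])
s_a, r_a = np.array([0, 1, 1, 2]), np.array([1, 0, 2, 1])
s_b, r_b = np.array([0, 1]), np.array([1, 0])

# Merge: shift B's vertex indices past A's vertices so B's edges stay in B.
offset = n_node[0]
s = np.concatenate([s_a, s_b + offset])  # [0 1 1 2 3 4]
r = np.concatenate([r_a, r_b + offset])  # [1 0 2 1 4 3]

# Graph id of every vertex and edge in the merged graph.
node_graph = np.repeat(np.arange(len(n_node)), n_node)  # [0 0 0 1 1]
edge_graph = np.repeat(np.arange(len(n_edge)), n_edge)  # [0 0 0 0 1 1]

# No edge crosses graphs, so edge-to-vertex sums stay within each graph.
assert (node_graph[s] == edge_graph).all()

# Toy 1-d features for the merged graph.
V_prime = np.arange(5, dtype=float).reshape(5, 1)  # vertices: 0, 1, 2, 3, 4
E_prime = np.ones((6, 1))                          # edges: all ones

# Per-graph readout: each graph's sum only includes its own rows.
v_u = segment_sum(V_prime, node_graph, num_segments=len(n_node))
e_u = segment_sum(E_prime, edge_graph, num_segments=len(n_edge))
print(v_u.ravel(), e_u.ravel())  # [3. 7.] [4. 2.]
```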