3. Convolutional Graph Neural Networks, Deep Graph Infomax
Convolutional Graph Neural Networks (ConvGNNs)
What is a Convolutional Graph Neural Network?
Convolutional Graph Neural Networks (ConvGNNs) are a type of neural network designed to operate on graph-structured data. They generalize the concept of convolution from grid-structured data (like images) to graphs. This allows ConvGNNs to aggregate information from a node's neighbors and to use the graph topology when learning representations of nodes or whole graphs.
Why Use ConvGNNs?
- Enable the application of deep learning techniques to graph data.
- Capture complex relationships and structures in the data.
- Improve performance on various tasks such as node classification, link prediction, and graph classification.
How Do They Work in General?
ConvGNNs work by:
- Aggregating information from a node's neighbors.
- Applying a convolutional operation to combine this information.
- Updating the node's representation based on the aggregated and convolved information.
This process is repeated over multiple layers. Each additional layer lets a node draw on information from one hop further away, so deeper networks capture patterns and dependencies over larger neighborhoods.
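The three steps above can be sketched in a few lines of NumPy. This is a minimal illustration, not a specific published layer: the mean aggregation with self-loops, the ReLU activation, and the names `conv_layer`, `W1`, and `W2` are all assumptions chosen for clarity.

```python
import numpy as np

def conv_layer(A, H, W):
    """One illustrative graph convolution layer (mean aggregation + ReLU)."""
    # Add self-loops so each node keeps its own features when aggregating.
    A_hat = A + np.eye(A.shape[0])
    # Aggregate: average the features of each node and its neighbors.
    deg = A_hat.sum(axis=1, keepdims=True)
    aggregated = (A_hat @ H) / deg
    # Transform and update: learnable linear map followed by a non-linearity.
    return np.maximum(aggregated @ W, 0.0)

# Toy path graph 0 - 1 - 2 with 2-dimensional node features.
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
H0 = np.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])

rng = np.random.default_rng(0)
W1 = rng.normal(size=(2, 4))  # hypothetical weights; normally learned
W2 = rng.normal(size=(4, 3))

H1 = conv_layer(A, H0, W1)  # each node has seen its 1-hop neighborhood
H2 = conv_layer(A, H1, W2)  # after two layers, 2-hop information is mixed in
```

Stacking two calls is what lets node 0 be influenced by node 2, even though they share no edge.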
Algorithms
Neural Fingerprints
Idea:
Neural Fingerprints (Neural FPs) are among the first algorithms to leverage convolution for learning in graphs, specifically designed for molecular data. The goal is to create vector representations (embeddings) of molecules that capture their structural information, enabling the identification of functionally similar molecules.
Algorithm:
- Initialization:
- Start from the classical circular-fingerprint procedure, which uses hand-crafted hashing and indexing techniques to build molecular representations.
- Inner Iteration:
- Collect Neighbor Information: For each atom in the molecule, gather information from its neighboring atoms.
- Aggregate Representations: Combine the collected information to form a new representation using a hash function.
- Update Output Vector: Use the hash to update an output vector that holds binary flags indicating certain molecular properties.
- Outer Cycle:
- Propagate Information: Iteratively propagate the properties to more distant neighbors, ensuring that information spreads across the molecule.
- Learnable Operations:
- Replace the hashing and aggregation steps with learnable operations using matrices H_N and W_L, along with non-linear activation functions (e.g., sigmoid or tanh).
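As a rough sketch of the learnable variant, the code below replaces hashing with a smooth layer and indexing with a softmax that writes into the fingerprint vector. It is an illustration under simplifying assumptions: one hidden matrix per layer (here `H_list`, standing in for H_N), one output matrix per layer (`W_list`, standing in for W_L), tanh as the non-linearity, and random rather than trained weights.

```python
import numpy as np

def softmax(x):
    """Row-wise softmax."""
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def neural_fingerprint(A, R0, H_list, W_list):
    """A: (n, n) bond adjacency, R0: (n, d) initial atom features.
    H_list[l]: (d, d) hidden weights, W_list[l]: (d, fp_len) output weights."""
    r = R0
    f = np.zeros(W_list[0].shape[1])
    for H, W in zip(H_list, W_list):      # outer cycle: one pass per radius
        v = r + A @ r                     # collect own and neighbor features
        r = np.tanh(v @ H)                # smooth replacement for the hash
        f += softmax(r @ W).sum(axis=0)   # soft "indexing" into the output vector
    return f

# Toy 3-atom chain with 4 features per atom (all values hypothetical).
rng = np.random.default_rng(0)
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=float)
R0 = rng.normal(size=(3, 4))
H_list = [rng.normal(size=(4, 4)) for _ in range(2)]
W_list = [rng.normal(size=(4, 5)) for _ in range(2)]

fp = neural_fingerprint(A, R0, H_list, W_list)

# Relabeling the atoms should not change the fingerprint.
perm = [2, 0, 1]
fp_perm = neural_fingerprint(A[perm][:, perm], R0[perm], H_list, W_list)
```

Because every step is differentiable, the weights can be trained end to end, which is what separates this from the fixed hashing scheme.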
Conclusion:
Neural FPs provide learnable convolutional representations that can infer embeddings for unseen graphs. However, they are computationally expensive, and information propagates only as far across the graph as the number of layers allows.
Diffusion-Convolution Neural Networks (DCNNs)
Idea:
DCNNs generalize convolutional neural networks to graphs, supporting both node and graph classification. They utilize diffusion kernels to measure connectivity between nodes, considering all paths but weighting shorter paths more heavily.
Diffusion Kernel:
A diffusion kernel is a mathematical construct used to measure the connectivity or similarity between nodes in a graph. It is based on the idea of diffusion processes, where information spreads from one node to others through the edges of the graph. In the context of DCNNs:
- Diffusion Process: Represents the spread of information from one node to others.
- Kernel Function: Measures how strongly connected two nodes are, considering all possible paths between them.
- Weighting Paths: Shorter paths are weighted more heavily than longer paths, reflecting the more direct connection between nodes.
Algorithm:
- Graph Setup:
- Define the Graphs: Assume a set of T graphs G = \{G_t \mid t \in \{1, \ldots, T\}\}, where each graph G_t has vertices V_t and edges E_t.
- Feature Matrix: Each graph G_t has a feature matrix X_t representing additional information about the vertices.
- Diffusion Matrix:
- Compute Transition Probabilities: Normalize each row of the adjacency matrix by the node's degree to create the diffusion matrix P_t, which contains the probabilities of transitioning between nodes in one step.
- Power Series of Diffusion Matrix: Compute P_t^*, which stacks successive powers of P_t, to capture multi-step transitions.
- Node Classification:
- Vector Representation: Convert each node to a vector representation Z_t by combining its features with the diffusion matrix.
- Prediction: Use a non-linear function f and learnable weights W_c and W_d to make predictions.
- Graph Classification:
- Aggregate Node Representations: Combine node representations to form a graph-level representation.
- Prediction: Make predictions using the aggregated representation, similar to node classification.
- Learning:
- Training: Use backpropagation and gradient descent to optimize the model parameters.
- Minibatch Approach: Train on randomly selected nodes or graphs in minibatches for efficiency.
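The forward pass described above can be sketched as follows. This is a hedged illustration rather than the reference implementation: the hop range (powers 1 to `n_hops`), the tanh non-linearity for f, the dense softmax output layer using W_d, and all function names are assumptions.

```python
import numpy as np

def diffusion_tensor(A, n_hops):
    """Stack powers P^1..P^n_hops of the transition matrix: shape (N, H, N)."""
    P = A / A.sum(axis=1, keepdims=True)  # row-normalize; assumes no isolated nodes
    powers = [np.linalg.matrix_power(P, k) for k in range(1, n_hops + 1)]
    return np.stack(powers, axis=1)

def diffuse_features(A, X, n_hops):
    """P* X: diffused features per node and hop, shape (N, H, F)."""
    return np.einsum('ihj,jf->ihf', diffusion_tensor(A, n_hops), X)

def node_repr(A, X, W_c, n_hops):
    """Node-level Z: elementwise weights W_c (H, F) shared across nodes."""
    return np.tanh(W_c * diffuse_features(A, X, n_hops))

def graph_repr(A, X, W_c, n_hops):
    """Graph-level Z: average the diffused features over nodes first."""
    return np.tanh(W_c * diffuse_features(A, X, n_hops).mean(axis=0))

def predict(Z, W_d):
    """Class probabilities from flattened node representations."""
    logits = Z.reshape(Z.shape[0], -1) @ W_d
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

# Toy 4-node cycle, 3 features per node, 2 hops, 2 classes (all hypothetical).
rng = np.random.default_rng(0)
A = np.array([[0, 1, 0, 1],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [1, 0, 1, 0]], dtype=float)
X = rng.normal(size=(4, 3))
W_c = rng.normal(size=(2, 3))
W_d = rng.normal(size=(2 * 3, 2))

P_star = diffusion_tensor(A, n_hops=2)
Z = node_repr(A, X, W_c, n_hops=2)
probs = predict(Z, W_d)
Z_graph = graph_repr(A, X, W_c, n_hops=2)
```

Note that `P_star` is dense with N x H x N entries, which is where the memory cost on large graphs comes from.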
Limitations:
DCNNs scale poorly to large graphs because the dense diffusion tensor P_t^* requires memory that grows quadratically with the number of nodes. They may also fail to capture long-range dependencies.
GraphSAGE (Graph SAmple and aggreGatE)
Idea:
GraphSAGE is an inductive framework that generates node embeddings by sampling and aggregating features from a node's local neighborhood, allowing it to generalize to unseen nodes.
Algorithm:
- Initialization:
- Node Features: Start with initial node features x_v, which serve as the initial representations.
- Aggregation:
- Aggregate Neighbor Features: Use different strategies (MEAN, LSTM, POOL) to aggregate features from neighboring nodes.
- Combine Representations: Combine the node's features with its neighbors' aggregated features to update the node's representation.
- Neighborhood Sampling:
- Sample Neighbors: Sample a fixed number of neighbors to limit computational costs and avoid considering the entire neighborhood.
- Propagate Information: Information from a node is propagated one step further in each iteration.
- Training:
- Unsupervised Learning: Train without labels using a loss with negative sampling, which encourages nearby nodes (for example, nodes that co-occur on short random walks) to have similar representations and pushes apart the representations of randomly sampled, unrelated nodes.
- Optimization: Use stochastic gradient descent to optimize weight matrices.
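One sample-and-aggregate step with the MEAN strategy might look like the sketch below. It is a simplified illustration: the function names, the choice to sample with replacement when a node has fewer than `k` neighbors, the ReLU activation, and the final L2 normalization are assumptions, and the weights are random rather than trained.

```python
import numpy as np

def sample_neighbors(adj_list, v, k, rng):
    """Draw a fixed-size neighbor sample; assumes v has at least one neighbor."""
    nbrs = adj_list[v]
    # Sample with replacement only when there are fewer than k neighbors.
    return rng.choice(nbrs, size=k, replace=len(nbrs) < k)

def sage_mean_layer(adj_list, H, W, k, rng):
    """H: (n, d) current representations, W: (2d, d_out) weight matrix."""
    out = np.zeros((H.shape[0], W.shape[1]))
    for v in range(H.shape[0]):
        nbrs = sample_neighbors(adj_list, v, k, rng)
        h_nbr = H[nbrs].mean(axis=0)              # MEAN aggregator
        h_cat = np.concatenate([H[v], h_nbr])     # combine self and neighbors
        out[v] = np.maximum(h_cat @ W, 0.0)       # linear map + ReLU
    # Normalize each representation to unit length (zero rows stay zero).
    norms = np.linalg.norm(out, axis=1, keepdims=True)
    return out / np.maximum(norms, 1e-12)

# Toy graph with 5 nodes and 4 input features (hypothetical values).
rng = np.random.default_rng(0)
adj_list = {0: [1, 2], 1: [0, 2, 3], 2: [0, 1], 3: [1, 4], 4: [3]}
X = rng.normal(size=(5, 4))
W = rng.normal(size=(8, 3))

H1 = sage_mean_layer(adj_list, X, W, k=2, rng=rng)
```

The weight matrix `W` does not depend on which nodes exist, so the same function can embed a node added after training as long as its features and neighbors are known.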
Wrap Up:
Because GraphSAGE learns aggregation functions rather than a separate embedding for every node, it can generate embeddings for unseen nodes and reduces space requirements by not storing all node embeddings directly.
HetGNN (Heterogeneous Graph Neural Network)
Idea:
HetGNN addresses the challenges of heterogeneous graphs, where nodes and edges can have different types and attributes. It processes and aggregates information from different types of nodes and edges to generate embeddings.
Algorithm:
- Sampling Strategy:
- Random Walk with Restart (RWR): Perform a random walk from each node with a probability of restarting at the initial node.
- Collect Neighbors: Collect a fixed number of nodes of all types during the walk.
- Group Neighbors: Group neighbors by type and select the top k_t most frequently visited nodes for each type.
- Attribute Encoding:
- Convert Attributes: Convert each attribute into a vector representation using a technique suited to its modality (e.g., a text embedding model for text, a CNN for images).
- Aggregate Attributes: Use a bidirectional LSTM to aggregate the attribute vectors of a node into its final representation.
- Aggregation:
- Same-Type Nodes: Aggregate the representations of sampled neighbors of the same type using another bidirectional LSTM.
- Combine Types: Combine representations from all types of nodes, employing an attention mechanism to identify important nodes.
- Objective and Training:
- Predict Probability: Model the probability p(v_c | v; \Theta) that a context node v_c appears near node v, and approximate it with negative sampling.
- Optimize Loss Function: Optimize a loss function incorporating both positive and negative samples to train the model.
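The sampling strategy at the start of the algorithm can be sketched with the standard library alone. This is an illustrative simplification: it runs a walk of fixed length instead of stopping once a fixed number of nodes has been collected, and the function name, the toy graph, and the parameters `walk_len`, `restart_p`, and `top_k` are hypothetical.

```python
import random
from collections import Counter, defaultdict

def rwr_neighbors(adj, node_type, start, top_k, walk_len=200, restart_p=0.5, seed=0):
    """Random walk with restart, then keep the top-k visited nodes per type."""
    rng = random.Random(seed)
    visits = Counter()
    current = start
    for _ in range(walk_len):
        if rng.random() < restart_p:
            current = start                       # restart at the initial node
        else:
            current = rng.choice(adj[current])    # step to a random neighbor
            if current != start:
                visits[current] += 1
    grouped = defaultdict(list)
    for node, _count in visits.most_common():     # most frequently visited first
        t = node_type[node]
        if len(grouped[t]) < top_k.get(t, 0):
            grouped[t].append(node)
    return {t: nodes for t, nodes in grouped.items() if nodes}

# Toy academic graph: authors (a*), papers (p*), one venue (v1).
adj = {
    "a1": ["p1", "p2"],
    "a2": ["p2", "p3"],
    "p1": ["a1", "v1"],
    "p2": ["a1", "a2", "v1"],
    "p3": ["a2", "v1"],
    "v1": ["p1", "p2", "p3"],
}
type_names = {"a": "author", "p": "paper", "v": "venue"}
node_type = {n: type_names[n[0]] for n in adj}
top_k = {"author": 1, "paper": 2, "venue": 1}

neighbors = rwr_neighbors(adj, node_type, "a1", top_k)
```

Grouping by type before selecting the top k_t nodes guarantees that rare node types are still represented in each node's neighborhood.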
Deep Graph Infomax (DGI)
Idea:
DGI is a general unsupervised learning framework that maximizes mutual information between node representations and a summary vector of the graph. It trains the encoder so that representations of nodes from the original graph agree with that graph's summary, while representations of nodes from a corrupted graph (or, in multi-graph settings, a different graph) do not.
Algorithm:
- Feature and Adjacency Matrix:
- Input Data: Start with node features X and an adjacency matrix A that represents the graph structure.
- Encoder:
- Node Representations: Learn high-level node representations H = E(X, A) using an encoder.
- Readout Function:
- Graph-Level Representation: Summarize node representations into a graph-level representation \vec{s} = R(H) using a readout function.
- Negative Samples:
- Corrupted Samples: Generate corrupted samples (\tilde{X}, \tilde{A}) by altering the original features or topology.
- Corrupted Representations: Compute representations for the corrupted graph \tilde{H} = E(\tilde{X}, \tilde{A}).
- Discriminator:
- Mutual Information: Train a discriminator that scores pairs of a node representation and the summary \vec{s}, assigning high scores to nodes from the original graph and low scores to nodes from the corrupted graph.
- Loss Function:
- Maximize Mutual Information: Optimize the loss function to maximize mutual information between node representations and the graph summary.
- Gradient Descent: Update the encoder, readout function, and discriminator parameters using gradient descent.
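The pipeline above can be sketched end to end in NumPy. This is a minimal forward pass under stated assumptions, without the gradient step: the encoder E is a single mean-aggregation layer with ReLU, the readout R is a sigmoid of the mean node representation, the corruption shuffles feature rows while keeping A fixed, and the discriminator is a bilinear score. The names `encoder`, `readout`, `discriminator`, and `dgi_loss` are illustrative.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def encoder(X, A, W):
    """H = E(X, A): one mean-aggregation layer with self-loops and ReLU."""
    A_hat = A + np.eye(A.shape[0])
    A_norm = A_hat / A_hat.sum(axis=1, keepdims=True)
    return np.maximum(A_norm @ X @ W, 0.0)

def readout(H):
    """s = R(H): squash the average node representation."""
    return sigmoid(H.mean(axis=0))

def discriminator(H, s, W_d):
    """Probability that each node representation matches the summary s."""
    return sigmoid(H @ W_d @ s)

def dgi_loss(X, A, W, W_d, rng):
    H = encoder(X, A, W)
    s = readout(H)
    X_tilde = X[rng.permutation(X.shape[0])]   # corrupt: shuffle feature rows
    H_tilde = encoder(X_tilde, A, W)           # corrupted representations
    pos = discriminator(H, s, W_d)             # should be pushed toward 1
    neg = discriminator(H_tilde, s, W_d)       # should be pushed toward 0
    eps = 1e-12                                # guards against log(0)
    total = np.log(pos + eps).sum() + np.log(1.0 - neg + eps).sum()
    return -total / (len(pos) + len(neg))      # binary cross-entropy

# Toy graph: 4-node cycle with 3 features per node (hypothetical values).
rng = np.random.default_rng(0)
A = np.array([[0, 1, 0, 1],
              [1, 0, 1, 0],
              [0, 1, 0, 1],
              [1, 0, 1, 0]], dtype=float)
X = rng.normal(size=(4, 3))
W = 0.1 * rng.normal(size=(3, 5))
W_d = 0.1 * rng.normal(size=(5, 5))

loss = dgi_loss(X, A, W, W_d, rng)
scores = discriminator(encoder(X, A, W), readout(encoder(X, A, W)), W_d)
```

In practice the loss would be minimized with an automatic differentiation library so that `W` and `W_d` are updated jointly.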
Loss Function: