From 0e5242cd11fda341e55a70d70d3f08fe0b5c1ee1 Mon Sep 17 00:00:00 2001
From: mini rawat
Date: Sat, 14 Dec 2024 21:37:20 -0800
Subject: [PATCH] Add XENetConv convolution layer as requested in issue ticket #8257

---
 CHANGELOG.md                          |   1 +
 test/nn/conv/test_xenet.py            | 169 ++++++++++++++++++++++++++
 torch_geometric/nn/conv/__init__.py   |   2 +
 torch_geometric/nn/conv/xenet_conv.py | 155 +++++++++++++++++++++++
 4 files changed, 327 insertions(+)
 create mode 100644 test/nn/conv/test_xenet.py
 create mode 100644 torch_geometric/nn/conv/xenet_conv.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4e6789b9a86d..cb19065c5199 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added `XENetConv` - a convolution layer based on the XENet paper ([#8257](https://github.com/pyg-team/pytorch_geometric/issues/8257))
 - Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794))
 - Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
 - Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
diff --git a/test/nn/conv/test_xenet.py b/test/nn/conv/test_xenet.py
new file mode 100644
index 000000000000..395028f57881
--- /dev/null
+++ b/test/nn/conv/test_xenet.py
@@ -0,0 +1,169 @@
+import unittest
+
+import torch
+
+from torch_geometric.data import Data
+from torch_geometric.nn import XENetConv
+
+
+class TestXENetConv(unittest.TestCase):
+    def setUp(self):
+        # Set random seed for reproducibility
+        torch.manual_seed(42)
+
+        # Define test dimensions
+        self.num_nodes = 4
+        self.in_node_channels = 3
+        self.in_edge_channels = 2
+        self.node_channels = 5
+        self.edge_channels = 4
+        self.stack_channels = [8, 16]
+
+        # Create a simple graph for testing
+        self.x = torch.randn(self.num_nodes, self.in_node_channels)
+        self.edge_index = torch.tensor(
+            [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]], dtype=torch.long)
+        self.edge_attr = torch.randn(self.edge_index.size(1),
+                                     self.in_edge_channels)
+
+        # Create different variants of the layer for testing
+        self.conv_attention = XENetConv(
+            in_node_channels=self.in_node_channels,
+            in_edge_channels=self.in_edge_channels,
+            stack_channels=self.stack_channels,
+            node_channels=self.node_channels,
+            edge_channels=self.edge_channels,
+            attention=True)
+
+        self.conv_no_attention = XENetConv(
+            in_node_channels=self.in_node_channels,
+            in_edge_channels=self.in_edge_channels,
+            stack_channels=self.stack_channels,
+            node_channels=self.node_channels,
+            edge_channels=self.edge_channels,
+            attention=False)
+
+    def test_basic_forward(self):
+        """Test basic forward pass with attention."""
+        out_x, out_edge_attr = self.conv_attention(self.x, self.edge_index,
+                                                   self.edge_attr)
+
+        # Check output shapes
+        self.assertEqual(out_x.shape, (self.num_nodes, self.node_channels))
+        self.assertEqual(out_edge_attr.shape,
+                         (self.edge_index.size(1), self.edge_channels))
+
+        # Check that outputs contain no NaN values
+        self.assertFalse(torch.isnan(out_x).any())
+        self.assertFalse(torch.isnan(out_edge_attr).any())
+
+    def test_no_attention_forward(self):
+        """Test forward pass without attention."""
+        out_x, out_edge_attr = self.conv_no_attention(self.x, self.edge_index,
+                                                      self.edge_attr)
+
+        # Check output shapes
+        self.assertEqual(out_x.shape, (self.num_nodes, self.node_channels))
+        self.assertEqual(out_edge_attr.shape,
+                         (self.edge_index.size(1), self.edge_channels))
+
+        # Check that outputs contain no NaN values
+        self.assertFalse(torch.isnan(out_x).any())
+        self.assertFalse(torch.isnan(out_edge_attr).any())
+
+    def test_custom_activation(self):
+        """Test with custom activation functions."""
+        conv = XENetConv(in_node_channels=self.in_node_channels,
+                         in_edge_channels=self.in_edge_channels,
+                         stack_channels=self.stack_channels,
+                         node_channels=self.node_channels,
+                         edge_channels=self.edge_channels, attention=True,
+                         node_activation=torch.tanh,
+                         edge_activation=torch.relu)
+
+        out_x, out_edge_attr = conv(self.x, self.edge_index, self.edge_attr)
+
+        # Check output ranges for activations
+        self.assertTrue(torch.all(out_x >= -1)
+                        and torch.all(out_x <= 1))  # tanh range
+        self.assertTrue(torch.all(out_edge_attr >= 0))  # ReLU range
+
+    def test_single_stack_channel(self):
+        """Test with a single stack channel instead of a list."""
+        conv = XENetConv(
+            in_node_channels=self.in_node_channels,
+            in_edge_channels=self.in_edge_channels,
+            stack_channels=32,  # single integer
+            node_channels=self.node_channels,
+            edge_channels=self.edge_channels)
+
+        out_x, out_edge_attr = conv(self.x, self.edge_index, self.edge_attr)
+
+        # Check output shapes
+        self.assertEqual(out_x.shape, (self.num_nodes, self.node_channels))
+        self.assertEqual(out_edge_attr.shape,
+                         (self.edge_index.size(1), self.edge_channels))
+
+    def test_batch_processing(self):
+        """Test processing of two separate graphs of different sizes."""
+        # Create two graphs with different sizes
+        x1 = torch.randn(3, self.in_node_channels)
+        edge_index1 = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]],
+                                   dtype=torch.long)
+        edge_attr1 = torch.randn(edge_index1.size(1), self.in_edge_channels)
+
+        x2 = torch.randn(4, self.in_node_channels)
+        edge_index2 = torch.tensor([[0, 1, 2, 2, 3], [1, 2, 1, 3, 2]],
+                                   dtype=torch.long)
+        edge_attr2 = torch.randn(edge_index2.size(1), self.in_edge_channels)
+
+        # Create PyG Data objects
+        data1 = Data(x=x1, edge_index=edge_index1, edge_attr=edge_attr1)
+        data2 = Data(x=x2, edge_index=edge_index2, edge_attr=edge_attr2)
+
+        # Process each graph separately
+        out_x1, out_edge_attr1 = self.conv_attention(data1.x, data1.edge_index,
+                                                     data1.edge_attr)
+        out_x2, out_edge_attr2 = self.conv_attention(data2.x, data2.edge_index,
+                                                     data2.edge_attr)
+
+        # Check output shapes
+        self.assertEqual(out_x1.shape, (3, self.node_channels))
+        self.assertEqual(out_edge_attr1.shape, (4, self.edge_channels))
+        self.assertEqual(out_x2.shape, (4, self.node_channels))
+        self.assertEqual(out_edge_attr2.shape, (5, self.edge_channels))
+
+    def test_isolated_nodes(self):
+        """Test handling of isolated nodes."""
+        # Create a graph with an isolated node
+        x = torch.randn(4, self.in_node_channels)
+        edge_index = torch.tensor([[0, 1], [1, 2]],
+                                  dtype=torch.long)  # Node 3 is isolated
+        edge_attr = torch.randn(edge_index.size(1), self.in_edge_channels)
+
+        out_x, out_edge_attr = self.conv_attention(x, edge_index, edge_attr)
+
+        # Check that isolated node features are updated
+        self.assertFalse(torch.isnan(out_x[3]).any())
+        self.assertEqual(out_x.shape, (4, self.node_channels))
+        self.assertEqual(out_edge_attr.shape, (2, self.edge_channels))
+
+    def test_gradients(self):
+        """Test gradient computation."""
+        self.x.requires_grad_()
+        self.edge_attr.requires_grad_()
+
+        out_x, out_edge_attr = self.conv_attention(self.x, self.edge_index,
+                                                   self.edge_attr)
+
+        # Compute gradients
+        loss = out_x.sum() + out_edge_attr.sum()
+        loss.backward()
+
+        # Check that gradients are computed
+        self.assertIsNotNone(self.x.grad)
+        self.assertIsNotNone(self.edge_attr.grad)
+        self.assertFalse(torch.isnan(self.x.grad).any())
+        self.assertFalse(torch.isnan(self.edge_attr.grad).any())
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/torch_geometric/nn/conv/__init__.py b/torch_geometric/nn/conv/__init__.py
index d0169f65f0a0..ac2ac6225e53 100644
--- a/torch_geometric/nn/conv/__init__.py
+++ b/torch_geometric/nn/conv/__init__.py
@@ -61,6 +61,7 @@
 from .antisymmetric_conv import AntiSymmetricConv
 from .dir_gnn_conv import DirGNNConv
 from .mixhop_conv import MixHopConv
+from .xenet_conv import XENetConv
 
 import torch_geometric.nn.conv.utils  # noqa
 
@@ -131,6 +132,7 @@
     'AntiSymmetricConv',
     'DirGNNConv',
     'MixHopConv',
+    'XENetConv',
 ]
 
 classes = __all__
diff --git a/torch_geometric/nn/conv/xenet_conv.py b/torch_geometric/nn/conv/xenet_conv.py
new file mode 100644
index 000000000000..67e7f1c6ee3f
--- /dev/null
+++ b/torch_geometric/nn/conv/xenet_conv.py
@@ -0,0 +1,155 @@
+from typing import Callable, List, Optional, Tuple, Union
+
+import torch
+from torch import Tensor, nn
+
+from torch_geometric.nn.conv import MessagePassing
+
+
+class XENetConv(MessagePassing):
+    r"""The XENet convolution layer from the paper "XENet: Using a new graph
+    convolution to accelerate the timeline for protein design on quantum
+    computers".
+
+    Based on the original Spektral implementation:
+    https://github.com/danielegrattarola/spektral/blob/master/spektral/layers/convolutional/xenet_conv.py
+
+    Args:
+        in_node_channels (int): Size of input node features.
+        in_edge_channels (int): Size of input edge features.
+        stack_channels (Union[int, List[int]]): Number of channels for the
+            hidden stack layers.
+        node_channels (int): Number of output node features.
+        edge_channels (int): Number of output edge features.
+        attention (bool, optional): Whether to use attention when aggregating
+            messages. (default: True)
+        node_activation (Optional[Callable], optional): Activation function
+            for node features. (default: None)
+        edge_activation (Optional[Callable], optional): Activation function
+            for edge features. (default: None)
+    """
+    def __init__(self, in_node_channels: int, in_edge_channels: int,
+                 stack_channels: Union[int, List[int]], node_channels: int,
+                 edge_channels: int, attention: bool = True,
+                 node_activation: Optional[Callable] = None,
+                 edge_activation: Optional[Callable] = None, **kwargs):
+        super().__init__(aggr='add', node_dim=0, **kwargs)
+
+        self.in_node_channels = in_node_channels
+        self.in_edge_channels = in_edge_channels
+        self.stack_channels = stack_channels if isinstance(
+            stack_channels, list) else [stack_channels]
+        self.node_channels = node_channels
+        self.edge_channels = edge_channels
+        self.attention = attention
+
+        # Node and edge activation functions (identity if not provided)
+        self.node_activation = node_activation if node_activation is not None \
+            else lambda x: x
+        self.edge_activation = edge_activation if edge_activation is not None \
+            else lambda x: x
+
+        # Stack MLP: operates on [x_i, x_j, e_ij, e_ji] for every edge
+        stack_input_size = 2 * in_node_channels + 2 * in_edge_channels
+        self.stack_layers = nn.ModuleList()
+        current_channels = stack_input_size
+
+        for i, channels in enumerate(self.stack_channels):
+            self.stack_layers.append(nn.Linear(current_channels, channels))
+            if i != len(self.stack_channels) - 1:
+                self.stack_layers.append(nn.ReLU())
+            else:
+                self.stack_layers.append(nn.PReLU())
+            current_channels = channels
+
+        # Final node and edge MLPs
+        node_input_size = in_node_channels + 2 * self.stack_channels[-1]
+        self.node_mlp = nn.Linear(node_input_size, node_channels)
+        self.edge_mlp = nn.Linear(self.stack_channels[-1], edge_channels)
+
+        # Attention layers
+        if self.attention:
+            self.att_in = nn.Sequential(nn.Linear(self.stack_channels[-1], 1),
+                                        nn.Sigmoid())
+            self.att_out = nn.Sequential(nn.Linear(self.stack_channels[-1], 1),
+                                         nn.Sigmoid())
+
+    def forward(self, x: Tensor, edge_index: Tensor,
+                edge_attr: Tensor) -> Tuple[Tensor, Tensor]:
+        """Computes updated node and edge features.
+
+        Args:
+            x (Tensor): Node feature matrix of shape
+                [num_nodes, in_node_channels].
+            edge_index (Tensor): Graph connectivity matrix of shape
+                [2, num_edges].
+            edge_attr (Tensor): Edge feature matrix of shape
+                [num_edges, in_edge_channels].
+
+        Returns:
+            Tuple[Tensor, Tensor]: Updated node features of shape
+            [num_nodes, node_channels] and edge features of shape
+            [num_edges, edge_channels].
+        """
+        # Propagate messages
+        out_dict = self.propagate(edge_index, x=x, edge_attr=edge_attr,
+                                  size=(x.size(0), x.size(0)))
+
+        # Update node features
+        x_new = self.node_mlp(
+            torch.cat([x, out_dict['incoming'], out_dict['outgoing']], dim=-1))
+        x_new = self.node_activation(x_new)
+
+        # Update edge features
+        edge_features = out_dict['edge_features']
+        edge_attr_new = self.edge_mlp(edge_features)
+        edge_attr_new = self.edge_activation(edge_attr_new)
+
+        return x_new, edge_attr_new
+
+    def message(self, x_i: Tensor, x_j: Tensor, edge_attr: Tensor) -> dict:
+        """Constructs messages for each edge."""
+        # Approximate the reverse-edge features by reversing the row order of
+        # `edge_attr` (this matches the true reverse edges only when each
+        # reverse edge is stored at the mirrored position in `edge_index`):
+        edge_attr_rev = edge_attr[torch.arange(edge_attr.size(0) - 1, -1, -1)]
+
+        # Concatenate all features
+        stack = torch.cat([x_i, x_j, edge_attr, edge_attr_rev], dim=-1)
+
+        # Apply stack MLPs
+        for layer in self.stack_layers:
+            stack = layer(stack)
+
+        # Apply attention if needed
+        if self.attention:
+            att_in = self.att_in(stack)
+            att_out = self.att_out(stack)
+            stack_in = stack * att_in
+            stack_out = stack * att_out
+        else:
+            stack_in = stack_out = stack
+
+        return {
+            'incoming': stack_in,
+            'outgoing': stack_out,
+            'edge_features': stack
+        }
+
+    def aggregate(self, inputs: dict, index: Tensor,
+                  dim_size: Optional[int] = None) -> dict:
+        """Aggregates messages from neighbors."""
+        incoming = self.aggr_module(inputs['incoming'], index,
+                                    dim_size=dim_size)
+        outgoing = self.aggr_module(inputs['outgoing'], index,
+                                    dim_size=dim_size)
+
+        return {
+            'incoming': incoming,
+            'outgoing': outgoing,
+            'edge_features': inputs['edge_features']
+        }
+
+    def __repr__(self) -> str:
+        return (f'{self.__class__.__name__}('
+                f'in_node_channels={self.in_node_channels}, '
+                f'in_edge_channels={self.in_edge_channels}, '
+                f'stack_channels={self.stack_channels}, '
+                f'node_channels={self.node_channels}, '
+                f'edge_channels={self.edge_channels}, '
+                f'attention={self.attention})')
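
For reference, a minimal usage sketch of the new layer (not part of the patch itself; the graph sizes and argument values below are illustrative and simply mirror the dimensions exercised in the tests):

    import torch
    from torch_geometric.nn import XENetConv

    # A toy graph: 4 nodes with 3 features each, 4 directed edges with
    # 2 features each.
    x = torch.randn(4, 3)
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])
    edge_attr = torch.randn(edge_index.size(1), 2)

    conv = XENetConv(in_node_channels=3, in_edge_channels=2,
                     stack_channels=[8, 16], node_channels=5,
                     edge_channels=4, attention=True)

    # The layer returns updated node *and* edge features.
    x_out, edge_attr_out = conv(x, edge_index, edge_attr)
    print(x_out.shape)          # torch.Size([4, 5])
    print(edge_attr_out.shape)  # torch.Size([4, 4])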