diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4a18280634a4..b86a735e68d5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 
+- Added `Exphormer` implementation based on this [paper](https://arxiv.org/abs/2303.06147) ([#9783](https://github.com/pyg-team/pytorch_geometric/pull/9783))
 - Consolidate Cugraph examples into ogbn_train_cugraph.py and ogbn_train_cugraph_multigpu.py for ogbn-arxiv, ogbn-products and ogbn-papers100M ([#9953](https://github.com/pyg-team/pytorch_geometric/pull/9953))
 - Added `InstructMol` dataset ([#9975](https://github.com/pyg-team/pytorch_geometric/pull/9975))
 - Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947))
diff --git a/test/transforms/test_exphormer.py b/test/transforms/test_exphormer.py
new file mode 100644
index 000000000000..7e14be38a208
--- /dev/null
+++ b/test/transforms/test_exphormer.py
@@ -0,0 +1,192 @@
+import pytest
+import torch
+
+from torch_geometric.data import Data
+from torch_geometric.nn.attention.expander import ExpanderAttention
+from torch_geometric.nn.attention.local import LocalAttention
+from torch_geometric.transforms.exphormer import EXPHORMER
+
+
+def create_mock_data(num_nodes, hidden_dim, edge_dim=None):
+    """Build a random graph with ``2 * num_nodes`` random edges.
+
+    ``edge_attr`` holds one row per edge when ``edge_dim`` is given and is
+    left out otherwise.
+    """
+    x = torch.rand(num_nodes, hidden_dim)
+    edge_index = torch.randint(0, num_nodes, (2, num_nodes * 2))
+    edge_attr = None
+    if edge_dim:
+        edge_attr = torch.rand(edge_index.size(1), edge_dim)
+    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
+
+
+# Tests for LocalAttention
+@pytest.fixture
+def local_attention_model():
+    return LocalAttention(hidden_dim=16, num_heads=4)
+
+
+def test_local_attention_initialization(local_attention_model):
+    assert local_attention_model.hidden_dim == 16
+    assert local_attention_model.num_heads == 4
+    assert local_attention_model.q_proj.weight.shape == (16, 16)
+
+
+def test_local_attention_indivisible_hidden_dim():
+    with pytest.raises(ValueError, match="divisible by num_heads"):
+        LocalAttention(hidden_dim=18, num_heads=4)
+
+
+def test_local_attention_forward(local_attention_model):
+    data = create_mock_data(num_nodes=10, hidden_dim=16, edge_dim=16)
+    output = local_attention_model(data.x, data.edge_index, data.edge_attr)
+    assert output.shape == (data.x.shape[0], 16)
+
+
+def test_local_attention_no_edge_attr(local_attention_model):
+    data = create_mock_data(num_nodes=10, hidden_dim=16)
+    output = local_attention_model(data.x, data.edge_index, None)
+    assert output.shape == (data.x.shape[0], 16)
+
+
+def test_local_attention_mismatched_edge_attr(local_attention_model):
+    data = create_mock_data(num_nodes=10, hidden_dim=16, edge_dim=16)
+    mismatched_edge_attr = torch.rand(data.edge_attr.size(0) - 1, 16)
+    with pytest.raises(ValueError, match="edge_attr size doesn't match"):
+        local_attention_model(data.x, data.edge_index, mismatched_edge_attr)
+
+
+def test_local_attention_message(local_attention_model):
+    q_i = torch.rand(20, 4, 4)
+    k_j = torch.rand_like(q_i)
+    v_j = torch.rand_like(q_i)
+    edge_attr = torch.rand(20, 4)
+    out = local_attention_model.message(q_i, k_j, v_j, edge_attr)
+    assert out.shape == v_j.shape
+
+
+# Tests for ExpanderAttention
+@pytest.fixture
+def expander_attention_model():
+    return ExpanderAttention(hidden_dim=16, expander_degree=4, num_heads=4)
+
+
+def test_expander_attention_initialization(expander_attention_model):
+    assert expander_attention_model.expander_degree == 4
+    assert expander_attention_model.q_proj.weight.shape == (16, 16)
+
+
+def test_expander_attention_invalid_arguments():
+    with pytest.raises(ValueError, match="must be an even number"):
+        ExpanderAttention(hidden_dim=16, expander_degree=3, num_heads=4)
+    with pytest.raises(ValueError, match="divisible by num_heads"):
+        ExpanderAttention(hidden_dim=18, expander_degree=4, num_heads=4)
+
+
+def test_expander_attention_generate_expander_edges(expander_attention_model):
+    edge_index = expander_attention_model.generate_expander_edges(num_nodes=10)
+    # Two permutations, each adding 10 forward and 10 reverse edges.
+    assert edge_index.shape == (2, 40)
+    assert edge_index.dtype == torch.long
+    assert int(edge_index.min()) >= 0
+    assert int(edge_index.max()) < 10
+    # Every node is a source and a target exactly `expander_degree` times.
+    degree = torch.full((10, ), 4)
+    assert torch.equal(torch.bincount(edge_index[0], minlength=10), degree)
+    assert torch.equal(torch.bincount(edge_index[1], minlength=10), degree)
+
+
+def test_expander_attention_forward(expander_attention_model):
+    data = create_mock_data(num_nodes=10, hidden_dim=16)
+    output, edge_index = expander_attention_model(data.x, num_nodes=10)
+    assert output.shape == (data.x.shape[0], 16)
+    assert edge_index.shape == (2, 40)
+
+
+def test_expander_attention_insufficient_nodes(expander_attention_model):
+    with pytest.raises(ValueError, match="insufficient"):
+        expander_attention_model.generate_expander_edges(2)
+
+
+def test_expander_attention_too_many_nodes(expander_attention_model):
+    x = torch.rand(10, 16)
+    with pytest.raises(ValueError, match="num_nodes exceeds"):
+        expander_attention_model(x, num_nodes=11)
+
+
+def test_expander_attention_message(expander_attention_model):
+    q_i = torch.rand(20, 4, 4)
+    k_j = torch.rand_like(q_i)
+    v_j = torch.rand_like(q_i)
+    out = expander_attention_model.message(q_i, k_j, v_j)
+    assert out.shape == v_j.shape
+
+
+# Tests for EXPHORMER
+@pytest.fixture
+def exphormer_model():
+    return EXPHORMER(hidden_dim=16, num_layers=3, num_heads=4,
+                     expander_degree=4)
+
+
+def test_exphormer_initialization(exphormer_model):
+    assert len(exphormer_model.layers) == 3
+    assert isinstance(exphormer_model.layers[0]['local'], LocalAttention)
+    assert exphormer_model.use_expander
+    assert isinstance(exphormer_model.layers[0]['expander'],
+                      ExpanderAttention)
+
+
+def test_exphormer_invalid_num_virtual_nodes():
+    with pytest.raises(ValueError, match="num_virtual_nodes"):
+        EXPHORMER(hidden_dim=16, num_virtual_nodes=0)
+
+
+def test_exphormer_forward(exphormer_model):
+    data = create_mock_data(num_nodes=10, hidden_dim=16)
+    output = exphormer_model(data)
+    assert output.shape == (data.x.shape[0], 16)
+
+
+def test_exphormer_forward_with_edge_attr(exphormer_model):
+    # The virtual node adds edges, so edge features must stay aligned.
+    data = create_mock_data(num_nodes=10, hidden_dim=16, edge_dim=16)
+    output = exphormer_model(data)
+    assert output.shape == (10, 16)
+    # The caller's graph must be left untouched.
+    assert data.x.shape == (10, 16)
+    assert data.edge_index.shape == (2, 20)
+
+
+@pytest.mark.parametrize('use_global', [False, True])
+@pytest.mark.parametrize('use_expander', [False, True])
+def test_exphormer_component_toggles(use_global, use_expander):
+    model = EXPHORMER(hidden_dim=16, num_layers=2, use_global=use_global,
+                      use_expander=use_expander)
+    data = create_mock_data(num_nodes=10, hidden_dim=16)
+    assert model(data).shape == (10, 16)
+
+
+def test_exphormer_multiple_virtual_nodes():
+    model = EXPHORMER(hidden_dim=16, num_layers=2, num_virtual_nodes=2)
+    data = create_mock_data(num_nodes=10, hidden_dim=16)
+    assert model(data).shape == (10, 16)
+
+
+def test_exphormer_empty_graph(exphormer_model):
+    # One node plus the virtual node cannot host a degree-4 expander.
+    data_tiny = create_mock_data(num_nodes=1, hidden_dim=16)
+    with pytest.raises(ValueError, match="insufficient"):
+        exphormer_model(data_tiny)
+
+    data_empty = Data(x=torch.empty(0, 16),
+                      edge_index=torch.empty(2, 0, dtype=torch.long))
+    with pytest.raises(ValueError, match="Input graph is empty"):
+        exphormer_model(data_empty)
+
+
+def test_exphormer_incorrect_input_dimensions(exphormer_model):
+    data_incorrect_dim = create_mock_data(num_nodes=10, hidden_dim=17)
+    with pytest.raises(RuntimeError):
+        exphormer_model(data_incorrect_dim)
diff --git a/torch_geometric/nn/attention/expander.py b/torch_geometric/nn/attention/expander.py
new file mode 100644
index 000000000000..7d5cbee654e6
--- /dev/null
+++ b/torch_geometric/nn/attention/expander.py
@@ -0,0 +1,131 @@
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from torch_geometric.nn import MessagePassing
+
+
+class ExpanderAttention(MessagePassing):
+    r"""Multi-head attention over a randomly sampled expander graph.
+
+    Each call to :meth:`forward` samples :obj:`expander_degree // 2` random
+    node permutations and links every node to its image under each
+    permutation in both directions. Every node therefore has
+    :obj:`expander_degree` outgoing and incoming expander edges (self-loops
+    and duplicate edges are possible).
+
+    Args:
+        hidden_dim (int): Size of the input and output node features. Must
+            be divisible by :obj:`num_heads`.
+        expander_degree (int, optional): Number of expander edges per node.
+            Must be even. (default: :obj:`4`)
+        num_heads (int, optional): Number of attention heads.
+            (default: :obj:`4`)
+        dropout (float, optional): Dropout probability applied to the
+            attention weights. (default: :obj:`0.1`)
+
+    Raises:
+        ValueError: If :obj:`expander_degree` is odd or :obj:`hidden_dim` is
+            not divisible by :obj:`num_heads`.
+    """
+    def __init__(self, hidden_dim: int, expander_degree: int = 4,
+                 num_heads: int = 4, dropout: float = 0.1) -> None:
+        super().__init__(aggr='add', node_dim=0)
+        if expander_degree % 2 != 0:
+            raise ValueError("expander_degree must be an even number.")
+        if hidden_dim % num_heads != 0:
+            # Fail here rather than at the first `.view()` in `forward`.
+            raise ValueError("hidden_dim must be divisible by num_heads.")
+        self.hidden_dim = hidden_dim
+        self.expander_degree = expander_degree
+        self.num_heads = num_heads
+        self.head_dim = hidden_dim // num_heads
+        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
+        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
+        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
+        self.o_proj = nn.Linear(hidden_dim, hidden_dim)
+        # One learnable score offset per head, shared by all expander edges.
+        self.edge_embedding = nn.Parameter(torch.randn(1, num_heads))
+        self.dropout = nn.Dropout(dropout)
+
+    def generate_expander_edges(self, num_nodes: int) -> torch.Tensor:
+        """Sample the expander edges for a graph with ``num_nodes`` nodes.
+
+        Args:
+            num_nodes (int): Number of nodes to connect.
+
+        Returns:
+            torch.Tensor: A ``[2, expander_degree * num_nodes]`` tensor of
+            type ``torch.long`` holding the edge indices.
+
+        Raises:
+            ValueError: If ``num_nodes`` is smaller than the expander degree.
+        """
+        if num_nodes < self.expander_degree:
+            raise ValueError(
+                "Number of nodes is insufficient to generate expander edges.")
+        node_ids = torch.arange(num_nodes)
+        rows, cols = [], []
+        for _ in range(self.expander_degree // 2):
+            perm = torch.randperm(num_nodes)
+            # Add `i -> perm[i]` and the reverse edge `perm[i] -> i`.
+            rows += [node_ids, perm]
+            cols += [perm, node_ids]
+        if not rows:
+            # `expander_degree == 0`: keep a well-formed, empty edge_index.
+            return torch.empty((2, 0), dtype=torch.long)
+        return torch.stack([torch.cat(rows), torch.cat(cols)])
+
+    def forward(self, x: torch.Tensor,
+                num_nodes: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Attend over a newly sampled expander graph.
+
+        Args:
+            x (torch.Tensor): Node features of shape ``[N, hidden_dim]``.
+            num_nodes (int): Number of leading nodes to connect with
+                expander edges. Must not exceed ``N``.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: The updated node features of
+            shape ``[N, hidden_dim]`` and the sampled ``edge_index``.
+
+        Raises:
+            ValueError: If ``num_nodes`` exceeds ``N`` or is smaller than
+                the expander degree.
+        """
+        if num_nodes > x.size(0):
+            # Sampled indices would point past the last row of `x`.
+            raise ValueError("num_nodes exceeds the number of rows in x.")
+        edge_index = self.generate_expander_edges(num_nodes).to(x.device)
+        q = self.q_proj(x).view(-1, self.num_heads, self.head_dim)
+        k = self.k_proj(x).view(-1, self.num_heads, self.head_dim)
+        v = self.v_proj(x).view(-1, self.num_heads, self.head_dim)
+        out = self.propagate(edge_index, q=q, k=k, v=v)
+        return self.o_proj(out.view(-1, self.hidden_dim)), edge_index
+
+    def message(self, q_i: torch.Tensor, k_j: torch.Tensor,
+                v_j: torch.Tensor) -> torch.Tensor:
+        """Weight the source values by scaled dot-product scores.
+
+        All inputs have shape ``[E, num_heads, head_dim]``; so does the
+        returned tensor.
+        """
+        attention = (q_i * k_j).sum(dim=-1) / np.sqrt(self.head_dim)
+        # NOTE(review): `dim=-1` is the head axis here, so the scores are
+        # normalised across heads per edge rather than across the incoming
+        # edges of each target node. Confirm this is intended.
+        attention = torch.softmax(attention + self.edge_embedding, dim=-1)
+        attention = self.dropout(attention)
+        return attention.unsqueeze(-1) * v_j
+
+    def edge_update(self) -> torch.Tensor:
+        """Not supported by this layer."""
+        raise NotImplementedError(
+            "edge_update not implemented in ExpanderAttention.")
+
+    def message_and_aggregate(self, edge_index: torch.Tensor) -> torch.Tensor:
+        """Not supported by this layer."""
+        raise NotImplementedError(
+            "message_and_aggregate not implemented in ExpanderAttention.")
diff --git a/torch_geometric/nn/attention/local.py b/torch_geometric/nn/attention/local.py
new file mode 100644
index 000000000000..be5904537cf9
--- /dev/null
+++ b/torch_geometric/nn/attention/local.py
@@ -0,0 +1,101 @@
+from typing import Optional
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+from torch_geometric.nn import MessagePassing
+
+
+class LocalAttention(MessagePassing):
+    r"""Multi-head attention over the edges of the input graph.
+
+    Args:
+        hidden_dim (int): Size of the input and output node features and of
+            the edge features. Must be divisible by :obj:`num_heads`.
+        num_heads (int, optional): Number of attention heads.
+            (default: :obj:`4`)
+        dropout (float, optional): Dropout probability applied to the
+            attention weights. (default: :obj:`0.1`)
+
+    Raises:
+        ValueError: If :obj:`hidden_dim` is not divisible by
+            :obj:`num_heads`.
+    """
+    def __init__(self, hidden_dim: int, num_heads: int = 4,
+                 dropout: float = 0.1) -> None:
+        super().__init__(aggr='add', node_dim=0)
+        if hidden_dim % num_heads != 0:
+            # Fail here rather than at the first `.view()` in `forward`.
+            raise ValueError("hidden_dim must be divisible by num_heads.")
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.head_dim = hidden_dim // num_heads
+        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
+        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
+        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
+        self.o_proj = nn.Linear(hidden_dim, hidden_dim)
+        # Maps each edge feature vector to one score offset per head.
+        self.edge_proj = nn.Linear(hidden_dim, num_heads)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
+                edge_attr: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """Attend over the given edges.
+
+        Args:
+            x (torch.Tensor): Node features of shape ``[N, hidden_dim]``.
+            edge_index (torch.Tensor): Edge indices of shape ``[2, E]``.
+            edge_attr (torch.Tensor, optional): Edge features of shape
+                ``[E, hidden_dim]``. (default: :obj:`None`)
+
+        Returns:
+            torch.Tensor: Updated node features of shape ``[N, hidden_dim]``.
+
+        Raises:
+            ValueError: If ``edge_attr`` does not have one row per edge.
+        """
+        if edge_attr is not None and edge_attr.size(0) != edge_index.size(1):
+            raise ValueError(
+                "edge_attr size doesn't match the no of edges in edge_index.")
+        q = self.q_proj(x).view(-1, self.num_heads, self.head_dim)
+        k = self.k_proj(x).view(-1, self.num_heads, self.head_dim)
+        v = self.v_proj(x).view(-1, self.num_heads, self.head_dim)
+        edge_bias = None if edge_attr is None else self.edge_proj(edge_attr)
+        out = self.propagate(edge_index, q=q, k=k, v=v, edge_attr=edge_bias)
+        return self.o_proj(out.view(-1, self.hidden_dim))
+
+    def message(self, q_i: torch.Tensor, k_j: torch.Tensor, v_j: torch.Tensor,
+                edge_attr: Optional[torch.Tensor]) -> torch.Tensor:
+        """Weight the source values by scaled dot-product scores.
+
+        ``q_i``, ``k_j`` and ``v_j`` have shape ``[E, num_heads, head_dim]``
+        and ``edge_attr``, if given, holds the per-head score offsets of
+        shape ``[E, num_heads]``.
+        """
+        attention = (q_i * k_j).sum(dim=-1) / np.sqrt(self.head_dim)
+        if edge_attr is not None:
+            num_edges = attention.size(0)
+            if edge_attr.size(0) < num_edges:
+                # Too few rows: tile the offsets until every edge has one.
+                num_repeats = num_edges // edge_attr.size(0) + 1
+                edge_attr = edge_attr.repeat(num_repeats, 1)
+            # Drop surplus rows (no-op when the sizes already agree).
+            attention = attention + edge_attr[:num_edges]
+
+        # NOTE(review): `dim=-1` is the head axis here, so the scores are
+        # normalised across heads per edge rather than across the incoming
+        # edges of each target node. Confirm this is intended.
+        attention = torch.softmax(attention, dim=-1)
+        attention = self.dropout(attention)
+        return attention.unsqueeze(-1) * v_j
+
+    def edge_update(self) -> torch.Tensor:
+        """Not supported by this layer."""
+        raise NotImplementedError(
+            "edge_update not implemented in LocalAttention.")
+
+    def message_and_aggregate(self, edge_index: torch.Tensor) -> torch.Tensor:
+        """Not supported by this layer."""
+        raise NotImplementedError(
+            "message_and_aggregate not implemented in LocalAttention.")
diff --git a/torch_geometric/transforms/exphormer.py b/torch_geometric/transforms/exphormer.py
new file mode 100644
index 000000000000..5f0d2218bdb2
--- /dev/null
+++ b/torch_geometric/transforms/exphormer.py
@@ -0,0 +1,134 @@
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from torch_geometric.data import Data
+from torch_geometric.nn.attention.expander import ExpanderAttention
+from torch_geometric.nn.attention.local import LocalAttention
+from torch_geometric.transforms import VirtualNode
+
+
+class EXPHORMER(nn.Module):
+    r"""The Exphormer architecture from the `"Exphormer: Sparse Transformers
+    for Graphs" <https://arxiv.org/abs/2303.06147>`_ paper.
+
+    Each layer adds local neighborhood attention and attention over a random
+    expander graph to its input, followed by layer normalization and a
+    feed-forward block. Virtual nodes provide a global communication path.
+
+    Args:
+        hidden_dim (int): Size of the node (and edge) features.
+        num_layers (int, optional): Number of layers. (default: :obj:`3`)
+        num_heads (int, optional): Number of attention heads.
+            (default: :obj:`4`)
+        expander_degree (int, optional): Degree of the expander graph. Must
+            be even. (default: :obj:`4`)
+        dropout (float, optional): Dropout probability. (default: :obj:`0.1`)
+        use_expander (bool, optional): If set to :obj:`False`, skips the
+            expander attention. (default: :obj:`True`)
+        use_global (bool, optional): If set to :obj:`False`, no virtual
+            nodes are added. (default: :obj:`True`)
+        num_virtual_nodes (int, optional): Number of virtual nodes added when
+            :obj:`use_global` is set. (default: :obj:`1`)
+
+    Raises:
+        ValueError: If :obj:`use_global` is set and
+            :obj:`num_virtual_nodes` is smaller than one.
+    """
+    def __init__(self, hidden_dim: int, num_layers: int = 3,
+                 num_heads: int = 4, expander_degree: int = 4,
+                 dropout: float = 0.1, use_expander: bool = True,
+                 use_global: bool = True, num_virtual_nodes: int = 1) -> None:
+        super().__init__()
+        if use_global and num_virtual_nodes < 1:
+            raise ValueError(
+                "num_virtual_nodes must be >= 1 if use_global is enabled.")
+        self.hidden_dim = hidden_dim
+        self.use_expander = use_expander
+        self.use_global = use_global
+        self.num_virtual_nodes = num_virtual_nodes
+        self.virtual_node_transform = VirtualNode() if use_global else None
+        self.layers = nn.ModuleList([
+            nn.ModuleDict({
+                'local':
+                LocalAttention(hidden_dim, num_heads=num_heads,
+                               dropout=dropout),
+                'expander':
+                ExpanderAttention(hidden_dim, expander_degree=expander_degree,
+                                  num_heads=num_heads, dropout=dropout)
+                if use_expander else nn.Identity(),
+                'layer_norm':
+                nn.LayerNorm(hidden_dim),
+                'ffn':
+                nn.Sequential(nn.Linear(hidden_dim, 4 * hidden_dim), nn.GELU(),
+                              nn.Dropout(dropout),
+                              nn.Linear(4 * hidden_dim, hidden_dim))
+            }) for _ in range(num_layers)
+        ])
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, data: Data) -> torch.Tensor:
+        """Compute node embeddings for the nodes of ``data``.
+
+        Args:
+            data (Data): Graph holding ``x`` of shape ``[N, hidden_dim]``,
+                ``edge_index`` and, optionally, ``edge_attr`` with
+                ``hidden_dim`` features per edge.
+
+        Returns:
+            torch.Tensor: Embeddings of shape ``[N, hidden_dim]`` for the
+            original nodes; virtual nodes are dropped.
+
+        Raises:
+            ValueError: If ``x`` or ``edge_index`` is missing, the graph has
+                no nodes, or it is too small for the expander degree.
+        """
+        if data.x is None:
+            raise ValueError("Input data.x cannot be None")
+        x: torch.Tensor = data.x
+
+        if x.size(0) == 0:
+            raise ValueError("Input graph is empty.")
+
+        if not hasattr(data, 'edge_index') or data.edge_index is None:
+            raise ValueError(
+                "Input data must contain 'edge_index' for message passing.")
+        edge_index: torch.Tensor = data.edge_index
+
+        edge_attr: Optional[torch.Tensor] = getattr(data, 'edge_attr', None)
+
+        num_real_nodes = x.size(0)
+
+        if self.virtual_node_transform is not None:
+            # `VirtualNode` appends one node per call.
+            for _ in range(self.num_virtual_nodes):
+                data = self.virtual_node_transform(data)
+            if data.x is None:
+                raise ValueError("Virtual node transform resulted in None x")
+            x = data.x
+            if data.edge_index is None:
+                raise ValueError(
+                    "Virtual node transform resulted in None edge_index")
+            edge_index = data.edge_index
+            # The transform added edges, so re-read the edge features to keep
+            # them aligned with the new `edge_index`.
+            edge_attr = getattr(data, 'edge_attr', None)
+
+        # Includes the virtual nodes, if any were added.
+        num_nodes = x.size(0)
+
+        for layer in self.layers:
+            residual = x
+            local_out = layer['local'](x, edge_index, edge_attr)
+            if self.use_expander:
+                # Size the expander from the actual node count so that it
+                # never indexes past `x`, e.g. when `use_global=False`.
+                expander_out, _ = layer['expander'](x, num_nodes)
+            else:
+                expander_out = torch.zeros_like(x)
+
+            x = layer['layer_norm'](residual + local_out + expander_out)
+            x = x + layer['ffn'](x)
+
+        return x[:num_real_nodes]