From 31757b3c9b91553830289b088805e5c88f075db1 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Wed, 13 Nov 2024 23:02:14 +0530 Subject: [PATCH 01/19] initial commit --- test/transforms/test_exphormer.py | 106 +++++++++++++++++++++++ torch_geometric/nn/attention/expander.py | 45 ++++++++++ torch_geometric/nn/attention/local.py | 50 +++++++++++ torch_geometric/transforms/exphormer.py | 51 +++++++++++ 4 files changed, 252 insertions(+) create mode 100644 test/transforms/test_exphormer.py create mode 100644 torch_geometric/nn/attention/expander.py create mode 100644 torch_geometric/nn/attention/local.py create mode 100644 torch_geometric/transforms/exphormer.py diff --git a/test/transforms/test_exphormer.py b/test/transforms/test_exphormer.py new file mode 100644 index 000000000000..2d5d2f7e7df6 --- /dev/null +++ b/test/transforms/test_exphormer.py @@ -0,0 +1,106 @@ +import pytest +import torch +from torch_geometric.data import Data + +from torch_geometric.nn.attention.local import LocalAttention +from torch_geometric.nn.attention.expander import ExpanderAttention +from torch_geometric.transforms.exphormer import EXPHORMER + + +def create_mock_data(num_nodes=10, hidden_dim=16, edge_dim=16): + x = torch.rand(num_nodes, hidden_dim) + edge_index = torch.randint(0, num_nodes, (2, num_nodes * 2)) + edge_attr = torch.rand(num_nodes * 2, edge_dim) + return Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + +# Tests for LocalAttention +@pytest.fixture +def local_attention_model(): + hidden_dim, num_heads = 16, 4 + return LocalAttention(hidden_dim=hidden_dim, num_heads=num_heads) + +def test_local_attention_initialization(local_attention_model): + assert local_attention_model.hidden_dim == 16 + assert local_attention_model.num_heads == 4 + assert local_attention_model.q_proj.weight.shape == (16, 16) + +def test_local_attention_forward(local_attention_model): + data = create_mock_data(num_nodes=10, hidden_dim=16, edge_dim=16) + output = local_attention_model(data.x, data.edge_index, data.edge_attr) + assert output.shape == (data.x.shape[0], 16) + +def test_local_attention_no_edge_attr(local_attention_model): + data = create_mock_data(num_nodes=10, hidden_dim=16, edge_dim=16) + output_no_edge_attr = local_attention_model(data.x, data.edge_index, None) + assert output_no_edge_attr.shape == (data.x.shape[0], 16) + +def test_local_attention_mismatched_edge_attr(local_attention_model): + data = create_mock_data(num_nodes=10, hidden_dim=16, edge_dim=16) + with pytest.raises(Exception): + mismatched_edge_attr = torch.rand(data.edge_attr.size(0) - 1, 16) + local_attention_model(data.x, data.edge_index, mismatched_edge_attr) + +def test_local_attention_message(local_attention_model): + q_i = torch.rand(20, 4, 4) + k_j = torch.rand_like(q_i) + v_j = torch.rand_like(q_i) + edge_attr = torch.rand(20, 4) + attention_scores = local_attention_model.message(q_i, k_j, v_j, edge_attr) + assert attention_scores.shape == v_j.shape + +# Tests for ExpanderAttention +@pytest.fixture +def expander_attention_model(): + hidden_dim, num_heads, expander_degree = 16, 4, 4 + return ExpanderAttention(hidden_dim=hidden_dim, expander_degree=expander_degree, num_heads=num_heads) + +def test_expander_attention_initialization(expander_attention_model): + assert expander_attention_model.expander_degree == 4 + assert expander_attention_model.q_proj.weight.shape == (16, 16) + +def test_expander_attention_generate_expander_edges(expander_attention_model): + edge_index = expander_attention_model.generate_expander_edges(num_nodes=10) + 
assert edge_index.shape[0] == 2 + +def test_expander_attention_forward(expander_attention_model): + data = create_mock_data(num_nodes=10, hidden_dim=16) + output, edge_index = expander_attention_model(data.x, num_nodes=10) + assert output.shape == (data.x.shape[0], 16) + +def test_expander_attention_insufficient_nodes(expander_attention_model): + with pytest.raises(Exception): + expander_attention_model.generate_expander_edges(2) + +def test_expander_attention_message(expander_attention_model): + q_i = torch.rand(20, 4, 4) + k_j = torch.rand_like(q_i) + v_j = torch.rand_like(q_i) + attention_scores = expander_attention_model.message(q_i, k_j, v_j) + assert attention_scores.shape == v_j.shape + +# Tests for EXPHORMER +@pytest.fixture +def exphormer_model(): + hidden_dim, num_layers, num_heads, expander_degree = 16, 3, 4, 4 + return EXPHORMER(hidden_dim=hidden_dim, num_layers=num_layers, num_heads=num_heads, expander_degree=expander_degree) + +def test_exphormer_initialization(exphormer_model): + assert len(exphormer_model.layers) == 3 + assert isinstance(exphormer_model.layers[0]['local'], LocalAttention) + if exphormer_model.use_expander: + assert isinstance(exphormer_model.layers[0]['expander'], ExpanderAttention) + +def test_exphormer_forward(exphormer_model): + data = create_mock_data(num_nodes=10, hidden_dim=16) + output = exphormer_model(data) + assert output.shape == (data.x.shape[0], 16) + +def test_exphormer_empty_graph(exphormer_model): + data_empty = create_mock_data(num_nodes=1, hidden_dim=16) + with pytest.raises(Exception): + exphormer_model(data_empty) + +def test_exphormer_incorrect_input_dimensions(exphormer_model): + data_incorrect_dim = create_mock_data(num_nodes=10, hidden_dim=17) + with pytest.raises(Exception): + exphormer_model(data_incorrect_dim) diff --git a/torch_geometric/nn/attention/expander.py b/torch_geometric/nn/attention/expander.py new file mode 100644 index 000000000000..08ba6cee777b --- /dev/null +++ b/torch_geometric/nn/attention/expander.py @@ -0,0 +1,45 @@ +import torch +import torch.nn as nn +import numpy as np + +from typing import Tuple +from torch_geometric.nn import MessagePassing + + +class ExpanderAttention(MessagePassing): + """Expander graph attention using random d-regular near-Ramanujan graphs.""" + def __init__(self, hidden_dim: int, expander_degree: int = 4, num_heads: int = 4, dropout: float = 0.1): + super().__init__(aggr='add', node_dim=0) + self.hidden_dim = hidden_dim + self.expander_degree = expander_degree + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads + self.q_proj = nn.Linear(hidden_dim, hidden_dim) + self.k_proj = nn.Linear(hidden_dim, hidden_dim) + self.v_proj = nn.Linear(hidden_dim, hidden_dim) + self.o_proj = nn.Linear(hidden_dim, hidden_dim) + self.edge_embedding = nn.Parameter(torch.randn(1, num_heads)) + self.dropout = nn.Dropout(dropout) + + def generate_expander_edges(self, num_nodes: int) -> torch.Tensor: + edges = [] + for _ in range(self.expander_degree // 2): + perm = torch.randperm(num_nodes) + edges.extend([(i, perm[i].item()) for i in range(num_nodes)]) + edges.extend([(perm[i].item(), i) for i in range(num_nodes)]) + edge_index = torch.tensor(edges, dtype=torch.long).t() + return edge_index + + def forward(self, x: torch.Tensor, num_nodes: int) -> Tuple[torch.Tensor, torch.Tensor]: + edge_index = self.generate_expander_edges(num_nodes).to(x.device) + q = self.q_proj(x).view(-1, self.num_heads, self.head_dim) + k = self.k_proj(x).view(-1, self.num_heads, self.head_dim) + v = 
self.v_proj(x).view(-1, self.num_heads, self.head_dim) + out = self.propagate(edge_index, q=q, k=k, v=v) + return self.o_proj(out.view(-1, self.hidden_dim)), edge_index + + def message(self, q_i: torch.Tensor, k_j: torch.Tensor, v_j: torch.Tensor) -> torch.Tensor: + attention = (q_i * k_j).sum(dim=-1) / np.sqrt(self.head_dim) + attention = torch.softmax(attention + self.edge_embedding, dim=-1) + attention = self.dropout(attention) + return attention.unsqueeze(-1) * v_j \ No newline at end of file diff --git a/torch_geometric/nn/attention/local.py b/torch_geometric/nn/attention/local.py new file mode 100644 index 000000000000..e8c8ce3b2839 --- /dev/null +++ b/torch_geometric/nn/attention/local.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn +import numpy as np + +from typing import Optional +from torch_geometric.nn import MessagePassing + + +class LocalAttention(MessagePassing): + """Local neighborhood attention.""" + + def __init__(self, hidden_dim: int, num_heads: int = 4, dropout: float = 0.1): + super().__init__(aggr='add', node_dim=0) + self.hidden_dim = hidden_dim + self.num_heads = num_heads + self.head_dim = hidden_dim // num_heads + self.q_proj = nn.Linear(hidden_dim, hidden_dim) + self.k_proj = nn.Linear(hidden_dim, hidden_dim) + self.v_proj = nn.Linear(hidden_dim, hidden_dim) + self.o_proj = nn.Linear(hidden_dim, hidden_dim) + self.edge_proj = nn.Linear(hidden_dim, num_heads) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor, edge_index: torch.Tensor, + edge_attr: Optional[torch.Tensor] = None) -> torch.Tensor: + q = self.q_proj(x).view(-1, self.num_heads, self.head_dim) + k = self.k_proj(x).view(-1, self.num_heads, self.head_dim) + v = self.v_proj(x).view(-1, self.num_heads, self.head_dim) + edge_attention = self.edge_proj(edge_attr) if edge_attr is not None else None + out = self.propagate(edge_index, q=q, k=k, v=v, edge_attr=edge_attention) + return self.o_proj(out.view(-1, self.hidden_dim)) + + def message(self, q_i: torch.Tensor, k_j: torch.Tensor, v_j: torch.Tensor, edge_attr: Optional[torch.Tensor]) -> torch.Tensor: + print(q_i.shape, k_j.shape, v_j.shape) + attention = (q_i * k_j).sum(dim=-1) / np.sqrt(self.head_dim) + if edge_attr is not None: + print('edge:', edge_attr.shape, 'attention:', attention.shape) + if edge_attr.size(0) < attention.size(0): + num_repeats = attention.size(0) // edge_attr.size(0) + 1 + edge_attr = edge_attr.repeat(num_repeats, 1)[:attention.size(0)] + elif edge_attr.size(0) > attention.size(0): + edge_attr = edge_attr[:attention.size(0)] + attention = attention + edge_attr + + attention = torch.softmax(attention, dim=-1) + attention = self.dropout(attention) + + # Apply attention scores to the values + out = attention.unsqueeze(-1) * v_j + return out \ No newline at end of file diff --git a/torch_geometric/transforms/exphormer.py b/torch_geometric/transforms/exphormer.py new file mode 100644 index 000000000000..a574d13f218e --- /dev/null +++ b/torch_geometric/transforms/exphormer.py @@ -0,0 +1,51 @@ +import torch.nn as nn +from torch_geometric.transforms import VirtualNode + +from torch_geometric.nn.attention.local import LocalAttention +from torch_geometric.nn.attention.expander import ExpanderAttention + + +class EXPHORMER(nn.Module): + """EXPHORMER architecture. 
+ Based on the paper: https://arxiv.org/abs/2303.06147 + """ + def __init__(self, hidden_dim: int, num_layers: int = 3, num_heads: int = 4, expander_degree: int = 4, + dropout: float = 0.1, use_expander: bool = True, use_global: bool = True, num_virtual_nodes: int = 1): + super().__init__() + self.hidden_dim = hidden_dim + self.use_expander = use_expander + self.use_global = use_global + self.virtual_node_transform = VirtualNode() if use_global else None + self.num_virtual_nodes = num_virtual_nodes + self.layers = nn.ModuleList([ + nn.ModuleDict({ + 'local': LocalAttention(hidden_dim, num_heads=num_heads, dropout=dropout), + 'expander': ExpanderAttention(hidden_dim, expander_degree=expander_degree, num_heads=num_heads, dropout=dropout) if use_expander else None, + 'layer_norm': nn.LayerNorm(hidden_dim), + 'ffn': nn.Sequential( + nn.Linear(hidden_dim, 4 * hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(4 * hidden_dim, hidden_dim) + ) + }) for _ in range(num_layers) + ]) + self.dropout = nn.Dropout(dropout) + + def forward(self, data): + x, edge_index = data.x, data.edge_index + edge_attr = data.edge_attr if hasattr(data, 'edge_attr') else None + batch_size = x.size(0) + if self.virtual_node_transform is not None: + data = self.virtual_node_transform(data) + x, edge_index = data.x, data.edge_index + + for layer in self.layers: + residual = x + local_out = layer['local'](x, edge_index, edge_attr) + expander_out = 0 + if self.use_expander and layer['expander'] is not None: + expander_out, _ = layer['expander'](x[:batch_size + self.num_virtual_nodes], batch_size + self.num_virtual_nodes) + x = layer['layer_norm'](residual + local_out + expander_out) + x = x + layer['ffn'](x) + return x[:batch_size] \ No newline at end of file From d29c887de3b13d496b9dce129c32d949e7d0ec55 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Nov 2024 17:36:03 +0000 Subject: [PATCH 02/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/transforms/test_exphormer.py | 31 +++++++++++++++--- torch_geometric/nn/attention/expander.py | 16 ++++++---- torch_geometric/nn/attention/local.py | 23 ++++++++------ torch_geometric/transforms/exphormer.py | 40 ++++++++++++++---------- 4 files changed, 74 insertions(+), 36 deletions(-) diff --git a/test/transforms/test_exphormer.py b/test/transforms/test_exphormer.py index 2d5d2f7e7df6..b69f3c539a4d 100644 --- a/test/transforms/test_exphormer.py +++ b/test/transforms/test_exphormer.py @@ -1,9 +1,9 @@ import pytest import torch -from torch_geometric.data import Data -from torch_geometric.nn.attention.local import LocalAttention +from torch_geometric.data import Data from torch_geometric.nn.attention.expander import ExpanderAttention +from torch_geometric.nn.attention.local import LocalAttention from torch_geometric.transforms.exphormer import EXPHORMER @@ -13,33 +13,39 @@ def create_mock_data(num_nodes=10, hidden_dim=16, edge_dim=16): edge_attr = torch.rand(num_nodes * 2, edge_dim) return Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + # Tests for LocalAttention @pytest.fixture def local_attention_model(): hidden_dim, num_heads = 16, 4 return LocalAttention(hidden_dim=hidden_dim, num_heads=num_heads) + def test_local_attention_initialization(local_attention_model): assert local_attention_model.hidden_dim == 16 assert local_attention_model.num_heads == 4 assert local_attention_model.q_proj.weight.shape == (16, 16) + def 
test_local_attention_forward(local_attention_model): data = create_mock_data(num_nodes=10, hidden_dim=16, edge_dim=16) output = local_attention_model(data.x, data.edge_index, data.edge_attr) assert output.shape == (data.x.shape[0], 16) + def test_local_attention_no_edge_attr(local_attention_model): data = create_mock_data(num_nodes=10, hidden_dim=16, edge_dim=16) output_no_edge_attr = local_attention_model(data.x, data.edge_index, None) assert output_no_edge_attr.shape == (data.x.shape[0], 16) + def test_local_attention_mismatched_edge_attr(local_attention_model): data = create_mock_data(num_nodes=10, hidden_dim=16, edge_dim=16) with pytest.raises(Exception): mismatched_edge_attr = torch.rand(data.edge_attr.size(0) - 1, 16) local_attention_model(data.x, data.edge_index, mismatched_edge_attr) + def test_local_attention_message(local_attention_model): q_i = torch.rand(20, 4, 4) k_j = torch.rand_like(q_i) @@ -48,29 +54,37 @@ def test_local_attention_message(local_attention_model): attention_scores = local_attention_model.message(q_i, k_j, v_j, edge_attr) assert attention_scores.shape == v_j.shape + # Tests for ExpanderAttention @pytest.fixture def expander_attention_model(): hidden_dim, num_heads, expander_degree = 16, 4, 4 - return ExpanderAttention(hidden_dim=hidden_dim, expander_degree=expander_degree, num_heads=num_heads) + return ExpanderAttention(hidden_dim=hidden_dim, + expander_degree=expander_degree, + num_heads=num_heads) + def test_expander_attention_initialization(expander_attention_model): assert expander_attention_model.expander_degree == 4 assert expander_attention_model.q_proj.weight.shape == (16, 16) + def test_expander_attention_generate_expander_edges(expander_attention_model): edge_index = expander_attention_model.generate_expander_edges(num_nodes=10) assert edge_index.shape[0] == 2 + def test_expander_attention_forward(expander_attention_model): data = create_mock_data(num_nodes=10, hidden_dim=16) output, edge_index = expander_attention_model(data.x, num_nodes=10) assert output.shape == (data.x.shape[0], 16) + def test_expander_attention_insufficient_nodes(expander_attention_model): with pytest.raises(Exception): expander_attention_model.generate_expander_edges(2) + def test_expander_attention_message(expander_attention_model): q_i = torch.rand(20, 4, 4) k_j = torch.rand_like(q_i) @@ -78,28 +92,35 @@ def test_expander_attention_message(expander_attention_model): attention_scores = expander_attention_model.message(q_i, k_j, v_j) assert attention_scores.shape == v_j.shape + # Tests for EXPHORMER @pytest.fixture def exphormer_model(): hidden_dim, num_layers, num_heads, expander_degree = 16, 3, 4, 4 - return EXPHORMER(hidden_dim=hidden_dim, num_layers=num_layers, num_heads=num_heads, expander_degree=expander_degree) + return EXPHORMER(hidden_dim=hidden_dim, num_layers=num_layers, + num_heads=num_heads, expander_degree=expander_degree) + def test_exphormer_initialization(exphormer_model): assert len(exphormer_model.layers) == 3 assert isinstance(exphormer_model.layers[0]['local'], LocalAttention) if exphormer_model.use_expander: - assert isinstance(exphormer_model.layers[0]['expander'], ExpanderAttention) + assert isinstance(exphormer_model.layers[0]['expander'], + ExpanderAttention) + def test_exphormer_forward(exphormer_model): data = create_mock_data(num_nodes=10, hidden_dim=16) output = exphormer_model(data) assert output.shape == (data.x.shape[0], 16) + def test_exphormer_empty_graph(exphormer_model): data_empty = create_mock_data(num_nodes=1, hidden_dim=16) with 
pytest.raises(Exception): exphormer_model(data_empty) + def test_exphormer_incorrect_input_dimensions(exphormer_model): data_incorrect_dim = create_mock_data(num_nodes=10, hidden_dim=17) with pytest.raises(Exception): diff --git a/torch_geometric/nn/attention/expander.py b/torch_geometric/nn/attention/expander.py index 08ba6cee777b..f9bb1a837fa3 100644 --- a/torch_geometric/nn/attention/expander.py +++ b/torch_geometric/nn/attention/expander.py @@ -1,14 +1,16 @@ +from typing import Tuple + +import numpy as np import torch import torch.nn as nn -import numpy as np -from typing import Tuple from torch_geometric.nn import MessagePassing class ExpanderAttention(MessagePassing): """Expander graph attention using random d-regular near-Ramanujan graphs.""" - def __init__(self, hidden_dim: int, expander_degree: int = 4, num_heads: int = 4, dropout: float = 0.1): + def __init__(self, hidden_dim: int, expander_degree: int = 4, + num_heads: int = 4, dropout: float = 0.1): super().__init__(aggr='add', node_dim=0) self.hidden_dim = hidden_dim self.expander_degree = expander_degree @@ -30,7 +32,8 @@ def generate_expander_edges(self, num_nodes: int) -> torch.Tensor: edge_index = torch.tensor(edges, dtype=torch.long).t() return edge_index - def forward(self, x: torch.Tensor, num_nodes: int) -> Tuple[torch.Tensor, torch.Tensor]: + def forward(self, x: torch.Tensor, + num_nodes: int) -> Tuple[torch.Tensor, torch.Tensor]: edge_index = self.generate_expander_edges(num_nodes).to(x.device) q = self.q_proj(x).view(-1, self.num_heads, self.head_dim) k = self.k_proj(x).view(-1, self.num_heads, self.head_dim) @@ -38,8 +41,9 @@ def forward(self, x: torch.Tensor, num_nodes: int) -> Tuple[torch.Tensor, torch. out = self.propagate(edge_index, q=q, k=k, v=v) return self.o_proj(out.view(-1, self.hidden_dim)), edge_index - def message(self, q_i: torch.Tensor, k_j: torch.Tensor, v_j: torch.Tensor) -> torch.Tensor: + def message(self, q_i: torch.Tensor, k_j: torch.Tensor, + v_j: torch.Tensor) -> torch.Tensor: attention = (q_i * k_j).sum(dim=-1) / np.sqrt(self.head_dim) attention = torch.softmax(attention + self.edge_embedding, dim=-1) attention = self.dropout(attention) - return attention.unsqueeze(-1) * v_j \ No newline at end of file + return attention.unsqueeze(-1) * v_j diff --git a/torch_geometric/nn/attention/local.py b/torch_geometric/nn/attention/local.py index e8c8ce3b2839..6aed206391c1 100644 --- a/torch_geometric/nn/attention/local.py +++ b/torch_geometric/nn/attention/local.py @@ -1,15 +1,16 @@ +from typing import Optional + +import numpy as np import torch import torch.nn as nn -import numpy as np -from typing import Optional from torch_geometric.nn import MessagePassing class LocalAttention(MessagePassing): """Local neighborhood attention.""" - - def __init__(self, hidden_dim: int, num_heads: int = 4, dropout: float = 0.1): + def __init__(self, hidden_dim: int, num_heads: int = 4, + dropout: float = 0.1): super().__init__(aggr='add', node_dim=0) self.hidden_dim = hidden_dim self.num_heads = num_heads @@ -26,18 +27,22 @@ def forward(self, x: torch.Tensor, edge_index: torch.Tensor, q = self.q_proj(x).view(-1, self.num_heads, self.head_dim) k = self.k_proj(x).view(-1, self.num_heads, self.head_dim) v = self.v_proj(x).view(-1, self.num_heads, self.head_dim) - edge_attention = self.edge_proj(edge_attr) if edge_attr is not None else None - out = self.propagate(edge_index, q=q, k=k, v=v, edge_attr=edge_attention) + edge_attention = self.edge_proj( + edge_attr) if edge_attr is not None else None + out = 
self.propagate(edge_index, q=q, k=k, v=v, + edge_attr=edge_attention) return self.o_proj(out.view(-1, self.hidden_dim)) - def message(self, q_i: torch.Tensor, k_j: torch.Tensor, v_j: torch.Tensor, edge_attr: Optional[torch.Tensor]) -> torch.Tensor: + def message(self, q_i: torch.Tensor, k_j: torch.Tensor, v_j: torch.Tensor, + edge_attr: Optional[torch.Tensor]) -> torch.Tensor: print(q_i.shape, k_j.shape, v_j.shape) attention = (q_i * k_j).sum(dim=-1) / np.sqrt(self.head_dim) if edge_attr is not None: print('edge:', edge_attr.shape, 'attention:', attention.shape) if edge_attr.size(0) < attention.size(0): num_repeats = attention.size(0) // edge_attr.size(0) + 1 - edge_attr = edge_attr.repeat(num_repeats, 1)[:attention.size(0)] + edge_attr = edge_attr.repeat(num_repeats, + 1)[:attention.size(0)] elif edge_attr.size(0) > attention.size(0): edge_attr = edge_attr[:attention.size(0)] attention = attention + edge_attr @@ -47,4 +52,4 @@ def message(self, q_i: torch.Tensor, k_j: torch.Tensor, v_j: torch.Tensor, edge_ # Apply attention scores to the values out = attention.unsqueeze(-1) * v_j - return out \ No newline at end of file + return out diff --git a/torch_geometric/transforms/exphormer.py b/torch_geometric/transforms/exphormer.py index a574d13f218e..68185a134429 100644 --- a/torch_geometric/transforms/exphormer.py +++ b/torch_geometric/transforms/exphormer.py @@ -1,16 +1,18 @@ import torch.nn as nn -from torch_geometric.transforms import VirtualNode -from torch_geometric.nn.attention.local import LocalAttention from torch_geometric.nn.attention.expander import ExpanderAttention +from torch_geometric.nn.attention.local import LocalAttention +from torch_geometric.transforms import VirtualNode class EXPHORMER(nn.Module): """EXPHORMER architecture. - Based on the paper: https://arxiv.org/abs/2303.06147 + Based on the paper: https://arxiv.org/abs/2303.06147 """ - def __init__(self, hidden_dim: int, num_layers: int = 3, num_heads: int = 4, expander_degree: int = 4, - dropout: float = 0.1, use_expander: bool = True, use_global: bool = True, num_virtual_nodes: int = 1): + def __init__(self, hidden_dim: int, num_layers: int = 3, + num_heads: int = 4, expander_degree: int = 4, + dropout: float = 0.1, use_expander: bool = True, + use_global: bool = True, num_virtual_nodes: int = 1): super().__init__() self.hidden_dim = hidden_dim self.use_expander = use_expander @@ -19,15 +21,19 @@ def __init__(self, hidden_dim: int, num_layers: int = 3, num_heads: int = 4, exp self.num_virtual_nodes = num_virtual_nodes self.layers = nn.ModuleList([ nn.ModuleDict({ - 'local': LocalAttention(hidden_dim, num_heads=num_heads, dropout=dropout), - 'expander': ExpanderAttention(hidden_dim, expander_degree=expander_degree, num_heads=num_heads, dropout=dropout) if use_expander else None, - 'layer_norm': nn.LayerNorm(hidden_dim), - 'ffn': nn.Sequential( - nn.Linear(hidden_dim, 4 * hidden_dim), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(4 * hidden_dim, hidden_dim) - ) + 'local': + LocalAttention(hidden_dim, num_heads=num_heads, + dropout=dropout), + 'expander': + ExpanderAttention(hidden_dim, expander_degree=expander_degree, + num_heads=num_heads, dropout=dropout) + if use_expander else None, + 'layer_norm': + nn.LayerNorm(hidden_dim), + 'ffn': + nn.Sequential(nn.Linear(hidden_dim, 4 * hidden_dim), nn.GELU(), + nn.Dropout(dropout), + nn.Linear(4 * hidden_dim, hidden_dim)) }) for _ in range(num_layers) ]) self.dropout = nn.Dropout(dropout) @@ -45,7 +51,9 @@ def forward(self, data): local_out = layer['local'](x, edge_index, 
edge_attr) expander_out = 0 if self.use_expander and layer['expander'] is not None: - expander_out, _ = layer['expander'](x[:batch_size + self.num_virtual_nodes], batch_size + self.num_virtual_nodes) + expander_out, _ = layer['expander']( + x[:batch_size + self.num_virtual_nodes], + batch_size + self.num_virtual_nodes) x = layer['layer_norm'](residual + local_out + expander_out) x = x + layer['ffn'](x) - return x[:batch_size] \ No newline at end of file + return x[:batch_size] From e2e9b683fdb305c5afeb6b6fca88301448411fa7 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Thu, 14 Nov 2024 00:42:55 +0530 Subject: [PATCH 03/19] adding some exception handling to pass pytest --- torch_geometric/nn/attention/expander.py | 4 ++++ torch_geometric/nn/attention/local.py | 2 ++ torch_geometric/transforms/exphormer.py | 6 ++++++ 3 files changed, 12 insertions(+) diff --git a/torch_geometric/nn/attention/expander.py b/torch_geometric/nn/attention/expander.py index f9bb1a837fa3..daa7939f2d27 100644 --- a/torch_geometric/nn/attention/expander.py +++ b/torch_geometric/nn/attention/expander.py @@ -14,6 +14,8 @@ def __init__(self, hidden_dim: int, expander_degree: int = 4, super().__init__(aggr='add', node_dim=0) self.hidden_dim = hidden_dim self.expander_degree = expander_degree + if expander_degree % 2 != 0: + raise ValueError("expander_degree must be an even number.") self.num_heads = num_heads self.head_dim = hidden_dim // num_heads self.q_proj = nn.Linear(hidden_dim, hidden_dim) @@ -24,6 +26,8 @@ def __init__(self, hidden_dim: int, expander_degree: int = 4, self.dropout = nn.Dropout(dropout) def generate_expander_edges(self, num_nodes: int) -> torch.Tensor: + if num_nodes < self.expander_degree: + raise ValueError("Number of nodes is insufficient to generate expander edges.") edges = [] for _ in range(self.expander_degree // 2): perm = torch.randperm(num_nodes) diff --git a/torch_geometric/nn/attention/local.py b/torch_geometric/nn/attention/local.py index 6aed206391c1..02317d8173d0 100644 --- a/torch_geometric/nn/attention/local.py +++ b/torch_geometric/nn/attention/local.py @@ -24,6 +24,8 @@ def __init__(self, hidden_dim: int, num_heads: int = 4, def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: Optional[torch.Tensor] = None) -> torch.Tensor: + if edge_attr is not None and edge_attr.size(0) != edge_index.size(1): + raise ValueError("edge_attr size does not match the number of edges in edge_index.") q = self.q_proj(x).view(-1, self.num_heads, self.head_dim) k = self.k_proj(x).view(-1, self.num_heads, self.head_dim) v = self.v_proj(x).view(-1, self.num_heads, self.head_dim) diff --git a/torch_geometric/transforms/exphormer.py b/torch_geometric/transforms/exphormer.py index 68185a134429..50de1bf0f79b 100644 --- a/torch_geometric/transforms/exphormer.py +++ b/torch_geometric/transforms/exphormer.py @@ -18,6 +18,8 @@ def __init__(self, hidden_dim: int, num_layers: int = 3, self.use_expander = use_expander self.use_global = use_global self.virtual_node_transform = VirtualNode() if use_global else None + if use_global and num_virtual_nodes < 1: + raise ValueError("num_virtual_nodes must be at least 1 if use_global is enabled.") self.num_virtual_nodes = num_virtual_nodes self.layers = nn.ModuleList([ nn.ModuleDict({ @@ -39,6 +41,10 @@ def __init__(self, hidden_dim: int, num_layers: int = 3, self.dropout = nn.Dropout(dropout) def forward(self, data): + if data.x.size(0) == 0: + raise ValueError("Input graph is empty.") + if not hasattr(data, 'edge_index') or data.edge_index is None: + 
raise ValueError("Input data must contain 'edge_index' for message passing.") x, edge_index = data.x, data.edge_index edge_attr = data.edge_attr if hasattr(data, 'edge_attr') else None batch_size = x.size(0) From fd86ec0fb786ca62ccf2f39086cb23990b963899 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 13 Nov 2024 19:14:27 +0000 Subject: [PATCH 04/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/nn/attention/expander.py | 3 ++- torch_geometric/nn/attention/local.py | 4 +++- torch_geometric/transforms/exphormer.py | 7 +++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/torch_geometric/nn/attention/expander.py b/torch_geometric/nn/attention/expander.py index daa7939f2d27..661a365d6e32 100644 --- a/torch_geometric/nn/attention/expander.py +++ b/torch_geometric/nn/attention/expander.py @@ -27,7 +27,8 @@ def __init__(self, hidden_dim: int, expander_degree: int = 4, def generate_expander_edges(self, num_nodes: int) -> torch.Tensor: if num_nodes < self.expander_degree: - raise ValueError("Number of nodes is insufficient to generate expander edges.") + raise ValueError( + "Number of nodes is insufficient to generate expander edges.") edges = [] for _ in range(self.expander_degree // 2): perm = torch.randperm(num_nodes) diff --git a/torch_geometric/nn/attention/local.py b/torch_geometric/nn/attention/local.py index 02317d8173d0..484ba9c2f363 100644 --- a/torch_geometric/nn/attention/local.py +++ b/torch_geometric/nn/attention/local.py @@ -25,7 +25,9 @@ def __init__(self, hidden_dim: int, num_heads: int = 4, def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: Optional[torch.Tensor] = None) -> torch.Tensor: if edge_attr is not None and edge_attr.size(0) != edge_index.size(1): - raise ValueError("edge_attr size does not match the number of edges in edge_index.") + raise ValueError( + "edge_attr size does not match the number of edges in edge_index." + ) q = self.q_proj(x).view(-1, self.num_heads, self.head_dim) k = self.k_proj(x).view(-1, self.num_heads, self.head_dim) v = self.v_proj(x).view(-1, self.num_heads, self.head_dim) diff --git a/torch_geometric/transforms/exphormer.py b/torch_geometric/transforms/exphormer.py index 50de1bf0f79b..92fe83d136f1 100644 --- a/torch_geometric/transforms/exphormer.py +++ b/torch_geometric/transforms/exphormer.py @@ -19,7 +19,9 @@ def __init__(self, hidden_dim: int, num_layers: int = 3, self.use_global = use_global self.virtual_node_transform = VirtualNode() if use_global else None if use_global and num_virtual_nodes < 1: - raise ValueError("num_virtual_nodes must be at least 1 if use_global is enabled.") + raise ValueError( + "num_virtual_nodes must be at least 1 if use_global is enabled." 
+ ) self.num_virtual_nodes = num_virtual_nodes self.layers = nn.ModuleList([ nn.ModuleDict({ @@ -44,7 +46,8 @@ def forward(self, data): if data.x.size(0) == 0: raise ValueError("Input graph is empty.") if not hasattr(data, 'edge_index') or data.edge_index is None: - raise ValueError("Input data must contain 'edge_index' for message passing.") + raise ValueError( + "Input data must contain 'edge_index' for message passing.") x, edge_index = data.x, data.edge_index edge_attr = data.edge_attr if hasattr(data, 'edge_attr') else None batch_size = x.size(0) From 8d64b444ff3d6c44aed190365aaa823d5738ddbe Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Thu, 14 Nov 2024 00:48:45 +0530 Subject: [PATCH 05/19] Update local.py --- torch_geometric/nn/attention/local.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch_geometric/nn/attention/local.py b/torch_geometric/nn/attention/local.py index 484ba9c2f363..6aed206391c1 100644 --- a/torch_geometric/nn/attention/local.py +++ b/torch_geometric/nn/attention/local.py @@ -24,10 +24,6 @@ def __init__(self, hidden_dim: int, num_heads: int = 4, def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: Optional[torch.Tensor] = None) -> torch.Tensor: - if edge_attr is not None and edge_attr.size(0) != edge_index.size(1): - raise ValueError( - "edge_attr size does not match the number of edges in edge_index." - ) q = self.q_proj(x).view(-1, self.num_heads, self.head_dim) k = self.k_proj(x).view(-1, self.num_heads, self.head_dim) v = self.v_proj(x).view(-1, self.num_heads, self.head_dim) From f9c7e9590f7f6911ffb44df12e02c588a00170c3 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Thu, 14 Nov 2024 22:50:55 +0530 Subject: [PATCH 06/19] solving pytest maybe --- test/transforms/test_exphormer.py | 4 ++-- torch_geometric/nn/attention/local.py | 9 +++++++-- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/test/transforms/test_exphormer.py b/test/transforms/test_exphormer.py index b69f3c539a4d..63eb0fb9ea69 100644 --- a/test/transforms/test_exphormer.py +++ b/test/transforms/test_exphormer.py @@ -41,8 +41,8 @@ def test_local_attention_no_edge_attr(local_attention_model): def test_local_attention_mismatched_edge_attr(local_attention_model): data = create_mock_data(num_nodes=10, hidden_dim=16, edge_dim=16) - with pytest.raises(Exception): - mismatched_edge_attr = torch.rand(data.edge_attr.size(0) - 1, 16) + mismatched_edge_attr = torch.rand(data.edge_attr.size(0) - 1, 16) + with pytest.raises(ValueError, match="edge_attr size does not match"): local_attention_model(data.x, data.edge_index, mismatched_edge_attr) diff --git a/torch_geometric/nn/attention/local.py b/torch_geometric/nn/attention/local.py index 6aed206391c1..a0910c53c4c2 100644 --- a/torch_geometric/nn/attention/local.py +++ b/torch_geometric/nn/attention/local.py @@ -24,6 +24,11 @@ def __init__(self, hidden_dim: int, num_heads: int = 4, def forward(self, x: torch.Tensor, edge_index: torch.Tensor, edge_attr: Optional[torch.Tensor] = None) -> torch.Tensor: + + if edge_attr is not None and edge_attr.size(0) != edge_index.size(1): + raise ValueError( + "edge_attr size does not match the number of edges in edge_index." 
+ ) q = self.q_proj(x).view(-1, self.num_heads, self.head_dim) k = self.k_proj(x).view(-1, self.num_heads, self.head_dim) v = self.v_proj(x).view(-1, self.num_heads, self.head_dim) @@ -35,10 +40,10 @@ def forward(self, x: torch.Tensor, edge_index: torch.Tensor, def message(self, q_i: torch.Tensor, k_j: torch.Tensor, v_j: torch.Tensor, edge_attr: Optional[torch.Tensor]) -> torch.Tensor: - print(q_i.shape, k_j.shape, v_j.shape) + attention = (q_i * k_j).sum(dim=-1) / np.sqrt(self.head_dim) if edge_attr is not None: - print('edge:', edge_attr.shape, 'attention:', attention.shape) + if edge_attr.size(0) < attention.size(0): num_repeats = attention.size(0) // edge_attr.size(0) + 1 edge_attr = edge_attr.repeat(num_repeats, From 603f5c1f70a60fac7257c825f3ab98c749ab0b66 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Thu, 14 Nov 2024 22:58:41 +0530 Subject: [PATCH 07/19] Update test_exphormer.py --- test/transforms/test_exphormer.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/test/transforms/test_exphormer.py b/test/transforms/test_exphormer.py index 63eb0fb9ea69..a2522a80474d 100644 --- a/test/transforms/test_exphormer.py +++ b/test/transforms/test_exphormer.py @@ -7,13 +7,18 @@ from torch_geometric.transforms.exphormer import EXPHORMER -def create_mock_data(num_nodes=10, hidden_dim=16, edge_dim=16): +def create_mock_data(num_nodes, hidden_dim, edge_dim=None): + from torch_geometric.data import Data x = torch.rand(num_nodes, hidden_dim) - edge_index = torch.randint(0, num_nodes, (2, num_nodes * 2)) - edge_attr = torch.rand(num_nodes * 2, edge_dim) + edge_index = torch.randint(0, num_nodes, (2, num_nodes * 2)) # Create a few random edges + + # Ensure edge_attr has the same number of rows as edges in edge_index + edge_attr = torch.rand(edge_index.size(1), edge_dim) if edge_dim else None + return Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + # Tests for LocalAttention @pytest.fixture def local_attention_model(): From f7b64b64714f534bbd8c8876bdff65a14e33f581 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Nov 2024 17:30:18 +0000 Subject: [PATCH 08/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/transforms/test_exphormer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/transforms/test_exphormer.py b/test/transforms/test_exphormer.py index a2522a80474d..d0dd310b2fa1 100644 --- a/test/transforms/test_exphormer.py +++ b/test/transforms/test_exphormer.py @@ -1,7 +1,6 @@ import pytest import torch -from torch_geometric.data import Data from torch_geometric.nn.attention.expander import ExpanderAttention from torch_geometric.nn.attention.local import LocalAttention from torch_geometric.transforms.exphormer import EXPHORMER @@ -10,7 +9,8 @@ def create_mock_data(num_nodes, hidden_dim, edge_dim=None): from torch_geometric.data import Data x = torch.rand(num_nodes, hidden_dim) - edge_index = torch.randint(0, num_nodes, (2, num_nodes * 2)) # Create a few random edges + edge_index = torch.randint(0, num_nodes, + (2, num_nodes * 2)) # Create a few random edges # Ensure edge_attr has the same number of rows as edges in edge_index edge_attr = torch.rand(edge_index.size(1), edge_dim) if edge_dim else None @@ -18,7 +18,6 @@ def create_mock_data(num_nodes, hidden_dim, edge_dim=None): return Data(x=x, edge_index=edge_index, edge_attr=edge_attr) - # Tests for LocalAttention @pytest.fixture def 
local_attention_model(): From a2b58f8d15411447f2944ebfdefd1778c8ea42bc Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Thu, 14 Nov 2024 23:05:39 +0530 Subject: [PATCH 09/19] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 73a782026189..3a98fde352cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `Exphormer` implementation based on this [paper](https://arxiv.org/abs/2303.06147) ([#9783](https://github.com/pyg-team/pytorch_geometric/pull/9783)) - Added support for fast `Delaunay()` triangulation via the `torch_delaunay` package ([#9748](https://github.com/pyg-team/pytorch_geometric/pull/9748)) - Added PyTorch 2.5 support ([#9779](https://github.com/pyg-team/pytorch_geometric/pull/9779), [#9779](https://github.com/pyg-team/pytorch_geometric/pull/9780)) - Support 3D tetrahedral mesh elements of shape `[4, num_faces]` in the `FaceToEdge` transformation ([#9776](https://github.com/pyg-team/pytorch_geometric/pull/9776)) From 4ae11a5ee11984e18261dcb6088b91edbe12b7ac Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Thu, 14 Nov 2024 23:08:03 +0530 Subject: [PATCH 10/19] pre-commit issue --- torch_geometric/transforms/exphormer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch_geometric/transforms/exphormer.py b/torch_geometric/transforms/exphormer.py index 92fe83d136f1..f7a22d08526c 100644 --- a/torch_geometric/transforms/exphormer.py +++ b/torch_geometric/transforms/exphormer.py @@ -7,6 +7,7 @@ class EXPHORMER(nn.Module): """EXPHORMER architecture. + Based on the paper: https://arxiv.org/abs/2303.06147 """ def __init__(self, hidden_dim: int, num_layers: int = 3, From a50a0e5ae37fac7f7146ff1ba654468befd32fda Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Nov 2024 17:39:36 +0000 Subject: [PATCH 11/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/transforms/exphormer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_geometric/transforms/exphormer.py b/torch_geometric/transforms/exphormer.py index f7a22d08526c..e67580b48dd0 100644 --- a/torch_geometric/transforms/exphormer.py +++ b/torch_geometric/transforms/exphormer.py @@ -7,7 +7,7 @@ class EXPHORMER(nn.Module): """EXPHORMER architecture. 
- + Based on the paper: https://arxiv.org/abs/2303.06147 """ def __init__(self, hidden_dim: int, num_layers: int = 3, From 86508cc69e33317bbd982fb758ed60513e5cb936 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Thu, 14 Nov 2024 23:23:51 +0530 Subject: [PATCH 12/19] pre-commit --- torch_geometric/nn/attention/expander.py | 2 +- torch_geometric/nn/attention/local.py | 2 +- torch_geometric/transforms/exphormer.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_geometric/nn/attention/expander.py b/torch_geometric/nn/attention/expander.py index 661a365d6e32..8f00315a2acf 100644 --- a/torch_geometric/nn/attention/expander.py +++ b/torch_geometric/nn/attention/expander.py @@ -8,7 +8,7 @@ class ExpanderAttention(MessagePassing): - """Expander graph attention using random d-regular near-Ramanujan graphs.""" + """Expander attention with random d-regular, near-Ramanujan graphs.""" def __init__(self, hidden_dim: int, expander_degree: int = 4, num_heads: int = 4, dropout: float = 0.1): super().__init__(aggr='add', node_dim=0) diff --git a/torch_geometric/nn/attention/local.py b/torch_geometric/nn/attention/local.py index a0910c53c4c2..0467813d1208 100644 --- a/torch_geometric/nn/attention/local.py +++ b/torch_geometric/nn/attention/local.py @@ -27,7 +27,7 @@ def forward(self, x: torch.Tensor, edge_index: torch.Tensor, if edge_attr is not None and edge_attr.size(0) != edge_index.size(1): raise ValueError( - "edge_attr size does not match the number of edges in edge_index." + "edge_attr size doesn't match the no of edges in edge_index." ) q = self.q_proj(x).view(-1, self.num_heads, self.head_dim) k = self.k_proj(x).view(-1, self.num_heads, self.head_dim) diff --git a/torch_geometric/transforms/exphormer.py b/torch_geometric/transforms/exphormer.py index e67580b48dd0..78f684229c26 100644 --- a/torch_geometric/transforms/exphormer.py +++ b/torch_geometric/transforms/exphormer.py @@ -21,7 +21,7 @@ def __init__(self, hidden_dim: int, num_layers: int = 3, self.virtual_node_transform = VirtualNode() if use_global else None if use_global and num_virtual_nodes < 1: raise ValueError( - "num_virtual_nodes must be at least 1 if use_global is enabled." + "num_virtual_nodes must be >= 1 if use_global is enabled." ) self.num_virtual_nodes = num_virtual_nodes self.layers = nn.ModuleList([ From 4981d8cfdd73c279c5d5a3c4ae851356c8185cc4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Nov 2024 17:55:26 +0000 Subject: [PATCH 13/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/nn/attention/local.py | 3 +-- torch_geometric/transforms/exphormer.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/torch_geometric/nn/attention/local.py b/torch_geometric/nn/attention/local.py index 0467813d1208..df5a33fd0c79 100644 --- a/torch_geometric/nn/attention/local.py +++ b/torch_geometric/nn/attention/local.py @@ -27,8 +27,7 @@ def forward(self, x: torch.Tensor, edge_index: torch.Tensor, if edge_attr is not None and edge_attr.size(0) != edge_index.size(1): raise ValueError( - "edge_attr size doesn't match the no of edges in edge_index." 
- ) + "edge_attr size doesn't match the no of edges in edge_index.") q = self.q_proj(x).view(-1, self.num_heads, self.head_dim) k = self.k_proj(x).view(-1, self.num_heads, self.head_dim) v = self.v_proj(x).view(-1, self.num_heads, self.head_dim) diff --git a/torch_geometric/transforms/exphormer.py b/torch_geometric/transforms/exphormer.py index 78f684229c26..f07911561a37 100644 --- a/torch_geometric/transforms/exphormer.py +++ b/torch_geometric/transforms/exphormer.py @@ -21,8 +21,7 @@ def __init__(self, hidden_dim: int, num_layers: int = 3, self.virtual_node_transform = VirtualNode() if use_global else None if use_global and num_virtual_nodes < 1: raise ValueError( - "num_virtual_nodes must be >= 1 if use_global is enabled." - ) + "num_virtual_nodes must be >= 1 if use_global is enabled.") self.num_virtual_nodes = num_virtual_nodes self.layers = nn.ModuleList([ nn.ModuleDict({ From 9f874c43afc3311d1c3b968b847e31bb1721d178 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Thu, 14 Nov 2024 23:29:53 +0530 Subject: [PATCH 14/19] Update test_exphormer.py --- test/transforms/test_exphormer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/transforms/test_exphormer.py b/test/transforms/test_exphormer.py index d0dd310b2fa1..7e14be38a208 100644 --- a/test/transforms/test_exphormer.py +++ b/test/transforms/test_exphormer.py @@ -46,7 +46,7 @@ def test_local_attention_no_edge_attr(local_attention_model): def test_local_attention_mismatched_edge_attr(local_attention_model): data = create_mock_data(num_nodes=10, hidden_dim=16, edge_dim=16) mismatched_edge_attr = torch.rand(data.edge_attr.size(0) - 1, 16) - with pytest.raises(ValueError, match="edge_attr size does not match"): + with pytest.raises(ValueError, match="edge_attr size doesn't match"): local_attention_model(data.x, data.edge_index, mismatched_edge_attr) From fb1516d8c02eb84ed7eb0dfaa07140b0b1a19f1f Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Thu, 14 Nov 2024 23:49:31 +0530 Subject: [PATCH 15/19] trying solving linting error --- torch_geometric/nn/attention/expander.py | 8 +++++++- torch_geometric/nn/attention/local.py | 8 +++++++- torch_geometric/transforms/exphormer.py | 8 ++++---- 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/torch_geometric/nn/attention/expander.py b/torch_geometric/nn/attention/expander.py index 8f00315a2acf..f584543e1a7e 100644 --- a/torch_geometric/nn/attention/expander.py +++ b/torch_geometric/nn/attention/expander.py @@ -10,7 +10,7 @@ class ExpanderAttention(MessagePassing): """Expander attention with random d-regular, near-Ramanujan graphs.""" def __init__(self, hidden_dim: int, expander_degree: int = 4, - num_heads: int = 4, dropout: float = 0.1): + num_heads: int = 4, dropout: float = 0.1) -> None: super().__init__(aggr='add', node_dim=0) self.hidden_dim = hidden_dim self.expander_degree = expander_degree @@ -52,3 +52,9 @@ def message(self, q_i: torch.Tensor, k_j: torch.Tensor, attention = torch.softmax(attention + self.edge_embedding, dim=-1) attention = self.dropout(attention) return attention.unsqueeze(-1) * v_j + + def edge_update(self) -> torch.Tensor: + raise NotImplementedError("edge_update not implemented in ExpanderAttention.") + + def message_and_aggregate(self, edge_index: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("message_and_aggregate not implemented in ExpanderAttention.") diff --git a/torch_geometric/nn/attention/local.py b/torch_geometric/nn/attention/local.py index df5a33fd0c79..5bf7592802ff 100644 --- 
a/torch_geometric/nn/attention/local.py +++ b/torch_geometric/nn/attention/local.py @@ -10,7 +10,7 @@ class LocalAttention(MessagePassing): """Local neighborhood attention.""" def __init__(self, hidden_dim: int, num_heads: int = 4, - dropout: float = 0.1): + dropout: float = 0.1) -> None: super().__init__(aggr='add', node_dim=0) self.hidden_dim = hidden_dim self.num_heads = num_heads @@ -57,3 +57,9 @@ def message(self, q_i: torch.Tensor, k_j: torch.Tensor, v_j: torch.Tensor, # Apply attention scores to the values out = attention.unsqueeze(-1) * v_j return out + + def edge_update(self) -> torch.Tensor: + raise NotImplementedError("edge_update not implemented in LocalAttention.") + + def message_and_aggregate(self, edge_index: torch.Tensor) -> torch.Tensor: + raise NotImplementedError("message_and_aggregate not implemented in LocalAttention.") diff --git a/torch_geometric/transforms/exphormer.py b/torch_geometric/transforms/exphormer.py index f07911561a37..b39a458bdca9 100644 --- a/torch_geometric/transforms/exphormer.py +++ b/torch_geometric/transforms/exphormer.py @@ -1,5 +1,5 @@ import torch.nn as nn - +import torch from torch_geometric.nn.attention.expander import ExpanderAttention from torch_geometric.nn.attention.local import LocalAttention from torch_geometric.transforms import VirtualNode @@ -13,7 +13,7 @@ class EXPHORMER(nn.Module): def __init__(self, hidden_dim: int, num_layers: int = 3, num_heads: int = 4, expander_degree: int = 4, dropout: float = 0.1, use_expander: bool = True, - use_global: bool = True, num_virtual_nodes: int = 1): + use_global: bool = True, num_virtual_nodes: int = 1) -> None: super().__init__() self.hidden_dim = hidden_dim self.use_expander = use_expander @@ -31,7 +31,7 @@ def __init__(self, hidden_dim: int, num_layers: int = 3, 'expander': ExpanderAttention(hidden_dim, expander_degree=expander_degree, num_heads=num_heads, dropout=dropout) - if use_expander else None, + if use_expander else nn.Identity(), 'layer_norm': nn.LayerNorm(hidden_dim), 'ffn': @@ -42,7 +42,7 @@ def __init__(self, hidden_dim: int, num_layers: int = 3, ]) self.dropout = nn.Dropout(dropout) - def forward(self, data): + def forward(self, data) -> torch.Tensor: if data.x.size(0) == 0: raise ValueError("Input graph is empty.") if not hasattr(data, 'edge_index') or data.edge_index is None: From 39316ff99d7c2bed1209ec75d68b407e94e77327 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Nov 2024 18:21:10 +0000 Subject: [PATCH 16/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/nn/attention/expander.py | 6 ++++-- torch_geometric/nn/attention/local.py | 6 ++++-- torch_geometric/transforms/exphormer.py | 3 ++- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/torch_geometric/nn/attention/expander.py b/torch_geometric/nn/attention/expander.py index f584543e1a7e..7d5cbee654e6 100644 --- a/torch_geometric/nn/attention/expander.py +++ b/torch_geometric/nn/attention/expander.py @@ -54,7 +54,9 @@ def message(self, q_i: torch.Tensor, k_j: torch.Tensor, return attention.unsqueeze(-1) * v_j def edge_update(self) -> torch.Tensor: - raise NotImplementedError("edge_update not implemented in ExpanderAttention.") + raise NotImplementedError( + "edge_update not implemented in ExpanderAttention.") def message_and_aggregate(self, edge_index: torch.Tensor) -> torch.Tensor: - raise NotImplementedError("message_and_aggregate not implemented in 
ExpanderAttention.") + raise NotImplementedError( + "message_and_aggregate not implemented in ExpanderAttention.") diff --git a/torch_geometric/nn/attention/local.py b/torch_geometric/nn/attention/local.py index 5bf7592802ff..be5904537cf9 100644 --- a/torch_geometric/nn/attention/local.py +++ b/torch_geometric/nn/attention/local.py @@ -59,7 +59,9 @@ def message(self, q_i: torch.Tensor, k_j: torch.Tensor, v_j: torch.Tensor, return out def edge_update(self) -> torch.Tensor: - raise NotImplementedError("edge_update not implemented in LocalAttention.") + raise NotImplementedError( + "edge_update not implemented in LocalAttention.") def message_and_aggregate(self, edge_index: torch.Tensor) -> torch.Tensor: - raise NotImplementedError("message_and_aggregate not implemented in LocalAttention.") + raise NotImplementedError( + "message_and_aggregate not implemented in LocalAttention.") diff --git a/torch_geometric/transforms/exphormer.py b/torch_geometric/transforms/exphormer.py index b39a458bdca9..24a7ef4cfc71 100644 --- a/torch_geometric/transforms/exphormer.py +++ b/torch_geometric/transforms/exphormer.py @@ -1,5 +1,6 @@ -import torch.nn as nn import torch +import torch.nn as nn + from torch_geometric.nn.attention.expander import ExpanderAttention from torch_geometric.nn.attention.local import LocalAttention from torch_geometric.transforms import VirtualNode From abccb1183b70d46009c9eb6d9ad1ecfbc638aaf5 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Fri, 15 Nov 2024 00:04:27 +0530 Subject: [PATCH 17/19] adding annotations --- torch_geometric/transforms/exphormer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch_geometric/transforms/exphormer.py b/torch_geometric/transforms/exphormer.py index 24a7ef4cfc71..0c81c50336c5 100644 --- a/torch_geometric/transforms/exphormer.py +++ b/torch_geometric/transforms/exphormer.py @@ -4,7 +4,7 @@ from torch_geometric.nn.attention.expander import ExpanderAttention from torch_geometric.nn.attention.local import LocalAttention from torch_geometric.transforms import VirtualNode - +from torch_geometric.data import Data class EXPHORMER(nn.Module): """EXPHORMER architecture. @@ -43,7 +43,7 @@ def __init__(self, hidden_dim: int, num_layers: int = 3, ]) self.dropout = nn.Dropout(dropout) - def forward(self, data) -> torch.Tensor: + def forward(self, data:Data) -> torch.Tensor: if data.x.size(0) == 0: raise ValueError("Input graph is empty.") if not hasattr(data, 'edge_index') or data.edge_index is None: From ebd2bf36e1ac53a571e521cf1bc81375623a8d4d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Nov 2024 18:36:38 +0000 Subject: [PATCH 18/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- torch_geometric/transforms/exphormer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torch_geometric/transforms/exphormer.py b/torch_geometric/transforms/exphormer.py index 0c81c50336c5..bd875542fad6 100644 --- a/torch_geometric/transforms/exphormer.py +++ b/torch_geometric/transforms/exphormer.py @@ -1,10 +1,11 @@ import torch import torch.nn as nn +from torch_geometric.data import Data from torch_geometric.nn.attention.expander import ExpanderAttention from torch_geometric.nn.attention.local import LocalAttention from torch_geometric.transforms import VirtualNode -from torch_geometric.data import Data + class EXPHORMER(nn.Module): """EXPHORMER architecture. 
@@ -43,7 +44,7 @@ def __init__(self, hidden_dim: int, num_layers: int = 3, ]) self.dropout = nn.Dropout(dropout) - def forward(self, data:Data) -> torch.Tensor: + def forward(self, data: Data) -> torch.Tensor: if data.x.size(0) == 0: raise ValueError("Input graph is empty.") if not hasattr(data, 'edge_index') or data.edge_index is None: From 39ee4f9d186ce56b9e3aed9067c68282c9396229 Mon Sep 17 00:00:00 2001 From: Aryan Saini Date: Sat, 16 Nov 2024 22:44:09 +0530 Subject: [PATCH 19/19] Update exphormer.py --- torch_geometric/transforms/exphormer.py | 37 +++++++++++++++++++------ 1 file changed, 29 insertions(+), 8 deletions(-) diff --git a/torch_geometric/transforms/exphormer.py b/torch_geometric/transforms/exphormer.py index bd875542fad6..5f0d2218bdb2 100644 --- a/torch_geometric/transforms/exphormer.py +++ b/torch_geometric/transforms/exphormer.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch import torch.nn as nn @@ -45,26 +47,45 @@ def __init__(self, hidden_dim: int, num_layers: int = 3, self.dropout = nn.Dropout(dropout) def forward(self, data: Data) -> torch.Tensor: - if data.x.size(0) == 0: + + if data.x is None: + raise ValueError("Input data.x cannot be None") + x: torch.Tensor = data.x + + if x.size(0) == 0: raise ValueError("Input graph is empty.") + if not hasattr(data, 'edge_index') or data.edge_index is None: raise ValueError( "Input data must contain 'edge_index' for message passing.") - x, edge_index = data.x, data.edge_index - edge_attr = data.edge_attr if hasattr(data, 'edge_attr') else None + edge_index: torch.Tensor = data.edge_index + + edge_attr: Optional[torch.Tensor] = data.edge_attr if hasattr( + data, 'edge_attr') else None + batch_size = x.size(0) + if self.virtual_node_transform is not None: data = self.virtual_node_transform(data) - x, edge_index = data.x, data.edge_index + if data.x is None: + raise ValueError("Virtual node transform resulted in None x") + x = data.x + if data.edge_index is None: + raise ValueError( + "Virtual node transform resulted in None edge_index") + edge_index = data.edge_index for layer in self.layers: residual = x local_out = layer['local'](x, edge_index, edge_attr) - expander_out = 0 + expander_out = torch.zeros_like(x) if self.use_expander and layer['expander'] is not None: - expander_out, _ = layer['expander']( - x[:batch_size + self.num_virtual_nodes], - batch_size + self.num_virtual_nodes) + + node_subset = x[:batch_size + self.num_virtual_nodes] + expander_out, _ = layer['expander'](node_subset, batch_size + + self.num_virtual_nodes) + x = layer['layer_norm'](residual + local_out + expander_out) x = x + layer['ffn'](x) + return x[:batch_size]
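
A minimal usage sketch of the three components added in this series, assuming the final module layout after [PATCH 19/19] (LocalAttention and ExpanderAttention under torch_geometric/nn/attention/, EXPHORMER under torch_geometric/transforms/exphormer.py). The tensor sizes are arbitrary and the snippet is illustrative rather than part of any patch above.

# Usage sketch only -- not part of the diffs above. Assumes the post-PATCH-19
# state of this series; shapes mirror test/transforms/test_exphormer.py.
import torch

from torch_geometric.data import Data
from torch_geometric.nn.attention.expander import ExpanderAttention
from torch_geometric.nn.attention.local import LocalAttention
from torch_geometric.transforms.exphormer import EXPHORMER

num_nodes, hidden_dim = 32, 16
x = torch.rand(num_nodes, hidden_dim)
edge_index = torch.randint(0, num_nodes, (2, 2 * num_nodes))
edge_attr = torch.rand(edge_index.size(1), hidden_dim)  # one row per edge

# Sparse attention restricted to the given edges; edge features are optional
# but, if provided, must match the number of edges in edge_index.
local = LocalAttention(hidden_dim=hidden_dim, num_heads=4)
out_local = local(x, edge_index, edge_attr)  # -> [num_nodes, hidden_dim]

# Expander attention samples a random d-regular overlay graph internally
# (expander_degree must be even) and also returns the sampled edges.
expander = ExpanderAttention(hidden_dim=hidden_dim, expander_degree=4,
                             num_heads=4)
out_expander, expander_edges = expander(x, num_nodes=num_nodes)

# Full EXPHORMER stack: local + expander + virtual-node ("global") attention.
# edge_attr is omitted here, as in test_exphormer_forward, since the VirtualNode
# transform extends edge_index beyond the edge_attr captured before it runs.
model = EXPHORMER(hidden_dim=hidden_dim, num_layers=3, num_heads=4,
                  expander_degree=4)
out = model(Data(x=x, edge_index=edge_index))  # -> [num_nodes, hidden_dim]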