Exphormer Implementation #9783

Open
wants to merge 22 commits into master
Changes from 20 commits (22 commits total)

Commits
31757b3  initial commit (phoeenniixx, Nov 13, 2024)
d29c887  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 13, 2024)
e2e9b68  adding some exception handling to pass pytest (phoeenniixx, Nov 13, 2024)
fd86ec0  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 13, 2024)
8d64b44  Update local.py (phoeenniixx, Nov 13, 2024)
f9c7e95  solving pytest maybe (phoeenniixx, Nov 14, 2024)
603f5c1  Update test_exphormer.py (phoeenniixx, Nov 14, 2024)
f7b64b6  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 14, 2024)
a2b58f8  Update CHANGELOG.md (phoeenniixx, Nov 14, 2024)
4ae11a5  pre-commit issue (phoeenniixx, Nov 14, 2024)
a50a0e5  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 14, 2024)
86508cc  pre-commit (phoeenniixx, Nov 14, 2024)
4981d8c  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 14, 2024)
9f874c4  Update test_exphormer.py (phoeenniixx, Nov 14, 2024)
fb1516d  trying solving linting error (phoeenniixx, Nov 14, 2024)
39316ff  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 14, 2024)
abccb11  adding annotations (phoeenniixx, Nov 14, 2024)
ebd2bf3  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 14, 2024)
39ee4f9  Update exphormer.py (phoeenniixx, Nov 16, 2024)
0040c33  Merge branch 'master' into exphormer (puririshi98, Jan 24, 2025)
8e15e62  Merge branch 'master' into exphormer (puririshi98, Jan 31, 2025)
e367f9e  Merge branch 'master' into exphormer (puririshi98, Feb 3, 2025)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `Exphormer` implementation based on this [paper](https://arxiv.org/abs/2303.06147) ([#9783](https://github.com/pyg-team/pytorch_geometric/pull/9783))
- Added `InstructMol` dataset ([#9975](https://github.com/pyg-team/pytorch_geometric/pull/9975))
- Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947))
- Added support for weighted `LinkPredNDCG` metric ([#9945](https://github.com/pyg-team/pytorch_geometric/pull/9945))
131 changes: 131 additions & 0 deletions test/transforms/test_exphormer.py
@@ -0,0 +1,131 @@
import pytest
import torch

from torch_geometric.nn.attention.expander import ExpanderAttention
from torch_geometric.nn.attention.local import LocalAttention
from torch_geometric.transforms.exphormer import EXPHORMER


def create_mock_data(num_nodes, hidden_dim, edge_dim=None):
from torch_geometric.data import Data
x = torch.rand(num_nodes, hidden_dim)
edge_index = torch.randint(0, num_nodes,
(2, num_nodes * 2)) # Create a few random edges

# Ensure edge_attr has the same number of rows as edges in edge_index
edge_attr = torch.rand(edge_index.size(1), edge_dim) if edge_dim else None

return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)


# Tests for LocalAttention
@pytest.fixture
def local_attention_model():
hidden_dim, num_heads = 16, 4
return LocalAttention(hidden_dim=hidden_dim, num_heads=num_heads)


def test_local_attention_initialization(local_attention_model):
assert local_attention_model.hidden_dim == 16
assert local_attention_model.num_heads == 4
assert local_attention_model.q_proj.weight.shape == (16, 16)


def test_local_attention_forward(local_attention_model):
data = create_mock_data(num_nodes=10, hidden_dim=16, edge_dim=16)
output = local_attention_model(data.x, data.edge_index, data.edge_attr)
assert output.shape == (data.x.shape[0], 16)


def test_local_attention_no_edge_attr(local_attention_model):
data = create_mock_data(num_nodes=10, hidden_dim=16, edge_dim=16)
output_no_edge_attr = local_attention_model(data.x, data.edge_index, None)
assert output_no_edge_attr.shape == (data.x.shape[0], 16)


def test_local_attention_mismatched_edge_attr(local_attention_model):
data = create_mock_data(num_nodes=10, hidden_dim=16, edge_dim=16)
mismatched_edge_attr = torch.rand(data.edge_attr.size(0) - 1, 16)
with pytest.raises(ValueError, match="edge_attr size doesn't match"):
local_attention_model(data.x, data.edge_index, mismatched_edge_attr)


def test_local_attention_message(local_attention_model):
q_i = torch.rand(20, 4, 4)
k_j = torch.rand_like(q_i)
v_j = torch.rand_like(q_i)
edge_attr = torch.rand(20, 4)
attention_scores = local_attention_model.message(q_i, k_j, v_j, edge_attr)
assert attention_scores.shape == v_j.shape


# Tests for ExpanderAttention
@pytest.fixture
def expander_attention_model():
hidden_dim, num_heads, expander_degree = 16, 4, 4
return ExpanderAttention(hidden_dim=hidden_dim,
expander_degree=expander_degree,
num_heads=num_heads)


def test_expander_attention_initialization(expander_attention_model):
assert expander_attention_model.expander_degree == 4
assert expander_attention_model.q_proj.weight.shape == (16, 16)


def test_expander_attention_generate_expander_edges(expander_attention_model):
edge_index = expander_attention_model.generate_expander_edges(num_nodes=10)
assert edge_index.shape[0] == 2


def test_expander_attention_forward(expander_attention_model):
data = create_mock_data(num_nodes=10, hidden_dim=16)
output, edge_index = expander_attention_model(data.x, num_nodes=10)
assert output.shape == (data.x.shape[0], 16)


def test_expander_attention_insufficient_nodes(expander_attention_model):
with pytest.raises(Exception):
expander_attention_model.generate_expander_edges(2)


def test_expander_attention_message(expander_attention_model):
q_i = torch.rand(20, 4, 4)
k_j = torch.rand_like(q_i)
v_j = torch.rand_like(q_i)
attention_scores = expander_attention_model.message(q_i, k_j, v_j)
assert attention_scores.shape == v_j.shape


# Tests for EXPHORMER
@pytest.fixture
def exphormer_model():
hidden_dim, num_layers, num_heads, expander_degree = 16, 3, 4, 4
return EXPHORMER(hidden_dim=hidden_dim, num_layers=num_layers,
num_heads=num_heads, expander_degree=expander_degree)


def test_exphormer_initialization(exphormer_model):
assert len(exphormer_model.layers) == 3
assert isinstance(exphormer_model.layers[0]['local'], LocalAttention)
if exphormer_model.use_expander:
assert isinstance(exphormer_model.layers[0]['expander'],
ExpanderAttention)


def test_exphormer_forward(exphormer_model):
data = create_mock_data(num_nodes=10, hidden_dim=16)
output = exphormer_model(data)
assert output.shape == (data.x.shape[0], 16)


def test_exphormer_empty_graph(exphormer_model):
data_empty = create_mock_data(num_nodes=1, hidden_dim=16)
with pytest.raises(Exception):
exphormer_model(data_empty)


def test_exphormer_incorrect_input_dimensions(exphormer_model):
data_incorrect_dim = create_mock_data(num_nodes=10, hidden_dim=17)
with pytest.raises(Exception):
exphormer_model(data_incorrect_dim)
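
For reviewers, the new tests can be run on their own; a minimal sketch, assuming the working directory is the repository root of this branch (the invocation itself is not part of the diff):

```python
# Convenience snippet (not part of this PR): run only the new Exphormer tests
# programmatically; equivalent to `pytest -q test/transforms/test_exphormer.py`.
import pytest

raise SystemExit(pytest.main(["-q", "test/transforms/test_exphormer.py"]))
```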
62 changes: 62 additions & 0 deletions torch_geometric/nn/attention/expander.py
@@ -0,0 +1,62 @@
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn

from torch_geometric.nn import MessagePassing


class ExpanderAttention(MessagePassing):
"""Expander attention with random d-regular, near-Ramanujan graphs."""
def __init__(self, hidden_dim: int, expander_degree: int = 4,
num_heads: int = 4, dropout: float = 0.1) -> None:
super().__init__(aggr='add', node_dim=0)
self.hidden_dim = hidden_dim
self.expander_degree = expander_degree
if expander_degree % 2 != 0:
raise ValueError("expander_degree must be an even number.")
self.num_heads = num_heads
self.head_dim = hidden_dim // num_heads
self.q_proj = nn.Linear(hidden_dim, hidden_dim)
self.k_proj = nn.Linear(hidden_dim, hidden_dim)
self.v_proj = nn.Linear(hidden_dim, hidden_dim)
self.o_proj = nn.Linear(hidden_dim, hidden_dim)
self.edge_embedding = nn.Parameter(torch.randn(1, num_heads))
self.dropout = nn.Dropout(dropout)

def generate_expander_edges(self, num_nodes: int) -> torch.Tensor:
if num_nodes < self.expander_degree:
raise ValueError(
"Number of nodes is insufficient to generate expander edges.")
edges = []
for _ in range(self.expander_degree // 2):
perm = torch.randperm(num_nodes)
edges.extend([(i, perm[i].item()) for i in range(num_nodes)])
edges.extend([(perm[i].item(), i) for i in range(num_nodes)])
edge_index = torch.tensor(edges, dtype=torch.long).t()
return edge_index

def forward(self, x: torch.Tensor,
num_nodes: int) -> Tuple[torch.Tensor, torch.Tensor]:
edge_index = self.generate_expander_edges(num_nodes).to(x.device)
q = self.q_proj(x).view(-1, self.num_heads, self.head_dim)
k = self.k_proj(x).view(-1, self.num_heads, self.head_dim)
v = self.v_proj(x).view(-1, self.num_heads, self.head_dim)
out = self.propagate(edge_index, q=q, k=k, v=v)
return self.o_proj(out.view(-1, self.hidden_dim)), edge_index

def message(self, q_i: torch.Tensor, k_j: torch.Tensor,
v_j: torch.Tensor) -> torch.Tensor:
attention = (q_i * k_j).sum(dim=-1) / np.sqrt(self.head_dim)
attention = torch.softmax(attention + self.edge_embedding, dim=-1)
attention = self.dropout(attention)
return attention.unsqueeze(-1) * v_j

def edge_update(self) -> torch.Tensor:
raise NotImplementedError(
"edge_update not implemented in ExpanderAttention.")

def message_and_aggregate(self, edge_index: torch.Tensor) -> torch.Tensor:
raise NotImplementedError(
"message_and_aggregate not implemented in ExpanderAttention.")
67 changes: 67 additions & 0 deletions torch_geometric/nn/attention/local.py
@@ -0,0 +1,67 @@
from typing import Optional

import numpy as np
import torch
import torch.nn as nn

from torch_geometric.nn import MessagePassing


class LocalAttention(MessagePassing):
"""Local neighborhood attention."""
def __init__(self, hidden_dim: int, num_heads: int = 4,
dropout: float = 0.1) -> None:
super().__init__(aggr='add', node_dim=0)
self.hidden_dim = hidden_dim
self.num_heads = num_heads
self.head_dim = hidden_dim // num_heads
self.q_proj = nn.Linear(hidden_dim, hidden_dim)
self.k_proj = nn.Linear(hidden_dim, hidden_dim)
self.v_proj = nn.Linear(hidden_dim, hidden_dim)
self.o_proj = nn.Linear(hidden_dim, hidden_dim)
self.edge_proj = nn.Linear(hidden_dim, num_heads)
self.dropout = nn.Dropout(dropout)

def forward(self, x: torch.Tensor, edge_index: torch.Tensor,
edge_attr: Optional[torch.Tensor] = None) -> torch.Tensor:

if edge_attr is not None and edge_attr.size(0) != edge_index.size(1):
raise ValueError(
"edge_attr size doesn't match the no of edges in edge_index.")
q = self.q_proj(x).view(-1, self.num_heads, self.head_dim)
k = self.k_proj(x).view(-1, self.num_heads, self.head_dim)
v = self.v_proj(x).view(-1, self.num_heads, self.head_dim)
edge_attention = self.edge_proj(
edge_attr) if edge_attr is not None else None
out = self.propagate(edge_index, q=q, k=k, v=v,
edge_attr=edge_attention)
return self.o_proj(out.view(-1, self.hidden_dim))

def message(self, q_i: torch.Tensor, k_j: torch.Tensor, v_j: torch.Tensor,
edge_attr: Optional[torch.Tensor]) -> torch.Tensor:

attention = (q_i * k_j).sum(dim=-1) / np.sqrt(self.head_dim)
if edge_attr is not None:

if edge_attr.size(0) < attention.size(0):
num_repeats = attention.size(0) // edge_attr.size(0) + 1
edge_attr = edge_attr.repeat(num_repeats,
1)[:attention.size(0)]
elif edge_attr.size(0) > attention.size(0):
edge_attr = edge_attr[:attention.size(0)]
attention = attention + edge_attr

attention = torch.softmax(attention, dim=-1)
attention = self.dropout(attention)

# Apply attention scores to the values
out = attention.unsqueeze(-1) * v_j
return out

def edge_update(self) -> torch.Tensor:
raise NotImplementedError(
"edge_update not implemented in LocalAttention.")

def message_and_aggregate(self, edge_index: torch.Tensor) -> torch.Tensor:
raise NotImplementedError(
"message_and_aggregate not implemented in LocalAttention.")
91 changes: 91 additions & 0 deletions torch_geometric/transforms/exphormer.py
@@ -0,0 +1,91 @@
from typing import Optional

import torch
import torch.nn as nn

from torch_geometric.data import Data
from torch_geometric.nn.attention.expander import ExpanderAttention
from torch_geometric.nn.attention.local import LocalAttention
from torch_geometric.transforms import VirtualNode


class EXPHORMER(nn.Module):
"""EXPHORMER architecture.

Based on the paper: https://arxiv.org/abs/2303.06147
"""
def __init__(self, hidden_dim: int, num_layers: int = 3,
num_heads: int = 4, expander_degree: int = 4,
dropout: float = 0.1, use_expander: bool = True,
use_global: bool = True, num_virtual_nodes: int = 1) -> None:
super().__init__()
self.hidden_dim = hidden_dim
self.use_expander = use_expander
self.use_global = use_global
self.virtual_node_transform = VirtualNode() if use_global else None
if use_global and num_virtual_nodes < 1:
raise ValueError(
"num_virtual_nodes must be >= 1 if use_global is enabled.")
self.num_virtual_nodes = num_virtual_nodes
self.layers = nn.ModuleList([
nn.ModuleDict({
'local':
LocalAttention(hidden_dim, num_heads=num_heads,
dropout=dropout),
'expander':
ExpanderAttention(hidden_dim, expander_degree=expander_degree,
num_heads=num_heads, dropout=dropout)
if use_expander else nn.Identity(),
'layer_norm':
nn.LayerNorm(hidden_dim),
'ffn':
nn.Sequential(nn.Linear(hidden_dim, 4 * hidden_dim), nn.GELU(),
nn.Dropout(dropout),
nn.Linear(4 * hidden_dim, hidden_dim))
}) for _ in range(num_layers)
])
self.dropout = nn.Dropout(dropout)

def forward(self, data: Data) -> torch.Tensor:

if data.x is None:
raise ValueError("Input data.x cannot be None")
x: torch.Tensor = data.x

if x.size(0) == 0:
raise ValueError("Input graph is empty.")

if not hasattr(data, 'edge_index') or data.edge_index is None:
raise ValueError(
"Input data must contain 'edge_index' for message passing.")
edge_index: torch.Tensor = data.edge_index

edge_attr: Optional[torch.Tensor] = data.edge_attr if hasattr(
data, 'edge_attr') else None

batch_size = x.size(0)

if self.virtual_node_transform is not None:
data = self.virtual_node_transform(data)
if data.x is None:
raise ValueError("Virtual node transform resulted in None x")
x = data.x
if data.edge_index is None:
raise ValueError(
"Virtual node transform resulted in None edge_index")
edge_index = data.edge_index

for layer in self.layers:
residual = x
local_out = layer['local'](x, edge_index, edge_attr)
expander_out = torch.zeros_like(x)
if self.use_expander and layer['expander'] is not None:

node_subset = x[:batch_size + self.num_virtual_nodes]
expander_out, _ = layer['expander'](node_subset, batch_size +
self.num_virtual_nodes)

x = layer['layer_norm'](residual + local_out + expander_out)
x = x + layer['ffn'](x)

return x[:batch_size]
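
Each `EXPHORMER` layer sums a residual, the local attention over the original edges, and (when `use_expander` is set) the expander attention over freshly sampled expander edges, then applies layer norm and a feed-forward block; with `use_global` enabled, a `VirtualNode` transform adds a global node that is stripped from the returned features. A forward-pass sketch, assuming this branch is installed locally (not part of the diff):

```python
# Forward-pass sketch (not part of this PR): the virtual node added by the
# transform participates in attention but is dropped from the output.
import torch

from torch_geometric.data import Data
from torch_geometric.transforms.exphormer import EXPHORMER

data = Data(x=torch.rand(10, 16),
            edge_index=torch.randint(0, 10, (2, 20)))

model = EXPHORMER(hidden_dim=16, num_layers=3, num_heads=4, expander_degree=4)
out = model(data)
assert out.shape == (10, 16)  # rows for virtual nodes are removed
```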