Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Hyperbolic GCN layer of paper [Hyperbolic Graph Convolutional Neural Networks](https://arxiv.org/abs/1910.12933) #9423

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `HGCNConv` layer ([#9423](https://github.com/pyg-team/pytorch_geometric/pull/9423))
- Added the heterogeneous `HeteroJumpingKnowledge` module ([#9380](https://github.com/pyg-team/pytorch_geometric/pull/9380))
- Started work on GNN+LLM package ([#9350](https://github.com/pyg-team/pytorch_geometric/pull/9350))
- Added support for negative sampling in `LinkLoader` according to source and destination node weights ([#9316](https://github.com/pyg-team/pytorch_geometric/pull/9316))
Expand Down
108 changes: 108 additions & 0 deletions examples/hgcn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import argparse
import os.path as osp
import time

import torch
import torch.nn.functional as F

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.logging import init_wandb, log
from torch_geometric.nn import HGCNConv

# Command-line interface, device selection, and dataset loading.
parser = argparse.ArgumentParser()
_ARG_SPECS = [
    ('--dataset', dict(type=str, default='Cora')),
    ('--hidden_channels', dict(type=int, default=128)),
    ('--heads', dict(type=int, default=8)),
    ('--lr', dict(type=float, default=1e-4)),
    ('--epochs', dict(type=int, default=100)),
    ('--wandb', dict(action='store_true', help='Track experiment')),
]
for _flag, _kwargs in _ARG_SPECS:
    parser.add_argument(_flag, **_kwargs)
args = parser.parse_args()

device = torch.device('cuda') if torch.cuda.is_available() \
    else torch.device('cpu')
init_wandb(name=f'HGCN-{args.dataset}', heads=args.heads, epochs=args.epochs,
           hidden_channels=args.hidden_channels, lr=args.lr, device=device)

# Dataset lives next to the examples directory, under ../data/Planetoid.
_root = osp.dirname(osp.realpath(__file__))
path = osp.join(_root, '..', 'data', 'Planetoid')
dataset = Planetoid(path, args.dataset, transform=T.NormalizeFeatures())
data = dataset[0].to(device)


class LinearDecoder(torch.nn.Module):
    """Decode manifold node embeddings into Euclidean class scores.

    The input is mapped back to the tangent space at the origin
    (``logmap0`` followed by ``proj_tan0``) and passed through a single
    linear classification layer.
    """
    def __init__(self, manifold, in_dim, out_dim, c):
        super().__init__()
        self.manifold = manifold
        self.input_dim = in_dim
        self.output_dim = out_dim
        self.cls = torch.nn.Linear(in_dim, out_dim)
        # Curvature is a trainable parameter so it can be tuned jointly.
        self.c = torch.nn.Parameter(torch.FloatTensor([c]))

    def forward(self, x):
        tangent = self.manifold.logmap0(x, c=self.c)
        tangent = self.manifold.proj_tan0(tangent, c=self.c)
        return self.cls(tangent)

    def extra_repr(self):
        return (f'in_features={self.input_dim}, '
                f'out_features={self.output_dim}, c={self.c}')


class HGCN(torch.nn.Module):
    """Node classifier: one hyperbolic GCN layer plus a linear decoder."""
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = HGCNConv(in_channels, hidden_channels, dropout=0.5)
        # The decoder reuses the conv layer's manifold with curvature 1.0.
        self.decoder = LinearDecoder(self.conv1.manifold, hidden_channels,
                                     out_channels, 1.0)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        # NOTE(review): dropout here (p=0.6) differs from the conv's own
        # dropout (0.5) — confirm the asymmetry is intentional.
        h = F.dropout(h, p=0.6, training=self.training)
        return F.log_softmax(self.decoder(h), dim=-1)


# Build the model and optimizer. Use the learning rate supplied on the
# command line (it was previously hard-coded to 1e-3, silently ignoring
# the --lr argument defined above).
model = HGCN(dataset.num_features, args.hidden_channels,
             dataset.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
                             weight_decay=1e-3)
print(model)


def train():
    """Run one full-batch training step and return the scalar loss.

    The model already returns log-probabilities (``log_softmax`` in
    ``HGCN.forward``), so use ``F.nll_loss`` directly rather than
    ``F.cross_entropy``, which would apply a second (redundant)
    log-softmax on top.
    """
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test():
    """Return accuracies for the train, validation, and test splits."""
    model.eval()
    pred = model(data.x, data.edge_index).argmax(dim=-1)

    return [
        int((pred[mask] == data.y[mask]).sum()) / int(mask.sum())
        for mask in (data.train_mask, data.val_mask, data.test_mask)
    ]


# Main loop: model selection on the validation split — report the test
# accuracy from the epoch with the best validation accuracy.
times = []
best_val_acc = final_test_acc = 0
for epoch in range(1, args.epochs + 1):
    start = time.time()
    loss = train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Fix: was assigned to an inconsistently named `test_acc`, which
        # left `final_test_acc` unused and raised a NameError in `log`
        # whenever the first epoch's val_acc did not exceed 0.
        final_test_acc = tmp_test_acc
    log(Epoch=epoch, Loss=loss, Train=train_acc, Val=val_acc,
        Test=final_test_acc)
    times.append(time.time() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")
94 changes: 94 additions & 0 deletions test/nn/conv/test_hgcn_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import torch

from torch_geometric.nn.conv.hgcn_conv import (
HGCNConv,
HypAct,
Hyperboloid,
HypLinear,
PoincareBall,
)
from torch_geometric.utils import to_torch_csc_tensor


def test_hgcn_conv_hyperboloid_forward():
    """HGCNConv on the hyperboloid: dense and sparse inputs must agree."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    value = torch.rand(edge_index.size(1))
    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))
    adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4))

    conv = HGCNConv(16, 32, manifold='hyperboloid')
    assert str(conv) == 'HGCNConv(16, 32)'

    out1 = conv(x, edge_index)
    assert out1.size() == (4, 32)
    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)

    # Same equivalence check with edge weights.
    out2 = conv(x, edge_index, value)
    assert out2.size() == (4, 32)
    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)


def test_hgcn_conv_poincareBall_forward():
    """HGCNConv on the Poincare ball: dense and sparse inputs must agree."""
    x = torch.randn(4, 16)
    edge_index = torch.tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]])
    value = torch.rand(edge_index.size(1))
    adj1 = to_torch_csc_tensor(edge_index, size=(4, 4))
    adj2 = to_torch_csc_tensor(edge_index, value, size=(4, 4))

    conv = HGCNConv(16, 32, manifold='poincare')
    assert str(conv) == 'HGCNConv(16, 32)'

    out1 = conv(x, edge_index)
    assert out1.size() == (4, 32)
    assert torch.allclose(conv(x, adj1.t()), out1, atol=1e-6)

    # Same equivalence check with edge weights.
    out2 = conv(x, edge_index, value)
    assert out2.size() == (4, 32)
    assert torch.allclose(conv(x, adj2.t()), out2, atol=1e-6)


def test_hgcn_linear_hyperboloid_forward():
    """HypLinear produces the expected output shape on the hyperboloid."""
    x = torch.randn(4, 16)
    manifold = Hyperboloid()

    # Fix typo: variable was previously named `linaer`.
    linear = HypLinear(manifold, 16, 32, 1.0, 0.0, True)

    out = linear(x)
    assert out.size() == (4, 32)


def test_hgcn_linear_poincareBall_forward():
    """HypLinear produces the expected output shape on the Poincare ball."""
    x = torch.randn(4, 16)
    manifold = PoincareBall()

    # Fix typo: variable was previously named `linaer`.
    linear = HypLinear(manifold, 16, 32, 1.0, 0.0, True)

    out = linear(x)
    assert out.size() == (4, 32)


def test_hypact_hyperboloid_forward():
    """HypAct preserves the input shape on the hyperboloid."""
    x = torch.randn(4, 16)
    manifold = Hyperboloid()

    hypact = HypAct(manifold, 1.0, 1.0, None)

    out = hypact(x)
    assert out.size() == (4, 16)


def test_hypact_poincareBall_forward():
    """HypAct preserves the input shape on the Poincare ball."""
    x = torch.randn(4, 16)
    manifold = PoincareBall()

    hypact = HypAct(manifold, 1.0, 1.0, None)

    out = hypact(x)
    assert out.size() == (4, 16)
2 changes: 2 additions & 0 deletions torch_geometric/nn/conv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from .antisymmetric_conv import AntiSymmetricConv
from .dir_gnn_conv import DirGNNConv
from .mixhop_conv import MixHopConv
from .hgcn_conv import HGCNConv

import torch_geometric.nn.conv.utils # noqa

Expand Down Expand Up @@ -131,6 +132,7 @@
'AntiSymmetricConv',
'DirGNNConv',
'MixHopConv',
'HGCNConv',
]

classes = __all__
Expand Down
Loading
Loading