Adapting experiments of GCN and Simplex2Vec to PyTorch Lightning experiments. Rearranging entries and cleaning up files.
rballeba committed May 1, 2024
1 parent beda461 commit 55e48ce
Showing 24 changed files with 477 additions and 131 deletions.
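
For context, the pattern these experiments are being adapted to is the standard PyTorch Lightning one: a LightningModule that owns the step logic and optimizer, trained by a Trainer. The sketch below is purely illustrative; the class, data shapes, and hyperparameters are placeholders, not code from this commit.

import lightning as L
import torch
from torch import nn


class IllustrativeModule(L.LightningModule):
    # Placeholder module showing the structure the experiment files in this commit follow.
    def __init__(self):
        super().__init__()
        self.model = nn.Linear(8, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self.model(x).squeeze()
        return nn.functional.binary_cross_entropy_with_logits(logits, y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# Training is then delegated to the Lightning Trainer, e.g.:
# trainer = L.Trainer(max_epochs=10)
# trainer.fit(IllustrativeModule(), train_dataloader, val_dataloader)
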
2 changes: 1 addition & 1 deletion .github/workflows/black.yaml
@@ -1,6 +1,6 @@
name: Enforce coding style

on: [push, pull_request]
on: [ push, pull_request ]

jobs:
lint:
experiments/lightning_modules/BaseModuleOrientability.py
@@ -4,8 +4,6 @@
import torch
import torchmetrics

from mantra.utils import transfer_simplicial_complex_batch_to_device


class BaseOrientabilityModule(L.LightningModule):
def __init__(self):
@@ -15,11 +13,6 @@ def __init__(self):
self.validation_accuracy = torchmetrics.classification.BinaryAccuracy()
self.test_accuracy = torchmetrics.classification.BinaryAccuracy()

def transfer_batch_to_device(self, batch, device, dataloader_idx):
return transfer_simplicial_complex_batch_to_device(
batch, device, dataloader_idx
)

def log_accuracies(
self, x_hat, y, batch_len, step: Literal["train", "test", "validation"]
):
33 changes: 33 additions & 0 deletions experiments/lightning_modules/GraphCommonModuleOrientability.py
@@ -0,0 +1,33 @@
from torch import nn

from experiments.lightning_modules.BaseModuleOrientability import (
BaseOrientabilityModule,
)


class GraphCommonModuleOrientability(BaseOrientabilityModule):
def __init__(self, base_model):
super().__init__()
self.base_model = base_model

def forward(self, x, edge_index, batch):
x = self.base_model(x, edge_index, batch)
return x

def general_step(self, batch, batch_idx, step: str):
batch_len = len(batch.y)
x_hat = self(batch.x, batch.edge_index, batch.batch)
# Squeeze x_hat to match the shape of y
x_hat = x_hat.squeeze()
y = batch.y.float()
loss = nn.functional.binary_cross_entropy_with_logits(x_hat, y)
self.log(
f"{step}_loss",
loss,
prog_bar=True,
batch_size=batch_len,
on_step=False,
on_epoch=True,
)
self.log_accuracies(x_hat, y, batch_len, step)
return loss
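
The module above only defines general_step; the per-phase Lightning hooks are presumably supplied by BaseOrientabilityModule, which is only partially shown here. A hedged sketch of how such hooks would typically delegate (the method bodies below are an assumption, not taken from the repository):

    # Assumed wiring: each Lightning hook forwards to general_step with its phase name.
    def training_step(self, batch, batch_idx):
        return self.general_step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self.general_step(batch, batch_idx, "validation")

    def test_step(self, batch, batch_idx):
        return self.general_step(batch, batch_idx, "test")
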
File renamed without changes.
87 changes: 87 additions & 0 deletions experiments/orientability/graphs/GATSimplex2Vec.py
@@ -0,0 +1,87 @@
import lightning as L
import torch
import torchvision.transforms as transforms
from torch.utils.data import Subset
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import FaceToEdge

from experiments.lightning_modules.GraphCommonModuleOrientability import (
GraphCommonModuleOrientability,
)
from mantra.simplicial import SimplicialDataset
from mantra.transforms import (
TriangulationToFaceTransform,
SetNumNodesTransform,
DegreeTransform,
OrientableToClassTransform,
Simplex2VecTransform,
)
from models.graphs.GAT import GATNetwork


class GATSimplexToVecModule(GraphCommonModuleOrientability):
def __init__(
self,
hidden_channels,
num_node_features,
out_channels,
num_heads,
num_hidden_layers,
learning_rate=0.0001,
):
base_model = GATNetwork(
hidden_channels=hidden_channels,
num_node_features=num_node_features,
out_channels=out_channels,
num_heads=num_heads,
num_hidden_layers=num_hidden_layers,
)
super().__init__(base_model=base_model)
self.learning_rate = learning_rate

def configure_optimizers(self):
optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
return optimizer


def load_dataset_with_transformations():
tr = transforms.Compose(
[
TriangulationToFaceTransform(),
SetNumNodesTransform(),
FaceToEdge(remove_faces=False),
DegreeTransform(),
OrientableToClassTransform(),
Simplex2VecTransform(),
]
)
dataset = SimplicialDataset(root="./data", transform=tr)
return dataset


def single_experiment_orientability_gat_simplex2vec():
# ===============================
# Training parameters
# ===============================
hidden_channels = 64
num_hidden_layers = 2
num_heads = 4
batch_size = 32
max_epochs = 100
learning_rate = 0.0001
# ===============================
dataset = load_dataset_with_transformations()
model = GATSimplexToVecModule(
hidden_channels=hidden_channels,
num_node_features=dataset.num_node_features,
out_channels=1,
num_heads=num_heads,
num_hidden_layers=num_hidden_layers,
learning_rate=learning_rate,
)
train_ds = Subset(dataset, dataset.train_orientability_indices)
test_ds = Subset(dataset, dataset.test_orientability_indices)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
trainer = L.Trainer(max_epochs=max_epochs, log_every_n_steps=1)
trainer.fit(model, train_dl, test_dl)
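
Note that trainer.fit(model, train_dl, test_dl) passes the test loader in the val_dataloaders position, so the metrics logged under "validation" during training are in fact computed on the test split. A hypothetical entry point for running this experiment (not part of the commit) could look like:

if __name__ == "__main__":
    single_experiment_orientability_gat_simplex2vec()
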
80 changes: 71 additions & 9 deletions experiments/orientability/graphs/GCN.py
@@ -1,19 +1,81 @@
# TODO
import lightning as L
import torch
import torchvision.transforms as transforms
from torch.utils.data import Subset
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import FaceToEdge

from experiments.lightning_modules.GraphCommonModuleOrientability import (
GraphCommonModuleOrientability,
)
from mantra.simplicial import SimplicialDataset
from mantra.transforms import (
TriangulationToFaceTransform,
DegreeTransform,
OrientableToClassTransform,
)
from models.graphs.GCN import GCNetwork


class GCNModule(GraphCommonModuleOrientability):
def __init__(
self,
hidden_channels,
num_node_features,
out_channels,
num_hidden_layers,
learning_rate=0.01,
):
base_model = GCNetwork(
hidden_channels=hidden_channels,
num_node_features=num_node_features,
out_channels=out_channels,
num_hidden_layers=num_hidden_layers,
)
super().__init__(base_model=base_model)
self.learning_rate = learning_rate

def configure_optimizers(self):
optimizer = torch.optim.Adam(
self.base_model.parameters(), lr=self.learning_rate
)
return optimizer


def load_dataset_with_transformations():
tr = transforms.Compose(
[
SimplicialComplexTransform(),
SimplicialComplexOnesTransform(ones_length=10),
DimZeroHodgeLaplacianSimplicialComplexTransform(),
DimOneHodgeLaplacianUpSimplicialComplexTransform(),
DimOneHodgeLaplacianDownSimplicialComplexTransform(),
DimTwoHodgeLaplacianSimplicialComplexTransform(),
OrientableToClassSimplicialComplexTransform(),
TriangulationToFaceTransform(),
FaceToEdge(remove_faces=False),
DegreeTransform(),
OrientableToClassTransform(),
]
)
dataset = SimplicialDataset(root="./data", transform=tr)
return dataset


def single_experiment_orientability_gcn():
def single_experiment_orientability_gnn():
# ===============================
# Training parameters
# ===============================
hidden_channels = 64
num_hidden_layers = 2
batch_size = 32
max_epochs = 100
learning_rate = 0.1
# ===============================
dataset = load_dataset_with_transformations()
model = GCNModule(
hidden_channels=hidden_channels,
num_node_features=dataset.num_node_features,
out_channels=1, # Binary classification
num_hidden_layers=num_hidden_layers,
learning_rate=learning_rate,
)
train_ds = Subset(dataset, dataset.train_orientability_indices)
test_ds = Subset(dataset, dataset.test_orientability_indices)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
trainer = L.Trainer(max_epochs=max_epochs, log_every_n_steps=1)
trainer.fit(model, train_dl, test_dl)
104 changes: 92 additions & 12 deletions experiments/orientability/simplicial_complexes/SCNN.py
@@ -1,9 +1,14 @@
import math
from typing import Literal

import lightning as L
import torch
import torchvision.transforms as transforms
from torch.utils.data import random_split
from torch import nn
from torch.utils.data import Subset

from experiments.lightning_modules.BaseModuleOrientability import (
BaseOrientabilityModule,
)
from mantra.dataloaders import SimplicialDataLoader
from mantra.simplicial import SimplicialDataset
from mantra.transforms import (
@@ -15,9 +20,84 @@
SimplicialComplexOnesTransform,
)
from mantra.transforms import SimplicialComplexTransform
from models.orientability.simplicial_complexes.SCNNNetworkOrientability import (
SCNNNetwork,
)
from mantra.utils import transfer_simplicial_complex_batch_to_device
from models.simplicial_complexes.SCNN import SCNNNetwork


class SCNNNModule(BaseOrientabilityModule):
def __init__(
self,
rank,
in_channels,
hidden_channels,
out_channels,
conv_order_down,
conv_order_up,
n_layers=3,
learning_rate=0.01,
):
super().__init__()
self.rank = rank
self.learning_rate = learning_rate
self.base_model = SCNNNetwork(
rank=rank,
in_channels=in_channels,
hidden_channels=hidden_channels,
out_channels=out_channels,
conv_order_down=conv_order_down,
conv_order_up=conv_order_up,
n_layers=n_layers,
)

def forward(self, x, laplacian_down, laplacian_up, signal_belongings):
x = self.base_model(x, laplacian_down, laplacian_up, signal_belongings)
return x

def general_step(
self, batch, batch_idx, step: Literal["train", "test", "validation"]
):
s_complexes, signal_belongings, batch_len = batch
x = s_complexes.signals[self.rank]
if self.rank == 0:
laplacian_down = None
laplacian_up = s_complexes.neighborhood_matrices[f"0_laplacian"]
elif self.rank == 1:
laplacian_down = s_complexes.neighborhood_matrices[
f"1_laplacian_down"
]
laplacian_up = s_complexes.neighborhood_matrices[f"1_laplacian_up"]
elif self.rank == 2:
laplacian_down = s_complexes.neighborhood_matrices[f"2_laplacian"]
laplacian_up = None
else:
raise ValueError("rank must be 0, 1 or 2.")
y = s_complexes.other_features["y"].float()
signal_belongings = signal_belongings[self.rank]
x_hat = self(x, laplacian_down, laplacian_up, signal_belongings)
# Squeeze x_hat to match the shape of y
x_hat = x_hat.squeeze()
loss = nn.functional.binary_cross_entropy_with_logits(x_hat, y)
self.log(
f"{step}_loss",
loss,
prog_bar=True,
batch_size=batch_len,
on_step=False,
on_epoch=True,
)
self.log_accuracies(x_hat, y, batch_len, step)
return loss

def transfer_batch_to_device(self, batch, device, dataloader_idx):
return transfer_simplicial_complex_batch_to_device(
batch, device, dataloader_idx
)

def configure_optimizers(self):
optimizer = torch.optim.Adam(
self.base_model.parameters(), lr=self.learning_rate
)
return optimizer


def load_dataset_with_transformations():
@@ -47,8 +127,9 @@ def single_experiment_orientability_scnn():
hidden_channels = 20
out_channels = 1 # num classes
num_layers = 5
test_percentage = 0.2
batch_size = 128
max_epochs = 100
learning_rate = 0.01
# ===============================
# Checks and dependent parameters
# ===============================
@@ -62,19 +143,18 @@ def single_experiment_orientability_scnn():
# ===============================
# Model and dataloader creation
# ===============================
model = SCNNNetwork(
model = SCNNNModule(
rank=rank,
in_channels=in_channels,
hidden_channels=hidden_channels,
out_channels=out_channels,
conv_order_down=conv_order_down,
conv_order_up=conv_order_up,
n_layers=num_layers,
learning_rate=learning_rate,
)
test_len = math.floor(len(dataset) * test_percentage)
train_ds, test_ds = random_split(
dataset, [len(dataset) - test_len, test_len]
)
train_ds = Subset(dataset, dataset.train_orientability_indices)
test_ds = Subset(dataset, dataset.test_orientability_indices)
train_dl = SimplicialDataLoader(
train_ds, batch_size=batch_size, shuffle=True
)
@@ -84,6 +164,6 @@ def single_experiment_orientability_scnn():
    # Use CPU acceleration: SCNN does not support GPU acceleration because it creates matrices that are not
    # placed on the same device as the network.
trainer = L.Trainer(
max_epochs=1000, accelerator="cpu", log_every_n_steps=1
max_epochs=max_epochs, accelerator="cpu", log_every_n_steps=1
)
trainer.fit(model, train_dl, test_dl)