
Commit

Adding betti number task and GCN architectures on it for pytorch lightning.
rballeba committed May 5, 2024
1 parent 609c648 commit eea48a0
Showing 8 changed files with 231 additions and 3 deletions.
Empty file.
92 changes: 92 additions & 0 deletions experiments/betti_numbers/graphs/GCN.py
@@ -0,0 +1,92 @@
import lightning as L
import torch
import wandb
from torch.utils.data import Subset
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import FaceToEdge, OneHotDegree
from torchvision import transforms

from experiments.experiment_utils import get_wandb_logger
from experiments.lightning_modules.GraphCommonModuleBettiNumbers import (
GraphCommonModuleBettiNumbers,
)
from mantra.simplicial import SimplicialDataset
from mantra.transforms import TriangulationToFaceTransform, DegreeTransform
from models.graphs.GCN import GCNetwork


class GCNModule(GraphCommonModuleBettiNumbers):
def __init__(
self,
hidden_channels,
num_node_features,
out_channels,
num_hidden_layers,
learning_rate=0.01,
):
base_model = GCNetwork(
hidden_channels=hidden_channels,
num_node_features=num_node_features,
out_channels=out_channels,
num_hidden_layers=num_hidden_layers,
)
super().__init__(base_model=base_model)
self.learning_rate = learning_rate

def configure_optimizers(self):
optimizer = torch.optim.Adam(
self.base_model.parameters(), lr=self.learning_rate
)
return optimizer


def load_dataset_with_transformations():
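    # Build node features from the triangulation: faces -> edges, then
    # one-hot encoded node degrees (max degree 8) as input features.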
tr = transforms.Compose(
[
TriangulationToFaceTransform(),
FaceToEdge(remove_faces=False),
DegreeTransform(),
OneHotDegree(max_degree=8, cat=False),
]
)
dataset = SimplicialDataset(root="./data", transform=tr)
return dataset


def single_experiment_betti_numbers_gnn():
# ===============================
# Training parameters
# ===============================
hidden_channels = 64
num_hidden_layers = 2
batch_size = 32
max_epochs = 100
learning_rate = 0.1
num_workers = 0
# ===============================
dataset = load_dataset_with_transformations()
model = GCNModule(
hidden_channels=hidden_channels,
num_node_features=dataset.num_node_features,
out_channels=3, # Three different Betti numbers
num_hidden_layers=num_hidden_layers,
learning_rate=learning_rate,
)
train_ds = Subset(dataset, dataset.train_betti_numbers_indices)
test_ds = Subset(dataset, dataset.test_betti_numbers_indices)
train_dl = DataLoader(
train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers
)
test_dl = DataLoader(
test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
)
logger = get_wandb_logger(task_name="betti_numbers", model_name="GCN")
trainer = L.Trainer(
max_epochs=max_epochs, log_every_n_steps=1, logger=logger
)
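    # Note: the test split is passed as the validation loader, so the
    # "validation" metrics reported during training are computed on it.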
trainer.fit(
model,
train_dl,
test_dl,
)
wandb.finish()
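
For context, the GCNetwork backbone imported from models.graphs.GCN is not included in this diff. A minimal sketch of what such a model could look like, assuming a plain stack of GCNConv layers with global mean pooling and a linear readout (hypothetical; the committed implementation may differ):

import torch.nn as nn
from torch_geometric.nn import GCNConv, global_mean_pool


class GCNetwork(nn.Module):
    def __init__(self, hidden_channels, num_node_features, out_channels, num_hidden_layers):
        super().__init__()
        # First layer maps input features to the hidden width; the rest
        # keep the hidden width.
        self.convs = nn.ModuleList([GCNConv(num_node_features, hidden_channels)])
        for _ in range(num_hidden_layers - 1):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
        self.readout = nn.Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        for conv in self.convs:
            x = conv(x, edge_index).relu()
        # Pool to one vector per graph, then regress the three Betti numbers.
        x = global_mean_pool(x, batch)
        return self.readout(x)

The constructor signature matches the call in GCNModule above.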
Empty file.
76 changes: 76 additions & 0 deletions experiments/lightning_modules/BaseModuleBettiNumbers.py
@@ -0,0 +1,76 @@
from typing import Literal

import lightning as L
from torch.nn import ModuleList

from experiments.metrics import GeneralAccuracy


class BaseBettiNumbersModule(L.LightningModule):
def __init__(self):
super().__init__()
        # Three accuracies per split (train/validation/test): one per Betti number
self.training_accuracies = ModuleList(
[GeneralAccuracy() for _ in range(3)]
)
self.validation_accuracies = ModuleList(
[GeneralAccuracy() for _ in range(3)]
)
self.test_accuracies = ModuleList(
[GeneralAccuracy() for _ in range(3)]
)

def log_scores(
self, x_hat, y, batch_len, step: Literal["train", "test", "validation"]
):
        # x_hat is a float tensor of shape (batch_len, 3): one row per
        # sample, one column per Betti number.
        # y has the same shape; its entries are integer-valued Betti
        # numbers stored as floats (cast per dimension below).
        accuracies, prefix = {
            "train": (self.training_accuracies, "training"),
            "test": (self.test_accuracies, "test"),
            "validation": (self.validation_accuracies, "validation"),
        }[step]
        for dim in range(3):
            # Round the regression output to the nearest integer before
            # comparing it with the target Betti number.
            accuracies[dim](
                x_hat[:, dim].round().long(), y[:, dim].long()
            )
            self.log(
                f"{prefix}_accuracy_betti_{dim}",
                accuracies[dim],
                prog_bar=True,
                on_step=False,
                on_epoch=True,
                batch_size=batch_len,
            )

def test_step(self, batch, batch_idx):
return self.general_step(batch, batch_idx, "test")

def validation_step(self, batch, batch_idx):
return self.general_step(batch, batch_idx, "validation")

def training_step(self, batch, batch_idx):
return self.general_step(batch, batch_idx, "train")
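
The accuracies above treat the task as regression: predictions are rounded to the nearest integer before being compared with the integer-valued targets. A small illustration:

import torch

x_hat = torch.tensor([[0.9, 1.2, 0.1], [1.6, 0.4, 1.0]])  # raw model outputs
x_hat.round().long()  # tensor([[1, 1, 0], [2, 0, 1]]) -> predicted Betti numbers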
34 changes: 34 additions & 0 deletions experiments/lightning_modules/GraphCommonModuleBettiNumbers.py
@@ -0,0 +1,34 @@
import torch
from torch import nn

from experiments.lightning_modules.BaseModuleBettiNumbers import (
BaseBettiNumbersModule,
)


class GraphCommonModuleBettiNumbers(BaseBettiNumbersModule):
def __init__(self, base_model):
super().__init__()
self.base_model = base_model

def forward(self, x, edge_index, batch):
x = self.base_model(x, edge_index, batch)
return x

def general_step(self, batch, batch_idx, step: str):
x_hat = self(batch.x, batch.edge_index, batch.batch)
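        # batch.betti_numbers arrives as a list of per-sample [b0, b1, b2]
        # integer targets; it is cast to x_hat's float dtype for the MSE loss.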
y = torch.tensor(
batch.betti_numbers, device=x_hat.device, dtype=x_hat.dtype
)
batch_len = len(y)
loss = nn.functional.mse_loss(x_hat, y)
self.log(
f"{step}_loss",
loss,
prog_bar=True,
batch_size=batch_len,
on_step=False,
on_epoch=True,
)
self.log_scores(x_hat, y, batch_len, step)
return loss
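
For intuition, the MSE loss on a single sample whose true Betti numbers are (1, 2, 0):

import torch
from torch import nn

x_hat = torch.tensor([[0.9, 1.8, 0.2]])  # model output
y = torch.tensor([[1.0, 2.0, 0.0]])      # targets, cast to float as above
nn.functional.mse_loss(x_hat, y)  # (0.1**2 + 0.2**2 + 0.2**2) / 3 = 0.03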
22 changes: 22 additions & 0 deletions experiments/metrics.py
@@ -0,0 +1,22 @@
import torch
from torch import Tensor
from torchmetrics import Metric


class GeneralAccuracy(Metric):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_state(
"correct", default=torch.tensor(0), dist_reduce_fx="sum"
)
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

def update(self, preds: Tensor, target: Tensor) -> None:
if preds.shape != target.shape:
raise ValueError("preds and target must have the same shape")

self.correct += torch.sum(preds == target)
self.total += target.numel()

def compute(self) -> Tensor:
return self.correct.float() / self.total
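
A quick standalone check of the metric, which accumulates element-wise accuracy over every entry it sees:

import torch

acc = GeneralAccuracy()
acc.update(torch.tensor([1, 0, 2]), torch.tensor([1, 1, 2]))
acc.compute()  # tensor(0.6667): 2 of the 3 entries match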
Empty file added experiments/name/__init__.py
Empty file.
10 changes: 7 additions & 3 deletions main.py
@@ -1,3 +1,6 @@
+from experiments.betti_numbers.graphs.GCN import (
+    single_experiment_betti_numbers_gnn,
+)
 from experiments.orientability.graphs.GATSimplex2Vec import (
     single_experiment_orientability_gat_simplex2vec,
 )
@@ -9,6 +12,7 @@
)

 if __name__ == "__main__":
-    single_experiment_orientability_gnn()
-    single_experiment_orientability_scnn()
-    single_experiment_orientability_gat_simplex2vec()
+    # single_experiment_orientability_gnn()
+    # single_experiment_orientability_scnn()
+    # single_experiment_orientability_gat_simplex2vec()
+    single_experiment_betti_numbers_gnn()
