diff --git a/experiments/betti_numbers/__init__.py b/experiments/betti_numbers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/betti_numbers/graphs/GCN.py b/experiments/betti_numbers/graphs/GCN.py new file mode 100644 index 0000000..93c2b97 --- /dev/null +++ b/experiments/betti_numbers/graphs/GCN.py @@ -0,0 +1,92 @@ +import lightning as L +import torch +import wandb +from torch.utils.data import Subset +from torch_geometric.loader import DataLoader +from torch_geometric.transforms import FaceToEdge, OneHotDegree +from torchvision import transforms + +from experiments.experiment_utils import get_wandb_logger +from experiments.lightning_modules.GraphCommonModuleBettiNumbers import ( + GraphCommonModuleBettiNumbers, +) +from mantra.simplicial import SimplicialDataset +from mantra.transforms import TriangulationToFaceTransform, DegreeTransform +from models.graphs.GCN import GCNetwork + + +class GCNModule(GraphCommonModuleBettiNumbers): + def __init__( + self, + hidden_channels, + num_node_features, + out_channels, + num_hidden_layers, + learning_rate=0.01, + ): + base_model = GCNetwork( + hidden_channels=hidden_channels, + num_node_features=num_node_features, + out_channels=out_channels, + num_hidden_layers=num_hidden_layers, + ) + super().__init__(base_model=base_model) + self.learning_rate = learning_rate + + def configure_optimizers(self): + optimizer = torch.optim.Adam( + self.base_model.parameters(), lr=self.learning_rate + ) + return optimizer + + +def load_dataset_with_transformations(): + tr = transforms.Compose( + [ + TriangulationToFaceTransform(), + FaceToEdge(remove_faces=False), + DegreeTransform(), + OneHotDegree(max_degree=8, cat=False), + ] + ) + dataset = SimplicialDataset(root="./data", transform=tr) + return dataset + + +def single_experiment_betti_numbers_gnn(): + # =============================== + # Training parameters + # =============================== + hidden_channels = 64 + num_hidden_layers = 2 + batch_size = 32 + max_epochs = 100 + learning_rate = 0.1 + num_workers = 0 + # =============================== + dataset = load_dataset_with_transformations() + model = GCNModule( + hidden_channels=hidden_channels, + num_node_features=dataset.num_node_features, + out_channels=3, # Three different Betti numbers + num_hidden_layers=num_hidden_layers, + learning_rate=learning_rate, + ) + train_ds = Subset(dataset, dataset.train_betti_numbers_indices) + test_ds = Subset(dataset, dataset.test_betti_numbers_indices) + train_dl = DataLoader( + train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers + ) + test_dl = DataLoader( + test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers + ) + logger = get_wandb_logger(task_name="betti_numbers", model_name="GCN") + trainer = L.Trainer( + max_epochs=max_epochs, log_every_n_steps=1, logger=logger + ) + trainer.fit( + model, + train_dl, + test_dl, + ) + wandb.finish() diff --git a/experiments/betti_numbers/graphs/__init__.py b/experiments/betti_numbers/graphs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/lightning_modules/BaseModuleBettiNumbers.py b/experiments/lightning_modules/BaseModuleBettiNumbers.py new file mode 100644 index 0000000..a4fb1d5 --- /dev/null +++ b/experiments/lightning_modules/BaseModuleBettiNumbers.py @@ -0,0 +1,76 @@ +from typing import Literal + +import lightning as L +from torch.nn import ModuleList + +from experiments.metrics import GeneralAccuracy + + +class BaseBettiNumbersModule(L.LightningModule): + def __init__(self): + super().__init__() + # 3 accuracies per modality: one for each betti number + self.training_accuracies = ModuleList( + [GeneralAccuracy() for _ in range(3)] + ) + self.validation_accuracies = ModuleList( + [GeneralAccuracy() for _ in range(3)] + ) + self.test_accuracies = ModuleList( + [GeneralAccuracy() for _ in range(3)] + ) + + def log_scores( + self, x_hat, y, batch_len, step: Literal["train", "test", "validation"] + ): + # x_hat is a float tensor of shape (batch_len, 3), one column per betti number and row per sample + # y is a long tensor of shape (batch_len, 3), one row per sample and column per betti number + if step == "train": + for dim in range(3): + self.training_accuracies[dim]( + x_hat[:, dim].round().long(), y[:, dim].long() + ) + self.log( + f"training_accuracy_betti_{dim}", + self.training_accuracies[dim], + prog_bar=True, + on_step=False, + on_epoch=True, + batch_size=batch_len, + ) + + elif step == "test": + for dim in range(3): + self.test_accuracies[dim]( + x_hat[:, dim].round().long(), y[:, dim].long() + ) + self.log( + f"test_accuracy_betti_{dim}", + self.test_accuracies[dim], + prog_bar=True, + on_step=False, + on_epoch=True, + batch_size=batch_len, + ) + elif step == "validation": + for dim in range(3): + self.validation_accuracies[dim]( + x_hat[:, dim].round().long(), y[:, dim].long() + ) + self.log( + f"validation_accuracy_betti_{dim}", + self.validation_accuracies[dim], + prog_bar=True, + on_step=False, + on_epoch=True, + batch_size=batch_len, + ) + + def test_step(self, batch, batch_idx): + return self.general_step(batch, batch_idx, "test") + + def validation_step(self, batch, batch_idx): + return self.general_step(batch, batch_idx, "validation") + + def training_step(self, batch, batch_idx): + return self.general_step(batch, batch_idx, "train") diff --git a/experiments/lightning_modules/GraphCommonModuleBettiNumbers.py b/experiments/lightning_modules/GraphCommonModuleBettiNumbers.py new file mode 100644 index 0000000..eaa3b19 --- /dev/null +++ b/experiments/lightning_modules/GraphCommonModuleBettiNumbers.py @@ -0,0 +1,34 @@ +import torch +from torch import nn + +from experiments.lightning_modules.BaseModuleBettiNumbers import ( + BaseBettiNumbersModule, +) + + +class GraphCommonModuleBettiNumbers(BaseBettiNumbersModule): + def __init__(self, base_model): + super().__init__() + self.base_model = base_model + + def forward(self, x, edge_index, batch): + x = self.base_model(x, edge_index, batch) + return x + + def general_step(self, batch, batch_idx, step: str): + x_hat = self(batch.x, batch.edge_index, batch.batch) + y = torch.tensor( + batch.betti_numbers, device=x_hat.device, dtype=x_hat.dtype + ) + batch_len = len(y) + loss = nn.functional.mse_loss(x_hat, y) + self.log( + f"{step}_loss", + loss, + prog_bar=True, + batch_size=batch_len, + on_step=False, + on_epoch=True, + ) + self.log_scores(x_hat, y, batch_len, step) + return loss diff --git a/experiments/metrics.py b/experiments/metrics.py new file mode 100644 index 0000000..6551441 --- /dev/null +++ b/experiments/metrics.py @@ -0,0 +1,22 @@ +import torch +from torch import Tensor +from torchmetrics import Metric + + +class GeneralAccuracy(Metric): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.add_state( + "correct", default=torch.tensor(0), dist_reduce_fx="sum" + ) + self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum") + + def update(self, preds: Tensor, target: Tensor) -> None: + if preds.shape != target.shape: + raise ValueError("preds and target must have the same shape") + + self.correct += torch.sum(preds == target) + self.total += target.numel() + + def compute(self) -> Tensor: + return self.correct.float() / self.total diff --git a/experiments/name/__init__.py b/experiments/name/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/main.py b/main.py index 1fd09cb..0320834 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,6 @@ +from experiments.betti_numbers.graphs.GCN import ( + single_experiment_betti_numbers_gnn, +) from experiments.orientability.graphs.GATSimplex2Vec import ( single_experiment_orientability_gat_simplex2vec, ) @@ -9,6 +12,7 @@ ) if __name__ == "__main__": - single_experiment_orientability_gnn() - single_experiment_orientability_scnn() - single_experiment_orientability_gat_simplex2vec() + # single_experiment_orientability_gnn() + # single_experiment_orientability_scnn() + # single_experiment_orientability_gat_simplex2vec() + single_experiment_betti_numbers_gnn()