From 0c608595a181490f7c9177de74675c7302469f0b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Rub=C3=A9n=20Ballester?=
Date: Mon, 6 May 2024 13:36:58 +0200
Subject: [PATCH] Added SCN to Betti number experiments.

---
 .../simplicial_complexes/SCNN.py              | 163 ++++++++++++++++++
 .../simplicial_complexes/__init__.py          |   0
 main.py                                       |   6 +-
 mantra/transforms.py                          |   7 +
 4 files changed, 175 insertions(+), 1 deletion(-)
 create mode 100644 experiments/betti_numbers/simplicial_complexes/SCNN.py
 create mode 100644 experiments/betti_numbers/simplicial_complexes/__init__.py

diff --git a/experiments/betti_numbers/simplicial_complexes/SCNN.py b/experiments/betti_numbers/simplicial_complexes/SCNN.py
new file mode 100644
index 0000000..d30b7c0
--- /dev/null
+++ b/experiments/betti_numbers/simplicial_complexes/SCNN.py
@@ -0,0 +1,163 @@
+from typing import Literal
+
+import torch
+import torchvision.transforms as transforms
+from torch import nn
+
+from experiments.experiment_utils import perform_experiment
+from experiments.lightning_modules.BaseModuleBettiNumbers import (
+    BaseBettiNumbersModule,
+)
+from mantra.dataloaders import SimplicialDataLoader
+from mantra.simplicial import SimplicialDataset
+from mantra.transforms import (
+    DimTwoHodgeLaplacianSimplicialComplexTransform,
+    DimOneHodgeLaplacianDownSimplicialComplexTransform,
+    DimOneHodgeLaplacianUpSimplicialComplexTransform,
+    DimZeroHodgeLaplacianSimplicialComplexTransform,
+    SimplicialComplexOnesTransform,
+    BettiNumbersToTargetSimplicialComplexTransform,
+)
+from mantra.transforms import SimplicialComplexTransform
+from mantra.utils import transfer_simplicial_complex_batch_to_device
+from models.simplicial_complexes.SCNN import SCNNNetwork
+
+
+class SCNNNModule(BaseBettiNumbersModule):
+    def __init__(
+        self,
+        rank,
+        in_channels,
+        hidden_channels,
+        out_channels,
+        conv_order_down,
+        conv_order_up,
+        n_layers=3,
+        learning_rate=0.01,
+    ):
+        super().__init__()
+        self.rank = rank
+        self.learning_rate = learning_rate
+        self.base_model = SCNNNetwork(
+            rank=rank,
+            in_channels=in_channels,
+            hidden_channels=hidden_channels,
+            out_channels=out_channels,
+            conv_order_down=conv_order_down,
+            conv_order_up=conv_order_up,
+            n_layers=n_layers,
+        )
+
+    def forward(self, x, laplacian_down, laplacian_up, signal_belongings):
+        x = self.base_model(x, laplacian_down, laplacian_up, signal_belongings)
+        return x
+
+    def general_step(
+        self, batch, batch_idx, step: Literal["train", "test", "validation"]
+    ):
+        s_complexes, signal_belongings, batch_len = batch
+        x = s_complexes.signals[self.rank]
+        if self.rank == 0:
+            laplacian_down = None
+            laplacian_up = s_complexes.neighborhood_matrices[f"0_laplacian"]
+        elif self.rank == 1:
+            laplacian_down = s_complexes.neighborhood_matrices[
+                f"1_laplacian_down"
+            ]
+            laplacian_up = s_complexes.neighborhood_matrices[f"1_laplacian_up"]
+        elif self.rank == 2:
+            laplacian_down = s_complexes.neighborhood_matrices[f"2_laplacian"]
+            laplacian_up = None
+        else:
+            raise ValueError("rank must be 0, 1 or 2.")
+        y = s_complexes.other_features["y"].to(torch.float32)
+        signal_belongings = signal_belongings[self.rank]
+        x_hat = self(x, laplacian_down, laplacian_up, signal_belongings)
+        loss = nn.functional.mse_loss(x_hat, y)
+        self.log(
+            f"{step}_loss",
+            loss,
+            prog_bar=True,
+            batch_size=batch_len,
+            on_step=False,
+            on_epoch=True,
+        )
+        self.log_scores(x_hat, y, batch_len, step)
+        return loss
+
+    def transfer_batch_to_device(self, batch, device, dataloader_idx):
+        return transfer_simplicial_complex_batch_to_device(
+            batch, device, dataloader_idx
+        )
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(
+            self.base_model.parameters(), lr=self.learning_rate
+        )
+        return optimizer
+
+
+def load_dataset_with_transformations():
+    tr = transforms.Compose(
+        [
+            SimplicialComplexTransform(),
+            SimplicialComplexOnesTransform(ones_length=10),
+            DimZeroHodgeLaplacianSimplicialComplexTransform(),
+            DimOneHodgeLaplacianUpSimplicialComplexTransform(),
+            DimOneHodgeLaplacianDownSimplicialComplexTransform(),
+            DimTwoHodgeLaplacianSimplicialComplexTransform(),
+            BettiNumbersToTargetSimplicialComplexTransform(),
+        ]
+    )
+    dataset = SimplicialDataset(root="./data", transform=tr)
+    return dataset
+
+
+def single_experiment_betti_numbers_scnn():
+    dataset = load_dataset_with_transformations()
+    # ===============================
+    # Training parameters
+    # ===============================
+    rank = 1  # We work with edge features
+    conv_order_down = 2  # TODO: No idea of what this parameter does
+    conv_order_up = 2  # TODO: No idea of what this parameter does
+    hidden_channels = 20
+    num_layers = 5
+    batch_size = 128
+    max_epochs = 100
+    learning_rate = 0.01
+    num_workers = 0
+    # ===============================
+    # Checks and dependent parameters
+    # ===============================
+    # Check the rank has an appropriate value.
+    assert 0 <= rank <= 2, "rank must be 0, 1 or 2."
+    # select the simplex level
+    if rank == 0:
+        conv_order_down = 0
+    # configure parameters
+    in_channels = dataset[0].x[rank].shape[1]
+    # ===============================
+    # Model and dataloader creation
+    # ===============================
+    model = SCNNNModule(
+        rank=rank,
+        in_channels=in_channels,
+        hidden_channels=hidden_channels,
+        out_channels=3,  # Betti numbers
+        conv_order_down=conv_order_down,
+        conv_order_up=conv_order_up,
+        n_layers=num_layers,
+        learning_rate=learning_rate,
+    )
+    perform_experiment(
+        task="betti_numbers",
+        model=model,
+        model_name="SCNN",
+        dataset=dataset,
+        batch_size=batch_size,
+        num_workers=num_workers,
+        max_epochs=max_epochs,
+        data_loader_class=SimplicialDataLoader,
+        accelerator="cpu",
+    )
diff --git a/experiments/betti_numbers/simplicial_complexes/__init__.py b/experiments/betti_numbers/simplicial_complexes/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/main.py b/main.py
index c29f3b1..eca20bd 100644
--- a/main.py
+++ b/main.py
@@ -1,3 +1,6 @@
+from experiments.betti_numbers.simplicial_complexes.SCNN import (
+    single_experiment_betti_numbers_scnn,
+)
 from experiments.name.simplicial_complexes.SCNN import (
     single_experiment_name_scnn,
 )
@@ -17,4 +20,5 @@
     # single_experiment_orientability_tag()
     # single_experiment_orientability_transformer_conv()
     # single_experiment_orientability_mlp_constant_shape()
-    single_experiment_name_scnn()
+    # single_experiment_name_scnn()
+    single_experiment_betti_numbers_scnn()
diff --git a/mantra/transforms.py b/mantra/transforms.py
index 9a9cb97..39486af 100644
--- a/mantra/transforms.py
+++ b/mantra/transforms.py
@@ -242,3 +242,10 @@ def __call__(self, data):
         data = create_other_features_on_data_if_needed(data)
         data.other_features["y"] = torch.tensor([self.class_dict[data.name]])
         return data
+
+
+class BettiNumbersToTargetSimplicialComplexTransform:
+    def __call__(self, data):
+        data = create_other_features_on_data_if_needed(data)
+        data.other_features["y"] = torch.tensor([data.betti_numbers])
+        return data