Skip to content

Commit

Permalink
Adapting topomodelx models to the benchmark. Currently, SCN is working.
Browse files Browse the repository at this point in the history
  • Loading branch information
rballeba committed Jun 21, 2024
1 parent 7d51283 commit 8252e16
Show file tree
Hide file tree
Showing 20 changed files with 1,048 additions and 354 deletions.
13 changes: 8 additions & 5 deletions datasets/simplicial.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import Callable

from lightning import LightningDataModule
from sklearn.model_selection import train_test_split
from mantra.simplicial import SimplicialDataset
from torch_geometric.loader import DataLoader
from torch_geometric.loader import DataLoader as DataLoaderGeometric
from torch_geometric.transforms import Compose
import torch
from torch.utils.data import Subset
import numpy as np


class SimplicialDataModule(LightningDataModule):
def __init__(
self,
Expand All @@ -16,6 +17,7 @@ def __init__(
use_stratified: bool = False,
batch_size: int = 128,
seed: int = 2024,
dataloader_builder: Callable = DataLoaderGeometric,
):
super().__init__()
self.data_dir = data_dir
Expand All @@ -24,6 +26,7 @@ def __init__(
self.stratified = None
self.batch_size = batch_size
self.seed = seed
self.dataloader_builder = dataloader_builder

def prepare_data(self) -> None:
SimplicialDataset(root=self.data_dir)
Expand All @@ -50,13 +53,13 @@ def setup(self, stage=None):
self.val_ds = Subset(simplicial_full, val_indices)

def train_dataloader(self):
return DataLoader(self.train_ds, batch_size=self.batch_size)
return self.dataloader_builder(self.train_ds, batch_size=self.batch_size)

def val_dataloader(self):
return DataLoader(self.val_ds, batch_size=self.batch_size)
return self.dataloader_builder(self.val_ds, batch_size=self.batch_size)

def test_dataloader(self):
return DataLoader(self.val_ds, batch_size=self.batch_size)
return self.dataloader_builder(self.val_ds, batch_size=self.batch_size)


if __name__ == "__main__":
Expand Down
94 changes: 94 additions & 0 deletions datasets/topox_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import torch
from toponetx import SimplicialComplex
from torch.utils.data import DataLoader
from torch_geometric.data import Data


class SimplicialTopoXDataloader(DataLoader):
def __init__(self, dataset, **kwargs):
collate_fn = kwargs.get("collate_fn", collate_simplicial_models_topox)
kwargs = {k: v for k, v in kwargs.items() if k != "collate_fn"}
super().__init__(dataset, collate_fn=collate_fn, **kwargs)


def batch_connectivity_matrices(key, matrices, batch):
rows, columns, values = [], [], []
matrix_dim = int(key.split("_")[-1])
row_idx, col_idx = 0, 0
for matrix, example in zip(matrices, batch):
if matrix is None:
if key.startswith("incidence"):
# We only add rows if the simplicial complex has simplices of dimension matrix_dim - 1.
# because otherwise we do not have simplices of the dimensions represented by the rows and columns.
if example.dim == matrix_dim - 1:
# The simplicial complex has a non-empty set of simplices of dimension matrix_dim - 1,
# but it does not have simplices of dimension matrix_dim
row_idx += example.shape(matrix_dim - 1)
elif (key.startswith("down_laplacian") or key.startswith("up_laplacian")
or key.startswith("hodge_laplacian")) or key.startswith("adjacency"):
# We do not do nothing, as if the matrix is None, it is because there are no cells
# of that dimension in the cell complex so we do not need to add any row or column
pass
else:
raise NotImplementedError(f"{key} is not valid connectivity matrix.")
else:
indices = matrix.indices()
rows_submatrix = indices[0]
cols_submatrix = indices[1]
rows.append(rows_submatrix + row_idx)
columns.append(cols_submatrix + col_idx)
values.append(matrix.values())
row_idx += matrix.shape[0]
col_idx += matrix.shape[1]
rows_cat = torch.cat(rows, dim=0)
columns_cat = torch.cat(columns, dim=0)
values_cat = torch.cat(values, dim=0)
return torch.sparse_coo_tensor(torch.stack([rows_cat, columns_cat]), values_cat, (row_idx, col_idx))


def collate_connectivity_matrices(batch):
connectivity_batched = dict()
connectivity_keys = set([key
for example in batch
if example.connectivity is not None
for key in example.connectivity.keys()])
for key in connectivity_keys:
connectivity_batched[key] = batch_connectivity_matrices(key,
[example.connectivity[key]
if key in example.connectivity
else None
for example in batch],
batch)
return connectivity_batched


def collate_signals(batch):
x_batched = dict()
all_x_keys = set([key for example in batch for key in example.x.keys()])
x_belonging = dict()
for key in all_x_keys:
x_to_batch = [
example.x[key] for example in batch if key in example.x
]
x_batched[key] = torch.cat(x_to_batch, dim=0)
signals_of_dim_belonging = [
torch.tensor([i] * len(batch[i].x[key]), dtype=torch.int64)
for i in range(len(batch))
if key in batch[i].x
]
x_belonging[key] = torch.cat(signals_of_dim_belonging, dim=0)
return x_batched, x_belonging


def collate_simplicial_models_topox(batch):
batched_data = Data()
batched_data.batch_size = len(batch)
# First, batch signals
x_batched, x_belonging = collate_signals(batch)
batched_data.x = x_batched
batched_data.x_belonging = x_belonging
# Second, batch structure matrices
batched_data.connectivity = collate_connectivity_matrices(batch)
# Third, batch output
batched_data.y = torch.cat([example.y for example in batch], dim=0)
return batched_data
102 changes: 33 additions & 69 deletions datasets/transforms.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
import torch
from toponetx.classes import SimplicialComplex
from toponetx.utils import (
compute_bunch_normalized_matrices,
compute_x_laplacian_normalized_matrix,
)
from torch_geometric.utils import degree
from torch_geometric.transforms import FaceToEdge, OneHotDegree
from datasets.utils import (
create_signals_on_data_if_needed,
append_signals,
create_other_features_on_data_if_needed,
create_neighborhood_matrices_on_data_if_needed,
get_complex_connectivity,
)
from enum import Enum

Expand Down Expand Up @@ -52,8 +47,6 @@ class SimplicialComplexTransform:
def __call__(self, data):
data.sc = SimplicialComplex(data.triangulation)
create_signals_on_data_if_needed(data)
create_other_features_on_data_if_needed(data)
create_neighborhood_matrices_on_data_if_needed(data)
return data


Expand Down Expand Up @@ -111,65 +104,13 @@ def __call__(self, data):

class OrientableToClassSimplicialComplexTransform:
def __call__(self, data):
data = create_other_features_on_data_if_needed(data)
data.other_features["y"] = data.orientable.long()
return data


class SCNNNeighborhoodMatricesTransform:
class SimplicialComplexStructureMatricesTransform:
def __call__(self, data):
data = create_neighborhood_matrices_on_data_if_needed(data)
data.neighborhood_matrices["1_boundary"] = data.sc.incidence_matrix(1)
data.neighborhood_matrices["2_boundary"] = data.sc.incidence_matrix(2)
data.neighborhood_matrices["0_laplacian"] = data.sc.laplacian_matrix(0)
data.neighborhood_matrices["1_laplacian_up"] = (
data.sc.up_laplacian_matrix(rank=1)
)
data.neighborhood_matrices["1_laplacian_down"] = (
data.sc.down_laplacian_matrix(rank=1)
)
data.neighborhood_matrices["1_laplacian"] = (
data.sc.hodge_laplacian_matrix(rank=1)
)
data.neighborhood_matrices["2_laplacian"] = (
data.sc.hodge_laplacian_matrix(rank=2)
)
return data


class SCConvNeighborhoodMatricesTransform:
def __call__(self, data):
B1 = data.sc.incidence_matrix(1)
B2 = data.sc.incidence_matrix(2)
B1N, B1TN, B2N, B2TN = compute_bunch_normalized_matrices(B1, B2)
data = create_neighborhood_matrices_on_data_if_needed(data)
data.neighborhood_matrices["1_boundary"] = B1
data.neighborhood_matrices["2_boundary"] = B2
data.neighborhood_matrices["1_boundary_norm"] = B1N
data.neighborhood_matrices["2_boundary_norm"] = B2N
data.neighborhood_matrices["1_boundary_transpose_norm"] = B1TN
data.neighborhood_matrices["2_boundary_transpose_norm"] = B2TN
# Matrices normalized using the normalization given by TopoNetX. For incidence matrices,
# it coincides with the normalization of the paper. For the Laplacian matrices, however, it does not coincide.
L0_up = data.sc.up_laplacian_matrix(0)
L1_down = data.sc.down_laplacian_matrix(1)
L1_up = data.sc.up_laplacian_matrix(1)
L2_down = data.sc.down_laplacian_matrix(2)
L0 = L0_up
L1 = L1_down + L1_up
L2 = L2_down
data.neighborhood_matrices["0_laplacian_up_norm"] = (
compute_x_laplacian_normalized_matrix(L0, L0_up)
)
data.neighborhood_matrices["1_laplacian_up_norm"] = (
compute_x_laplacian_normalized_matrix(L1, L1_up)
)
data.neighborhood_matrices["1_laplacian_down_norm"] = (
compute_x_laplacian_normalized_matrix(L1, L1_down)
)
data.neighborhood_matrices["2_laplacian_down_norm"] = (
compute_x_laplacian_normalized_matrix(L2, L2_down)
)
data.connectivity = get_complex_connectivity(data.sc, data.sc.dim)
return data


Expand Down Expand Up @@ -200,24 +141,29 @@ def __init__(self):
}

def __call__(self, data):
data = create_other_features_on_data_if_needed(data)
data.other_features["y"] = torch.tensor([self.class_dict[data.name]])
return data


class BettiNumbersToTargetSimplicialComplexTransform:
class RandomNodeFeatures:
def __init__(self, size):
self.size = size

def __call__(self, data):
data = create_other_features_on_data_if_needed(data)
data.other_features["y"] = torch.tensor([data.betti_numbers])
data.x = torch.normal(0, 1, size=(data.num_nodes, self.size))
return data


class RandomNodeFeatures:
class RandomSimplicesFeatures:
def __init__(self, size):
self.size = size

def __call__(self, data):
data.x = torch.normal(0, 1, size=(data.num_nodes, self.size))
data = create_signals_on_data_if_needed(data)
for dim in range(data.sc.dim):
data = append_signals(
data, dim, torch.normal(0, 1, size=(data.sc.shape[dim], self.size))
)
return data


Expand All @@ -233,7 +179,6 @@ def __call__(self, data):
OneHotDegree(max_degree=8),
]


degree_transform = [
TriangulationToFaceTransform(),
FaceToEdge(remove_faces=False),
Expand All @@ -250,15 +195,34 @@ def __call__(self, data):
BettiToY(),
]

degree_transform_sc = [
SimplicialComplexTransform(),
SimplicialComplexStructureMatricesTransform(),
SimplicialComplexDegreeTransform(),
SimplicialComplexEdgeAdjacencyDegreeTransform(),
SimplicialComplexEdgeCoadjacencyDegreeTransform(),
SimplicialComplexTriangleCoadjacencyDegreeTransform()
]

random_simplices_features = [
SimplicialComplexTransform(),
SimplicialComplexStructureMatricesTransform(),
RandomSimplicesFeatures(size=8),
]


class TransformType(Enum):
degree_transform = "degree_transform"
degree_transform_onehot = "degree_transform_onehot"
random_node_features = "random_node_features"
degree_transform_sc = "degree_transform_sc"
random_simplices_features = "random_simplices_features"


transforms_lookup = {
TransformType.degree_transform: degree_transform,
TransformType.degree_transform_onehot: degree_transform_onehot,
TransformType.random_node_features: random_node_features,
TransformType.degree_transform_sc: degree_transform_sc,
TransformType.random_simplices_features: random_simplices_features
}
Loading

0 comments on commit 8252e16

Please sign in to comment.