Skip to content

Commit

Permalink
Started working on toponetx, commit does not work
Browse files Browse the repository at this point in the history
  • Loading branch information
ErnstRoell committed Jun 14, 2024
1 parent 7e59a18 commit 46520bd
Show file tree
Hide file tree
Showing 7 changed files with 655 additions and 1 deletion.
19 changes: 19 additions & 0 deletions datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
"""
Function to load datasets based on configuration key.
"""

import importlib
from typing import Protocol


class DataModuleConfig(Protocol):
    """
    Structural type of the configuration consumed by ``load_datamodule``.

    The ``module`` key is a string that points towards the module to load:
    ``load_datamodule`` passes it to ``importlib.import_module`` and then
    calls the ``DataModule`` attribute of the imported module with the
    configuration object itself.
    """

    # Import path of the module that provides ``DataModule``.
    module: str


def load_datamodule(config: "DataModuleConfig | dict"):
    """
    Import the module named by the configuration and build its datamodule.

    Parameters
    ----------
    config:
        Either an object exposing a ``module`` attribute or a mapping with a
        ``"module"`` key. The value must be an importable module path whose
        module defines a ``DataModule`` callable.

    Returns
    -------
    Whatever ``module.DataModule(config)`` returns; the configuration is
    forwarded unchanged.

    Raises
    ------
    KeyError
        If ``config`` is a mapping without a ``"module"`` key.
    ImportError
        If the named module cannot be imported.
    AttributeError
        If the imported module has no ``DataModule`` attribute.
    """
    try:
        module_name = config.module
    except AttributeError:
        # Plain dictionaries (which the previous ``dict`` annotation
        # advertised) have no attribute access, so fall back to item lookup.
        module_name = config["module"]
    module = importlib.import_module(module_name)
    return module.DataModule(config)
289 changes: 289 additions & 0 deletions datasets/toponetx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
"""
Dataloaders for the toponetx models
"""

from typing import Any, Optional

import numpy as np
import scipy
from sklearn.model_selection import train_test_split
import torch.utils.data
from torch.utils.data.dataloader import DataLoader
from torch_geometric.transforms import Compose, FaceToEdge
from mantra.simplicial import SimplicialDataset
from lightning import LightningDataModule
from torch.utils.data import Subset
from datasets.transforms import (
BettiNumbersToTargetSimplicialComplexTransform,
DimOneHodgeLaplacianDownSimplicialComplexTransform,
DimOneHodgeLaplacianUpSimplicialComplexTransform,
DimTwoHodgeLaplacianSimplicialComplexTransform,
DimZeroHodgeLaplacianSimplicialComplexTransform,
SimplicialComplexDegreeTransform,
SimplicialComplexOnesTransform,
SimplicialComplexTransform,
)


class BatchedSimplicialComplex:
    """
    Plain container bundling the data attached to a simplicial complex.

    Attributes
    ----------
    signals:
        Dictionary of tensors (presumably keyed by cell dimension -- confirm
        against the transforms that produce them).
    neighborhood_matrices:
        Optional dictionary of scipy sparse matrices, or ``None``.
    other_features:
        Optional dictionary of additional tensors, or ``None``.
    """

    def __init__(
        self,
        signals: dict[Any, torch.Tensor],
        neighborhood_matrices: Optional[
            dict[Any, scipy.sparse.spmatrix]
        ] = None,
        other_features: Optional[dict[Any, torch.Tensor]] = None,
    ):
        # The arguments are stored as given; nothing is copied or validated.
        self.other_features = other_features
        self.neighborhood_matrices = neighborhood_matrices
        self.signals = signals


def batch_signals(signals):
    """Concatenate the given tensors along their first dimension."""
    concatenated = torch.cat(signals, dim=0)
    return concatenated


# Boundary matrices are rectangular: their rows are indexed by the cells of
# one dimension lower. A complex that lacks such a matrix (because it has no
# cells of the higher dimension) still owns rows in the batched matrix.
_BOUNDARY_ROW_DIMENSION = {"1_boundary": 0, "2_boundary": 1}


def _rows_of_missing_matrix(key, c_complex):
    """Return how many rows a complex without matrix ``key`` still occupies."""
    row_dimension = _BOUNDARY_ROW_DIMENSION.get(key)
    if row_dimension is None:
        # Adjacency / coadjacency / Laplacian matrices (and any other key):
        # a missing matrix means there are no cells of that dimension, so
        # neither rows nor columns have to be reserved.
        return 0
    if row_dimension not in c_complex.signals:
        # No cells of the row dimension either: nothing to reserve.
        return 0
    return len(c_complex.signals[row_dimension])


def convert_sparse_matrices_to_sparse_block_matrix(
    key: Any,
    sparse_matrices: list[Optional[scipy.sparse.spmatrix]],
    batch: "list[BatchedSimplicialComplex]",
):
    """
    Stack per-example sparse matrices into one block-diagonal COO matrix.

    Parameters
    ----------
    key:
        Name of the neighborhood matrix being batched (e.g. ``"1_boundary"``).
        It only matters for examples whose matrix is ``None``.
    sparse_matrices:
        One entry per example; ``None`` where the example has no such matrix.
    batch:
        The examples, aligned with ``sparse_matrices``. Only their
        ``signals`` dictionary is read, and only for boundary matrices.

    Returns
    -------
    scipy.sparse.coo_matrix
        Block matrix whose shape is the sum of the block shapes, plus the
        rows reserved for examples lacking a boundary matrix. When no example
        provides a matrix, an empty matrix of the accumulated shape is
        returned instead of failing on an empty concatenation.
    """
    rows, columns, values = [], [], []
    idx_rows, idx_cols = 0, 0
    for matrix, c_complex in zip(sparse_matrices, batch):
        if matrix is None:
            idx_rows += _rows_of_missing_matrix(key, c_complex)
            continue
        coo_matrix = matrix.tocoo()
        len_rows, len_cols = coo_matrix.shape
        # Shift the block's coordinates to its position on the diagonal.
        rows.append(coo_matrix.row + idx_rows)
        columns.append(coo_matrix.col + idx_cols)
        values.append(coo_matrix.data)
        idx_rows += len_rows
        idx_cols += len_cols
    if not rows:
        # np.concatenate rejects an empty list, so build the all-zero matrix
        # directly from the accumulated shape.
        return scipy.sparse.coo_matrix((idx_rows, idx_cols))
    rows_cat = np.concatenate(rows, axis=0)
    columns_cat = np.concatenate(columns, axis=0)
    values_cat = np.concatenate(values, axis=0)
    return scipy.sparse.coo_matrix(
        (values_cat, (rows_cat, columns_cat)), shape=(idx_rows, idx_cols)
    )


def collate_signals(batch):
    """
    Batch the signals of all examples and record which example owns each row.

    Returns a pair ``(signals, signals_belonging)`` of dictionaries sharing
    the same keys. ``signals[key]`` is the concatenation of the examples'
    signals for ``key``; ``signals_belonging[key]`` is an int64 tensor with
    one entry per concatenated row holding the position of the owning example
    in ``batch``. Examples without ``key`` contribute nothing.
    """
    all_signals_keys = {key for example in batch for key in example.signals}
    signals = {}
    signals_belonging = {}
    for key in all_signals_keys:
        tensors_for_key = []
        owners_for_key = []
        for position, example in enumerate(batch):
            if key not in example.signals:
                continue
            signal = example.signals[key]
            tensors_for_key.append(signal)
            owners_for_key.append(
                torch.tensor([position] * len(signal), dtype=torch.int64)
            )
        signals[key] = batch_signals(tensors_for_key)
        signals_belonging[key] = torch.cat(owners_for_key, dim=0)
    return signals, signals_belonging


def collate_other_features(batch):
    """
    Concatenate the ``other_features`` tensors of all examples, key by key.

    Examples whose ``other_features`` is ``None`` or lacks a given key are
    skipped for that key. The ``None`` case is checked explicitly because a
    membership test on ``None`` would raise ``TypeError`` when a batch mixes
    examples with and without extra features.

    Returns
    -------
    dict or None
        Mapping from feature name to the tensors concatenated along
        dimension 0, or ``None`` when no example carries any feature.
    """
    feature_names = {
        key
        for example in batch
        if example.other_features is not None
        for key in example.other_features
    }
    if not feature_names:
        return None
    return {
        key: torch.cat(
            [
                example.other_features[key]
                for example in batch
                if example.other_features is not None
                and key in example.other_features
            ],
            dim=0,
        )
        for key in feature_names
    }


def collate_neighborhood_matrices(batch):
    """
    Batch every neighborhood matrix found in the examples.

    For each key present in at least one example, the per-example matrices
    (``None`` where an example has no matrix for that key, or no matrices at
    all) are handed to ``convert_sparse_matrices_to_sparse_block_matrix``.
    The ``None`` case is checked explicitly because a membership test on
    ``None`` would raise ``TypeError`` for batches mixing examples with and
    without neighborhood matrices.

    Returns
    -------
    dict
        Mapping from key to the batched sparse matrix; empty when no example
        carries any neighborhood matrix.
    """
    all_neighborhood_matrices_keys = {
        key
        for example in batch
        if example.neighborhood_matrices is not None
        for key in example.neighborhood_matrices
    }
    neighborhood_matrices = {}
    for key in all_neighborhood_matrices_keys:
        matrices_for_key = [
            # Get the neighborhood matrix if it exists, otherwise None
            (
                example.neighborhood_matrices[key]
                if example.neighborhood_matrices is not None
                and key in example.neighborhood_matrices
                else None
            )
            for example in batch
        ]
        neighborhood_matrices[key] = (
            convert_sparse_matrices_to_sparse_block_matrix(
                key, matrices_for_key, batch
            )
        )
    return neighborhood_matrices


def generate_batched_simplicial_complex_from_data(data):
    """
    Wrap a dataset example into a ``BatchedSimplicialComplex``.

    The example's ``x``, ``neighborhood_matrices`` and ``other_features``
    attributes are forwarded unchanged; ``x`` becomes the ``signals``.
    """
    fields = {
        "signals": data.x,
        "neighborhood_matrices": data.neighborhood_matrices,
        "other_features": data.other_features,
    }
    return BatchedSimplicialComplex(**fields)


def collate(batch):
    """
    Collate a list of dataset examples into one batched simplicial complex.

    Returns a triple: the ``BatchedSimplicialComplex`` holding the batched
    signals, neighborhood matrices and other features; the dictionary of
    belonging vectors produced by ``collate_signals``; and the number of
    examples in the batch.
    """
    complexes = list(
        map(generate_batched_simplicial_complex_from_data, batch)
    )
    # Signals come with a belonging vector telling which example owns a row.
    signals, signals_belonging = collate_signals(complexes)
    # Other features are concatenated per feature name across all examples.
    other_features = collate_other_features(complexes)
    # Neighborhood matrices are stacked into block matrices.
    neighborhood_matrices = collate_neighborhood_matrices(complexes)

    batched = BatchedSimplicialComplex(
        signals,
        neighborhood_matrices=neighborhood_matrices,
        other_features=other_features,
    )
    return batched, signals_belonging, len(complexes)


class SimplicialDataLoader(DataLoader):
    """
    ``DataLoader`` that uses the simplicial ``collate`` function by default.

    A ``collate_fn`` passed explicitly by the caller takes precedence; every
    other keyword argument is forwarded to ``DataLoader`` untouched.
    """

    def __init__(self, dataset, **kwargs):
        # ``kwargs`` is this call's own dictionary, so removing the entry
        # here does not affect the caller.
        collate_fn = kwargs.pop("collate_fn", collate)
        super().__init__(dataset, collate_fn=collate_fn, **kwargs)


################################################################################
### Simplicial datamodule for toponetx.
################################################################################


class TopoNetXDataModule(LightningDataModule):
    """
    Lightning datamodule serving ``SimplicialDataset`` to toponetx models.

    Parameters
    ----------
    data_dir:
        Root directory handed to ``SimplicialDataset``.
    transform:
        Optional extra transform applied after the base transforms. A single
        callable (such as a ``Compose``) is appended as one step; a plain
        iterable of transforms is appended element by element.
    use_stratified:
        When true, the train/validation split is stratified on the stacked
        ``y`` targets of the dataset.
    batch_size:
        Batch size used by all dataloaders.
    seed:
        Random state of the train/validation split, which makes the split
        reproducible across runs and across repeated ``setup`` calls.
    """

    def __init__(
        self,
        data_dir: str = "./data",
        transform: Compose | None = None,
        use_stratified: bool = False,
        batch_size: int = 128,
        seed: int = 2024,
    ):
        super().__init__()
        base_transform = [
            SimplicialComplexTransform(),
            DimZeroHodgeLaplacianSimplicialComplexTransform(),
            DimOneHodgeLaplacianUpSimplicialComplexTransform(),
            DimOneHodgeLaplacianDownSimplicialComplexTransform(),
            DimTwoHodgeLaplacianSimplicialComplexTransform(),
            BettiNumbersToTargetSimplicialComplexTransform(),
            SimplicialComplexOnesTransform(),
        ]
        if transform is not None:
            if callable(transform):
                # A ``Compose`` is a callable transform but not an iterable,
                # so ``list += transform`` would fail; add it as one step.
                base_transform.append(transform)
            else:
                base_transform.extend(transform)
        self.data_dir = data_dir
        self.transform = Compose(base_transform)
        self.use_stratified = use_stratified
        self.stratified = None
        self.batch_size = batch_size
        self.seed = seed

    def prepare_data(self) -> None:
        """Instantiate the dataset once without transforms."""
        # NOTE(review): presumably this triggers the download/processing of
        # the raw data -- confirm against ``SimplicialDataset``.
        SimplicialDataset(root=self.data_dir)

    def setup(self, stage=None):
        """Build the dataset and split it 80/20 into train and validation."""
        simplicial_full = SimplicialDataset(
            root=self.data_dir, transform=self.transform
        )
        if self.use_stratified:
            self.stratified = torch.vstack(
                [data.y for data in simplicial_full]
            )

        indices_dataset = np.arange(len(simplicial_full))

        train_indices, val_indices = train_test_split(
            indices_dataset,
            test_size=0.2,
            shuffle=True,
            stratify=self.stratified,
            # Without a fixed random state every call would produce a
            # different split, mixing train and validation samples between
            # calls and leaving ``seed`` without any effect.
            random_state=self.seed,
        )
        self.train_ds = Subset(simplicial_full, train_indices)
        self.val_ds = Subset(simplicial_full, val_indices)

    def train_dataloader(self):
        """Return the dataloader over the training subset."""
        return SimplicialDataLoader(self.train_ds, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return the dataloader over the validation subset."""
        return SimplicialDataLoader(self.val_ds, batch_size=self.batch_size)

    def test_dataloader(self):
        """Return the dataloader used for testing."""
        # NOTE(review): there is no separate test split; the validation
        # subset is reused here -- confirm this is intended.
        return SimplicialDataLoader(self.val_ds, batch_size=self.batch_size)
Loading

0 comments on commit 46520bd

Please sign in to comment.