-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Started working on toponetx, commit does not work
- Loading branch information
1 parent
7e59a18
commit 46520bd
Showing
7 changed files
with
655 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
""" | ||
Function to load datasets based on configuration key. | ||
""" | ||
|
||
import importlib | ||
from typing import Protocol | ||
|
||
|
||
class DataModuleConfig(Protocol): | ||
""" | ||
The module key is a string that points towards | ||
the | ||
""" | ||
module: str | ||
|
||
|
||
def load_datamodule(config: dict): | ||
module = importlib.import_module(config.module) | ||
return module.DataModule(config) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,289 @@ | ||
""" | ||
Dataloaders for the toponetx models | ||
""" | ||
|
||
from typing import Any, Optional | ||
|
||
import numpy as np | ||
import scipy | ||
from sklearn.model_selection import train_test_split | ||
import torch.utils.data | ||
from torch.utils.data.dataloader import DataLoader | ||
from torch_geometric.transforms import Compose, FaceToEdge | ||
from mantra.simplicial import SimplicialDataset | ||
from lightning import LightningDataModule | ||
from torch.utils.data import Subset | ||
from datasets.transforms import ( | ||
BettiNumbersToTargetSimplicialComplexTransform, | ||
DimOneHodgeLaplacianDownSimplicialComplexTransform, | ||
DimOneHodgeLaplacianUpSimplicialComplexTransform, | ||
DimTwoHodgeLaplacianSimplicialComplexTransform, | ||
DimZeroHodgeLaplacianSimplicialComplexTransform, | ||
SimplicialComplexDegreeTransform, | ||
SimplicialComplexOnesTransform, | ||
SimplicialComplexTransform, | ||
) | ||
|
||
|
||
class BatchedSimplicialComplex: | ||
def __init__( | ||
self, | ||
signals: dict[Any, torch.Tensor], | ||
neighborhood_matrices: Optional[ | ||
dict[Any, scipy.sparse.spmatrix] | ||
] = None, | ||
other_features: Optional[dict[Any, torch.Tensor]] = None, | ||
): | ||
self.signals = signals | ||
self.neighborhood_matrices = neighborhood_matrices | ||
self.other_features = other_features | ||
|
||
|
||
def batch_signals(signals): | ||
return torch.cat(signals, dim=0) | ||
|
||
|
||
def convert_sparse_matrices_to_sparse_block_matrix( | ||
key: Any, | ||
sparse_matrices: list[Optional[scipy.sparse.spmatrix]], | ||
batch: list[BatchedSimplicialComplex], | ||
): | ||
rows, columns, values = [], [], [] | ||
idx_rows, idx_cols = 0, 0 | ||
for matrix, c_complex in zip(sparse_matrices, batch): | ||
if matrix is None: | ||
# Depending on the neighborhood matrix, we need to perform different operations | ||
match key: | ||
case ( | ||
"0_adjacency" | ||
| "1_coadjacency" | ||
| "1_adjacency" | ||
| "2_coadjacency" | ||
| "1_laplacian" | ||
| "2_laplacian" | ||
| "0_laplacian" | ||
| "1_laplacian_up" | ||
| "1_laplacian_down" | ||
): | ||
# We do not do nothing, as if the matrix is None, it is | ||
# because there are no cells of that dimension in the cell | ||
# complex so we do not need to add any row or column | ||
pass | ||
case "1_boundary": | ||
# Only one possibility: the dimension of the cell complex is | ||
# 0, and therefore we need to add only as many rows as nodes | ||
# there are in the cell complex | ||
zero_cells_cardinality = len(c_complex.signals[0]) | ||
idx_rows += zero_cells_cardinality | ||
case "2_boundary": | ||
# Only two possibilities: the dimension of the cell complex | ||
# is 0 or 1. If the dimension is 0, we do not add any row or | ||
# column. If the dimension is 1, we need to add as many rows | ||
# as edges there are in the cell complex | ||
if 1 not in c_complex.signals: | ||
pass | ||
else: | ||
one_cells_cardinality = len(c_complex.signals[1]) | ||
idx_rows += one_cells_cardinality | ||
else: | ||
coo_matrix = matrix.tocoo() | ||
len_rows, len_cols = coo_matrix.shape | ||
rows_example, cols_example, values_example = ( | ||
coo_matrix.row, | ||
coo_matrix.col, | ||
coo_matrix.data, | ||
) | ||
rows_abs, cols_abs = ( | ||
rows_example + idx_rows, | ||
cols_example + idx_cols, | ||
) | ||
rows.append(rows_abs) | ||
columns.append(cols_abs) | ||
values.append(values_example) | ||
idx_rows += len_rows | ||
idx_cols += len_cols | ||
rows_cat = np.concatenate(rows, axis=0) | ||
columns_cat = np.concatenate(columns, axis=0) | ||
values_cat = np.concatenate(values, axis=0) | ||
return scipy.sparse.coo_matrix( | ||
(values_cat, (rows_cat, columns_cat)), shape=(idx_rows, idx_cols) | ||
) | ||
|
||
|
||
def collate_signals(batch): | ||
all_signals_keys = set( | ||
[key for example in batch for key in example.signals] | ||
) | ||
signals = dict() | ||
signals_belonging = dict() | ||
for key in all_signals_keys: | ||
signals_to_batch = [ | ||
example.signals[key] for example in batch if key in example.signals | ||
] | ||
signals[key] = batch_signals(signals_to_batch) | ||
signals_of_dim_belonging = [ | ||
torch.tensor([i] * len(batch[i].signals[key]), dtype=torch.int64) | ||
for i in range(len(batch)) | ||
if key in batch[i].signals | ||
] | ||
signals_belonging[key] = torch.cat(signals_of_dim_belonging, dim=0) | ||
return signals, signals_belonging | ||
|
||
|
||
def collate_other_features(batch): | ||
feature_names = set( | ||
[ | ||
key | ||
for example in batch | ||
if example.other_features is not None | ||
for key in example.other_features.keys() | ||
] | ||
) | ||
if len(feature_names) == 0: | ||
other_features = None | ||
else: | ||
other_features = {} | ||
for key in feature_names: | ||
other_features[key] = torch.cat( | ||
[ | ||
example.other_features[key] | ||
for example in batch | ||
if key in example.other_features | ||
], | ||
dim=0, | ||
) | ||
return other_features | ||
|
||
|
||
def collate_neighborhood_matrices(batch): | ||
all_neighborhood_matrices_keys = set( | ||
[ | ||
key | ||
for example in batch | ||
if example.neighborhood_matrices is not None | ||
for key in example.neighborhood_matrices | ||
] | ||
) | ||
neighborhood_matrices = dict() | ||
for key in all_neighborhood_matrices_keys: | ||
neighborhood_matrices[key] = ( | ||
convert_sparse_matrices_to_sparse_block_matrix( | ||
key, | ||
[ | ||
# Get the neighborhood matrix if it exists, otherwise None | ||
( | ||
example.neighborhood_matrices[key] | ||
if key in example.neighborhood_matrices | ||
else None | ||
) | ||
for example in batch | ||
], | ||
batch, | ||
) | ||
) | ||
return neighborhood_matrices | ||
|
||
|
||
def generate_batched_simplicial_complex_from_data(data): | ||
return BatchedSimplicialComplex( | ||
signals=data.x, | ||
neighborhood_matrices=data.neighborhood_matrices, | ||
other_features=data.other_features, | ||
) | ||
|
||
|
||
def collate(batch): | ||
batch = [ | ||
generate_batched_simplicial_complex_from_data(example) | ||
for example in batch | ||
] | ||
# Collate the signals and make a belonging vector | ||
signals, signals_belonging = collate_signals(batch) | ||
# Concatenate other features. Take all the feature names from all the examples in the batch. | ||
other_features = collate_other_features(batch) | ||
# Concatenate neighborhood matrices | ||
neighborhood_matrices = collate_neighborhood_matrices(batch) | ||
|
||
return ( | ||
BatchedSimplicialComplex( | ||
signals, | ||
neighborhood_matrices=neighborhood_matrices, | ||
other_features=other_features, | ||
), | ||
signals_belonging, | ||
len(batch), | ||
) | ||
|
||
|
||
class SimplicialDataLoader(DataLoader): | ||
def __init__(self, dataset, **kwargs): | ||
collate_fn = kwargs.get("collate_fn", collate) | ||
kwargs = {k: v for k, v in kwargs.items() if k != "collate_fn"} | ||
super().__init__(dataset, collate_fn=collate_fn, **kwargs) | ||
|
||
|
||
################################################################################ | ||
### Simplicial datamodule for toponetx. | ||
################################################################################ | ||
|
||
|
||
class TopoNetXDataModule(LightningDataModule): | ||
def __init__( | ||
self, | ||
data_dir: str = "./data", | ||
transform: Compose | None = None, | ||
use_stratified: bool = False, | ||
batch_size: int = 128, | ||
seed: int = 2024, | ||
): | ||
super().__init__() | ||
base_transform = [ | ||
SimplicialComplexTransform(), | ||
DimZeroHodgeLaplacianSimplicialComplexTransform(), | ||
DimOneHodgeLaplacianUpSimplicialComplexTransform(), | ||
DimOneHodgeLaplacianDownSimplicialComplexTransform(), | ||
DimTwoHodgeLaplacianSimplicialComplexTransform(), | ||
BettiNumbersToTargetSimplicialComplexTransform(), | ||
SimplicialComplexOnesTransform(), | ||
] | ||
if transform is not None: | ||
base_transform += transform | ||
self.data_dir = data_dir | ||
self.transform = Compose(base_transform) | ||
self.use_stratified = use_stratified | ||
self.stratified = None | ||
self.batch_size = batch_size | ||
self.seed = seed | ||
|
||
def prepare_data(self) -> None: | ||
SimplicialDataset(root=self.data_dir) | ||
|
||
def setup(self, stage=None): | ||
simplicial_full = SimplicialDataset( | ||
root=self.data_dir, transform=self.transform | ||
) | ||
if self.use_stratified: | ||
self.stratified = torch.vstack( | ||
[data.y for data in simplicial_full] | ||
) | ||
|
||
indices_dataset = np.arange(len(simplicial_full)) | ||
|
||
train_indices, val_indices = train_test_split( | ||
indices_dataset, | ||
test_size=0.2, | ||
shuffle=True, | ||
stratify=self.stratified, | ||
# random_state=RandomState(self.seed), | ||
) | ||
self.train_ds = Subset(simplicial_full, train_indices) | ||
self.val_ds = Subset(simplicial_full, val_indices) | ||
|
||
def train_dataloader(self): | ||
return SimplicialDataLoader(self.train_ds, batch_size=self.batch_size) | ||
|
||
def val_dataloader(self): | ||
return SimplicialDataLoader(self.val_ds, batch_size=self.batch_size) | ||
|
||
def test_dataloader(self): | ||
return SimplicialDataLoader(self.val_ds, batch_size=self.batch_size) |
Oops, something went wrong.