From 069c6bf3bddca7f6d051e4a2317e2d7e7fa23f6f Mon Sep 17 00:00:00 2001 From: binbash Date: Sat, 16 Nov 2024 10:49:26 +0100 Subject: [PATCH] finishing the code of the dataloader --- code/datasets/cell_dataloader.py | 40 ++++++++++++++++++++++++-------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/code/datasets/cell_dataloader.py b/code/datasets/cell_dataloader.py index 3acf883..3febadd 100644 --- a/code/datasets/cell_dataloader.py +++ b/code/datasets/cell_dataloader.py @@ -6,7 +6,7 @@ from typing import List from datasets.utils import eliminate_zeros_torch_sparse, torch_sparse_to_scipy_sparse -from models.cells.mp.complex import ComplexBatch, Cochain +from models.cells.mp.complex import ComplexBatch, Cochain, CochainBatch, Complex class CellDataloader(DataLoader): @@ -77,22 +77,22 @@ def extract_adj_from_boundary(B, G=None): return index, orient -def data_to_cochain(data: Data, dim: int): +def data_to_cochain(data: Data, dim: int, max_dim: int): """ Convert a SimplicialComplex object into a Cochain object for a given dimension. """ x = data.x[dim] upper_index, upper_orient = extract_adj_from_boundary( torch_sparse_to_scipy_sparse(data.connectivity[f"boundary_{dim+1}"].T) - ) if dim < data.sc.dim else (None, None) + ) if (dim < data.sc.dim) and (dim < max_dim) else (None, None) lower_index, lower_orient = extract_adj_from_boundary( torch_sparse_to_scipy_sparse(data.connectivity[f"boundary_{dim}"]) - ) if dim < data.sc.dim else (None, None) + ) if (dim < data.sc.dim) and (dim > 0) else (None, None) shared_boundaries = _get_shared_simplices(data, lower_index, dim, cofaces=False) \ if lower_index is not None else None shared_coboundaries = _get_shared_simplices(data, upper_index, dim, cofaces=True) \ if upper_index is not None else None - boundary_index = data.connectivity[f"boundary_{dim}"].indices() + boundary_index = data.connectivity[f"boundary_{dim}"].indices() if dim > 0 else None y = getattr(data, "y", None) # TODO: Mapping is not used in their implementation, so I leave it as None for now return Cochain(dim, x, upper_index, lower_index, shared_boundaries, shared_coboundaries, @@ -102,10 +102,30 @@ def data_to_cochain(data: Data, dim: int): def collate_cell_models(batch: List[Data]) -> ComplexBatch: """ Convert a list of SimplicialComplex objects into a ComplexBatch. + + Data Point 1: [Cochain(0D), Cochain(1D), Cochain(2D)] → Complex(Data Point 1) + Data Point 2: [Cochain(0D), Cochain(1D), Cochain(2D)] → Complex(Data Point 2) + + ComplexBatch = [Complex(Data Point 1), Complex(Data Point 2)] """ - cochains = [] + complexes = [] max_dim = max(x.sc.dim for x in batch) - for dim in range(max_dim + 1): - dim_cochains = [data_to_cochain(data, dim) for data in batch] - cochains.append(dim_cochains) - raise NotImplementedError() + + for data in batch: + cochains = [] + for dim in range(max_dim + 1): + cochain = data_to_cochain(data, dim, max_dim) + cochains.append(cochain) + complex = Complex( + *cochains, + y=data.y, + dimension=max_dim + ) + complexes.append(complex) + + cochain_batch = ComplexBatch.from_complex_list( + data_list=complexes, + max_dim=max_dim + ) + + return cochain_batch