Skip to content

Commit

Permalink
finishing the code of the dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
danielbinschmid committed Nov 16, 2024
1 parent 8f19697 commit 069c6bf
Showing 1 changed file with 30 additions and 10 deletions.
40 changes: 30 additions & 10 deletions code/datasets/cell_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import List

from datasets.utils import eliminate_zeros_torch_sparse, torch_sparse_to_scipy_sparse
from models.cells.mp.complex import ComplexBatch, Cochain
from models.cells.mp.complex import ComplexBatch, Cochain, CochainBatch, Complex


class CellDataloader(DataLoader):
Expand Down Expand Up @@ -77,22 +77,22 @@ def extract_adj_from_boundary(B, G=None):
return index, orient


def data_to_cochain(data: Data, dim: int):
def data_to_cochain(data: Data, dim: int, max_dim: int):
"""
Convert a SimplicialComplex object into a Cochain object for a given dimension.
"""
x = data.x[dim]
upper_index, upper_orient = extract_adj_from_boundary(
torch_sparse_to_scipy_sparse(data.connectivity[f"boundary_{dim+1}"].T)
) if dim < data.sc.dim else (None, None)
) if (dim < data.sc.dim) and (dim < max_dim) else (None, None)
lower_index, lower_orient = extract_adj_from_boundary(
torch_sparse_to_scipy_sparse(data.connectivity[f"boundary_{dim}"])
) if dim < data.sc.dim else (None, None)
) if (dim < data.sc.dim) and (dim > 0) else (None, None)
shared_boundaries = _get_shared_simplices(data, lower_index, dim, cofaces=False) \
if lower_index is not None else None
shared_coboundaries = _get_shared_simplices(data, upper_index, dim, cofaces=True) \
if upper_index is not None else None
boundary_index = data.connectivity[f"boundary_{dim}"].indices()
boundary_index = data.connectivity[f"boundary_{dim}"].indices() if dim > 0 else None
y = getattr(data, "y", None)
# TODO: Mapping is not used in their implementation, so I leave it as None for now
return Cochain(dim, x, upper_index, lower_index, shared_boundaries, shared_coboundaries,
Expand All @@ -102,10 +102,30 @@ def data_to_cochain(data: Data, dim: int):
def collate_cell_models(batch: List[Data]) -> ComplexBatch:
"""
Convert a list of SimplicialComplex objects into a ComplexBatch.
Data Point 1: [Cochain(0D), Cochain(1D), Cochain(2D)] → Complex(Data Point 1)
Data Point 2: [Cochain(0D), Cochain(1D), Cochain(2D)] → Complex(Data Point 2)
ComplexBatch = [Complex(Data Point 1), Complex(Data Point 2)]
"""
cochains = []
complexes = []
max_dim = max(x.sc.dim for x in batch)
for dim in range(max_dim + 1):
dim_cochains = [data_to_cochain(data, dim) for data in batch]
cochains.append(dim_cochains)
raise NotImplementedError()

for data in batch:
cochains = []
for dim in range(max_dim + 1):
cochain = data_to_cochain(data, dim, max_dim)
cochains.append(cochain)
complex = Complex(
*cochains,
y=data.y,
dimension=max_dim
)
complexes.append(complex)

cochain_batch = ComplexBatch.from_complex_list(
data_list=complexes,
max_dim=max_dim
)

return cochain_batch

0 comments on commit 069c6bf

Please sign in to comment.