Skip to content

Commit

Permalink
Updating dataloader for cell model
Browse files Browse the repository at this point in the history
  • Loading branch information
rballeba committed Nov 16, 2024
1 parent 8acd984 commit 8f19697
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 7 deletions.
94 changes: 87 additions & 7 deletions code/datasets/cell_dataloader.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import numpy as np
import torch
from toponetx import SimplicialComplex
from torch.utils.data import DataLoader
from torch_geometric.data.data import Data
from typing import List
from models.cells.mp.complex import ComplexBatch

from datasets.utils import eliminate_zeros_torch_sparse, torch_sparse_to_scipy_sparse
from models.cells.mp.complex import ComplexBatch, Cochain


class CellDataloader(DataLoader):
Expand All @@ -13,19 +16,96 @@ def __init__(self, dataset, **kwargs):
super().__init__(dataset, collate_fn=collate_fn, **kwargs)


def simplicial_complex_to_cochain(sc: SimplicialComplex, dim: int, x, y):
def get_boundary_index(sc: Data, dim: int):
    """
    Return the boundary index tensor for cells of dimension ``dim``.

    Args:
        sc: Data object carrying the complex connectivity (sparse boundary
            operators under ``connectivity[f"boundary_{dim}"]``).
        dim: Cell dimension whose boundary relation is requested.

    Returns:
        None for ``dim == 0`` (vertices have no boundary); otherwise the
        (2, nnz) indices of the sparse boundary matrix B_dim.
    """
    if dim == 0:
        return None  # No boundary for vertices
    # NOTE(review): the original fell through and implicitly returned None for
    # every dim > 0 — presumably unfinished. Mirror data_to_cochain, which
    # reads the stored boundary operator's indices directly.
    return sc.connectivity[f"boundary_{dim}"].indices()


def _get_shared_simplices(data: Data, adj_index, dim: int, cofaces: bool = False):
    """
    For each pair of dim-simplices referenced by ``adj_index``, look up the
    index of the simplex they share: a common coface (dimension ``dim + 1``)
    when ``cofaces`` is True, otherwise a common face (dimension ``dim - 1``).

    Args:
        data: Data object whose ``sc`` attribute exposes ``skeleton(d)``.
        adj_index: (2, n) tensor of adjacent simplex-index pairs.
        dim: Dimension of the simplices indexed by ``adj_index``.
        cofaces: If True, share a coface (union); otherwise a face
            (intersection).

    Returns:
        1-D tensor with one shared-simplex index per column of ``adj_index``.
    """
    simplex_at = dict(enumerate(data.sc.skeleton(dim)))
    related_dim = dim + 1 if cofaces else dim - 1
    index_of = {s: pos for pos, s in enumerate(data.sc.skeleton(related_dim))}

    shared = []
    for col in range(adj_index.shape[1]):
        first = simplex_at[adj_index[0, col].item()]
        second = simplex_at[adj_index[1, col].item()]
        if cofaces:
            combined = frozenset(first).union(second)
        else:
            combined = frozenset(first).intersection(second)
        shared.append(index_of[tuple(sorted(combined))])
    return torch.tensor(shared)


from scipy import sparse


def extract_adj_from_boundary(B, G=None):
    """
    Extract a symmetric adjacency (index, orientation) pair from a boundary
    matrix.

    Two columns i, j of ``B`` are adjacent when ``(B^T B)[i, j] != 0``, i.e.
    they share a non-zero row. The diagonal of ``B^T B`` (self-interactions)
    is ignored.

    Args:
        B: Boundary matrix (scipy sparse or dense array-like) with +/-1
           entries encoding orientation.
        G: Optional networkx-style graph used only as a sanity check that
           the adjacency has one node per edge of ``G``.

    Returns:
        index: (2, n) long tensor containing both (i, j) and (j, i) for each
            adjacent pair.
        orient: (n,) float tensor with the sign of the shared entry, stored
            identically for both directions of a pair.
    """
    A = sparse.csr_matrix(B.T).dot(sparse.csr_matrix(B))

    n = A.shape[0]
    if G is not None:
        assert n == G.number_of_edges()

    # Off-diagonal non-zeros only: self-loops (the diagonal) are not counted.
    connections = A.count_nonzero() - np.sum(A.diagonal() != 0)

    index = torch.empty((2, connections), dtype=torch.long)
    orient = torch.empty(connections)

    connection = 0
    cA = A.tocoo()
    for i, j, v in zip(cA.row, cA.col, cA.data):
        # Visit each unordered pair once (strict lower triangle); this also
        # skips the diagonal. Both directions are emitted below.
        if j >= i:
            continue
        # FIX: the original used `print(v)` as the assert message, which
        # evaluates to None; use a real diagnostic message instead.
        assert v == 1 or v == -1, f"expected +/-1 adjacency entry, got {v}"

        index[0, connection] = i
        index[1, connection] = j
        orient[connection] = float(np.sign(v))

        index[0, connection + 1] = j
        index[1, connection + 1] = i
        orient[connection + 1] = float(np.sign(v))

        connection += 2

    assert connection == connections
    return index, orient


def data_to_cochain(data: Data, dim: int):
    """
    Build a Cochain for dimension ``dim`` from a Data object wrapping a
    simplicial/cell complex (features in ``data.x``, sparse boundary
    operators in ``data.connectivity``, the complex itself in ``data.sc``).

    Upper adjacency (shared coface) is derived from B_{dim+1} B_{dim+1}^T and
    lower adjacency (shared face) from B_dim^T B_dim, both extracted as
    (index, orientation) pairs by extract_adj_from_boundary.
    """
    # FIX: removed a leftover `raise NotImplementedError()` that made the
    # whole body below unreachable dead code.
    x = data.x[dim]
    # Upper adjacency exists only strictly below the top dimension
    # (it needs B_{dim+1}).
    upper_index, upper_orient = extract_adj_from_boundary(
        torch_sparse_to_scipy_sparse(data.connectivity[f"boundary_{dim+1}"].T)
    ) if dim < data.sc.dim else (None, None)
    # Lower adjacency needs B_dim, which only exists for dim > 0.
    # FIX: the original guarded this with `dim < data.sc.dim` (copy-paste of
    # the upper-branch condition), which both broke dim == 0 (no "boundary_0"
    # operator) and dropped the lower adjacency of the top dimension.
    lower_index, lower_orient = extract_adj_from_boundary(
        torch_sparse_to_scipy_sparse(data.connectivity[f"boundary_{dim}"])
    ) if dim > 0 else (None, None)
    shared_boundaries = _get_shared_simplices(data, lower_index, dim, cofaces=False) \
        if lower_index is not None else None
    shared_coboundaries = _get_shared_simplices(data, upper_index, dim, cofaces=True) \
        if upper_index is not None else None
    # Vertices have no boundary; the dim > 0 guard matches the lower branch.
    boundary_index = data.connectivity[f"boundary_{dim}"].indices() if dim > 0 else None
    y = getattr(data, "y", None)
    # TODO: `mapping` is unused in the reference implementation, so pass None.
    return Cochain(dim, x, upper_index, lower_index, shared_boundaries, shared_coboundaries,
                   None, boundary_index, upper_orient, lower_orient, y)


def collate_cell_models(batch: List[Data]) -> ComplexBatch:
    """
    Collate a list of Data objects (each wrapping a simplicial/cell complex)
    into a ComplexBatch for the cell message-passing models.

    Cochains are built per dimension for every sample; assembling them into
    the final ComplexBatch is still unimplemented.

    Raises:
        NotImplementedError: always — the final batching step is a TODO.
    """
    # FIX: reconstructed the function from diff-interleaved text that mixed
    # the old signature (`sc_list`) and a stray `pass` with the new body.
    cochains = []
    max_dim = max(x.sc.dim for x in batch)
    for dim in range(max_dim + 1):
        # One Cochain per sample at this dimension.
        # NOTE(review): samples whose own sc.dim < max_dim may fail inside
        # data_to_cochain — confirm intended handling of mixed-dim batches.
        dim_cochains = [data_to_cochain(data, dim) for data in batch]
        cochains.append(dim_cochains)
    # TODO(review): batch `cochains` into a ComplexBatch — intentionally
    # left unfinished by the author.
    raise NotImplementedError()
33 changes: 33 additions & 0 deletions code/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,3 +246,36 @@ def torch_sparse_to_scipy_sparse(torch_tensor):
)

return scipy_sparse_matrix


def eliminate_zeros_torch_sparse(sparse_tensor: torch.Tensor):
    """
    Return a copy of a sparse COO tensor with all explicitly-stored zero
    values dropped.

    Args:
        sparse_tensor: Input sparse tensor (coalesced or not).

    Returns:
        A coalesced sparse tensor of the same shape, dtype and device
        containing only the non-zero stored entries.
    """
    # Coalescing first guarantees unique, sorted indices to filter on.
    coalesced = sparse_tensor if sparse_tensor.is_coalesced() else sparse_tensor.coalesce()

    values = coalesced.values()
    keep = values != 0  # boolean mask over the stored entries

    filtered = torch.sparse_coo_tensor(
        coalesced.indices()[:, keep],
        values[keep],
        coalesced.size(),
        dtype=values.dtype,
        device=values.device,
    )
    # Coalesce again so callers can rely on canonical (unique, sorted) form.
    return filtered.coalesce()

0 comments on commit 8f19697

Please sign in to comment.