Skip to content

Commit

Permalink
almost runs
Browse files Browse the repository at this point in the history
  • Loading branch information
danielbinschmid committed Nov 16, 2024
1 parent 069c6bf commit ed715fc
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 11 deletions.
6 changes: 3 additions & 3 deletions code/datasets/cell_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
import torch
from toponetx import SimplicialComplex
from torch.utils.data import DataLoader
from torch_geometric.data.data import Data
from torch_geometric.data import Data
from typing import List

from datasets.utils import eliminate_zeros_torch_sparse, torch_sparse_to_scipy_sparse
from models.cells.mp.complex import ComplexBatch, Cochain, CochainBatch, Complex

Expand Down Expand Up @@ -109,10 +108,11 @@ def collate_cell_models(batch: List[Data]) -> ComplexBatch:
ComplexBatch = [Complex(Data Point 1), Complex(Data Point 2)]
"""
complexes = []
max_dim = max(x.sc.dim for x in batch)
max_dim = 0 max(x.sc.dim for x in batch)

for data in batch:
cochains = []
dim = max_dim
for dim in range(max_dim + 1):
cochain = data_to_cochain(data, dim, max_dim)
cochains.append(cochain)
Expand Down
10 changes: 6 additions & 4 deletions code/models/cells/mp/cin0.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,7 @@ class SparseCIN(torch.nn.Module):

def __init__(
self,
num_input_features,
num_classes,
num_layers,
hidden,
config: CellMPConfig,
dropout_rate: float = 0.5,
max_dim: int = 2,
jump_mode=None,
Expand All @@ -178,6 +175,11 @@ def __init__(
graph_norm="bn",
):
super(SparseCIN, self).__init__()

num_input_features = config.num_input_features
num_classes = config.num_classes
num_layers = config.num_layers
hidden = config.hidden

self.max_dim = max_dim
if readout_dims is not None:
Expand Down
7 changes: 7 additions & 0 deletions code/models/cells/mp/complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,3 +808,10 @@ def from_complex_list(
)

return batch

@property
def batch_size(self) -> int:
return self.num_complexes

def __len__(self) -> int:
return self.num_complexes
1 change: 0 additions & 1 deletion code/models/cells/mp/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ def forward(self, cochain: CochainMessagePassingParams):
up_attr=cochain.kwargs["up_attr"],
boundary_attr=cochain.kwargs["boundary_attr"],
)

# As in GIN, we can learn an injective update function for each multi-set
out_up += (1 + self.eps1) * cochain.x
out_boundaries += (1 + self.eps2) * cochain.x
Expand Down
4 changes: 2 additions & 2 deletions code/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from models.simplicial_complexes.sccn import SCCN, SCCNConfig
from models.simplicial_complexes.sccnn import SCCNN, SCCNNConfig
from models.simplicial_complexes.scn import SCN, SCNConfig
from models.cells.mp.cin0 import CIN0, CellMPConfig
from models.cells.mp.cin0 import CIN0, CellMPConfig, SparseCIN

model_lookup: Dict[ModelType, nn.Module] = {
ModelType.GAT: GAT,
Expand All @@ -33,7 +33,7 @@
ModelType.SCCNN: SCCNN,
ModelType.SCN: SCN,
ModelType.TransfConv: TransfConv,
ModelType.CELL_MP: CIN0,
ModelType.CELL_MP: SparseCIN,
}

ModelConfig = Union[
Expand Down
5 changes: 4 additions & 1 deletion code/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@
# CONFIG ---------------------------------------------------------------------

# model
model_config = CellMPConfig()
model_config = CellMPConfig(
num_input_features=1,
num_classes=3,
)
transform_type = TransformType.degree_transform_sc
# model_config = MLPConfig(
# num_hidden_neurons=64,
Expand Down

0 comments on commit ed715fc

Please sign in to comment.