From ed715fcce78cc9dd1e3f3a3c0797b2658fbbb0ac Mon Sep 17 00:00:00 2001
From: binbash
Date: Sat, 16 Nov 2024 17:31:02 +0100
Subject: [PATCH] almost runs

---
 code/datasets/cell_dataloader.py |  6 +++---
 code/models/cells/mp/cin0.py     | 10 ++++++----
 code/models/cells/mp/complex.py  |  7 +++++++
 code/models/cells/mp/layers.py   |  1 -
 code/models/models.py            |  4 ++--
 code/run.py                      |  5 ++++-
 6 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/code/datasets/cell_dataloader.py b/code/datasets/cell_dataloader.py
index 3febadd..4c92c4e 100644
--- a/code/datasets/cell_dataloader.py
+++ b/code/datasets/cell_dataloader.py
@@ -2,9 +2,8 @@
 import torch
 from toponetx import SimplicialComplex
 from torch.utils.data import DataLoader
-from torch_geometric.data.data import Data
+from torch_geometric.data import Data
 from typing import List
-
 from datasets.utils import eliminate_zeros_torch_sparse, torch_sparse_to_scipy_sparse
 from models.cells.mp.complex import ComplexBatch, Cochain, CochainBatch, Complex
 
@@ -109,10 +108,11 @@ def collate_cell_models(batch: List[Data]) -> ComplexBatch:
     ComplexBatch = [Complex(Data Point 1), Complex(Data Point 2)]
     """
     complexes = []
-    max_dim = max(x.sc.dim for x in batch)
+    max_dim = 0  # max(x.sc.dim for x in batch)
 
     for data in batch:
         cochains = []
+        dim = max_dim
         for dim in range(max_dim + 1):
             cochain = data_to_cochain(data, dim, max_dim)
             cochains.append(cochain)
diff --git a/code/models/cells/mp/cin0.py b/code/models/cells/mp/cin0.py
index cd5c881..e311f47 100644
--- a/code/models/cells/mp/cin0.py
+++ b/code/models/cells/mp/cin0.py
@@ -160,10 +160,7 @@ class SparseCIN(torch.nn.Module):
 
     def __init__(
         self,
-        num_input_features,
-        num_classes,
-        num_layers,
-        hidden,
+        config: CellMPConfig,
         dropout_rate: float = 0.5,
         max_dim: int = 2,
         jump_mode=None,
@@ -178,6 +175,11 @@ def __init__(
         graph_norm="bn",
     ):
         super(SparseCIN, self).__init__()
+
+        num_input_features = config.num_input_features
+        num_classes = config.num_classes
+        num_layers = config.num_layers
+        hidden = config.hidden
 
         self.max_dim = max_dim
         if readout_dims is not None:
diff --git a/code/models/cells/mp/complex.py b/code/models/cells/mp/complex.py
index de2c4bd..fa914e4 100644
--- a/code/models/cells/mp/complex.py
+++ b/code/models/cells/mp/complex.py
@@ -808,3 +808,10 @@ def from_complex_list(
         )
 
         return batch
+
+    @property
+    def batch_size(self) -> int:
+        return self.num_complexes
+
+    def __len__(self) -> int:
+        return self.num_complexes
\ No newline at end of file
diff --git a/code/models/cells/mp/layers.py b/code/models/cells/mp/layers.py
index b9b5ba9..c532dc5 100644
--- a/code/models/cells/mp/layers.py
+++ b/code/models/cells/mp/layers.py
@@ -276,7 +276,6 @@ def forward(self, cochain: CochainMessagePassingParams):
             up_attr=cochain.kwargs["up_attr"],
             boundary_attr=cochain.kwargs["boundary_attr"],
         )
-        # As in GIN, we can learn an injective update function for each multi-set
         out_up += (1 + self.eps1) * cochain.x
         out_boundaries += (1 + self.eps2) * cochain.x
 
diff --git a/code/models/models.py b/code/models/models.py
index 9470d04..ef69868 100644
--- a/code/models/models.py
+++ b/code/models/models.py
@@ -21,7 +21,7 @@
 from models.simplicial_complexes.sccn import SCCN, SCCNConfig
 from models.simplicial_complexes.sccnn import SCCNN, SCCNNConfig
 from models.simplicial_complexes.scn import SCN, SCNConfig
-from models.cells.mp.cin0 import CIN0, CellMPConfig
+from models.cells.mp.cin0 import CIN0, CellMPConfig, SparseCIN
 
 model_lookup: Dict[ModelType, nn.Module] = {
     ModelType.GAT: GAT,
@@ -33,7 +33,7 @@ model_lookup: Dict[ModelType, nn.Module] = {
     ModelType.SCCNN: SCCNN,
     ModelType.SCN: SCN,
     ModelType.TransfConv: TransfConv,
-    ModelType.CELL_MP: CIN0,
+    ModelType.CELL_MP: SparseCIN,
 }
 
 ModelConfig = Union[
diff --git a/code/run.py b/code/run.py
index 380d64e..741a7a9 100644
--- a/code/run.py
+++ b/code/run.py
@@ -29,7 +29,10 @@
 # CONFIG ---------------------------------------------------------------------
 
 # model
-model_config = CellMPConfig()
+model_config = CellMPConfig(
+    num_input_features=1,
+    num_classes=3,
+)
 transform_type = TransformType.degree_transform_sc
 # model_config = MLPConfig(
 #     num_hidden_neurons=64,