diff --git a/models/models.py b/models/models.py index e7bd77c..a032ffd 100644 --- a/models/models.py +++ b/models/models.py @@ -5,10 +5,9 @@ from typing import Dict, Union, Annotated, Callable from pydantic import Tag import torch.nn as nn -from topomodelx.nn.simplicial.sccnn import SCCNN from torch_geometric.loader import DataLoader -from datasets.topox_dataloader import collate_simplicial_models_topox, SimplicialTopoXDataloader +from datasets.topox_dataloader import SimplicialTopoXDataloader from models.GCN import GCN, GCNConfig from models.GAT import GAT, GATConfig from models.MLP import MLP, MLPConfig @@ -19,7 +18,7 @@ from models.simplicial_complexes.san import SAN, SANConfig from models.simplicial_complexes.sccn import SCCN, SCCNConfig -from models.simplicial_complexes.sccnn import SCCNNConfig +from models.simplicial_complexes.sccnn import SCCNN, SCCNNConfig from models.simplicial_complexes.scn import SCN, SCNConfig model_lookup: Dict[ModelType, nn.Module] = { diff --git a/models/simplicial_complexes/sccnn.py b/models/simplicial_complexes/sccnn.py index bc7028a..eda100c 100644 --- a/models/simplicial_complexes/sccnn.py +++ b/models/simplicial_complexes/sccnn.py @@ -52,7 +52,7 @@ def __init__( self, config: SCCNNConfig ): - super(SCCNN, self).__init__() + super().__init__() self.sccnn_backbone = SCCNNCustom( config.in_channels, config.hidden_channels_all,