diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
deleted file mode 100755
index 7a2a6bce..00000000
--- a/.devcontainer/devcontainer.json
+++ /dev/null
@@ -1,18 +0,0 @@
-//devcontainer.json
-{
-    "name": "TopoBenchmarkX:new",
-    "dockerFile": "./Dockerfile",
-    "customizations": {
-        "vscode": {
-            "settings": {
-                "terminal.integrated.shell.linux": "/bin/bash"
-            },
-            "extensions": [
-                "ms-python.python",
-                "ms-python.isort",
-                "ms-python.vscode-pylance",
-                "ms-toolsai.jupyter"
-            ]
-        }
-    }
-}
\ No newline at end of file
diff --git a/.devcontainer/pyproject.toml b/.devcontainer/pyproject.toml
deleted file mode 100755
index e9439fe3..00000000
--- a/.devcontainer/pyproject.toml
+++ /dev/null
@@ -1,129 +0,0 @@
-[build-system]
-requires = ["setuptools", "wheel"]
-build-backend = "setuptools.build_meta"
-
-[project]
-name = "TopoBenchmarkX"
-version = "0.0.1"
-authors = [
-    {name = "PyT-Team Authors", email = "tlscabinet@gmail.com"}
-]
-readme = "README.md"
-description = "Topological Deep Learning"
-license = {file = "LICENSE.txt"}
-classifiers = [
-    "License :: OSI Approved :: MIT License",
-    "Development Status :: 4 - Beta",
-    "Intended Audience :: Science/Research",
-    "Topic :: Scientific/Engineering",
-    "Topic :: Scientific/Engineering :: Mathematics",
-    "Topic :: Scientific/Engineering :: Artificial Intelligence",
-    "Natural Language :: English",
-    "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.10",
-    "Programming Language :: Python :: 3.11"
-]
-requires-python = ">= 3.10"
-dependencies=[
-    "tqdm",
-    "numpy",
-    "scipy",
-    "requests",
-    "scikit-learn",
-    "matplotlib",
-    "networkx",
-    "pandas",
-    "pyg-nightly",
-    "decorator",
-    "hypernetx < 2.0.0",
-    "trimesh",
-    "spharapy",
-    "hydra-core==1.3.2",
-    "hydra-colorlog==1.2.0",
-    "hydra-optuna-sweeper==1.2.0",
-    "rich",
-    "rootutils",
-    "pytest",
-
-]
-
-[project.optional-dependencies]
-doc = [
-    "jupyter",
-    "nbsphinx",
-    "nbsphinx_link",
-    "sphinx",
-    "sphinx_gallery",
-    "pydata-sphinx-theme"
-]
-lint = [
-    "black",
-    "black[jupyter]",
-    "flake8",
-    "flake8-docstrings",
-    "Flake8-pyproject",
-    "isort",
-    "pre-commit"
-]
-test = [
-    "pytest",
-    "pytest-cov",
-    "coverage",
-    "jupyter",
-    "mypy"
-]
-
-dev = ["TopoBenchmarkX[test, lint]"]
-all = ["TopoBenchmarkX[dev, doc]"]
-
-[project.urls]
-homepage="https://github.com/pyt-team/TopoBenchmarkX"
-repository="https://github.com/pyt-team/TopoBenchmarkX"
-
-[tool.setuptools.dynamic]
-version = {attr = "topobenchmarkx.__version__"}
-
-[tool.setuptools.packages.find]
-include = [
-    "topobenchmarkx",
-    "topobenchmarkx.*"
-]
-
-[tool.mypy]
-warn_redundant_casts = true
-warn_unused_ignores = true
-show_error_codes = true
-plugins = "numpy.typing.mypy_plugin"
-
-[[tool.mypy.overrides]]
-module = [
-    "torch_cluster.*","networkx.*","scipy.spatial","scipy.sparse","toponetx.classes.simplicial_complex"
-]
-ignore_missing_imports = true
-
-[tool.pytest.ini_options]
-addopts = "--capture=no"
-
-[tool.black]
-line-length = 88
-
-[tool.isort]
-line_length = 88
-multi_line_output = 3
-include_trailing_comma = true
-skip = [".gitignore", "__init__.py"]
-
-[tool.flake8]
-max-line-length = 88
-application_import_names = "topobenchmarkx"
-docstring-convention = "numpy"
-exclude = [
-    "topobenchmarkx/__init__.py",
-    "docs/conf.py"
-]
-
-import_order_style = "smarkets"
-extend-ignore = ["E501", "E203"]
-per-file-ignores = [
-    "*/__init__.py: D104, F401",
-]
diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 00000000..1d720b4b
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,3 @@
+/logs
+/datasets/graph
+.ruff_cache
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2122dbd9..9b0a82ab 100755
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -15,21 +15,16 @@ repos:
       - id: check-added-large-files
         args:
           - --maxkb=2048
-#     - id: trailing-whitespace
      - id: requirements-txt-fixer
 
-  - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.4.4
-    hooks:
-      - id: ruff
-        #types_or: [ python, pyi, jupyter ]
-        #types_or: [ python, pyi ]
-        args: [ --fix ]
-      - id: ruff-format
-        #types_or: [ python, pyi, jupyter ]
-        #types_or: [ python, pyi ]
+  # - repo: https://github.com/astral-sh/ruff-pre-commit
+  #   rev: v0.4.4
+  #   hooks:
+  #     - id: ruff
+  #       args: [ --fix ]
+  #     - id: ruff-format
 
-  - repo: https://github.com/numpy/numpydoc
-    rev: v1.6.0
-    hooks:
-      - id: numpydoc-validation
\ No newline at end of file
+  # - repo: https://github.com/numpy/numpydoc
+  #   rev: v1.6.0
+  #   hooks:
+  #     - id: numpydoc-validation
diff --git a/.devcontainer/Dockerfile b/Dockerfile
old mode 100755
new mode 100644
similarity index 59%
rename from .devcontainer/Dockerfile
rename to Dockerfile
index afb6e0e4..0dd0bd01
--- a/.devcontainer/Dockerfile
+++ b/Dockerfile
@@ -7,10 +7,11 @@ COPY . .
 RUN pip install --upgrade pip
 RUN pip install -e '.[all]'
 
-RUN pip install --no-dependencies git+https://github.com/pyt-team/TopoNetX.git
-RUN pip install --no-dependencies git+https://github.com/pyt-team/TopoModelX.git
+RUN pip install git+https://github.com/pyt-team/TopoNetX.git
+RUN pip install git+https://github.com/pyt-team/TopoModelX.git
+RUN pip install git+https://github.com/pyt-team/TopoEmbedX.git
+
+RUN pip install torch_geometric==2.4.0
 RUN pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/cu115
 RUN pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cu115.html
 RUN pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+cu115.html
-RUN pip install lightning>=2.0.0
-RUN pip install numpy pre-commit jupyterlab notebook ipykernel
\ No newline at end of file
diff --git a/conda.sh b/conda.sh
new file mode 100755
index 00000000..40717fa9
--- /dev/null
+++ b/conda.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+mkdir -p ~/miniconda3
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
+bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
+rm -rf ~/miniconda3/miniconda.sh
+
+~/miniconda3/bin/conda init bash
+
+#conda create -n topox python=3.11.3
+#conda activate topox
\ No newline at end of file
diff --git a/configs/dataset/MUTAG.yaml b/configs/dataset/MUTAG.yaml
index 654c7b33..e3cdbdb7 100755
--- a/configs/dataset/MUTAG.yaml
+++ b/configs/dataset/MUTAG.yaml
@@ -15,6 +15,7 @@ parameters:
   num_features:
     - 7 # initial node features
    - 4 # initial edge features
+  num_classes: 2
   task: classification
   loss_type: cross_entropy
 
@@ -26,7 +27,7 @@ parameters:
   train_prop: 0.5 # for "random" strategy splitting
 
   # Lifting parameters
-  max_dim_if_lifted: 2
+  max_dim_if_lifted: 3 # This is the maximum dimension of the simplicial complex in the dataset
   preserve_edge_attr_if_lifted: False
 
   # Dataloader parameters
diff --git a/configs/dataset/PROTEINS_TU.yaml b/configs/dataset/PROTEINS_TU.yaml
index d69a2a31..1b4afb96 100755
--- a/configs/dataset/PROTEINS_TU.yaml
+++ b/configs/dataset/PROTEINS_TU.yaml
@@ -19,7 +19,7 @@ parameters:
   monitor_metric: accuracy
   task_level: graph
   data_seed: 9
-  split_type: k-fold #'k-fold' # either "k-fold" or "random" strategies
+  split_type: random #'k-fold' # either "k-fold" or "random" strategies
   k: 10 # for "k-fold" Cross-Validation
   train_prop: 0.5 # for "random" strategy splitting
 
diff --git a/configs/dataset/ZINC.yaml b/configs/dataset/ZINC.yaml
index 62a39319..58d2ee3e 100644
--- a/configs/dataset/ZINC.yaml
+++ b/configs/dataset/ZINC.yaml
@@ -1,7 +1,10 @@
 _target_: topobenchmarkx.io.load.loaders.GraphLoader
 
+# USE python train.py dataset.transforms.one_hot_node_degree_features.degrees_fields=x to run this config
+
 defaults:
-  - transforms/data_manipulations: node_feat_to_float
+  - transforms/data_manipulations: node_degrees
+  - transforms/data_manipulations@transforms.one_hot_node_degree_features: one_hot_node_degree_features
   - transforms: ${get_default_transform:graph,${model}}
 
 # Data definition
@@ -13,14 +16,15 @@ parameters:
   data_split_dir: ${paths.data_dir}data_splits/${dataset.parameters.data_name}
 
   # Dataset parameters
-  num_features: 1 # here basically I specify the initial num features in mutang at x aka x_0
+  num_features: 21 # torch_geometric ZINC dataset has 21 atom types
+  max_node_degree: 20 # Use it to one_hot encode node degrees. Additional parameter to run dataset.transforms.one_hot_node_degree_features.degrees_fields=x
   num_classes: 1
   task: regression
   loss_type: mse
   monitor_metric: mae
   task_level: graph
   data_seed: 0
-  split_type: 'fixed' # either k-fold or test
+  split_type: 'fixed' # ZINC accepts only the fixed split
   #k: 10 # for k-Fold Cross-Validation
 
   # Dataloader parameters
diff --git a/configs/dataset/coauthorship_citeseer.yaml b/configs/dataset/coauthorship_citeseer.yaml
index 2b9dc3c4..078b0e21 100755
--- a/configs/dataset/coauthorship_citeseer.yaml
+++ b/configs/dataset/coauthorship_citeseer.yaml
@@ -20,7 +20,7 @@ parameters:
   monitor_metric: accuracy
   task_level: node
   data_seed: 0
-  split_type: k-fold #'k-fold' # either k-fold or test
+  split_type: random #'k-fold' # either k-fold or test
   k: 10 # for k-Fold Cross-Validation
 
   # Dataloader parameters
diff --git a/configs/dataset/coauthorship_cora.yaml b/configs/dataset/coauthorship_cora.yaml
index 0c626534..d48233de 100755
--- a/configs/dataset/coauthorship_cora.yaml
+++ b/configs/dataset/coauthorship_cora.yaml
@@ -19,7 +19,7 @@ parameters:
   monitor_metric: accuracy
   task_level: node
   data_seed: 0
-  split_type: k-fold #'k-fold' # either k-fold or test
+  split_type: random #'k-fold' # either k-fold or test
   k: 10 # for k-Fold Cross-Validation
 
   # Dataloader parameters
diff --git a/configs/dataset/manual_dataset.yaml b/configs/dataset/manual_dataset.yaml
index 6d7191cc..861ad4c8 100755
--- a/configs/dataset/manual_dataset.yaml
+++ b/configs/dataset/manual_dataset.yaml
@@ -19,7 +19,7 @@ parameters:
   monitor_metric: accuracy
   task_level: node
   data_seed: 0
-  split_type: k-fold #'k-fold' # either k-fold or test
+  split_type: random #'k-fold' # either k-fold or test
   k: 10 # for k-Fold Cross-Validation
 
   # Dataloader parameters
diff --git a/configs/dataset/transforms/data_manipulations/node_degrees.yaml b/configs/dataset/transforms/data_manipulations/node_degrees.yaml
index c1775453..1d666d32 100755
--- a/configs/dataset/transforms/data_manipulations/node_degrees.yaml
+++ b/configs/dataset/transforms/data_manipulations/node_degrees.yaml
@@ -1,5 +1,5 @@
 _target_: topobenchmarkx.transforms.data_transform.DataTransform
 transform_name: "NodeDegrees"
 transform_type: "data manipulation"
-selected_fields: ["edge_index", "incidence"] #"incidence"
+selected_fields: ["edge_index"] # "incidence"
 
diff --git a/configs/dataset/transforms/data_manipulations/one_hot_node_degree_features.yaml b/configs/dataset/transforms/data_manipulations/one_hot_node_degree_features.yaml
index b8892aa9..7fdd0e67 100755
--- a/configs/dataset/transforms/data_manipulations/one_hot_node_degree_features.yaml
+++ b/configs/dataset/transforms/data_manipulations/one_hot_node_degree_features.yaml
@@ -4,5 +4,5 @@ transform_type: "data manipulation"
 degrees_fields: "node_degrees"
 features_fields: "x"
-max_degrees: ${dataset.parameters.max_node_degree}
+max_degree: ${dataset.parameters.max_node_degree}
 
diff --git a/configs/dataset/transforms/graph2cell_lifting/cell_cycles.yaml b/configs/dataset/transforms/graph2cell_lifting/cell_cycles.yaml
index ba60f86c..d9a6c272 100644
--- a/configs/dataset/transforms/graph2cell_lifting/cell_cycles.yaml
+++ b/configs/dataset/transforms/graph2cell_lifting/cell_cycles.yaml
@@ -1,7 +1,6 @@
 _target_: topobenchmarkx.transforms.data_transform.DataTransform
 transform_type: 'lifting'
 transform_name: "CellCyclesLifting"
-k_value: 1
 complex_dim: ${oc.select:dataset.parameters.max_dim_if_lifted,3}
-max_cell_length: 6
+max_cell_length: 10
 preserve_edge_attr: ${oc.select:dataset.parameters.preserve_edge_attr_if_lifted,False}
diff --git a/configs/dataset/us_country_demos.yaml b/configs/dataset/us_country_demos.yaml
index 29d9b44b..61cd17a3 100755
--- a/configs/dataset/us_country_demos.yaml
+++ b/configs/dataset/us_country_demos.yaml
@@ -17,7 +17,7 @@ parameters:
   num_features: 6
   num_classes: 1
   task: regression
-  task_variable: 'Election' # options: ['Election', 'MedianIncome', 'MigraRate', 'BirthRate', 'DeathRate', 'BachelorRate', 'UnemploymentRate']
+  task_variable: 'MedianIncome' # options: ['Election', 'MedianIncome', 'MigraRate', 'BirthRate', 'DeathRate', 'BachelorRate', 'UnemploymentRate']
   force_reload: True
   loss_type: mse
   monitor_metric: mae
diff --git a/configs/logger/wandb.yaml b/configs/logger/wandb.yaml
index b40863f7..da285376 100755
--- a/configs/logger/wandb.yaml
+++ b/configs/logger/wandb.yaml
@@ -7,7 +7,7 @@ wandb:
   offline: False
   id: null # pass correct id to resume experiment!
   anonymous: null # enable anonymous logging
-  project: "topox_10fold_sweep"
+  project: "None"
   log_model: False # upload lightning ckpts
   prefix: "" # a string to put at the beginning of metric keys
   # entity: "" # set to name of your wandb team
diff --git a/configs/loss/default.yaml b/configs/loss/default.yaml
deleted file mode 100755
index e69de29b..00000000
diff --git a/configs/model/cell/can.yaml b/configs/model/cell/can.yaml
index 510ce09c..4d1e1575 100755
--- a/configs/model/cell/can.yaml
+++ b/configs/model/cell/can.yaml
@@ -1,9 +1,14 @@
-_target_: topobenchmarkx.models.network_module.NetworkModule
+_target_: topobenchmarkx.models.TopologicalNetworkModule
+
+model_name: can
+model_domain: cell
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
-  out_channels: 32
+  out_channels: 128
+  proj_dropout: 0.0
   selected_dimensions:
     - 0
     - 1
@@ -17,30 +22,32 @@ backbone:
   heads: 1 # For now we stuck to out_channels//heads, keep heads = 1
   concat: True
   skip_connection: True
-  n_layers: 1
+  n_layers: 4
   att_lift: False
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.CANWrapper
+  _target_: topobenchmarkx.models.wrappers.CANWrapper
   _partial_: true
+  wrapper_name: CANWrapper
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}
 
 readout:
-  _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut
-  readout_name: PropagateSignalDown # Use in case readout is not needed
+  _target_: topobenchmarkx.models.readouts.${model.readout.readout_name}
+  readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown
   hidden_dim: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
-  in_channels: ${parameter_multiplication:${model.backbone.out_channels},${model.backbone.heads}}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
+  in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
-  pooling_type: sum
+  task_level: ${dataset.parameters.task_level}
+  pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/cell/cwn_dcm.yaml b/configs/model/cell/cccn.yaml
similarity index 60%
rename from configs/model/cell/cwn_dcm.yaml
rename to configs/model/cell/cccn.yaml
index 764602f7..26a3360f 100755
--- a/configs/model/cell/cwn_dcm.yaml
+++ b/configs/model/cell/cccn.yaml
@@ -1,41 +1,47 @@
-_target_: topobenchmarkx.models.network_module.NetworkModule
+_target_: topobenchmarkx.models.TopologicalNetworkModule
+
+model_name: cccn
+model_domain: cell
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
   out_channels: 32
-  proj_dropout: 0.0
+  proj_dropout: 0.
   selected_dimensions:
     - 0
     - 1
 
 backbone:
-  _target_: custom_models.cell.cwn_dcm.CWNDCM
+  _target_: custom_models.cell.cccn.CCCN
   in_channels: ${model.feature_encoder.out_channels}
-  n_layers: 1
+  n_layers: 4
   dropout: 0.0
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.CWNDCMWrapper
+  _target_: topobenchmarkx.models.wrappers.CCCNWrapper
   _partial_: true
+  wrapper_name: CCCNWrapper
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}
 
 readout:
-  _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut
-  readout_name: PropagateSignalDown # Use in case readout is not needed
+  _target_: topobenchmarkx.models.readouts.${model.readout.readout_name}
+  readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown
   hidden_dim: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/cell/ccxn.yaml b/configs/model/cell/ccxn.yaml
index 7d4f336a..851c4544 100755
--- a/configs/model/cell/ccxn.yaml
+++ b/configs/model/cell/ccxn.yaml
@@ -1,9 +1,14 @@
-_target_: topobenchmarkx.models.network_module.NetworkModule
+_target_: topobenchmarkx.models.TopologicalNetworkModule
+
+model_name: ccxn
+model_domain: cell
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
   out_channels: 32
+  proj_dropout: 0.0
 
 backbone:
   _target_: topomodelx.nn.cell.ccxn.CCXN
@@ -17,26 +22,28 @@ backbone_additional_params:
   hidden_channels: ${model.feature_encoder.out_channels}
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.CCXNWrapper
+  _target_: topobenchmarkx.models.wrappers.CCXNWrapper
   _partial_: true
+  wrapper_name: CCXNWrapper
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
 
 readout:
-  _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut
-  readout_name: PropagateSignalDown # Use in case readout is not needed
+  _target_: topobenchmarkx.models.readouts.${model.readout.readout_name}
+  readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown
   hidden_dim: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/cell/cwn.yaml b/configs/model/cell/cwn.yaml
index 85ab2cb7..a5508dc7 100755
--- a/configs/model/cell/cwn.yaml
+++ b/configs/model/cell/cwn.yaml
@@ -1,9 +1,13 @@
-_target_: topobenchmarkx.models.network_module.NetworkModule
+_target_: topobenchmarkx.models.TopologicalNetworkModule
+
+model_name: cwn
+model_domain: cell
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
-  out_channels: 32
+  out_channels: 64
   proj_dropout: 0.0
 
 backbone:
@@ -12,30 +16,31 @@ backbone:
   in_channels_1: ${model.feature_encoder.out_channels}
   in_channels_2: ${model.feature_encoder.out_channels}
   hid_channels: ${model.feature_encoder.out_channels}
-  n_layers: 1
+  n_layers: 4
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.CWNWrapper
+  _target_: topobenchmarkx.models.wrappers.CWNWrapper
   _partial_: true
+  wrapper_name: CWNWrapper
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
-
 readout:
-  _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut
-  readout_name: PropagateSignalDown # Use in case readout is not needed
+  _target_: topobenchmarkx.models.readouts.${model.readout.readout_name}
+  readout_name: NoReadOut # Use in case readout is not needed Options: PropagateSignalDown
   hidden_dim: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/graph/gat.yaml b/configs/model/graph/gat.yaml
index 92a74591..27763d6a 100755
--- a/configs/model/graph/gat.yaml
+++ b/configs/model/graph/gat.yaml
@@ -1,9 +1,14 @@
-_target_: topobenchmarkx.models.network_module.NetworkModule
+_target_: topobenchmarkx.models.TopologicalNetworkModule
+
+model_name: gat
+model_domain: graph
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}}
   out_channels: 32
+  proj_dropout: 0.0
 
 backbone:
   _target_: torch_geometric.nn.models.GAT
@@ -17,27 +22,28 @@ backbone:
   concat: true
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.GNNWrapper
+  _target_: topobenchmarkx.models.wrappers.GNNWrapper
   _partial_: true
+  wrapper_name: GNNWrapper
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
 
 readout:
-  _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut
-  readout_name: PropagateSignalDown # Use in case readout is not needed
+  _target_: topobenchmarkx.models.readouts.${model.readout.readout_name}
+  readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown
   hidden_dim: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
-
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/graph/gcn.yaml b/configs/model/graph/gcn.yaml
index ffc152ad..a40f2bd6 100755
--- a/configs/model/graph/gcn.yaml
+++ b/configs/model/graph/gcn.yaml
@@ -1,39 +1,46 @@
-_target_: topobenchmarkx.models.network_module.NetworkModule
+_target_: topobenchmarkx.models.TopologicalNetworkModule
+
+model_name: gcn
+model_domain: graph
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
   out_channels: 64
+  proj_dropout: 0.0
 
 backbone:
   _target_: torch_geometric.nn.models.GCN
   in_channels: ${model.feature_encoder.out_channels}
   hidden_channels: ${model.feature_encoder.out_channels}
-  num_layers: 1
+  num_layers: 2
   dropout: 0.0
   act: relu
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.GNNWrapper
+  _target_: topobenchmarkx.models.wrappers.GNNWrapper
   _partial_: true
+  wrapper_name: GNNWrapper
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
 
 readout:
-  _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut
-  readout_name: PropagateSignalDown # Use in case readout is not needed
+  _target_: topobenchmarkx.models.readouts.${model.readout.readout_name}
+  readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown
   hidden_dim: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/graph/gin.yaml b/configs/model/graph/gin.yaml
index b3797cdf..76658ac8 100755
--- a/configs/model/graph/gin.yaml
+++ b/configs/model/graph/gin.yaml
@@ -1,9 +1,14 @@
-_target_: topobenchmarkx.models.network_module.NetworkModule
+_target_: topobenchmarkx.models.TopologicalNetworkModule
+
+model_name: gin
+model_domain: graph
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
   out_channels: 32
+  proj_dropout: 0.0
 
 backbone:
   _target_: torch_geometric.nn.models.GIN
@@ -14,27 +19,28 @@
   act: relu
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.GNNWrapper
+  _target_: topobenchmarkx.models.wrappers.GNNWrapper
   _partial_: true
+  wrapper_name: GNNWrapper
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
 
 readout:
-  _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut
-  readout_name: PropagateSignalDown # Use in case readout is not needed
+  _target_: topobenchmarkx.models.readouts.${model.readout.readout_name}
+  readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown
   hidden_dim: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
-
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/hypergraph/alldeepset.yaml b/configs/model/hypergraph/alldeepset.yaml
index ff6d84e2..58681a35 100755
--- a/configs/model/hypergraph/alldeepset.yaml
+++ b/configs/model/hypergraph/alldeepset.yaml
@@ -1,11 +1,14 @@
-_target_: topobenchmarkx.models.network_module.NetworkModule
+_target_: topobenchmarkx.models.TopologicalNetworkModule
 
 model_name: alldeepset
+model_domain: hypergraph
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
   out_channels: 32
+  proj_dropout: 0.0
 
 backbone:
   _target_: topomodelx.nn.hypergraph.allset.AllSet
@@ -24,26 +27,28 @@ backbone:
   #num_features: ${model.backbone.hidden_channels}
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.HypergraphWrapper
+  _target_: topobenchmarkx.models.wrappers.HypergraphWrapper
   _partial_: true
+  wrapper_name: HypergraphWrapper
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
 
 readout:
-  _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut
-  readout_name: None # Use in case readout is not needed
+  _target_: topobenchmarkx.models.readouts.${model.readout.readout_name}
+  readout_name: NoReadOut # Use in case readout is not needed Options: PropagateSignalDown
   hidden_dim: None
   num_cell_dimensions: None
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/hypergraph/allsettransformer.yaml b/configs/model/hypergraph/allsettransformer.yaml
index 6c589729..984d78ec 100755
--- a/configs/model/hypergraph/allsettransformer.yaml
+++ b/configs/model/hypergraph/allsettransformer.yaml
@@ -1,44 +1,48 @@
-_target_: topobenchmarkx.models.network_module.NetworkModule
+_target_: topobenchmarkx.models.TopologicalNetworkModule
 
 model_name: allsettransformer
+model_domain: hypergraph
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
-  out_channels: 32
+  out_channels: 128
+  proj_dropout: 0.0
 
 backbone:
   _target_: topomodelx.nn.hypergraph.allset_transformer.AllSetTransformer
   in_channels: ${model.feature_encoder.out_channels}
   hidden_channels: ${model.feature_encoder.out_channels}
-  n_layers: 1
+  n_layers: 4
   heads: 4
   dropout: 0.
   mlp_num_layers: 1
   mlp_dropout: 0.
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.HypergraphWrapper
+  _target_: topobenchmarkx.models.wrappers.HypergraphWrapper
   _partial_: true
+  wrapper_name: HypergraphWrapper
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
 
 readout:
-  _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut
-  readout_name: None # Use in case readout is not needed
+  _target_: topobenchmarkx.models.readouts.${model.readout.readout_name}
+  readout_name: NoReadOut # Use in case readout is not needed Options: PropagateSignalDown
   hidden_dim: None
   num_cell_dimensions: None
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
-
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/hypergraph/edgnn.yaml b/configs/model/hypergraph/edgnn.yaml
index 42088fff..2512dcf3 100755
--- a/configs/model/hypergraph/edgnn.yaml
+++ b/configs/model/hypergraph/edgnn.yaml
@@ -1,45 +1,49 @@
-_target_: topobenchmarkx.models.network_module.NetworkModule
+_target_: topobenchmarkx.models.TopologicalNetworkModule
 
 model_name: edgnn
 model_domain: hypergraph
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
-  out_channels: 32
+  out_channels: 128
+  proj_dropout: 0.0
 
 backbone:
   _target_: custom_models.hypergraph.edgnn.EDGNN
   num_features: ${model.feature_encoder.out_channels} # ${dataset.parameters.num_features}
-  input_dropout: 0.2
-  dropout: 0.2
+  input_dropout: 0.
+  dropout: 0.
   activation: relu
-  MLP_num_layers: 0
+  MLP_num_layers: 1
   All_num_layers: 1
   edconv_type: EquivSet
   aggregate: 'add'
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.HypergraphWrapper
+  _target_: topobenchmarkx.models.wrappers.HypergraphWrapper
   _partial_: true
+  wrapper_name: HypergraphWrapper
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
 
 readout:
-  _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut
-  readout_name: None # Use in case readout is not needed
+  _target_: topobenchmarkx.models.readouts.${model.readout.readout_name}
+  readout_name: NoReadOut # Use in case readout is not needed Options: PropagateSignalDown, NoReadOut
   hidden_dim: None
   num_cell_dimensions: None
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/hypergraph/unignn.yaml b/configs/model/hypergraph/unignn.yaml
index bfd56c5f..5623d33a 100755
--- a/configs/model/hypergraph/unignn.yaml
+++ b/configs/model/hypergraph/unignn.yaml
@@ -1,4 +1,14 @@
-_target_: topobenchmarkx.models.network_module.NetworkModule
+_target_: topobenchmarkx.models.TopologicalNetworkModule
+
+model_name: unignn
+model_domain: hypergraph
+
+feature_encoder:
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
+  in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
+  out_channels: 32
+  proj_dropout: 0.0
 
 backbone:
   _target_: topomodelx.nn.hypergraph.unigcn.UniGCN
@@ -7,26 +17,28 @@ backbone:
   n_layers: 1
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.HypergraphWrapper
+  _target_: topobenchmarkx.models.wrappers.HypergraphWrapper
   _partial_: true
+  wrapper_name: HypergraphWrapper
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
 
 readout:
-  _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut
-  readout_name: None # Use in case readout is not needed
+  _target_: topobenchmarkx.models.readouts.${model.readout.readout_name}
+  readout_name: NoReadOut # Use in case readout is not needed Options: PropagateSignalDown
   hidden_dim: None
   num_cell_dimensions: None
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/hypergraph/unignn2.yaml b/configs/model/hypergraph/unignn2.yaml
index 31dd1d62..1d380cfa 100755
--- a/configs/model/hypergraph/unignn2.yaml
+++ b/configs/model/hypergraph/unignn2.yaml
@@ -1,43 +1,48 @@
-_target_: topobenchmarkx.models.network_module.NetworkModule
+_target_: topobenchmarkx.models.TopologicalNetworkModule
 
 model_name: unignn2
+model_domain: hypergraph
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
-  out_channels: 32
+  out_channels: 128
+  proj_dropout: 0.0
 
 backbone:
   _target_: topomodelx.nn.hypergraph.unigcnii.UniGCNII
-  in_channels: ${model.feature_encoder.out_channels} # ${dataset.parameters.num_features}
+  in_channels: ${model.feature_encoder.out_channels}
   hidden_channels: ${model.feature_encoder.out_channels}
-  n_layers: 1
+  n_layers: 4
   alpha: 0.5
   beta: 0.5
-  input_drop: 0.2
-  layer_drop: 0.2
+  input_drop: 0.0
+  layer_drop: 0.0
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.HypergraphWrapper
+  _target_: topobenchmarkx.models.wrappers.HypergraphWrapper
   _partial_: true
+  wrapper_name: HypergraphWrapper
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
 
 readout:
-  _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut
-  readout_name: None # Use in case readout is not needed
+  _target_: topobenchmarkx.models.readouts.${model.readout.readout_name}
+  readout_name: NoReadOut # Use in case readout is not needed Options: PropagateSignalDown
   hidden_dim: None
   num_cell_dimensions: None
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/simplicial/san.yaml b/configs/model/simplicial/san.yaml
index eb13722f..53f2eeb5 100755
--- a/configs/model/simplicial/san.yaml
+++ b/configs/model/simplicial/san.yaml
@@ -1,9 +1,14 @@
-_target_: topobenchmarkx.models.network_module.NetworkModule
+_target_: topobenchmarkx.models.TopologicalNetworkModule
+
+model_name: san
+model_domain: simplicial
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
   out_channels: 64
+  proj_dropout: 0.0
   selected_dimensions:
     - 0
     - 1
@@ -18,27 +23,28 @@ backbone:
   epsilon_harmonic: 1e-1
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.SANWrapper
+  _target_: topobenchmarkx.models.wrappers.SANWrapper
   _partial_: true
+  wrapper_name: SANWrapper
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}
 
 readout:
-  _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut
-  readout_name: PropagateSignalDown # Use in case readout is not needed
+  _target_: topobenchmarkx.models.readouts.${model.readout.readout_name}
+  readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown
   hidden_dim: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
-
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/simplicial/sccn.yaml b/configs/model/simplicial/sccn.yaml
index b92ee200..55e61dde 100755
--- a/configs/model/simplicial/sccn.yaml
+++ b/configs/model/simplicial/sccn.yaml
@@ -1,38 +1,45 @@
-_target_: topobenchmarkx.models.network_module.NetworkModule
+_target_: topobenchmarkx.models.TopologicalNetworkModule
+
+model_name: sccn
+model_domain: simplicial
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} # ${dataset.parameters.num_features}
   out_channels: 32
+  proj_dropout: 0.0
 
 backbone:
   _target_: topomodelx.nn.simplicial.sccn.SCCN
   channels: ${model.feature_encoder.out_channels}
-  max_rank: 1
+  max_rank: 2
   n_layers: 1
   update_func: "sigmoid"
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.SCCNNWrapper
+  _target_: topobenchmarkx.models.wrappers.SCCNWrapper
   _partial_: true
+  wrapper_name: SCCNWrapper
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
 
 readout:
-  _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut
-  readout_name: PropagateSignalDown # Use in case readout is not needed
+  _target_: topobenchmarkx.models.readouts.${model.readout.readout_name}
+  readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown
   hidden_dim: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/simplicial/sccnn.yaml b/configs/model/simplicial/sccnn.yaml
index 4e1d65ff..2c15e115 100755
--- a/configs/model/simplicial/sccnn.yaml
+++ b/configs/model/simplicial/sccnn.yaml
@@ -1,9 +1,14 @@
-_target_: topobenchmarkx.models.network_module.NetworkModule
+_target_: topobenchmarkx.models.TopologicalNetworkModule
+
+model_name: sccnn
+model_domain: simplicial
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
   out_channels: 32
+  proj_dropout: 0.0
   selected_dimensions:
     - 0
     - 1
@@ -26,27 +31,29 @@ backbone:
   n_layers: 1
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.SCCNNWrapper
+  _target_: topobenchmarkx.models.wrappers.SCCNNWrapper
   _partial_: true
+  wrapper_name: SCCNNWrapper
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}
 
 readout:
-  _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut
-  readout_name: PropagateSignalDown # Use in case readout is not needed
+  _target_: topobenchmarkx.models.readouts.${model.readout.readout_name}
+  readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown
   hidden_dim: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/simplicial/sccnn_custom.yaml b/configs/model/simplicial/sccnn_custom.yaml
index 48a5f4a9..29fab672 100755
--- a/configs/model/simplicial/sccnn_custom.yaml
+++ b/configs/model/simplicial/sccnn_custom.yaml
@@ -1,15 +1,19 @@
-_target_: topobenchmarkx.models.network_module.NetworkModule
+_target_: topobenchmarkx.models.TopologicalNetworkModule
+
+model_name: sccnn_custom
+model_domain: simplicial
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
  in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
-  out_channels: 64
+  out_channels: 32
+  proj_dropout: 0.0
   selected_dimensions:
     - 0
     - 1
     - 2
-
 backbone:
   _target_: custom_models.simplicial.sccnn.SCCNNCusctom
   in_channels_all:
@@ -27,26 +31,28 @@ backbone:
   n_layers: 1
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.SCCNNWrapper
+  _target_: topobenchmarkx.models.wrappers.SCCNNWrapper
   _partial_: true
+  wrapper_name: SCCNNWrapper
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}
 
 readout:
-  _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut
-  readout_name: PropagateSignalDown # Use in case readout is not needed
+  _target_: topobenchmarkx.models.readouts.${model.readout.readout_name}
+  readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown
   hidden_dim: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/simplicial/scn.yaml b/configs/model/simplicial/scn.yaml
index bb2c2b28..cee0fce6 100755
--- a/configs/model/simplicial/scn.yaml
+++ b/configs/model/simplicial/scn.yaml
@@ -1,9 +1,14 @@
-_target_: topobenchmarkx.models.network_module.NetworkModule
+_target_: topobenchmarkx.models.TopologicalNetworkModule
+
+model_name: scn
+model_domain: simplicial
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
   out_channels: 32
+  proj_dropout: 0.0
   selected_dimensions:
     - 0
     - 1
@@ -17,26 +22,28 @@ backbone:
   n_layers: 1
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.SCNWrapper
+  _target_: topobenchmarkx.models.wrappers.SCNWrapper
   _partial_: true
+  wrapper_name: SCNWrapper
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}
 
 readout:
-  _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut
-  readout_name: PropagateSignalDown # Use in case readout is not needed
+  _target_: topobenchmarkx.models.readouts.${model.readout.readout_name}
+  readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown
   hidden_dim: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
#allsettransformer - evaluator: default - callbacks: default - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`) diff --git a/custom_models/cell/cwn_dcm.py b/custom_models/cell/cccn.py similarity index 98% rename from custom_models/cell/cwn_dcm.py rename to custom_models/cell/cccn.py index eaa81cb9..c1e936a7 100644 --- a/custom_models/cell/cwn_dcm.py +++ b/custom_models/cell/cccn.py @@ -34,7 +34,7 @@ def forward(self, xe, Lu, Ld): return z_h + z_s + z_i -class CWNDCM(nn.Module): +class CCCN(nn.Module): def __init__(self, in_channels, n_layers=2, dropout=0.0, last_act=False): super().__init__() self.d = dropout diff --git a/custom_models/cell/cin.py b/custom_models/cell/cin.py index deefffcf..9b136bbb 100644 --- a/custom_models/cell/cin.py +++ b/custom_models/cell/cin.py @@ -1,9 +1,8 @@ """CWN class.""" import torch -import torch.nn.functional as F -from topomodelx.nn.cell.cwn_layer import CWNLayer import torch.nn as nn +import torch.nn.functional as F from topomodelx.base.conv import Conv from torch_geometric.nn.models import MLP @@ -65,7 +64,8 @@ def forward( neighborhood_2_to_1, neighborhood_0_to_1, ): - """Forward computation through projection, convolutions, linear layers and average pooling. + """Forward computation through projection, convolutions, linear layers + and average pooling. Parameters ---------- @@ -192,15 +192,21 @@ def __init__( self.conv_1_to_1 = ( conv_1_to_1 if conv_1_to_1 is not None - else _CWNDefaultFirstConv(in_channels_1, in_channels_2, out_channels) + else _CWNDefaultFirstConv( + in_channels_1, in_channels_2, out_channels + ) ) self.conv_0_to_1 = ( conv_0_to_1 if conv_0_to_1 is not None - else _CWNDefaultSecondConv(in_channels_0, in_channels_1, out_channels) + else _CWNDefaultSecondConv( + in_channels_0, in_channels_1, out_channels + ) ) self.aggregate_fn = ( - aggregate_fn if aggregate_fn is not None else _CWNDefaultAggregate() + aggregate_fn + if aggregate_fn is not None + else _CWNDefaultAggregate() ) self.update_fn = ( update_fn @@ -325,11 +331,10 @@ def forward( class _CWNDefaultFirstConv(nn.Module): - r""" - Default implementation of the first convolutional step in CWNLayer. + r"""Default implementation of the first convolutional step in CWNLayer. - The self.forward method of this module must be treated as - a protocol for the first convolutional step in CWN layer. + The self.forward method of this module must be treated as a protocol for + the first convolutional step in CWN layer. """ def __init__( @@ -383,11 +388,10 @@ def forward(self, x_1, x_2, neighborhood_1_to_1, neighborhood_2_to_1): class _CWNDefaultSecondConv(nn.Module): - r""" - Default implementation of the second convolutional step in CWNLayer. + r"""Default implementation of the second convolutional step in CWNLayer. - The self.forward method of this module must be treated as - a protocol for the second convolutional step in CWN layer. + The self.forward method of this module must be treated as a protocol for + the second convolutional step in CWN layer. """ def __init__(self, in_channels_0, out_channels) -> None: @@ -417,11 +421,10 @@ def forward(self, x_0, neighborhood_0_to_1): class _CWNDefaultAggregate(nn.Module): - r""" - Default implementation of an aggregation step in CWNLayer. + r"""Default implementation of an aggregation step in CWNLayer. - The self.forward method of this module must be treated as - a protocol for the aggregation step in CWN layer. 
+ The self.forward method of this module must be treated as a protocol for + the aggregation step in CWN layer. """ def __init__(self) -> None: diff --git a/custom_models/hypergraph/edgnn.py b/custom_models/hypergraph/edgnn.py index 0888fc5b..3867ff62 100644 --- a/custom_models/hypergraph/edgnn.py +++ b/custom_models/hypergraph/edgnn.py @@ -46,7 +46,9 @@ def __init__( self.lins.append(nn.Linear(in_channels, hidden_channels)) self.normalizations.append(nn.BatchNorm1d(hidden_channels)) for _ in range(num_layers - 2): - self.lins.append(nn.Linear(hidden_channels, hidden_channels)) + self.lins.append( + nn.Linear(hidden_channels, hidden_channels) + ) self.normalizations.append(nn.BatchNorm1d(hidden_channels)) self.lins.append(nn.Linear(hidden_channels, out_channels)) elif Normalization == "ln": @@ -65,7 +67,9 @@ def __init__( self.lins.append(nn.Linear(in_channels, hidden_channels)) self.normalizations.append(nn.LayerNorm(hidden_channels)) for _ in range(num_layers - 2): - self.lins.append(nn.Linear(hidden_channels, hidden_channels)) + self.lins.append( + nn.Linear(hidden_channels, hidden_channels) + ) self.normalizations.append(nn.LayerNorm(hidden_channels)) self.lins.append(nn.Linear(hidden_channels, out_channels)) else: @@ -78,7 +82,9 @@ def __init__( self.lins.append(nn.Linear(in_channels, hidden_channels)) self.normalizations.append(nn.Identity()) for _ in range(num_layers - 2): - self.lins.append(nn.Linear(hidden_channels, hidden_channels)) + self.lins.append( + nn.Linear(hidden_channels, hidden_channels) + ) self.normalizations.append(nn.Identity()) self.lins.append(nn.Linear(hidden_channels, out_channels)) @@ -88,7 +94,7 @@ def reset_parameters(self): for lin in self.lins: lin.reset_parameters() for normalization in self.normalizations: - if not (normalization.__class__.__name__ == "Identity"): + if normalization.__class__.__name__ != "Identity": normalization.reset_parameters() def forward(self, x): @@ -245,7 +251,9 @@ def forward(self, X, vertex, edges, X0): class JumpLinkConv(nn.Module): - def __init__(self, in_features, out_features, mlp_layers=2, aggr="add", alpha=0.5): + def __init__( + self, in_features, out_features, mlp_layers=2, aggr="add", alpha=0.5 + ): super().__init__() self.W = MLP( in_features, @@ -339,7 +347,10 @@ def forward(self, X, vertex, edges, X0): ) # [E, C], reduce is 'mean' here as default deg_e = torch_scatter.scatter( - torch.ones(Xve.shape[0], device=Xve.device), edges, dim=-2, reduce="sum" + torch.ones(Xve.shape[0], device=Xve.device), + edges, + dim=-2, + reduce="sum", ) Xe = torch.cat([Xe, torch.log(deg_e)[..., None]], -1) @@ -350,7 +361,10 @@ def forward(self, X, vertex, edges, X0): ) # [N, C] deg_v = torch_scatter.scatter( - torch.ones(Xev.shape[0], device=Xev.device), vertex, dim=-2, reduce="sum" + torch.ones(Xev.shape[0], device=Xev.device), + vertex, + dim=-2, + reduce="sum", ) X = self.W3(torch.cat([Xv, X, X0, torch.log(deg_v)[..., None]], -1)) @@ -374,7 +388,7 @@ def __init__( normalization="None", AllSet_input_norm=False, ): - """EDGNN + """EDGNN. Args: num_features (int): number of input features @@ -390,7 +404,6 @@ def __init__( aggregate (str, optional): aggregation method. Defaults to 'add'. normalization (str, optional): normalization method. Defaults to 'None'. AllSet_input_norm (bool, optional): whether to normalize input features. Defaults to False. 
- """ super().__init__() act = {"Id": nn.Identity(), "relu": nn.ReLU(), "prelu": nn.PReLU()} @@ -402,8 +415,12 @@ def __init__( self.hidden_channels = self.in_channels self.mlp1_layers = MLP_num_layers - self.mlp2_layers = MLP_num_layers if MLP2_num_layers < 0 else MLP2_num_layers - self.mlp3_layers = MLP_num_layers if MLP3_num_layers < 0 else MLP3_num_layers + self.mlp2_layers = ( + MLP_num_layers if MLP2_num_layers < 0 else MLP2_num_layers + ) + self.mlp3_layers = ( + MLP_num_layers if MLP3_num_layers < 0 else MLP3_num_layers + ) self.nlayer = All_num_layers self.edconv_type = edconv_type diff --git a/custom_models/simplicial/sccnn.py b/custom_models/simplicial/sccnn.py index fa49754f..b9a816ab 100644 --- a/custom_models/simplicial/sccnn.py +++ b/custom_models/simplicial/sccnn.py @@ -1,11 +1,9 @@ """SCCNN implementation for complex classification.""" import torch -from topomodelx.nn.simplicial.sccnn_layer import SCCNNLayer from torch.nn.parameter import Parameter - class SCCNNCusctom(torch.nn.Module): """SCCNN implementation for complex classification. @@ -28,7 +26,6 @@ class SCCNNCusctom(torch.nn.Module): Update function for the simplicial complex convolution. n_layers: int Number of layers. - """ def __init__( @@ -44,9 +41,15 @@ def __init__( super().__init__() # first layer # we use an MLP to map the features on simplices of different dimensions to the same dimension - self.in_linear_0 = torch.nn.Linear(in_channels_all[0], hidden_channels_all[0]) - self.in_linear_1 = torch.nn.Linear(in_channels_all[1], hidden_channels_all[1]) - self.in_linear_2 = torch.nn.Linear(in_channels_all[2], hidden_channels_all[2]) + self.in_linear_0 = torch.nn.Linear( + in_channels_all[0], hidden_channels_all[0] + ) + self.in_linear_1 = torch.nn.Linear( + in_channels_all[1], hidden_channels_all[1] + ) + self.in_linear_2 = torch.nn.Linear( + in_channels_all[2], hidden_channels_all[2] + ) self.layers = torch.nn.ModuleList( SCCNNLayer( @@ -100,6 +103,7 @@ def forward(self, x_all, laplacian_all, incidence_all): # Layer """Simplicial Complex Convolutional Neural Network Layer.""" + class SCCNNLayer(torch.nn.Module): r"""Layer of a Simplicial Complex Convolutional Neural Network. @@ -215,7 +219,9 @@ def __init__( self.weight_0 = Parameter( torch.Tensor( - self.in_channels_0, self.out_channels_0, 1 + conv_order + 1 + conv_order + self.in_channels_0, + self.out_channels_0, + 1 + conv_order + 1 + conv_order, ) ) @@ -326,7 +332,9 @@ def chebyshev_conv(self, conv_operator, conv_order, x): Output tensor. x[:, :, k] = (conv_operator@....@conv_operator) @ x. 
""" num_simplices, num_channels = x.shape - X = torch.empty(size=(num_simplices, num_channels, conv_order)).to(x.device) + X = torch.empty(size=(num_simplices, num_channels, conv_order)).to( + x.device + ) if self.aggr_norm: X[:, :, 0] = torch.mm(conv_operator, x) @@ -388,7 +396,9 @@ def forward(self, x_all, laplacian_all, incidence_all): x_0, x_1, x_2 = x_all if self.sc_order == 2: - laplacian_0, laplacian_down_1, laplacian_up_1, laplacian_2 = laplacian_all + laplacian_0, laplacian_down_1, laplacian_up_1, laplacian_2 = ( + laplacian_all + ) elif self.sc_order > 2: ( laplacian_0, @@ -407,7 +417,6 @@ def forward(self, x_all, laplacian_all, incidence_all): # torch.eye(num_edges).to(x_0.device), # torch.eye(num_triangles).to(x_0.device), # ) - """ Convolution in the node space """ @@ -429,7 +438,9 @@ def forward(self, x_all, laplacian_all, incidence_all): x_1_to_0_laplacian = self.chebyshev_conv( laplacian_0, self.conv_order, x_1_to_0_upper ) - x_1_to_0 = torch.cat([x_1_to_0_upper.unsqueeze(2), x_1_to_0_laplacian], dim=2) + x_1_to_0 = torch.cat( + [x_1_to_0_upper.unsqueeze(2), x_1_to_0_laplacian], dim=2 + ) # ------------------- x_0_all = torch.cat((x_0_to_0, x_1_to_0), 2) @@ -460,13 +471,19 @@ def forward(self, x_all, laplacian_all, incidence_all): x_0_1_lower = torch.mm(b1.T, x_0) # Calculate lowwer chebyshev_conv - x_0_1_down = self.chebyshev_conv(laplacian_down_1, self.conv_order, x_0_1_lower) + x_0_1_down = self.chebyshev_conv( + laplacian_down_1, self.conv_order, x_0_1_lower + ) # Calculate upper chebyshev_conv (Note: in case of signed incidence should be always zero) - x_0_1_up = self.chebyshev_conv(laplacian_up_1, self.conv_order, x_0_1_lower) + x_0_1_up = self.chebyshev_conv( + laplacian_up_1, self.conv_order, x_0_1_lower + ) # Concatenate output of filters - x_0_to_1 = torch.cat([x_0_1_lower.unsqueeze(2), x_0_1_down, x_0_1_up], dim=2) + x_0_to_1 = torch.cat( + [x_0_1_lower.unsqueeze(2), x_0_1_down, x_0_1_up], dim=2 + ) # ------------------- # x_2_to_1 = torch.mm(b2, x_2) @@ -477,20 +494,23 @@ def forward(self, x_all, laplacian_all, incidence_all): x_2_1_upper = torch.mm(b2, x_2) # Calculate lowwer chebyshev_conv (Note: In case of signed incidence should be always zero) - x_2_1_down = self.chebyshev_conv(laplacian_down_1, self.conv_order, x_2_1_upper) + x_2_1_down = self.chebyshev_conv( + laplacian_down_1, self.conv_order, x_2_1_upper + ) # Calculate upper chebyshev_conv - x_2_1_up = self.chebyshev_conv(laplacian_up_1, self.conv_order, x_2_1_upper) + x_2_1_up = self.chebyshev_conv( + laplacian_up_1, self.conv_order, x_2_1_upper + ) - x_2_to_1 = torch.cat([x_2_1_upper.unsqueeze(2), x_2_1_down, x_2_1_up], dim=2) + x_2_to_1 = torch.cat( + [x_2_1_upper.unsqueeze(2), x_2_1_down, x_2_1_up], dim=2 + ) # ------------------- x_1_all = torch.cat((x_0_to_1, x_1_to_1, x_2_to_1), 2) - - """ - convolution in the face (triangle) space, depending on the SC order, - the exact form maybe a little different - """ + """Convolution in the face (triangle) space, depending on the SC order, + the exact form maybe a little different.""" # -------------------Logic to obtain update for 2-cells -------- # x_identity_2 = torch.unsqueeze(identity_2 @ x_2, 2) @@ -516,10 +536,16 @@ def forward(self, x_all, laplacian_all, incidence_all): # x_1_to_2 = torch.cat((x_1_to_2_identity, x_1_to_2), 2) x_1_2_lower = torch.mm(b2.T, x_1) - x_1_2_down = self.chebyshev_conv(laplacian_down_2, self.conv_order, x_1_2_lower) - x_1_2_down = self.chebyshev_conv(laplacian_up_2, self.conv_order, x_1_2_lower) + x_1_2_down = 
self.chebyshev_conv( + laplacian_down_2, self.conv_order, x_1_2_lower + ) + x_1_2_up = self.chebyshev_conv( + laplacian_up_2, self.conv_order, x_1_2_lower + ) - x_1_to_2 = torch.cat([x_1_2_lower.unsqueeze(2), x_1_2_down, x_1_2_down], dim=2) + x_1_to_2 = torch.cat( + [x_1_2_lower.unsqueeze(2), x_1_2_down, x_1_2_up], dim=2 + ) # Executing this part requires simplices of order k+1 (here, order 3) # x_3_2_upper = x_1_to_2 = torch.mm(b2, x_3) diff --git a/env.bash b/env.bash deleted file mode 100644 index b4225bbf..00000000 --- a/env.bash +++ /dev/null @@ -1,34 +0,0 @@ -# #!/bin/bash - -# set -e - -# # Step 1: Upgrade pip -# pip install --upgrade pip - -# # Step 2: Install dependencies -# yes | pip install -e '.[all]' -# yes | pip install --no-dependencies git+https://github.com/pyt-team/TopoNetX.git -# yes | pip install --no-dependencies git+https://github.com/pyt-team/TopoModelX.git -# yes | pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/cu115 -# yes | pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cu115.html -# yes | pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+cu115.html -# yes | pip install lightning>=2.0.0 -# yes | pip install numpy pre-commit jupyterlab notebook ipykernel - - -yes | conda create -n topox python=3.11.3 -conda activate topox - -pip install -e '.[all]' - -yes | pip install --no-dependencies git+https://github.com/pyt-team/TopoNetX.git -yes | pip install --no-dependencies git+https://github.com/pyt-team/TopoModelX.git - -yes | pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/cu115 -yes | pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cu115.html -yes | pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+cu115.html -yes | pip install numpy pre-commit jupyterlab notebook ipykernel - -pytest - -pre-commit install diff --git a/env.sh b/env.sh new file mode 100755 index 00000000..fe8c8123 --- /dev/null +++ b/env.sh @@ -0,0 +1,20 @@ +#!/bin/bash -l + +pip install --upgrade pip +pip install -e '.[all]' + +pip install --no-dependencies git+https://github.com/pyt-team/TopoNetX.git +pip install --no-dependencies git+https://github.com/pyt-team/TopoModelX.git +pip install --no-dependencies git+https://github.com/pyt-team/TopoEmbedX.git + +# Note that not all combinations of torch and CUDA are available +# See https://github.com/pyg-team/pyg-lib to check the configuration that works for you +TORCH="2.3.0" # available options: 1.12.0, 1.13.0, 2.0.0, 2.1.0, 2.2.0, or 2.3.0 +CUDA="cu121" # if available, select the CUDA version suitable for your system # available options: cpu, cu102, cu113, cu116, cu117, cu118, or cu121 +pip install torch==${TORCH} --extra-index-url https://download.pytorch.org/whl/${CUDA} +pip install lightning torch_geometric==2.4.0 +pip install pyg-lib torch-scatter torch-sparse torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html +pytest + +pre-commit install diff --git a/format_and_lint.sh b/format_and_lint.sh new file mode 100755 index 00000000..e7795e8c --- /dev/null +++ b/format_and_lint.sh @@ -0,0 +1,10 @@ +#!/bin/sh + +# Run ruff to check for issues and fix them +ruff check . --fix + +# Run docformatter to reformat docstrings and comments +docformatter --in-place --recursive --wrap-summaries 79 --wrap-descriptions 79 . + +# Run black to format the code +black . 
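For reference, the chebyshev_conv helper used throughout the SCCNN forward pass builds a stack of repeated operator applications. A minimal self-contained sketch of that recursion (dense tensors, aggregation normalization omitted; operator_power_stack is an illustrative name, not part of the codebase):

import torch

def operator_power_stack(conv_operator, conv_order, x):
    # Slice k holds conv_operator applied (k + 1) times to x, mirroring
    # x[:, :, k] = (conv_operator @ ... @ conv_operator) @ x above.
    num_simplices, num_channels = x.shape
    X = torch.empty(num_simplices, num_channels, conv_order)
    X[:, :, 0] = torch.mm(conv_operator, x)
    for k in range(1, conv_order):
        X[:, :, k] = torch.mm(conv_operator, X[:, :, k - 1])
    return X

laplacian = torch.rand(5, 5)  # stand-in for a dense Laplacian block
x = torch.rand(5, 3)          # 5 simplices, 3 channels
assert operator_power_stack(laplacian, 2, x).shape == (5, 3, 2)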
\ No newline at end of file diff --git a/hp_scripts/main_exp/cellular/CAN.sh b/hp_scripts/main_exp/cellular/CAN.sh new file mode 100644 index 00000000..d112e1a5 --- /dev/null +++ b/hp_scripts/main_exp/cellular/CAN.sh @@ -0,0 +1,147 @@ +# Description: Main experiment script for GCN model. +# ----Node regression datasets: US County Demographics---- +task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' ) + +for task_variable in ${task_variables[*]} +do + python train.py \ + dataset=us_country_demos \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.task_variable=$task_variable \ + model=cell/can \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.n_layers=1,2,3,4 \ + model.optimizer.lr="0.01,0.001" \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + trainer.max_epochs=1000 \ + trainer.min_epochs=500 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + tags="[MainExperiment]" \ + --multirun + +done + +# ----Cocitation datasets---- +datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + dataset.parameters.data_seed=0,3,5,7,9 \ + model=cell/can \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.n_layers=1,2 \ + model.optimizer.lr="0.01,0.001" \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=25 \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + tags="[MainExperiment]" \ + --multirun +done + +# # ----Graph regression dataset---- +# # Train on ZINC dataset +# python train.py \ +# dataset=ZINC \ +# seed=42,3,5,23,150 \ +# model=cell/can \ +# model.optimizer.lr=0.01,0.001 \ +# model.optimizer.weight_decay=0 \ +# model.feature_encoder.out_channels=32,64,128 \ +# model.backbone.n_layers=2,4 \ +# model.feature_encoder.proj_dropout=0.25,0.5 \ +# dataset.parameters.batch_size=128,256 \ +# dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ +# dataset.parameters.data_seed=0 \ +# dataset.transforms.graph2cell_lifting.max_cell_length=10 \ +# model.readout.readout_name="NoReadOut,PropagateSignalDown" \ +# logger.wandb.project=TopoBenchmarkX_Cellular \ +# trainer.max_epochs=500 \ +# trainer.min_epochs=50 \ +# callbacks.early_stopping.min_delta=0.005 \ +# trainer.check_val_every_n_epoch=5 \ +# callbacks.early_stopping.patience=10 \ +# tags="[MainExperiment]" \ +# --multirun + +# ----TU graph datasets---- +# MUTAG have very few samples, so we use a smaller batch size +# Train on MUTAG dataset +python train.py \ + dataset=MUTAG \ + model=cell/can \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.batch_size=32,64 \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + 
logger.wandb.project=TopoBenchmarkX_Cellular \ + callbacks.early_stopping.patience=25 \ + tags="[MainExperiment]" \ + --multirun + +# Train rest of the TU graph datasets +datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'IMDB-BINARY' 'IMDB-MULTI' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=cell/can \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.batch_size=128,256 \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + --multirun +done + +# ----Heterophilic datasets---- + +datasets=( roman_empire minesweeper ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=cell/can \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.batch_size=1 \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done diff --git a/hp_scripts/main_exp/cellular/CCCN.sh b/hp_scripts/main_exp/cellular/CCCN.sh new file mode 100644 index 00000000..cef694ca --- /dev/null +++ b/hp_scripts/main_exp/cellular/CCCN.sh @@ -0,0 +1,147 @@ +# Description: Main experiment script for the CCCN model. 
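+# A single-configuration dry run can help validate the wiring before launching
+# the full sweep: pick one value per override and drop --multirun (a
+# hypothetical example, not part of the original experiment grid):
+# python train.py dataset=MUTAG model=cell/cccn model.optimizer.lr=0.001 \
+#   model.feature_encoder.out_channels=64 model.backbone.n_layers=2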
+# ----Node regression datasets: US County Demographics---- +task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' ) + +for task_variable in ${task_variables[*]} +do + python train.py \ + dataset=us_country_demos \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.task_variable=$task_variable \ + model=cell/cccn \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.n_layers=1,2,3,4 \ + model.optimizer.lr="0.01,0.001" \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + trainer.max_epochs=1000 \ + trainer.min_epochs=500 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + tags="[MainExperiment]" \ + --multirun + +done + +# ----Cocitation datasets---- +datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + dataset.parameters.data_seed=0,3,5,7,9 \ + model=cell/cccn \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.n_layers=1,2 \ + model.optimizer.lr="0.01,0.001" \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=25 \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + tags="[MainExperiment]" \ + --multirun +done + +# # ----Graph regression dataset---- +# # Train on ZINC dataset +# python train.py \ +# dataset=ZINC \ +# seed=42,3,5,23,150 \ +# model=cell/cccn \ +# model.optimizer.lr=0.01,0.001 \ +# model.optimizer.weight_decay=0 \ +# model.feature_encoder.out_channels=32,64,128 \ +# model.backbone.n_layers=2,4 \ +# model.feature_encoder.proj_dropout=0.25,0.5 \ +# dataset.parameters.batch_size=128,256 \ +# dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ +# dataset.parameters.data_seed=0 \ +# dataset.transforms.graph2cell_lifting.max_cell_length=10 \ +# model.readout.readout_name="NoReadOut,PropagateSignalDown" \ +# logger.wandb.project=TopoBenchmarkX_Cellular \ +# trainer.max_epochs=500 \ +# trainer.min_epochs=50 \ +# callbacks.early_stopping.min_delta=0.005 \ +# trainer.check_val_every_n_epoch=5 \ +# callbacks.early_stopping.patience=10 \ +# tags="[MainExperiment]" \ +# --multirun + +# ----TU graph datasets---- +# MUTAG have very few samples, so we use a smaller batch size +# Train on MUTAG dataset +python train.py \ + dataset=MUTAG \ + model=cell/cccn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.batch_size=32,64 \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + callbacks.early_stopping.patience=25 \ + tags="[MainExperiment]" \ + --multirun + +# Train rest of the TU graph datasets +datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'IMDB-BINARY' 'IMDB-MULTI' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + 
model=cell/cccn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.batch_size=128,256 \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + --multirun +done + +# ----Heterophilic datasets---- + +datasets=( roman_empire minesweeper ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=cell/cccn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.batch_size=1 \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done diff --git a/hp_scripts/main_exp/cellular/CCXN.sh b/hp_scripts/main_exp/cellular/CCXN.sh new file mode 100644 index 00000000..87f316c9 --- /dev/null +++ b/hp_scripts/main_exp/cellular/CCXN.sh @@ -0,0 +1,147 @@ +# Description: Main experiment script for GCN model. +# ----Node regression datasets: US County Demographics---- +task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' ) + +for task_variable in ${task_variables[*]} +do + python train.py \ + dataset=us_country_demos \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.task_variable=$task_variable \ + model=cell/ccxn \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.n_layers=1,2,3,4 \ + model.optimizer.lr="0.01,0.001" \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + trainer.max_epochs=1000 \ + trainer.min_epochs=500 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + tags="[MainExperiment]" \ + --multirun + +done + +# ----Cocitation datasets---- +datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + dataset.parameters.data_seed=0,3,5,7,9 \ + model=cell/ccxn \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.n_layers=1,2 \ + model.optimizer.lr="0.01,0.001" \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=25 \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + tags="[MainExperiment]" \ + --multirun +done + +# # ----Graph regression dataset---- +# # Train on ZINC dataset +# python train.py \ +# dataset=ZINC \ +# seed=42,3,5,23,150 \ +# model=cell/ccxn \ +# model.optimizer.lr=0.01,0.001 \ +# model.optimizer.weight_decay=0 \ +# 
model.feature_encoder.out_channels=32,64,128 \ +# model.backbone.n_layers=2,4 \ +# model.feature_encoder.proj_dropout=0.25,0.5 \ +# dataset.parameters.batch_size=128,256 \ +# dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ +# dataset.parameters.data_seed=0 \ +# dataset.transforms.graph2cell_lifting.max_cell_length=10 \ +# model.readout.readout_name="NoReadOut,PropagateSignalDown" \ +# logger.wandb.project=TopoBenchmarkX_Cellular \ +# trainer.max_epochs=500 \ +# trainer.min_epochs=50 \ +# callbacks.early_stopping.min_delta=0.005 \ +# trainer.check_val_every_n_epoch=5 \ +# callbacks.early_stopping.patience=10 \ +# tags="[MainExperiment]" \ +# --multirun + +# ----TU graph datasets---- +# MUTAG has very few samples, so we use a smaller batch size +# Train on MUTAG dataset +python train.py \ + dataset=MUTAG \ + model=cell/ccxn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.batch_size=32,64 \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + callbacks.early_stopping.patience=25 \ + tags="[MainExperiment]" \ + --multirun + +# Train rest of the TU graph datasets +datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'IMDB-BINARY' 'IMDB-MULTI' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=cell/ccxn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.batch_size=128,256 \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + --multirun +done + +# ----Heterophilic datasets---- + +datasets=( roman_empire minesweeper ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=cell/ccxn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.batch_size=1 \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done diff --git a/hp_scripts/main_exp/cellular/CWN.sh b/hp_scripts/main_exp/cellular/CWN.sh new file mode 100644 index 00000000..58366f32 --- /dev/null +++ b/hp_scripts/main_exp/cellular/CWN.sh @@ -0,0 +1,147 @@ +# Description: Main experiment script for the CWN model. 
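+# Grid size of the node-regression sweep below, per task variable:
+# 5 seeds x 3 widths x 2 dropouts x 4 depths x 2 learning rates x 2 readouts
+# = 480 runs, i.e. 3360 runs over the 7 task variables.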
+# ----Node regression datasets: US County Demographics---- +task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' ) + +for task_variable in ${task_variables[*]} +do + python train.py \ + dataset=us_country_demos \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.task_variable=$task_variable \ + model=cell/cwn \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.n_layers=1,2,3,4 \ + model.optimizer.lr="0.01,0.001" \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + trainer.max_epochs=1000 \ + trainer.min_epochs=500 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + tags="[MainExperiment]" \ + --multirun + +done + +# ----Cocitation datasets---- +datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + dataset.parameters.data_seed=0,3,5,7,9 \ + model=cell/cwn \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.n_layers=1,2 \ + model.optimizer.lr="0.01,0.001" \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=25 \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + tags="[MainExperiment]" \ + --multirun +done + +# # ----Graph regression dataset---- +# # Train on ZINC dataset +# python train.py \ +# dataset=ZINC \ +# seed=42,3,5,23,150 \ +# model=cell/cwn \ +# model.optimizer.lr=0.01,0.001 \ +# model.optimizer.weight_decay=0 \ +# model.feature_encoder.out_channels=32,64,128 \ +# model.backbone.n_layers=2,4 \ +# model.feature_encoder.proj_dropout=0.25,0.5 \ +# dataset.parameters.batch_size=128,256 \ +# dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ +# dataset.parameters.data_seed=0 \ +# dataset.transforms.graph2cell_lifting.max_cell_length=10 \ +# model.readout.readout_name="NoReadOut,PropagateSignalDown" \ +# logger.wandb.project=TopoBenchmarkX_Cellular \ +# trainer.max_epochs=500 \ +# trainer.min_epochs=50 \ +# callbacks.early_stopping.min_delta=0.005 \ +# trainer.check_val_every_n_epoch=5 \ +# callbacks.early_stopping.patience=10 \ +# tags="[MainExperiment]" \ +# --multirun + +# ----TU graph datasets---- +# MUTAG have very few samples, so we use a smaller batch size +# Train on MUTAG dataset +python train.py \ + dataset=MUTAG \ + model=cell/cwn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.batch_size=32,64 \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + callbacks.early_stopping.patience=25 \ + tags="[MainExperiment]" \ + --multirun + +# Train rest of the TU graph datasets +datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'IMDB-BINARY' 'IMDB-MULTI' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + 
model=cell/cwn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.batch_size=128,256 \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + --multirun +done + +# ----Heterophilic datasets---- + +datasets=( roman_empire minesweeper ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=cell/cwn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.batch_size=1 \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done diff --git a/hp_scripts/main_exp/cellular/left_out.sh b/hp_scripts/main_exp/cellular/left_out.sh new file mode 100644 index 00000000..81f00969 --- /dev/null +++ b/hp_scripts/main_exp/cellular/left_out.sh @@ -0,0 +1,103 @@ +# ----Graph regression dataset---- +# Train on ZINC dataset + +# CWN +python train.py \ + dataset=ZINC \ + seed=42,3,5,23,150 \ + model=cell/cwn \ + model.optimizer.lr=0.01,0.001 \ + model.optimizer.weight_decay=0 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=2,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.batch_size=128,256 \ + dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ + dataset.parameters.data_seed=0 \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + callbacks.early_stopping.min_delta=0.005 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + tags="[MainExperiment]" \ + --multirun + +# CCXN +python train.py \ + dataset=ZINC \ + seed=42,3,5,23,150 \ + model=cell/ccxn \ + model.optimizer.lr=0.01,0.001 \ + model.optimizer.weight_decay=0 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=2,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.batch_size=128,256 \ + dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ + dataset.parameters.data_seed=0 \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + callbacks.early_stopping.min_delta=0.005 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + tags="[MainExperiment]" \ + --multirun + +# CCCN +python train.py \ + dataset=ZINC \ + seed=42,3,5,23,150 \ + model=cell/cccn \ + model.optimizer.lr=0.01,0.001 \ + model.optimizer.weight_decay=0 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=2,4 \ + 
model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.batch_size=128,256 \ + dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ + dataset.parameters.data_seed=0 \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + callbacks.early_stopping.min_delta=0.005 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + tags="[MainExperiment]" \ + --multirun + + +# CAN + +python train.py \ + dataset=ZINC \ + seed=42,3,5,23,150 \ + model=cell/can \ + model.optimizer.lr=0.01,0.001 \ + model.optimizer.weight_decay=0 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=2,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.batch_size=128,256 \ + dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ + dataset.parameters.data_seed=0 \ + dataset.transforms.graph2cell_lifting.max_cell_length=10 \ + model.readout.readout_name="NoReadOut,PropagateSignalDown" \ + logger.wandb.project=TopoBenchmarkX_Cellular \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + callbacks.early_stopping.min_delta=0.005 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + tags="[MainExperiment]" \ + --multirun + + +# REDDIT BINARY for all \ No newline at end of file diff --git a/hp_scripts/main_exp/graph/gat.sh b/hp_scripts/main_exp/graph/gat.sh new file mode 100644 index 00000000..231afe7c --- /dev/null +++ b/hp_scripts/main_exp/graph/gat.sh @@ -0,0 +1,136 @@ +# Description: Main experiment script for GCN model. +# ----Node regression datasets: US County Demographics---- +task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' ) + +for task_variable in ${task_variables[*]} +do + python train.py \ + dataset=us_country_demos \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.task_variable=$task_variable \ + model=graph/gat \ + model.feature_encoder.out_channels="32,64,128" \ + model.feature_encoder.proj_dropout="0,0.25,0.5" \ + model.backbone.num_layers="1,2,3,4" \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=1000 \ + trainer.min_epochs=500 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + tags="[MainExperiment]" \ + --multirun + +done + +# ----Cocitation datasets---- +datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + dataset.parameters.data_seed=0,3,5,7,9 \ + model=graph/gat \ + model.feature_encoder.out_channels="32,64,128" \ + model.feature_encoder.proj_dropout="0,0.25,0.5" \ + model.backbone.num_layers="1,2" \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=25 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + tags="[MainExperiment]" \ + --multirun +done + +# ----Graph regression dataset---- +# Train on ZINC dataset +python train.py \ + dataset=ZINC \ + seed=42,3,5,23,150 \ + model=graph/gat \ + model.optimizer.lr=0.01,0.001 \ + model.optimizer.weight_decay=0 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=2,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.batch_size=128,256 \ + 
dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ + dataset.parameters.data_seed=0 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + callbacks.early_stopping.min_delta=0.005 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + tags="[MainExperiment]" \ + --multirun + +# ----Heterophilic datasets---- + +datasets=( roman_empire amazon_ratings tolokers questions minesweeper ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=graph/gat \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done + +# ----TU graph datasets---- +# MUTAG have very few samples, so we use a smaller batch size +# Train on MUTAG dataset +python train.py \ + dataset=MUTAG \ + model=graph/gat \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=32,64 \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + callbacks.early_stopping.patience=25 \ + tags="[MainExperiment]" \ + --multirun + +# Train rest of the TU graph datasets +datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'REDDIT-BINARY' 'IMDB-BINARY' 'IMDB-MULTI' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=graph/gat \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + --multirun +done + diff --git a/hp_scripts/main_exp/graph/gcn.sh b/hp_scripts/main_exp/graph/gcn.sh new file mode 100644 index 00000000..3e9f85d2 --- /dev/null +++ b/hp_scripts/main_exp/graph/gcn.sh @@ -0,0 +1,136 @@ +# Description: Main experiment script for GCN model. 
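+# The logger can also be swapped from the command line (the wandb config
+# comment mentions `python train.py logger=tensorboard`), e.g. for a local run:
+# python train.py dataset=cocitation_cora model=graph/gcn logger=tensorboard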
+# ----Node regression datasets: US County Demographics---- +task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' ) + +for task_variable in ${task_variables[*]} +do + python train.py \ + dataset=us_country_demos \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.task_variable=$task_variable \ + model=graph/gcn \ + model.feature_encoder.out_channels="32,64,128" \ + model.feature_encoder.proj_dropout="0,0.25,0.5" \ + model.backbone.num_layers="1,2,3,4" \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=1000 \ + trainer.min_epochs=500 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + tags="[MainExperiment]" \ + --multirun + +done + +# ----Cocitation datasets---- +datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + dataset.parameters.data_seed=0,3,5,7,9 \ + model=graph/gcn \ + model.feature_encoder.out_channels="32,64,128" \ + model.feature_encoder.proj_dropout="0,0.25,0.5" \ + model.backbone.num_layers="1,2" \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=25 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + tags="[MainExperiment]" \ + --multirun +done + +# ----Graph regression dataset---- +# Train on ZINC dataset +python train.py \ + dataset=ZINC \ + seed=42,3,5,23,150 \ + model=graph/gcn \ + model.optimizer.lr=0.01,0.001 \ + model.optimizer.weight_decay=0 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=2,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.batch_size=128,256 \ + dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ + dataset.parameters.data_seed=0 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + callbacks.early_stopping.min_delta=0.005 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + tags="[MainExperiment]" \ + --multirun + +# ----Heterophilic datasets---- + +datasets=( roman_empire amazon_ratings tolokers questions minesweeper ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=graph/gcn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=1 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done + +# ----TU graph datasets---- +# MUTAG have very few samples, so we use a smaller batch size +# Train on MUTAG dataset +python train.py \ + dataset=MUTAG \ + model=graph/gcn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=32,64 \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + callbacks.early_stopping.patience=25 \ + tags="[MainExperiment]" \ + --multirun + +# Train rest of the TU graph 
datasets +datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'REDDIT-BINARY' 'IMDB-BINARY' 'IMDB-MULTI' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=graph/gcn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + --multirun +done + diff --git a/hp_scripts/main_exp/graph/gin.sh b/hp_scripts/main_exp/graph/gin.sh new file mode 100644 index 00000000..a71f31fe --- /dev/null +++ b/hp_scripts/main_exp/graph/gin.sh @@ -0,0 +1,136 @@ +# Description: Main experiment script for GCN model. +# ----Node regression datasets: US County Demographics---- +task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' ) + +for task_variable in ${task_variables[*]} +do + python train.py \ + dataset=us_country_demos \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.task_variable=$task_variable \ + model=graph/gin \ + model.feature_encoder.out_channels="32,64,128" \ + model.feature_encoder.proj_dropout="0,0.25,0.5" \ + model.backbone.num_layers="1,2,3,4" \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=1000 \ + trainer.min_epochs=500 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + tags="[MainExperiment]" \ + --multirun + +done + +# ----Cocitation datasets---- +datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + dataset.parameters.data_seed=0,3,5,7,9 \ + model=graph/gin \ + model.feature_encoder.out_channels="32,64,128" \ + model.feature_encoder.proj_dropout="0,0.25,0.5" \ + model.backbone.num_layers="1,2" \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=25 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + tags="[MainExperiment]" \ + --multirun +done + +# ----Graph regression dataset---- +# Train on ZINC dataset +python train.py \ + dataset=ZINC \ + seed=42,3,5,23,150 \ + model=graph/gin \ + model.optimizer.lr=0.01,0.001 \ + model.optimizer.weight_decay=0 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=2,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.batch_size=128,256 \ + dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ + dataset.parameters.data_seed=0 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + callbacks.early_stopping.min_delta=0.005 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + tags="[MainExperiment]" \ + --multirun + +# ----Heterophilic datasets---- + +datasets=( roman_empire amazon_ratings tolokers questions minesweeper ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=graph/gin \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + 
logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done + +# ----TU graph datasets---- +# MUTAG has very few samples, so we use a smaller batch size +# Train on MUTAG dataset +python train.py \ + dataset=MUTAG \ + model=graph/gin \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=32,64 \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + callbacks.early_stopping.patience=25 \ + tags="[MainExperiment]" \ + --multirun + +# Train rest of the TU graph datasets +datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'REDDIT-BINARY' 'IMDB-BINARY' 'IMDB-MULTI' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=graph/gin \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + --multirun +done + diff --git a/hp_scripts/main_exp/hypergraph/allsettransformer.sh b/hp_scripts/main_exp/hypergraph/allsettransformer.sh new file mode 100644 index 00000000..3f7987c0 --- /dev/null +++ b/hp_scripts/main_exp/hypergraph/allsettransformer.sh @@ -0,0 +1,136 @@ +# Description: Main experiment script for the AllSetTransformer model. 
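+# Hydra treats comma-separated override values as a sweep under --multirun;
+# quoting, as in model.optimizer.lr="0.01,0.001", only shields the list from
+# the shell and is otherwise equivalent to the unquoted form used elsewhere.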
+# ----Node regression datasets: US County Demographics---- +task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' ) + +for task_variable in ${task_variables[*]} +do + python train.py \ + dataset=us_country_demos \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.task_variable=$task_variable \ + model=hypergraph/allsettransformer \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.n_layers="1,2,3,4" \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=1000 \ + trainer.min_epochs=500 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + tags="[MainExperiment]" \ + --multirun + +done + +# ----Cocitation datasets---- +datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + dataset.parameters.data_seed=0,3,5,7,9 \ + model=hypergraph/allsettransformer \ + model.feature_encoder.out_channels="32,64,128" \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.n_layers="1,2" \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=25 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + tags="[MainExperiment]" \ + --multirun +done + +# ----Graph regression dataset---- +# Train on ZINC dataset +python train.py \ + dataset=ZINC \ + seed=42,3,5,23,150 \ + model=hypergraph/allsettransformer \ + model.optimizer.lr=0.01,0.001 \ + model.optimizer.weight_decay=0 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=2,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.batch_size=128,256 \ + dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ + dataset.parameters.data_seed=0 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + callbacks.early_stopping.min_delta=0.005 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + tags="[MainExperiment]" \ + --multirun + + +# ----TU graph datasets---- +# MUTAG have very few samples, so we use a smaller batch size +# Train on MUTAG dataset +python train.py \ + dataset=MUTAG \ + model=hypergraph/allsettransformer \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=32,64 \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + callbacks.early_stopping.patience=25 \ + tags="[MainExperiment]" \ + --multirun + +# Train rest of the TU graph datasets +datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'REDDIT-BINARY' 'IMDB-BINARY' 'IMDB-MULTI' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/allsettransformer \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=5 
\ + callbacks.early_stopping.patience=10 \ + --multirun +done + +# ----Heterophilic datasets---- + +datasets=( roman_empire amazon_ratings tolokers minesweeper ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/allsettransformer \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done \ No newline at end of file diff --git a/hp_scripts/main_exp/hypergraph/edgnn.sh b/hp_scripts/main_exp/hypergraph/edgnn.sh new file mode 100644 index 00000000..87a28e90 --- /dev/null +++ b/hp_scripts/main_exp/hypergraph/edgnn.sh @@ -0,0 +1,135 @@ +# Description: Main experiment script for GCN model. +# ----Node regression datasets: US County Demographics---- +task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' ) + +for task_variable in ${task_variables[*]} +do + python train.py \ + dataset=us_country_demos \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.task_variable=$task_variable \ + model=hypergraph/edgnn \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.All_num_layers=1,2,3,4 \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=1000 \ + trainer.min_epochs=500 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + tags="[MainExperiment]" \ + --multirun + +done + +# ----Cocitation datasets---- +datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + dataset.parameters.data_seed=0,3,5,7,9 \ + model=hypergraph/edgnn \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.All_num_layers="1,2" \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=25 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + tags="[MainExperiment]" \ + --multirun +done + +# ----Graph regression dataset---- +# Train on ZINC dataset +python train.py \ + dataset=ZINC \ + seed=42,3,5,23,150 \ + model=hypergraph/edgnn \ + model.optimizer.lr=0.01,0.001 \ + model.optimizer.weight_decay=0 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.All_num_layers=2,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.batch_size=128,256 \ + dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ + dataset.parameters.data_seed=0 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + callbacks.early_stopping.min_delta=0.005 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + tags="[MainExperiment]" \ + --multirun + +# ----TU graph datasets---- +# MUTAG have very few samples, so we use a smaller batch size +# Train on MUTAG dataset +python train.py \ + dataset=MUTAG \ + model=hypergraph/edgnn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 
\ + model.backbone.All_num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=32,64 \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + callbacks.early_stopping.patience=25 \ + tags="[MainExperiment]" \ + --multirun + +# Train rest of the TU graph datasets +datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'REDDIT-BINARY' 'IMDB-BINARY' 'IMDB-MULTI' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/edgnn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.All_num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + --multirun +done + +# ----Heterophilic datasets---- + +datasets=( roman_empire amazon_ratings tolokers minesweeper ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/edgnn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.All_num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done diff --git a/hp_scripts/main_exp/hypergraph/left_out.sh b/hp_scripts/main_exp/hypergraph/left_out.sh new file mode 100644 index 00000000..94ac96e9 --- /dev/null +++ b/hp_scripts/main_exp/hypergraph/left_out.sh @@ -0,0 +1,63 @@ +# ----Heterophilic datasets---- + +datasets=( questions ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/unignn2 \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/edgnn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.All_num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/allsettransformer \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + 
dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done \ No newline at end of file diff --git a/hp_scripts/main_exp/hypergraph/unignn2.sh b/hp_scripts/main_exp/hypergraph/unignn2.sh new file mode 100644 index 00000000..f62d8ac0 --- /dev/null +++ b/hp_scripts/main_exp/hypergraph/unignn2.sh @@ -0,0 +1,135 @@ +# Description: Main experiment script for the UniGNN2 model. +# ----Node regression datasets: US County Demographics---- +task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' ) + +for task_variable in ${task_variables[*]} +do + python train.py \ + dataset=us_country_demos \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.task_variable=$task_variable \ + model=hypergraph/unignn2 \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.n_layers=1,2,3,4 \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=1000 \ + trainer.min_epochs=500 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + tags="[MainExperiment]" \ + --multirun + +done + +# ----Cocitation datasets---- +datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + dataset.parameters.data_seed=0,3,5,7,9 \ + model=hypergraph/unignn2 \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.n_layers=1,2 \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=25 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + tags="[MainExperiment]" \ + --multirun +done + +# ----Graph regression dataset---- +# Train on ZINC dataset +python train.py \ + dataset=ZINC \ + seed=42,3,5,23,150 \ + model=hypergraph/unignn2 \ + model.optimizer.lr=0.01,0.001 \ + model.optimizer.weight_decay=0 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=2,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.batch_size=128,256 \ + dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ + dataset.parameters.data_seed=0 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + callbacks.early_stopping.min_delta=0.005 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + tags="[MainExperiment]" \ + --multirun + +# ----TU graph datasets---- +# MUTAG has very few samples, so we use a smaller batch size +# Train on MUTAG dataset +python train.py \ + dataset=MUTAG \ + model=hypergraph/unignn2 \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=32,64 \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + callbacks.early_stopping.patience=25 \ + tags="[MainExperiment]" \ + --multirun + +# Train rest of the TU graph datasets +datasets=( 
'PROTEINS_TU' 'NCI1' 'NCI109' 'REDDIT-BINARY' 'IMDB-BINARY' 'IMDB-MULTI' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/unignn2 \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + --multirun +done + +# ----Heterophilic datasets---- + +datasets=( roman_empire amazon_ratings tolokers minesweeper ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/unignn2 \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done diff --git a/hp_scripts/main_exp/simplicial/SCN.sh b/hp_scripts/main_exp/simplicial/SCN.sh new file mode 100644 index 00000000..9c139fcd --- /dev/null +++ b/hp_scripts/main_exp/simplicial/SCN.sh @@ -0,0 +1,137 @@ +# Description: Main experiment script for the SCN model. +# ----Node regression datasets: US County Demographics---- +task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' ) + +for task_variable in ${task_variables[*]} +do + python train.py \ + dataset=us_country_demos \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.task_variable=$task_variable \ + model=hypergraph/edgnn \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.n_layers=1,2,3,4 \ + model.optimizer.lr="0.01,0.001" \ + model.readout.readout_name=NoReadOut \ + dataset.transforms.graph2simplicial_lifting.signed=True \ + trainer.max_epochs=1000 \ + trainer.min_epochs=500 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + logger.wandb.project=TopoBenchmarkX_Simplicial \ + tags="[MainExperiment]" \ + --multirun + +done + +# ----Cocitation datasets---- +datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + dataset.parameters.data_seed=0,3,5,7,9 \ + model=hypergraph/edgnn \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.n_layers=1,2 \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=25 \ + logger.wandb.project=TopoBenchmarkX_Simplicial \ + tags="[MainExperiment]" \ + --multirun +done + +# ----Graph regression dataset---- +# Train on ZINC dataset +python train.py \ + dataset=ZINC \ + seed=42,3,5,23,150 \ + model=hypergraph/edgnn \ + model.optimizer.lr=0.01,0.001 \ + model.optimizer.weight_decay=0 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=2,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.batch_size=128,256 \ + 
dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ + dataset.parameters.data_seed=0 \ + logger.wandb.project=TopoBenchmarkX_Simplicial \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + callbacks.early_stopping.min_delta=0.005 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + tags="[MainExperiment]" \ + --multirun + +# ----TU graph datasets---- +# MUTAG has very few samples, so we use a smaller batch size +# Train on MUTAG dataset +python train.py \ + dataset=MUTAG \ + model=hypergraph/edgnn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=32,64 \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + logger.wandb.project=TopoBenchmarkX_Simplicial \ + callbacks.early_stopping.patience=25 \ + tags="[MainExperiment]" \ + --multirun + +# Train rest of the TU graph datasets +datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'REDDIT-BINARY' 'IMDB-BINARY' 'IMDB-MULTI' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/edgnn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Simplicial \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + --multirun +done + +# ----Heterophilic datasets---- + +datasets=( roman_empire amazon_ratings tolokers minesweeper questions ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/edgnn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Simplicial \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done diff --git a/topobenchmarkx/data/cornel_dataset.ipynb b/notebooks/cornel_dataset.ipynb similarity index 98% rename from topobenchmarkx/data/cornel_dataset.ipynb rename to notebooks/cornel_dataset.ipynb index 718f7ff1..e6eeae39 100644 --- a/topobenchmarkx/data/cornel_dataset.ipynb +++ b/notebooks/cornel_dataset.ipynb @@ -15,7 +15,6 @@ "\n", "import os.path as osp\n", "from collections.abc import Callable\n", - "from typing import Optional\n", "\n", "from torch_geometric.data import Data, InMemoryDataset\n", "from torch_geometric.io import fs\n", @@ -43,9 +42,9 @@ " root: str,\n", " name: str,\n", " parameters: dict = None,\n", - " transform: Optional[Callable] = None,\n", - " pre_transform: Optional[Callable] = None,\n", - " pre_filter: Optional[Callable] = None,\n", + " transform: Callable | None = None,\n", + " pre_transform: Callable | None = None,\n", + " pre_filter: Callable | None = None,\n", " force_reload: bool = True,\n", " use_node_attr: bool = False,\n", " use_edge_attr: bool = False,\n", diff --git a/notebooks/curvature_results.ipynb b/notebooks/curvature_results.ipynb index 4c61ea97..ea591cdc 100644 ---
a/notebooks/curvature_results.ipynb +++ b/notebooks/curvature_results.ipynb @@ -6,15 +6,14 @@ "metadata": {}, "outputs": [], "source": [ - "import pandas as pd\n", - "import wandb\n", - "import pandas as pd\n", "import ast\n", "import glob\n", - "import numpy as np\n", "import warnings\n", - "from datetime import date\n", "from collections import defaultdict\n", + "from datetime import date\n", + "\n", + "import pandas as pd\n", + "import wandb\n", "\n", "today = date.today()\n", "api = wandb.Api()\n", @@ -712,9 +711,9 @@ " ]\n", "\n", " if subset.empty:\n", - " print(f\"---------\")\n", + " print(\"---------\")\n", " print(f\"No results for {model} on {dataset}\")\n", - " print(f\"---------\")\n", + " print(\"---------\")\n", " continue\n", " # Suppress all warnings\n", " warnings.filterwarnings(\"ignore\")\n", @@ -749,7 +748,7 @@ " for col, unique in unique_colums_values.items():\n", " print(f\"{col}: {unique}\")\n", " print()\n", - " print(f\"---------\")\n", + " print(\"---------\")\n", "\n", " # Check if \"special colums\" are not in unique_colums_values\n", " # For example dataset.parameters.data_seed should not be in aggregation columns\n", @@ -1486,7 +1485,7 @@ "result_dict = pd.DataFrame.from_dict(\n", " {\n", " (i, j): nested_dict[i][j]\n", - " for i in nested_dict.keys()\n", + " for i in nested_dict\n", " for j in nested_dict[i].keys()\n", " },\n", " orient=\"index\",\n", diff --git a/notebooks/data.ipynb b/notebooks/data.ipynb index e34a964b..696a2b55 100755 --- a/notebooks/data.ipynb +++ b/notebooks/data.ipynb @@ -25,18 +25,17 @@ "%load_ext autoreload\n", "%autoreload 2\n", "\n", + "import hydra\n", "import torch\n", "import torch_geometric\n", - "from topobenchmarkx.data.datasets import CustomDataset\n", - "import hydra\n", - "from hydra import initialize, compose\n", + "from hydra import compose, initialize\n", + "from omegaconf import OmegaConf\n", + "\n", "from topobenchmarkx.data.dataloader_fullbatch import FullBatchDataModule\n", + "from topobenchmarkx.data.datasets import CustomDataset\n", "from topobenchmarkx.io.load.loaders import (\n", - " GraphLoader,\n", - " SimplicialLoader,\n", " HypergraphLoader,\n", ")\n", - "from omegaconf import DictConfig, OmegaConf\n", "from topobenchmarkx.utils.config_resolvers import (\n", " get_default_transform,\n", " get_monitor_metric,\n", @@ -44,7 +43,6 @@ " infer_in_channels,\n", ")\n", "\n", - "\n", "OmegaConf.register_new_resolver(\"get_default_transform\", get_default_transform)\n", "OmegaConf.register_new_resolver(\"get_monitor_metric\", get_monitor_metric)\n", "OmegaConf.register_new_resolver(\"get_monitor_mode\", get_monitor_mode)\n", @@ -114,9 +112,6 @@ } ], "source": [ - "import torch\n", - "import torch_geometric\n", - "import numpy as np\n", "\n", "nci1 = torch_geometric.datasets.TUDataset(\n", " root=\".\",\n", @@ -712,8 +707,7 @@ } ], "source": [ - "from lightning import Callback, LightningDataModule, LightningModule, Trainer\n", - "from lightning.pytorch.loggers import Logger\n", + "from lightning import LightningModule\n", "\n", "model: LightningModule = hydra.utils.instantiate(config.model)" ] @@ -1012,9 +1006,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "from topobenchmarkx.data.datasets import CustomDataset" - ] + "source": [] }, { "cell_type": "code", @@ -1050,7 +1042,7 @@ " for b in batch:\n", " values, keys = b[0], b[1]\n", " data = Data()\n", - " for key, value in zip(keys, values):\n", + " for key, value in zip(keys, values, strict=False):\n", " data[key] = value\n", "\n", " 
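
The zip(keys, values, strict=False) edits in this notebook are presumably Ruff's B905 fix (zip without an explicit strict argument); strict=False just spells out the old truncating behaviour. A minimal sketch of the rebuild pattern the cell uses, with toy keys and values:

import torch
from torch_geometric.data import Data

# Rebuild a Data object from parallel key/value lists (toy example; the
# notebook gets these lists from its custom collate path).
keys = ["x", "edge_index"]
values = [torch.randn(3, 4), torch.tensor([[0, 1], [1, 2]])]

data = Data()
for key, value in zip(keys, values, strict=False):
    # strict=False stops at the shorter list; strict=True would raise
    # if keys and values ever went out of sync.
    data[key] = value
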
return data\n", @@ -1181,7 +1173,7 @@ "outputs": [], "source": [ "import torch\n", - "from torch.utils.data import Dataset, DataLoader\n", + "from torch.utils.data import DataLoader, Dataset\n", "\n", "\n", "class TextDataset(Dataset):\n", @@ -1495,7 +1487,6 @@ "source": [ "# Load data\n", "from topobenchmarkx.data.load.loaders import HypergraphLoader\n", - "from topobenchmarkx.data.dataloader_fullbatch import FullBatchDataModule\n", "\n", "data_loader = HypergraphLoader(config)\n", "data = data_loader.load()\n", @@ -1562,7 +1553,7 @@ "metadata": {}, "outputs": [], "source": [ - "#b in []topomodelx.nn.hypergraph.unigcnii.UniGCNII" + "# b in []topomodelx.nn.hypergraph.unigcnii.UniGCNII" ] }, { @@ -1570,9 +1561,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "import topomodelx" - ] + "source": [] }, { "cell_type": "code", diff --git a/notebooks/play.ipynb b/notebooks/play.ipynb index 563e7c2d..93dc9f33 100644 --- a/notebooks/play.ipynb +++ b/notebooks/play.ipynb @@ -6,55 +6,58 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "import os\n", + "import urllib.request\n", + "\n", + "import numpy as np\n", "import torch\n", "import torch_geometric\n", - "import urllib.request\n", "\n", "\n", "def hetero_load(name, path):\n", - " file_name = f'{name}.npz'\n", + " file_name = f\"{name}.npz\"\n", "\n", " data = np.load(os.path.join(path, file_name))\n", "\n", - " x = torch.tensor(data['node_features'])\n", - " y = torch.tensor(data['node_labels'])\n", - " edge_index = torch.tensor(data['edges']).T\n", + " x = torch.tensor(data[\"node_features\"])\n", + " y = torch.tensor(data[\"node_labels\"])\n", + " edge_index = torch.tensor(data[\"edges\"]).T\n", "\n", " # Make edge_index undirected\n", " edge_index = torch_geometric.utils.to_undirected(edge_index)\n", "\n", " # Remove self-loops\n", " edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index)\n", - " \n", + "\n", " data = torch_geometric.data.Data(x=x, y=y, edge_index=edge_index)\n", " return data\n", "\n", + "\n", "def download_hetero_datasets(name, path):\n", - " url = 'https://github.com/OpenGSL/HeterophilousDatasets/raw/main/data/'\n", - " name = f'{name}.npz'\n", + " url = \"https://github.com/OpenGSL/HeterophilousDatasets/raw/main/data/\"\n", + " name = f\"{name}.npz\"\n", " try:\n", - " print(f'Downloading {name}')\n", + " print(f\"Downloading {name}\")\n", " path2save = os.path.join(path, name)\n", " urllib.request.urlretrieve(url + name, path2save)\n", - " print('Done!')\n", + " print(\"Done!\")\n", " except:\n", - " raise Exception('''Download failed! Make sure you have stable Internet connection and enter the right name''')\n", - "\n", + " raise Exception(\n", + " \"\"\"Download failed! 
Make sure you have stable Internet connection and enter the right name\"\"\"\n", + " )\n", "\n", "\n", "import os.path as osp\n", "from collections.abc import Callable\n", - "from typing import Optional\n", "\n", - "import torch\n", "from omegaconf import DictConfig\n", "from torch_geometric.data import Data, InMemoryDataset\n", "from torch_geometric.io import fs\n", "\n", - "from topobenchmarkx.io.load.heterophilic import download_hetero_datasets, load_heterophilic_data\n", - "\n", + "from topobenchmarkx.io.load.heterophilic import (\n", + " download_hetero_datasets,\n", + " load_heterophilic_data,\n", + ")\n", "from topobenchmarkx.io.load.split_utils import random_splitting\n", "\n", "\n", @@ -97,14 +100,14 @@ " root: str,\n", " name: str,\n", " parameters: DictConfig,\n", - " transform: Optional[Callable] = None,\n", - " pre_transform: Optional[Callable] = None,\n", - " pre_filter: Optional[Callable] = None,\n", + " transform: Callable | None = None,\n", + " pre_transform: Callable | None = None,\n", + " pre_filter: Callable | None = None,\n", " force_reload: bool = True,\n", " use_node_attr: bool = False,\n", " use_edge_attr: bool = False,\n", " ) -> None:\n", - " self.name = name #.replace(\"_\", \"-\")\n", + " self.name = name # .replace(\"_\", \"-\")\n", " self.parameters = parameters\n", " super().__init__(\n", " root, transform, pre_transform, pre_filter, force_reload=force_reload\n", @@ -144,7 +147,7 @@ " @property\n", " def processed_file_names(self) -> str:\n", " return \"data.pt\"\n", - " \n", + "\n", " @property\n", " def raw_file_names(self) -> list[str]:\n", " \"\"\"Spefify the downloaded raw fine name\"\"\"\n", @@ -171,7 +174,7 @@ " Returns:\n", " None\n", " \"\"\"\n", - " \n", + "\n", " data = load_heterophilic_data(name=self.name, path=self.raw_dir)\n", " data = data if self.pre_transform is None else self.pre_transform(data)\n", " self.save([data], self.processed_paths[0])\n", @@ -180,23 +183,24 @@ " return f\"{self.name}()\"\n", "\n", "\n", + "data_dir = \"/home/lev/projects/TopoBenchmarkX/datasets\"\n", + "data_domain = \"graph\"\n", + "data_type = \"heterophilic\"\n", + "data_name = \"amazon_ratings\"\n", "\n", - "data_dir = '/home/lev/projects/TopoBenchmarkX/datasets'\n", - "data_domain = 'graph'\n", - "data_type = 'heterophilic'\n", - "data_name = 'amazon_ratings'\n", - "\n", - "data_dir = f'{data_dir}/{data_domain}/{data_type}'\n", + "data_dir = f\"{data_dir}/{data_domain}/{data_type}\"\n", "\n", - "parameters={\n", - " 'split_type': 'random',\n", - " 'k': 10,\n", - " 'train_prop': 0.5,\n", - " 'data_seed':0,\n", - " 'data_split_dir': f'/home/lev/projects/TopoBenchmarkX/datasets/data_splits/{data_name}'\n", - " }\n", + "parameters = {\n", + " \"split_type\": \"random\",\n", + " \"k\": 10,\n", + " \"train_prop\": 0.5,\n", + " \"data_seed\": 0,\n", + " \"data_split_dir\": f\"/home/lev/projects/TopoBenchmarkX/datasets/data_splits/{data_name}\",\n", + "}\n", "\n", - "dataset = HeteroDataset(name=data_name, root = data_dir, parameters=parameters, force_reload=True)" + "dataset = HeteroDataset(\n", + " name=data_name, root=data_dir, parameters=parameters, force_reload=True\n", + ")" ] }, { diff --git a/notebooks/result_processing.ipynb b/notebooks/result_processing.ipynb index a289ba0f..24ce72e4 100644 --- a/notebooks/result_processing.ipynb +++ b/notebooks/result_processing.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -71,7 +71,7 @@ }, { "cell_type": "code", - 
"execution_count": 2, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -87,7 +87,7 @@ " dtype='object')" ] }, - "execution_count": 2, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -98,9 +98,52 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'enforce_tags': True, 'print_config': True, 'ignore_warnings': False}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_init['extras'][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "ename": "KeyError", + "evalue": "'model'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m~/miniconda3/envs/topox/lib/python3.11/site-packages/pandas/core/indexes/base.py:3805\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 3804\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 3805\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mcasted_key\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 3806\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m err:\n", + "File \u001b[0;32mindex.pyx:167\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n", + "File \u001b[0;32mindex.pyx:196\u001b[0m, in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n", + "File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:7081\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n", + "File \u001b[0;32mpandas/_libs/hashtable_class_helper.pxi:7089\u001b[0m, in \u001b[0;36mpandas._libs.hashtable.PyObjectHashTable.get_item\u001b[0;34m()\u001b[0m\n", + "\u001b[0;31mKeyError\u001b[0m: 'model'", + "\nThe above exception was the direct cause of the following exception:\n", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[7], line 23\u001b[0m\n\u001b[1;32m 21\u001b[0m config_columns \u001b[38;5;241m=\u001b[39m []\n\u001b[1;32m 22\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m column \u001b[38;5;129;01min\u001b[39;00m columns_to_normalize:\n\u001b[0;32m---> 23\u001b[0m df, columns \u001b[38;5;241m=\u001b[39m \u001b[43mnormalize_column\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdf\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcolumn\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 24\u001b[0m config_columns\u001b[38;5;241m.\u001b[39mextend(columns)\n", + "Cell \u001b[0;32mIn[7], line 3\u001b[0m, in \u001b[0;36mnormalize_column\u001b[0;34m(df, column_to_normalize)\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mnormalize_column\u001b[39m(df, column_to_normalize):\n\u001b[1;32m 2\u001b[0m \u001b[38;5;66;03m# Use json_normalize to flatten the nested dictionaries into separate columns\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m flattened_df \u001b[38;5;241m=\u001b[39m 
pd\u001b[38;5;241m.\u001b[39mjson_normalize(\u001b[43mdf\u001b[49m\u001b[43m[\u001b[49m\u001b[43mcolumn_to_normalize\u001b[49m\u001b[43m]\u001b[49m)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# Rename columns to include 'nested_column' prefix\u001b[39;00m\n\u001b[1;32m 5\u001b[0m flattened_df\u001b[38;5;241m.\u001b[39mcolumns \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 6\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcolumn_to_normalize\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcol\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;28;01mfor\u001b[39;00m col \u001b[38;5;129;01min\u001b[39;00m flattened_df\u001b[38;5;241m.\u001b[39mcolumns\n\u001b[1;32m 7\u001b[0m ]\n", + "File \u001b[0;32m~/miniconda3/envs/topox/lib/python3.11/site-packages/pandas/core/frame.py:4102\u001b[0m, in \u001b[0;36mDataFrame.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 4100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcolumns\u001b[38;5;241m.\u001b[39mnlevels \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 4101\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_getitem_multilevel(key)\n\u001b[0;32m-> 4102\u001b[0m indexer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcolumns\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4103\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_integer(indexer):\n\u001b[1;32m 4104\u001b[0m indexer \u001b[38;5;241m=\u001b[39m [indexer]\n", + "File \u001b[0;32m~/miniconda3/envs/topox/lib/python3.11/site-packages/pandas/core/indexes/base.py:3812\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 3807\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(casted_key, \u001b[38;5;28mslice\u001b[39m) \u001b[38;5;129;01mor\u001b[39;00m (\n\u001b[1;32m 3808\u001b[0m \u001b[38;5;28misinstance\u001b[39m(casted_key, abc\u001b[38;5;241m.\u001b[39mIterable)\n\u001b[1;32m 3809\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28many\u001b[39m(\u001b[38;5;28misinstance\u001b[39m(x, \u001b[38;5;28mslice\u001b[39m) \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m casted_key)\n\u001b[1;32m 3810\u001b[0m ):\n\u001b[1;32m 3811\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m InvalidIndexError(key)\n\u001b[0;32m-> 3812\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(key) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01merr\u001b[39;00m\n\u001b[1;32m 3813\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m 3814\u001b[0m \u001b[38;5;66;03m# If we have a listlike key, _check_indexing_error will raise\u001b[39;00m\n\u001b[1;32m 3815\u001b[0m \u001b[38;5;66;03m# InvalidIndexError. 
Otherwise we fall through and re-raise\u001b[39;00m\n\u001b[1;32m 3816\u001b[0m \u001b[38;5;66;03m# the TypeError.\u001b[39;00m\n\u001b[1;32m 3817\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_indexing_error(key)\n", + "\u001b[0;31mKeyError\u001b[0m: 'model'" + ] + } + ], "source": [ "def normalize_column(df, column_to_normalize):\n", " # Use json_normalize to flatten the nested dictionaries into separate columns\n", @@ -119,7 +162,7 @@ "\n", "\n", "# Config columns to normalize\n", - "columns_to_normalize = [\"model\", \"dataset\", \"callbacks\"]\n", + "columns_to_normalize = [\"model\", \"dataset\", \"callbacks\", \"paths\"]\n", "\n", "# Keep track of config columns added\n", "config_columns = []\n", @@ -2679,7 +2722,7 @@ "a[\"dataset.transforms.graph2simplicial_lifting.feature_lifting\"][\n", " a[\"dataset.transforms.graph2simplicial_lifting.feature_lifting\"].isna()\n", "] = \"projection\"\n", - "#a = a[~a[\"test/mae\"].isna()]\n", + "# a = a[~a[\"test/mae\"].isna()]\n", "a = a[~a[\"test/accuracy\"].isna()]\n", "\n", "a = a.groupby(\n", @@ -2692,7 +2735,7 @@ ").agg({col: [\"mean\", \"std\"] for col in performance_cols})\n", "\n", "ascending = True\n", - "#a = a.sort_values(by=(\"test/mae\", \"mean\"), ascending=ascending)\n", + "# a = a.sort_values(by=(\"test/mae\", \"mean\"), ascending=ascending)\n", "a = a.sort_values(by=(\"test/accuracy\", \"mean\"), ascending=ascending)\n", "# Show all rows\n", "pd.set_option(\"display.max_rows\", None)\n", @@ -6313,7 +6356,7 @@ "result_dict = pd.DataFrame.from_dict(\n", " {\n", " (i, j): nested_dict[i][j]\n", - " for i in nested_dict.keys()\n", + " for i in nested_dict\n", " for j in nested_dict[i].keys()\n", " },\n", " orient=\"index\",\n", @@ -6605,7 +6648,6 @@ ], "source": [ "import matplotlib.pyplot as plt\n", - "import numpy as np\n", "\n", "\n", "# Define the vector field function\n", diff --git a/notebooks/test_feature_lifting_dev.ipynb b/notebooks/test_feature_lifting_dev.ipynb index 704b350a..c2661f1d 100644 --- a/notebooks/test_feature_lifting_dev.ipynb +++ b/notebooks/test_feature_lifting_dev.ipynb @@ -47,7 +47,9 @@ "from topobenchmarkx.transforms.feature_liftings.feature_liftings import (\n", " ProjectionLifting,\n", ")\n", - "from topobenchmarkx.transforms.liftings.graph2simplicial import SimplicialCliqueLifting\n", + "from topobenchmarkx.transforms.liftings.graph2simplicial import (\n", + " SimplicialCliqueLifting,\n", + ")\n", "\n", "\n", "class TestProjectionLifting:\n", @@ -134,13 +136,10 @@ "\"\"\"Test the message passing module.\"\"\"\n", "\n", "import rootutils\n", - "import torch\n", "\n", - "from topobenchmarkx.io.load.loaders import manual_simple_graph\n", "from topobenchmarkx.transforms.feature_liftings.feature_liftings import (\n", " ConcatentionLifting,\n", ")\n", - "from topobenchmarkx.transforms.liftings.graph2simplicial import SimplicialCliqueLifting\n", "\n", "\n", "class TestConcatentionLifting:\n", @@ -254,11 +253,10 @@ "\"\"\"Test the message passing module.\"\"\"\n", "\n", "import rootutils\n", - "import torch\n", "\n", - "from topobenchmarkx.io.load.loaders import manual_simple_graph\n", - "from topobenchmarkx.transforms.feature_liftings.feature_liftings import SetLifting\n", - "from topobenchmarkx.transforms.liftings.graph2simplicial import SimplicialCliqueLifting\n", + "from topobenchmarkx.transforms.feature_liftings.feature_liftings import (\n", + " SetLifting,\n", + ")\n", "\n", "\n", "class TestSetLifting:\n", diff --git a/notebooks/test_hypergraph_lifting_dev.ipynb 
b/notebooks/test_hypergraph_lifting_dev.ipynb index 635965f6..90abf896 100644 --- a/notebooks/test_hypergraph_lifting_dev.ipynb +++ b/notebooks/test_hypergraph_lifting_dev.ipynb @@ -113,14 +113,8 @@ "\"\"\"Test the message passing module.\"\"\"\n", "\n", "import rootutils\n", - "import torch\n", "\n", "rootutils.setup_root(\"./\", indicator=\".project-root\", pythonpath=True)\n", - "from topobenchmarkx.io.load.loaders import manual_graph\n", - "from topobenchmarkx.transforms.liftings.graph2hypergraph import (\n", - " HypergraphKHopLifting,\n", - " HypergraphKNearestNeighborsLifting,\n", - ")\n", "\n", "\n", "class TestHypergraphKNearestNeighborsLifting:\n", diff --git a/notebooks/test_simplicialclique_dev.ipynb b/notebooks/test_simplicialclique_dev.ipynb index 2b69bb64..b2f8ad8a 100644 --- a/notebooks/test_simplicialclique_dev.ipynb +++ b/notebooks/test_simplicialclique_dev.ipynb @@ -23,7 +23,9 @@ "\n", "rootutils.setup_root(\"./\", indicator=\".project-root\", pythonpath=True)\n", "from topobenchmarkx.io.load.loaders import manual_simple_graph\n", - "from topobenchmarkx.transforms.liftings.graph2simplicial import SimplicialCliqueLifting\n", + "from topobenchmarkx.transforms.liftings.graph2simplicial import (\n", + " SimplicialCliqueLifting,\n", + ")\n", "\n", "\n", "class TestSimplicialCliqueLifting:\n", @@ -219,10 +221,8 @@ "\"\"\"Test the message passing module.\"\"\"\n", "\n", "import rootutils\n", - "import torch\n", "\n", "rootutils.setup_root(\"./\", indicator=\".project-root\", pythonpath=True)\n", - "from topobenchmarkx.io.load.loaders import manual_simple_graph\n", "from topobenchmarkx.transforms.liftings.graph2simplicial import (\n", " SimplicialNeighborhoodLifting,\n", ")\n", diff --git a/pyproject.toml b/pyproject.toml index 1a1b8bd5..e8290f40 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,6 @@ dependencies=[ "networkx", "pandas", "gudhi", - "pyg-nightly", "decorator", "hypernetx < 2.0.0", "trimesh", @@ -42,9 +41,12 @@ dependencies=[ "hydra-core==1.3.2", "hydra-colorlog==1.2.0", "hydra-optuna-sweeper==1.2.0", - "lightning==2.2.1", - "einops==0.7.0", "wandb==0.16.4", + "einops==0.7.0", + "tabulate", + "ipykernel", + "notebook", + "jupyterlab", "rich", "rootutils", "pytest", @@ -78,9 +80,18 @@ all = ["TopoBenchmarkX[dev, doc]"] homepage="https://github.com/pyt-team/TopoBenchmarkX" repository="https://github.com/pyt-team/TopoBenchmarkX" +[tool.black] +line-length = 79 # PEP 8 standard for maximum line length +target-version = ['py310'] + +[tool.docformatter] +wrap-summaries = 79 +wrap-descriptions = 79 + [tool.ruff] target-version = "py310" extend-include = ["*.ipynb"] +line-length = 79 # PEP 8 standard for maximum line length [tool.ruff.format] docstring-code-format = false @@ -110,7 +121,7 @@ ignore = [ "UP038", # Use `X | Y` in `isinstance` call instead of `(X, Y)` -- not compatible with python 3.9 (even with __future__ import) "W293", # Does not allow to have empty lines in multiline comments "PERF203", # [TODO: fix all such issues] `try`-`except` within a loop incurs performance overhead -] +] [tool.ruff.lint.pydocstyle] convention = "numpy" @@ -141,10 +152,11 @@ module = [ ignore_missing_imports = true [tool.pytest.ini_options] -addopts = [ - "--capture=no", -] -pythonpath = "." +addopts = "--capture=no" +pythonpath = [ + "." 
+] + [tool.numpydoc_validation] checks = [ @@ -157,4 +169,5 @@ checks = [ exclude = [ '\.undocumented_method$', '\.__init__$', -] \ No newline at end of file + '\.__repr__$', +] diff --git a/tables/cell_statistics.csv b/tables/cell_statistics.csv new file mode 100644 index 00000000..ca90bae6 --- /dev/null +++ b/tables/cell_statistics.csv @@ -0,0 +1,13 @@ +,3,4,5,6,7,8,9,10,greater_than_10,dataset,domain +0,1120,260,103,51,40,27,22,25,1000,Cora,cell +1,750,278,97,61,30,12,19,12,404,citeseer,cell +2,4174,3017,599,558,266,313,226,172,14280,PubMed,cell +3,769,192,10196,21486,407,53,0,1,17,ZINC,cell +4,7165,1701,700,310,178,97,56,24,35,roman_empire,cell +5,51642,9105,1392,710,359,289,170,103,4783,amazon_ratings,cell +6,24123,288,0,96,2,94,0,92,4260,minesweeper,cell +7,0,0,68,419,0,0,0,0,51,MUTAG,cell +8,24211,7495,1842,1531,568,585,358,349,1834,PROTEINS,cell +9,186,129,3025,10657,373,59,16,64,376,NCI1,cell +10,77706,36,13,3,0,0,0,0,0,IMDB-BINARY,cell +11,80846,35,17,3,0,0,0,0,0,IMDB-MULTI,cell diff --git a/tables/dataset_statistics.csv b/tables/dataset_statistics.csv new file mode 100644 index 00000000..38e43633 --- /dev/null +++ b/tables/dataset_statistics.csv @@ -0,0 +1,13 @@ +,num_hyperedges,zero_cell,one_cell,two_cell,three_cell,dataset,domain +0,0,2708,5278,2648,0,Cora,cell +1,0,3327,4552,1663,0,citeseer,cell +2,0,19717,44324,23605,0,PubMed,cell +3,0,277864,298985,33121,0,ZINC,cell +4,0,22662,32927,10266,0,roman_empire,cell +5,0,24492,93050,68553,0,amazon_ratings,cell +6,0,10000,39402,28955,0,minesweeper,cell +7,0,3371,3721,538,0,MUTAG,cell +8,0,43471,81044,38773,0,PROTEINS,cell +9,0,122747,132753,14885,0,NCI1,cell +10,0,19773,96531,77758,0,IMDB-BINARY,cell +11,0,19502,98903,80901,0,IMDB-MULTI,cell diff --git a/test.bash b/test.sh similarity index 100% rename from test.bash rename to test.sh diff --git a/test/data/test_Dataloaders.py b/test/data/test_Dataloaders.py new file mode 100644 index 00000000..42b6b17b --- /dev/null +++ b/test/data/test_Dataloaders.py @@ -0,0 +1,138 @@ +"""Test the collate function.""" +import hydra +from hydra import compose, initialize +from omegaconf import OmegaConf + +import torch + +from topobenchmarkx.data.dataloaders import to_data_list, DefaultDataModule + +from topobenchmarkx.utils.config_resolvers import ( + get_default_transform, + get_monitor_metric, + get_monitor_mode, + infer_in_channels, +) + +import rootutils + +rootutils.setup_root("./", indicator=".project-root", pythonpath=True) + +class TestCollateFunction: + """Test collate_fn.""" + + def setup_method(self): + """Setup the test. + + For this test we load the MUTAG dataset. 
+ + Parameters + ---------- + None + """ + OmegaConf.register_new_resolver("get_default_transform", get_default_transform) + OmegaConf.register_new_resolver("get_monitor_metric", get_monitor_metric) + OmegaConf.register_new_resolver("get_monitor_mode", get_monitor_mode) + OmegaConf.register_new_resolver("infer_in_channels", infer_in_channels) + OmegaConf.register_new_resolver( + "parameter_multiplication", lambda x, y: int(int(x) * int(y)) + ) + + initialize(version_base="1.3", config_path="../../configs", job_name="job") + cfg = compose(config_name="train.yaml") + + graph_loader = hydra.utils.instantiate(cfg.dataset, _recursive_=False) + datasets = graph_loader.load() + self.batch_size = 2 + datamodule = DefaultDataModule( + dataset_train=datasets[0], + dataset_val=datasets[1], + dataset_test=datasets[2], + batch_size=self.batch_size + ) + self.val_dataloader = datamodule.val_dataloader() + self.val_dataset = datasets[1] + + def test_lift_features(self): + """Test the collate function. + + To test the collate function we use the DefaultDataModule class to create a dataloader that uses the collate function. We first check that the batched data has the expected shape. We then convert the batched data back to a list and check that the data in the list is the same as the original data. + + Parameters + ---------- + None + """ + def check_shape(batch, elems, key): + """Check that the batched data has the expected shape.""" + if 'x_' in key or 'x'==key: + rows = 0 + for i in range(len(elems)): + rows += elems[i][key].shape[0] + assert batch[key].shape[0] == rows + assert batch[key].shape[1] == elems[0][key].shape[1] + elif 'edge_index' in key: + cols = 0 + for i in range(len(elems)): + cols += elems[i][key].shape[1] + assert batch[key].shape[0] == 2 + assert batch[key].shape[1] == cols + elif 'batch_' in key: + rows = 0 + n = int(key.split('_')[1]) + for i in range(len(elems)): + rows += elems[i][f'x_{n}'].shape[0] + assert batch[key].shape[0] == rows + elif key in elems[0].keys(): + for i in range(len(batch[key].shape)): + i_elems = 0 + for j in range(len(elems)): + i_elems += elems[j][key].shape[i] + assert batch[key].shape[i] == i_elems + + def check_separation(matrix, n_elems_0_row, n_elems_0_col): + """Check that the matrix is separated into two parts diagonally concatenated.""" + assert torch.all(matrix[:n_elems_0_row, n_elems_0_col:] == 0) + assert torch.all(matrix[n_elems_0_row:, :n_elems_0_col] == 0) + + def check_values(matrix, m1, m2): + """Check that the values in the matrix are the same as the values in the original data.""" + assert torch.allclose(matrix[:m1.shape[0], :m1.shape[1]], m1) + assert torch.allclose(matrix[m1.shape[0]:, m1.shape[1]:], m2) + + + batch = next(iter(self.val_dataloader)) + elems = [] + for i in range(self.batch_size): + elems.append(self.val_dataset.data_lst[i]) + + # Check shape + for key in batch.keys(): + check_shape(batch, elems, key) + + # Check that the batched data is separated correctly and the values are correct + if self.batch_size == 2: + for key in batch.keys(): + if 'incidence_' in key: + i = int(key.split('_')[1]) + if i==0: + n0_row = 1 + else: + n0_row = torch.sum(batch[f'batch_{i-1}']==0) + n0_col = torch.sum(batch[f'batch_{i}']==0) + check_separation(batch[key].to_dense(), n0_row, n0_col) + check_values(batch[key].to_dense(), + elems[0][key].to_dense(), + elems[1][key].to_dense()) + + # Check that going back to a list of data gives the same data + batch_list = to_data_list(batch) + assert len(batch_list) == len(elems) + for i in 
range(len(batch_list)): + for key in elems[i].keys(): + if key in batch_list[i].keys(): + if batch_list[i][key].is_sparse: + assert torch.all(batch_list[i][key].coalesce().indices() == elems[i][key].coalesce().indices()) + assert torch.allclose(batch_list[i][key].coalesce().values(), elems[i][key].coalesce().values()) + assert batch_list[i][key].shape, elems[i][key].shape + else: + assert torch.allclose(batch_list[i][key], elems[i][key]) \ No newline at end of file diff --git a/test/transforms/feature_liftings/test_ConcatenationLifting.py b/test/transforms/feature_liftings/test_ConcatenationLifting.py index 87deddef..6cf9122c 100644 --- a/test/transforms/feature_liftings/test_ConcatenationLifting.py +++ b/test/transforms/feature_liftings/test_ConcatenationLifting.py @@ -6,7 +6,9 @@ from topobenchmarkx.transforms.feature_liftings.feature_liftings import ( ConcatentionLifting, ) -from topobenchmarkx.transforms.liftings.graph2simplicial import SimplicialCliqueLifting +from topobenchmarkx.transforms.liftings.graph2simplicial import ( + SimplicialCliqueLifting, +) class TestConcatentionLifting: diff --git a/test/transforms/feature_liftings/test_SetLifting.py b/test/transforms/feature_liftings/test_SetLifting.py index a36bf62c..145c3eb8 100644 --- a/test/transforms/feature_liftings/test_SetLifting.py +++ b/test/transforms/feature_liftings/test_SetLifting.py @@ -3,8 +3,12 @@ import torch from topobenchmarkx.io.load.loaders import manual_simple_graph -from topobenchmarkx.transforms.feature_liftings.feature_liftings import SetLifting -from topobenchmarkx.transforms.liftings.graph2simplicial import SimplicialCliqueLifting +from topobenchmarkx.transforms.feature_liftings.feature_liftings import ( + SetLifting, +) +from topobenchmarkx.transforms.liftings.graph2simplicial import ( + SimplicialCliqueLifting, +) class TestSetLifting: diff --git a/test/transforms/liftings/cell/test_CellCyclesLifting.py b/test/transforms/liftings/cell/test_CellCyclesLifting.py index 364152d7..37f9005d 100644 --- a/test/transforms/liftings/cell/test_CellCyclesLifting.py +++ b/test/transforms/liftings/cell/test_CellCyclesLifting.py @@ -22,14 +22,126 @@ def test_lift_topology(self): expected_incidence_1 = torch.tensor( [ - [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 0.0, + 1.0, + 0.0, + 0.0, + 1.0, + 0.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 0.0, + ], + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + ], + [ + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 1.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 1.0, + 0.0, + ], + [ + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 
0.0, + 0.0, + 1.0, + ], ] ) diff --git a/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py b/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py index 384bf849..a961709d 100644 --- a/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py +++ b/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py @@ -3,7 +3,9 @@ import torch from topobenchmarkx.io.load.loaders import manual_graph -from topobenchmarkx.transforms.liftings.graph2hypergraph import HypergraphKHopLifting +from topobenchmarkx.transforms.liftings.graph2hypergraph import ( + HypergraphKHopLifting, +) class TestHypergraphKHopLifting: @@ -38,7 +40,8 @@ def test_lift_topology(self): ) assert ( - expected_incidence_1 == lifted_data_k1.incidence_hyperedges.to_dense() + expected_incidence_1 + == lifted_data_k1.incidence_hyperedges.to_dense() ).all(), "Something is wrong with incidence_hyperedges (k=1)." assert ( expected_n_hyperedges == lifted_data_k1.num_hyperedges @@ -63,7 +66,8 @@ def test_lift_topology(self): ) assert ( - expected_incidence_1 == lifted_data_k2.incidence_hyperedges.to_dense() + expected_incidence_1 + == lifted_data_k2.incidence_hyperedges.to_dense() ).all(), "Something is wrong with incidence_hyperedges (k=2)." assert ( expected_n_hyperedges == lifted_data_k2.num_hyperedges diff --git a/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py b/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py index 01b4ce9a..0807dd9c 100644 --- a/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py +++ b/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py @@ -16,8 +16,12 @@ def setup_method(self): self.data = manual_graph() # Initialise the HypergraphKNearestNeighborsLifting class - self.lifting_k2 = HypergraphKNearestNeighborsLifting(k_value=2, loop=True) - self.lifting_k3 = HypergraphKNearestNeighborsLifting(k_value=3, loop=True) + self.lifting_k2 = HypergraphKNearestNeighborsLifting( + k_value=2, loop=True + ) + self.lifting_k3 = HypergraphKNearestNeighborsLifting( + k_value=3, loop=True + ) def test_lift_topology(self): # Test the lift_topology method @@ -40,7 +44,8 @@ def test_lift_topology(self): ) assert ( - expected_incidence_1 == lifted_data_k2.incidence_hyperedges.to_dense() + expected_incidence_1 + == lifted_data_k2.incidence_hyperedges.to_dense() ).all(), "Something is wrong with incidence_hyperedges (k=2)." assert ( expected_n_hyperedges == lifted_data_k2.num_hyperedges @@ -65,7 +70,8 @@ def test_lift_topology(self): ) assert ( - expected_incidence_1 == lifted_data_k3.incidence_hyperedges.to_dense() + expected_incidence_1 + == lifted_data_k3.incidence_hyperedges.to_dense() ).all(), "Something is wrong with incidence_hyperedges (k=3)." 
assert ( expected_n_hyperedges == lifted_data_k3.num_hyperedges diff --git a/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py b/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py index d9ff1f18..8b23334c 100644 --- a/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py +++ b/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py @@ -3,7 +3,9 @@ import torch from topobenchmarkx.io.load.loaders import manual_simple_graph -from topobenchmarkx.transforms.liftings.graph2simplicial import SimplicialCliqueLifting +from topobenchmarkx.transforms.liftings.graph2simplicial import ( + SimplicialCliqueLifting, +) class TestSimplicialCliqueLifting: @@ -14,8 +16,12 @@ def setup_method(self): self.data = manual_simple_graph() # Initialise the SimplicialCliqueLifting class - self.lifting_signed = SimplicialCliqueLifting(complex_dim=3, signed=True) - self.lifting_unsigned = SimplicialCliqueLifting(complex_dim=3, signed=False) + self.lifting_signed = SimplicialCliqueLifting( + complex_dim=3, signed=True + ) + self.lifting_unsigned = SimplicialCliqueLifting( + complex_dim=3, signed=False + ) def test_lift_topology(self): """Test the lift_topology method.""" @@ -26,20 +32,135 @@ def test_lift_topology(self): expected_incidence_1 = torch.tensor( [ - [-1.0, -1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 1.0, 0.0, -1.0, -1.0, -1.0, -1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, -1.0, -1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [ + -1.0, + -1.0, + -1.0, + -1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 1.0, + 0.0, + 0.0, + 0.0, + -1.0, + -1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 0.0, + 1.0, + 0.0, + 0.0, + 1.0, + 0.0, + -1.0, + -1.0, + -1.0, + -1.0, + 0.0, + 0.0, + 0.0, + ], + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + -1.0, + 0.0, + 0.0, + ], + [ + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 1.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + -1.0, + -1.0, + ], + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 1.0, + 0.0, + ], + [ + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 1.0, + ], ] ) assert ( - abs(expected_incidence_1) == lifted_data_unsigned.incidence_1.to_dense() - ).all(), "Something is wrong with unsigned incidence_1 (nodes to edges)." + abs(expected_incidence_1) + == lifted_data_unsigned.incidence_1.to_dense() + ).all(), ( + "Something is wrong with unsigned incidence_1 (nodes to edges)." + ) assert ( expected_incidence_1 == lifted_data_signed.incidence_1.to_dense() ).all(), "Something is wrong with signed incidence_1 (nodes to edges)." @@ -63,26 +184,26 @@ def test_lift_topology(self): ) assert ( - abs(expected_incidence_2) == lifted_data_unsigned.incidence_2.to_dense() + abs(expected_incidence_2) + == lifted_data_unsigned.incidence_2.to_dense() ).all(), "Something is wrong with unsigned incidence_2 (edges to triangles)." 
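
For readers new to these fixtures: incidence_1 is the node-to-edge boundary matrix of the lifted complex, and in the sign convention the fixture above uses, a signed column carries the edge's orientation (-1 at its tail node, +1 at its head) while the unsigned variant is the elementwise absolute value. A toy illustration, independent of the fixture graph:

import torch

# Signed incidence_1 for the path graph 0-1-2 (toy example, not the fixture).
edges = [(0, 1), (1, 2)]
incidence_1 = torch.zeros(3, len(edges))
for j, (tail, head) in enumerate(edges):
    incidence_1[tail, j] = -1.0  # edge j leaves its tail node
    incidence_1[head, j] = 1.0   # and enters its head node

unsigned = incidence_1.abs()  # what the unsigned lifting is compared against
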
assert ( expected_incidence_2 == lifted_data_signed.incidence_2.to_dense() - ).all(), "Something is wrong with signed incidence_2 (edges to triangles)." + ).all(), ( + "Something is wrong with signed incidence_2 (edges to triangles)." + ) expected_incidence_3 = torch.tensor( [[-1.0], [1.0], [-1.0], [0.0], [1.0], [0.0]] ) assert ( - abs(expected_incidence_3) == lifted_data_unsigned.incidence_3.to_dense() - ).all(), ( - "Something is wrong with unsigned incidence_3 (triangles to tetrahedrons)." - ) + abs(expected_incidence_3) + == lifted_data_unsigned.incidence_3.to_dense() + ).all(), "Something is wrong with unsigned incidence_3 (triangles to tetrahedrons)." assert ( expected_incidence_3 == lifted_data_signed.incidence_3.to_dense() - ).all(), ( - "Something is wrong with signed incidence_3 (triangles to tetrahedrons)." - ) + ).all(), "Something is wrong with signed incidence_3 (triangles to tetrahedrons)." def test_lifted_features_signed(self): """Test the lift_features method in signed incidence cases.""" @@ -112,7 +233,9 @@ def test_lifted_features_signed(self): expected_features_1 == lifted_data.x_1 ).all(), "Something is wrong with x_1 features." - expected_features_2 = torch.tensor([[0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]) + expected_features_2 = torch.tensor( + [[0.0], [0.0], [0.0], [0.0], [0.0], [0.0]] + ) assert ( expected_features_2 == lifted_data.x_2 diff --git a/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py b/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py index 2ea913e2..ac07a745 100644 --- a/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py +++ b/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py @@ -16,7 +16,9 @@ def setup_method(self): self.data = manual_simple_graph() # Initialise the SimplicialNeighborhoodLifting class - self.lifting_signed = SimplicialNeighborhoodLifting(complex_dim=3, signed=True) + self.lifting_signed = SimplicialNeighborhoodLifting( + complex_dim=3, signed=True + ) self.lifting_unsigned = SimplicialNeighborhoodLifting( complex_dim=3, signed=False ) @@ -247,8 +249,11 @@ def test_lift_topology(self): ) assert ( - abs(expected_incidence_1) == lifted_data_unsigned.incidence_1.to_dense() - ).all(), "Something is wrong with unsigned incidence_1 (nodes to edges)." + abs(expected_incidence_1) + == lifted_data_unsigned.incidence_1.to_dense() + ).all(), ( + "Something is wrong with unsigned incidence_1 (nodes to edges)." + ) assert ( expected_incidence_1 == lifted_data_signed.incidence_1.to_dense() ).all(), "Something is wrong with signed incidence_1 (nodes to edges)." @@ -1309,11 +1314,14 @@ def test_lift_topology(self): ) assert ( - abs(expected_incidence_2) == lifted_data_unsigned.incidence_2.to_dense() + abs(expected_incidence_2) + == lifted_data_unsigned.incidence_2.to_dense() ).all(), "Something is wrong with unsigned incidence_2 (edges to triangles)." assert ( expected_incidence_2 == lifted_data_signed.incidence_2.to_dense() - ).all(), "Something is wrong with signed incidence_2 (edges to triangles)." + ).all(), ( + "Something is wrong with signed incidence_2 (edges to triangles)." 
+ ) def test_lifted_features_signed(self): # Test the lift_features method for signed case diff --git a/topobenchmarkx/=0.12.10 b/topobenchmarkx/=0.12.10 deleted file mode 100644 index 3afb89fd..00000000 --- a/topobenchmarkx/=0.12.10 +++ /dev/null @@ -1,44 +0,0 @@ -Collecting wandb - Downloading wandb-0.16.6-py3-none-any.whl.metadata (10 kB) -Collecting Click!=8.0.0,>=7.1 (from wandb) - Using cached click-8.1.7-py3-none-any.whl.metadata (3.0 kB) -Collecting GitPython!=3.1.29,>=1.0.0 (from wandb) - Downloading GitPython-3.1.43-py3-none-any.whl.metadata (13 kB) -Requirement already satisfied: requests<3,>=2.0.0 in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from wandb) (2.31.0) -Requirement already satisfied: psutil>=5.0.0 in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from wandb) (5.9.8) -Collecting sentry-sdk>=1.0.0 (from wandb) - Downloading sentry_sdk-2.1.1-py2.py3-none-any.whl.metadata (10 kB) -Collecting docker-pycreds>=0.4.0 (from wandb) - Using cached docker_pycreds-0.4.0-py2.py3-none-any.whl.metadata (1.8 kB) -Requirement already satisfied: PyYAML in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from wandb) (6.0.1) -Collecting setproctitle (from wandb) - Using cached setproctitle-1.3.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.9 kB) -Requirement already satisfied: setuptools in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from wandb) (68.2.2) -Collecting appdirs>=1.4.3 (from wandb) - Using cached appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB) -Collecting protobuf!=4.21.0,<5,>=3.19.0 (from wandb) - Using cached protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl.metadata (541 bytes) -Requirement already satisfied: six>=1.4.0 in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0) -Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->wandb) - Using cached gitdb-4.0.11-py3-none-any.whl.metadata (1.2 kB) -Requirement already satisfied: charset-normalizer<4,>=2 in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from requests<3,>=2.0.0->wandb) (3.3.2) -Requirement already satisfied: idna<4,>=2.5 in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from requests<3,>=2.0.0->wandb) (3.7) -Requirement already satisfied: urllib3<3,>=1.21.1 in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from requests<3,>=2.0.0->wandb) (2.2.1) -Requirement already satisfied: certifi>=2017.4.17 in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from requests<3,>=2.0.0->wandb) (2024.2.2) -Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) - Using cached smmap-5.0.1-py3-none-any.whl.metadata (4.3 kB) -Downloading wandb-0.16.6-py3-none-any.whl (2.2 MB) - ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 35.5 MB/s eta 0:00:00 -Using cached appdirs-1.4.4-py2.py3-none-any.whl (9.6 kB) -Using cached click-8.1.7-py3-none-any.whl (97 kB) -Using cached docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB) -Downloading GitPython-3.1.43-py3-none-any.whl (207 kB) - ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 207.3/207.3 kB 59.8 MB/s eta 0:00:00 -Using cached protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl (294 kB) -Downloading sentry_sdk-2.1.1-py2.py3-none-any.whl (277 kB) - ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 277.3/277.3 kB 66.9 MB/s eta 0:00:00 -Using cached 
setproctitle-1.3.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31 kB) -Using cached gitdb-4.0.11-py3-none-any.whl (62 kB) -Using cached smmap-5.0.1-py3-none-any.whl (24 kB) -Installing collected packages: appdirs, smmap, setproctitle, sentry-sdk, protobuf, docker-pycreds, Click, gitdb, GitPython, wandb -Successfully installed Click-8.1.7 GitPython-3.1.43 appdirs-1.4.4 docker-pycreds-0.4.0 gitdb-4.0.11 protobuf-4.25.3 sentry-sdk-2.1.1 setproctitle-1.3.3 smmap-5.0.1 wandb-0.16.6 diff --git a/topobenchmarkx/__init__.py b/topobenchmarkx/__init__.py index 895b4d03..e1e887f3 100755 --- a/topobenchmarkx/__init__.py +++ b/topobenchmarkx/__init__.py @@ -1,3 +1,23 @@ + +# Import submodules +from . import data +from . import evaluators +from . import io +from . import models +from . import transforms +from . import utils + +__all__ = [ + "data", + "evaluators", + "hp_scripts", + "io", + "models", + "transforms", + "utils", +] + + __version__ = "0.0.1" # from .io import * diff --git a/topobenchmarkx/data/dataloader_fullbatch.py b/topobenchmarkx/data/dataloader_fullbatch.py deleted file mode 100755 index 2ab9124d..00000000 --- a/topobenchmarkx/data/dataloader_fullbatch.py +++ /dev/null @@ -1,347 +0,0 @@ -from collections import defaultdict -from typing import Any, Optional - -import torch -from lightning import LightningDataModule -from torch.utils.data import DataLoader -from torch_geometric.data import Batch, Data -from torch_geometric.utils import is_sparse -from torch_sparse import SparseTensor - - -class MyData(Data): - """ - Data object class that overwrites some methods from torch_geometric.data.Data so that not only sparse matrices with adj in the name can work with the torch_geometric dataloaders. - """ - def is_valid(self, string): - valid_names = ["adj", "incidence", "laplacian"] - return any(name in string for name in valid_names) - - def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any: - if is_sparse(value) and self.is_valid(key): - return (0, 1) - elif "index" in key or key == "face": - return -1 - else: - return 0 - - -def to_data_list(batch): - """ - Workaround needed since torch_geometric doesn't work well with torch.sparse - """ - for key in batch: - if batch[key].is_sparse: - sparse_data = batch[key].coalesce() - batch[key] = SparseTensor.from_torch_sparse_coo_tensor(sparse_data) - data_list = batch.to_data_list() - for i, data in enumerate(data_list): - for key in data: - if isinstance(data[key], SparseTensor): - data_list[i][key] = data[key].to_torch_sparse_coo_tensor() - return data_list - - -def collate_fn(batch): - """ - args: - batch - list of (tensor, label) - - return: - xs - a tensor of all examples in 'batch' after padding - ys - a LongTensor of all labels in batch - """ - data_list = [] - batch_idx_dict = defaultdict(list) - - # Keep track of the running index for each cell dimension - running_idx = {} - - for batch_idx, b in enumerate(batch): - values, keys = b[0], b[1] - data = MyData() - for key, value in zip(keys, values, strict=False): - if is_sparse(value): - value = value.coalesce() - data[key] = value - - # Generate batch_slice values for x_2, x_3, ... 
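# Outline of the deleted loop below: for every higher-order feature key x_r
# (r > 0) it appended a per-cell batch vector batch_r -- the rank-r analogue
# of the `batch` vector PyG builds automatically for nodes -- and shifted the
# index-valued x_r entries by a running per-rank offset so that cell indices
# stayed contiguous once Batch.from_data_list concatenated the examples.
# Hyperedge features (x_hyperedges) got the same treatment under their own key.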
- x_keys = [el for el in keys if ("x_" in el)] - for x_key in x_keys: - # current_number_of_nodes = data["x_0"].shape[0] - - if x_key != "x_0" and x_key != "x_hyperedges": - cell_dim = int(x_key.split("_")[1]) - current_number_of_cells = data[x_key].shape[0] - - batch_idx_dict[f"batch_{cell_dim}"].append( - torch.tensor([[batch_idx] * current_number_of_cells]) - ) - - if running_idx.get(f"cell_running_idx_number_{cell_dim}") is None: - running_idx[f"cell_running_idx_number_{cell_dim}"] = ( - current_number_of_cells # current_number_of_nodes - ) - else: - # Make sure the idx is contiguous - data[f"x_{cell_dim}"] = ( - data[f"x_{cell_dim}"] - + running_idx[f"cell_running_idx_number_{cell_dim}"] - ).long() - - running_idx[ - f"cell_running_idx_number_{cell_dim}" - ] += current_number_of_cells # current_number_of_nodes - - elif x_key == "x_hyperedges": - cell_dim = x_key.split("_")[1] - current_number_of_hyperedges = data[x_key].shape[0] - - batch_idx_dict["batch_hyperedges"].append( - torch.tensor([[batch_idx] * current_number_of_hyperedges]) - ) - - if running_idx.get(f"cell_running_idx_number_{cell_dim}") is None: - running_idx[f"cell_running_idx_number_{cell_dim}"] = ( - current_number_of_hyperedges - ) - else: - # Make sure the idx is contiguous - data[f"x_{cell_dim}"] = ( - data[f"x_{cell_dim}"] - + running_idx[f"cell_running_idx_number_{cell_dim}"] - ).long() - - running_idx[ - f"cell_running_idx_number_{cell_dim}" - ] += current_number_of_hyperedges - else: - # Function Batch.from_data_list creates a running index automatically - pass - - data_list.append(data) - - batch = Batch.from_data_list(data_list) - - # Rename batch.batch to batch.batch_0 for consistency - batch["batch_0"] = batch.pop("batch") - - # Add batch slices to batch - for key, value in batch_idx_dict.items(): - batch[key] = torch.cat(value, dim=1).squeeze(0).long() - return batch - - -# class FullBatchDataModule(LightningDataModule): -# """ - -# Read the docs: -# https://lightning.ai/docs/pytorch/latest/data/datamodule.html -# """ - -# def __init__( -# self, -# dataset, -# batch_size: int = 64, -# num_workers: int = 0, -# pin_memory: bool = False, -# ) -> None: -# """Initialize a `MNISTDataModule`. - -# :param data_dir: The data directory. Defaults to `"data/"`. -# :param train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`. -# :param batch_size: The batch size. Defaults to `64`. -# :param num_workers: The number of workers. Defaults to `0`. -# :param pin_memory: Whether to pin memory. Defaults to `False`. -# """ -# super().__init__() - -# # this line allows to access init params with 'self.hparams' attribute -# # also ensures init params will be stored in ckpt -# self.save_hyperparameters(logger=False) - -# self.dataset = dataset -# self.batch_size = batch_size - -# def train_dataloader(self) -> DataLoader: -# """Create and return the train dataloader. - -# :return: The train dataloader. -# """ -# return DataLoader( -# dataset=self.dataset, -# batch_size=1, -# num_workers=self.hparams.num_workers, -# pin_memory=self.hparams.pin_memory, -# # persistent_workers=True, -# shuffle=True, -# collate_fn=collate_fn, -# ) - -# def val_dataloader(self) -> DataLoader: -# """Create and return the validation dataloader. - -# :return: The validation dataloader. 
-# """ -# return DataLoader( -# dataset=self.dataset, -# batch_size=1, -# num_workers=self.hparams.num_workers, -# pin_memory=self.hparams.pin_memory, -# # persistent_workers=True, -# shuffle=False, -# collate_fn=collate_fn, -# ) - -# def test_dataloader(self) -> DataLoader: -# """Create and return the test dataloader. - -# :return: The test dataloader. -# """ -# return DataLoader( -# dataset=self.dataset, -# batch_size=1, -# num_workers=self.hparams.num_workers, -# pin_memory=self.hparams.pin_memory, -# # persistent_workers=True, -# shuffle=False, -# collate_fn=collate_fn, -# ) - -# def teardown(self, stage: Optional[str] = None) -> None: -# """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, -# `trainer.test()`, and `trainer.predict()`. - -# :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. -# Defaults to ``None``. -# """ - -# def state_dict(self) -> dict[Any, Any]: -# """Called when saving a checkpoint. Implement to generate and save the datamodule state. - -# :return: A dictionary containing the datamodule state that you want to save. -# """ -# return {} - -# def load_state_dict(self, state_dict: dict[str, Any]) -> None: -# """Called when loading a checkpoint. Implement to reload datamodule state given datamodule -# `state_dict()`. - -# :param state_dict: The datamodule state returned by `self.state_dict()`. -# """ - - -class DefaultDataModule(LightningDataModule): - """ - Initializes the DefaultDataModule class. - - Args: - dataset_train: The training dataset. - dataset_val: The validation dataset (optional). - dataset_test: The test dataset (optional). - batch_size: The batch size for the dataloader. - num_workers: The number of worker processes to use for data loading. - pin_memory: If True, the data loader will copy tensors into pinned memory before returning them. - - Returns: - None - - Read the docs: - https://lightning.ai/docs/pytorch/latest/data/datamodule.html - """ - - def __init__( - self, - dataset_train, - dataset_val=None, - dataset_test=None, - batch_size=1, - num_workers: int = 0, - pin_memory: bool = False, - ) -> None: - super().__init__() - - # this line allows to access init params with 'self.hparams' attribute - # also ensures init params will be stored in ckpt - self.save_hyperparameters(logger=False, ignore=["dataset_train", "dataset_val", "dataset_test"]) - - - self.dataset_train = dataset_train - self.batch_size = batch_size - - if dataset_val is None and dataset_test is None: - # Transductive setting - self.dataset_val = dataset_train - self.dataset_test = dataset_train - assert ( - self.batch_size == 1 - ), "Batch size must be 1 for transductive setting." - else: - self.dataset_val = dataset_val - self.dataset_test = dataset_test - - def train_dataloader(self) -> DataLoader: - """Create and return the train dataloader. - - :return: The train dataloader. - """ - return DataLoader( - dataset=self.dataset_train, - batch_size=self.batch_size, - num_workers=self.hparams.num_workers, - pin_memory=self.hparams.pin_memory, - shuffle=True, - collate_fn=collate_fn, - ) - - def val_dataloader(self) -> DataLoader: - """Create and return the validation dataloader. - - :return: The validation dataloader. - """ - return DataLoader( - dataset=self.dataset_val, - batch_size=self.batch_size, - num_workers=self.hparams.num_workers, - pin_memory=self.hparams.pin_memory, - shuffle=False, - collate_fn=collate_fn, - ) - - def test_dataloader(self) -> DataLoader: - """Create and return the test dataloader. 
- - :return: The test dataloader. - """ - if self.dataset_test is None: - raise ValueError("There is no test dataloader.") - return DataLoader( - dataset=self.dataset_test, - batch_size=self.batch_size, - num_workers=self.hparams.num_workers, - pin_memory=self.hparams.pin_memory, - shuffle=False, - collate_fn=collate_fn, - ) - - def teardown(self, stage: Optional[str] = None) -> None: - """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, - `trainer.test()`, and `trainer.predict()`. - - :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. - Defaults to ``None``. - """ - - def state_dict(self) -> dict[Any, Any]: - """Called when saving a checkpoint. Implement to generate and save the datamodule state. - - :return: A dictionary containing the datamodule state that you want to save. - """ - return {} - - def load_state_dict(self, state_dict: dict[str, Any]) -> None: - """Called when loading a checkpoint. Implement to reload datamodule state given datamodule - `state_dict()`. - - :param state_dict: The datamodule state returned by `self.state_dict()`. - """ diff --git a/topobenchmarkx/data/dataloaders.py b/topobenchmarkx/data/dataloaders.py new file mode 100755 index 00000000..d474a7bf --- /dev/null +++ b/topobenchmarkx/data/dataloaders.py @@ -0,0 +1,240 @@ +from collections import defaultdict +from typing import Any + +import torch +from lightning import LightningDataModule +from torch.utils.data import DataLoader +from torch_geometric.data import Batch, Data +from torch_geometric.utils import is_sparse +from torch_sparse import SparseTensor + + +class DomainData(Data): + r"""Data object class that overwrites some methods from + `torch_geometric.data.Data` so that not only sparse matrices with adj in the + name can work with the `torch_geometric` dataloaders.""" + + def is_valid(self, string): + r"""Check if the string contains any of the valid names.""" + valid_names = ["adj", "incidence", "laplacian"] + return any(name in string for name in valid_names) + + def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any: + r"""Overwrite the `__cat_dim__` method to handle sparse matrices to handle the names specified in `is_valid`.""" + if is_sparse(value) and self.is_valid(key): + return (0, 1) + elif "index" in key or key == "face": + return -1 + else: + return 0 + + +def to_data_list(batch): + """Workaround needed since `torch_geometric` doesn't work when using `torch.sparse` instead of `torch_sparse`.""" + for key in batch.keys(): + if batch[key].is_sparse: + sparse_data = batch[key].coalesce() + batch[key] = SparseTensor.from_torch_sparse_coo_tensor(sparse_data) + data_list = batch.to_data_list() + for i, data in enumerate(data_list): + for key, d in data: + if isinstance(data[key], SparseTensor): + data_list[i][key] = d.to_torch_sparse_coo_tensor() + return data_list + + +def collate_fn(batch): + r"""This function overwrites the `torch_geometric.data.DataLoader` collate function to use the `DomainData` class. This ensures that the `torch_geometric` dataloaders work with sparse matrices that are not necessarily named `adj`. The function also generates the batch slices for the different cell dimensions. + + Args: + batch (list): List of data objects (e.g., `torch_geometric.data.Data`). + + Returns: + torch_geometric.data.Batch: A `torch_geometric.data.Batch` object. 
+ """ + data_list = [] + batch_idx_dict = defaultdict(list) + + # Keep track of the running index for each cell dimension + running_idx = {} + + for batch_idx, b in enumerate(batch): + values, keys = b[0], b[1] + data = DomainData() + for key, value in zip(keys, values, strict=False): + if is_sparse(value): + value = value.coalesce() + data[key] = value + + # Generate batch_slice values for x_1, x_2, x_3, ... + x_keys = [el for el in keys if ("x_" in el)] + for x_key in x_keys: + if x_key != "x_0": + if x_key != "x_hyperedges": + cell_dim = int(x_key.split("_")[1]) + else: + cell_dim = x_key.split("_")[1] + + current_number_of_cells = data[x_key].shape[0] + + batch_idx_dict[f"batch_{cell_dim}"].append( + torch.tensor([[batch_idx] * current_number_of_cells]) + ) + + if ( + running_idx.get(f"cell_running_idx_number_{cell_dim}") + is None + ): + running_idx[f"cell_running_idx_number_{cell_dim}"] = ( + current_number_of_cells + ) + + else: + running_idx[f"cell_running_idx_number_{cell_dim}"] += ( + current_number_of_cells + ) + + data_list.append(data) + + batch = Batch.from_data_list(data_list) + + # Rename batch.batch to batch.batch_0 for consistency + batch["batch_0"] = batch.pop("batch") + + # Add batch slices to batch + for key, value in batch_idx_dict.items(): + batch[key] = torch.cat(value, dim=1).squeeze(0).long() + + # Ensure shape is torch.Tensor + # "shape" describes the number of n_cells in each graph + if batch.get("shape") is not None: + cell_statistics = batch.pop("shape") + batch["cell_statistics"] = torch.Tensor(cell_statistics).long() + + return batch + + +class DefaultDataModule(LightningDataModule): + r"""This class takes care of returning the dataloaders for the training, validation, and test datasets. It also handles the collate function. The class is designed to work with the `torch` dataloaders. + + Args: + dataset_train (CustomDataset): The training dataset. + dataset_val (CustomDataset, optional): The validation dataset. (default: None) + dataset_test (CustomDataset, optional): The test dataset. (default: None) + batch_size (int, optional): The batch size for the dataloader. (default: 1) + num_workers (int, optional): The number of worker processes to use for data loading. (default: 0) + pin_memory (bool, optional): If True, the data loader will copy tensors into pinned memory before returning them. (default: False) + + Returns: + None + + Read the docs: + https://lightning.ai/docs/pytorch/latest/data/datamodule.html + """ + + def __init__( + self, + dataset_train, + dataset_val=None, + dataset_test=None, + batch_size=1, + num_workers: int = 0, + pin_memory: bool = False, + ) -> None: + super().__init__() + + # this line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters( + logger=False, + ignore=["dataset_train", "dataset_val", "dataset_test"], + ) + + self.dataset_train = dataset_train + self.batch_size = batch_size + + if dataset_val is None and dataset_test is None: + # Transductive setting + self.dataset_val = dataset_train + self.dataset_test = dataset_train + assert ( + self.batch_size == 1 + ), "Batch size must be 1 for transductive setting." 
+ else: + self.dataset_val = dataset_val + self.dataset_test = dataset_test + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(dataset_train={self.dataset_train}, dataset_val={self.dataset_val}, dataset_test={self.dataset_test}, batch_size={self.batch_size})" + + def train_dataloader(self) -> DataLoader: + r"""Create and return the train dataloader. + + Returns: + torch.utils.data.DataLoader: The train dataloader. + """ + return DataLoader( + dataset=self.dataset_train, + batch_size=self.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=True, + collate_fn=collate_fn, + ) + + def val_dataloader(self) -> DataLoader: + r"""Create and return the validation dataloader. + + Returns: + torch.utils.data.DataLoader: The validation dataloader. + """ + return DataLoader( + dataset=self.dataset_val, + batch_size=self.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=False, + collate_fn=collate_fn, + ) + + def test_dataloader(self) -> DataLoader: + r"""Create and return the test dataloader. + + Returns: + torch.utils.data.DataLoader: The test dataloader. + """ + if self.dataset_test is None: + raise ValueError("There is no test dataloader.") + return DataLoader( + dataset=self.dataset_test, + batch_size=self.batch_size, + num_workers=self.hparams.num_workers, + pin_memory=self.hparams.pin_memory, + shuffle=False, + collate_fn=collate_fn, + ) + + def teardown(self, stage: str | None = None) -> None: + r"""Lightning hook for cleaning up after `trainer.fit()`, + `trainer.validate()`, `trainer.test()`, and `trainer.predict()`. + + Args: + stage (str, optional): The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. (default: None) + """ + + def state_dict(self) -> dict[Any, Any]: + r"""Called when saving a checkpoint. Implement to generate and save the + datamodule state. + + Returns: + dict: A dictionary containing the datamodule state that you want to save. + """ + return {} + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + r"""Called when loading a checkpoint. Implement to reload datamodule + state given datamodule `state_dict()`. + + Args: + state_dict (dict): The datamodule state. This is the object returned by `state_dict()`. + """ diff --git a/topobenchmarkx/data/datasets.py b/topobenchmarkx/data/datasets.py index 9d3c77ed..d8d5b1bb 100644 --- a/topobenchmarkx/data/datasets.py +++ b/topobenchmarkx/data/datasets.py @@ -4,75 +4,34 @@ class CustomDataset(torch_geometric.data.Dataset): r"""Custom dataset to return all the values added to the dataset object. - Parameters - ---------- - data_lst: list - List of torch_geometric.data.Data objects . + Args: + data_lst (list[torch_geometric.data.Data]): List of torch_geometric.data.Data objects. """ + def __init__(self, data_lst): super().__init__() self.data_lst = data_lst + def __repr__(self): + return f"{self.__class__.__name__}(data_lst={self.data_lst})" + def get(self, idx): r"""Get data object from data list. - Parameters - ---------- - idx: int - Index of the data object to get. + Args: + idx (int): Index of the data object to get. - Returns - ------- - tuple - tuple containing a list of all the values for the data and the keys corresponding to the values. + Returns: + tuple: tuple containing a list of all the values for the data and the corresponding keys. 
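+ +        Example (hypothetical usage): +            >>> values, keys = dataset.get(0)  # values[i] corresponds to data[keys[i]]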
""" data = self.data_lst[idx] keys = list(data.keys()) return ([data[key] for key in keys], keys) def len(self): - r"""Return length of the dataset. - Returns - ------- - int - Length of the dataset. - """ - return len(self.data_lst) - - -class TorchGeometricDataset(torch_geometric.data.Dataset): - r"""Dataset to work with a list of data objects. - - Parameters - ---------- - data_lst: list - List of torch_geometric.data.Data objects . - """ - def __init__(self, data_lst): - super().__init__() - self.data_lst = data_lst - - def get(self, idx): - r"""Get data object from data list. - - Parameters - ---------- - idx: int - Index of the data object to get. + r"""Return the length of the dataset. - Returns - ------- - torch_geometric.data.Data - Data object of corresponding index. - """ - data = self.data_lst[idx] - return data - - def len(self): - r"""Return length of the dataset. - Returns - ------- - int - Length of the dataset. + Returns: + int: Length of the dataset. """ return len(self.data_lst) diff --git a/topobenchmarkx/data/heteriphilic_dataset.py b/topobenchmarkx/data/heteriphilic_dataset.py index 8a729440..1ee51bcf 100644 --- a/topobenchmarkx/data/heteriphilic_dataset.py +++ b/topobenchmarkx/data/heteriphilic_dataset.py @@ -1,46 +1,45 @@ import os.path as osp from collections.abc import Callable -from typing import Optional, ClassVar +from typing import ClassVar import torch from omegaconf import DictConfig from torch_geometric.data import Data, InMemoryDataset -from torch_geometric.io import fs +#from torch_geometric.io import fs -from topobenchmarkx.io.load.heterophilic import download_hetero_datasets, load_heterophilic_data +from topobenchmarkx.io.load.heterophilic import ( + download_hetero_datasets, + load_heterophilic_data, +) from topobenchmarkx.io.load.split_utils import random_splitting class HeteroDataset(InMemoryDataset): - r""" - Dataset class for US County Demographics dataset. + r"""Dataset class for heterophilic datasets. Args: root (str): Root directory where the dataset will be saved. name (str): Name of the dataset. parameters (DictConfig): Configuration parameters for the dataset. - transform (Optional[Callable]): A function/transform that takes in an + transform (Callable, optional): A function/transform that takes in an `torch_geometric.data.Data` object and returns a transformed version. - The transform function is applied to the loaded data before saving it. - pre_transform (Optional[Callable]): A function/transform that takes in an + The transform function is applied to the loaded data before saving it. (default: None) + pre_transform (Callable, optional): A function/transform that takes in an `torch_geometric.data.Data` object and returns a transformed version. The pre_transform function is applied to the data before the transform - function is applied. - pre_filter (Optional[Callable]): A function that takes in an + function is applied. (default: None) + pre_filter (Callable, optional): A function that takes in an `torch_geometric.data.Data` object and returns a boolean value - indicating whether the data object should be included in the dataset. - force_reload (bool): If set to True, the dataset will be re-downloaded + indicating whether the data object should be included in the dataset. (default: None) + force_reload (bool, optional): If set to True, the dataset will be re-downloaded even if it already exists on disk. 
(default: True) - use_node_attr (bool): If set to True, the node attributes will be included + use_node_attr (bool, optional): If set to True, the node attributes will be included in the dataset. (default: False) - use_edge_attr (bool): If set to True, the edge attributes will be included + use_edge_attr (bool, optional): If set to True, the edge attributes will be included in the dataset. (default: False) Attributes: - URLS (dict): Dictionary containing the URLs for downloading the dataset. - FILE_FORMAT (dict): Dictionary containing the file formats for the dataset. RAW_FILE_NAMES (dict): Dictionary containing the raw file names for the dataset. - """ RAW_FILE_NAMES: ClassVar = {} @@ -50,31 +49,30 @@ def __init__( root: str, name: str, parameters: DictConfig, - transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None, - pre_filter: Optional[Callable] = None, - force_reload: bool = True, + transform: Callable | None = None, + pre_transform: Callable | None = None, + pre_filter: Callable | None = None, + #force_reload: bool = True, use_node_attr: bool = False, use_edge_attr: bool = False, ) -> None: - self.name = name #.replace("_", "-") + self.name = name # .replace("_", "-") self.parameters = parameters super().__init__( - root, transform, pre_transform, pre_filter, force_reload=force_reload + root, + transform, + pre_transform, + pre_filter, + #force_reload=force_reload, ) - # Step 3:Load the processed data - # After the data has been downloaded from source - # Then preprocessed to obtain x,y and saved into processed folder - # We can now load the processed data from processed folder - # Load the processed data - data, _, _ = fs.torch_load(self.processed_paths[0]) + data, _, _ = torch.load(self.processed_paths[0]) # Map the loaded data into data = Data.from_dict(data) - # Step 5: Create the splits and upload desired fold + # Create the splits and upload desired fold splits = random_splitting(data.y, parameters=self.parameters) # Assign train val test masks to the graph @@ -85,7 +83,9 @@ def __init__( # Assign data object to self.data, to make it be prodessed by Dataset class self.data, self.slices = self.collate([data]) - # Do not forget to take care of properties + def __repr__(self) -> str: + return f"{self.name}(self.root={self.root}, self.name={self.name}, self.parameters={self.parameters}, self.transform={self.transform}, self.pre_transform={self.pre_transform}, self.pre_filter={self.pre_filter}, self.force_reload={self.force_reload})" + @property def raw_dir(self) -> str: return osp.join(self.root, self.name, "raw") @@ -97,37 +97,26 @@ def processed_dir(self) -> str: @property def processed_file_names(self) -> str: return "data.pt" - + @property def raw_file_names(self) -> list[str]: - """Spefify the downloaded raw fine name""" return [f"{self.name}.npz"] def download(self) -> None: - """ - Downloads the dataset from the specified URL and saves it to the raw directory. + r"""Downloads the dataset from the specified URL and saves it to the raw + directory. Raises: FileNotFoundError: If the dataset URL is not found. """ - - # Step 1: Download data from the source download_hetero_datasets(name=self.name, path=self.raw_dir) def process(self) -> None: - """ - Process the data for the dataset. + r"""Process the data for the dataset. 
- This method loads the US county demographics data, applies any pre-processing transformations if specified, + This method loads the heterophilic data, applies any pre-processing transformations if specified, and saves the processed data to the appropriate location. - - Returns: - None """ - data = load_heterophilic_data(name=self.name, path=self.raw_dir) data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) - - def __repr__(self) -> str: - return f"{self.name}()" diff --git a/topobenchmarkx/data/us_county_demos_dataset.py b/topobenchmarkx/data/us_county_demos_dataset.py index 4cd162b3..a5d5dd15 100644 --- a/topobenchmarkx/data/us_county_demos_dataset.py +++ b/topobenchmarkx/data/us_county_demos_dataset.py @@ -1,47 +1,46 @@ +import os import os.path as osp from collections.abc import Callable -from typing import Optional, ClassVar - +from typing import ClassVar +import shutil import torch from omegaconf import DictConfig -from torch_geometric.data import Data, InMemoryDataset -from torch_geometric.io import fs +from torch_geometric.data import Data, InMemoryDataset, extract_zip +# from torch_geometric.io import fs -from topobenchmarkx.io.load.us_county_demos import load_us_county_demos from topobenchmarkx.io.load.download_utils import download_file_from_drive from topobenchmarkx.io.load.split_utils import random_splitting +from topobenchmarkx.io.load.us_county_demos import load_us_county_demos class USCountyDemosDataset(InMemoryDataset): - r""" - Dataset class for US County Demographics dataset. + r"""Dataset class for US County Demographics dataset. Args: root (str): Root directory where the dataset will be saved. name (str): Name of the dataset. parameters (DictConfig): Configuration parameters for the dataset. - transform (Optional[Callable]): A function/transform that takes in an + transform (Callable, optional): A function/transform that takes in an `torch_geometric.data.Data` object and returns a transformed version. - The transform function is applied to the loaded data before saving it. - pre_transform (Optional[Callable]): A function/transform that takes in an + The transform function is applied to the loaded data before saving it. (default: None) + pre_transform (Callable, optional): A function/transform that takes in an `torch_geometric.data.Data` object and returns a transformed version. The pre_transform function is applied to the data before the transform - function is applied. - pre_filter (Optional[Callable]): A function that takes in an + function is applied. (default: None) + pre_filter (Callable, optional): A function that takes in an `torch_geometric.data.Data` object and returns a boolean value - indicating whether the data object should be included in the dataset. - force_reload (bool): If set to True, the dataset will be re-downloaded + indicating whether the data object should be included in the dataset. (default: None) + force_reload (bool, optional): If set to True, the dataset will be re-downloaded even if it already exists on disk. (default: True) - use_node_attr (bool): If set to True, the node attributes will be included + use_node_attr (bool, optional): If set to True, the node attributes will be included in the dataset. (default: False) - use_edge_attr (bool): If set to True, the edge attributes will be included + use_edge_attr (bool, optional): If set to True, the edge attributes will be included in the dataset. 
(default: False) Attributes: URLS (dict): Dictionary containing the URLs for downloading the dataset. FILE_FORMAT (dict): Dictionary containing the file formats for the dataset. RAW_FILE_NAMES (dict): Dictionary containing the raw file names for the dataset. - """ URLS: ClassVar = { @@ -61,31 +60,30 @@ def __init__( root: str, name: str, parameters: DictConfig, - transform: Optional[Callable] = None, - pre_transform: Optional[Callable] = None, - pre_filter: Optional[Callable] = None, - force_reload: bool = True, + transform: Callable | None = None, + pre_transform: Callable | None = None, + pre_filter: Callable | None = None, + #force_reload: bool = True, use_node_attr: bool = False, use_edge_attr: bool = False, ) -> None: self.name = name.replace("_", "-") self.parameters = parameters super().__init__( - root, transform, pre_transform, pre_filter, force_reload=force_reload + root, + transform, + pre_transform, + pre_filter, + #force_reload=force_reload, ) - # Step 3:Load the processed data - # After the data has been downloaded from source - # Then preprocessed to obtain x,y and saved into processed folder - # We can now load the processed data from processed folder - # Load the processed data - data, _, _ = fs.torch_load(self.processed_paths[0]) - + data, _ = torch.load(self.processed_paths[0]) + # Map the loaded data into - data = Data.from_dict(data) + data = Data.from_dict(data) if isinstance(data, dict) else data - # Step 5: Create the splits and upload desired fold + # Create the splits and upload desired fold splits = random_splitting(data.y, parameters=self.parameters) # Assign train val test masks to the graph data.train_mask = torch.from_numpy(splits["train"]) @@ -102,6 +100,15 @@ def __init__( # Assign data object to self.data, to make it be prodessed by Dataset class self.data, self.slices = self.collate([data]) + + # Make sure the dataset will be reloaded during next run + shutil.rmtree(self.raw_dir) + # Get parent dir of self.processed_paths[0] + processed_dir = os.path.abspath(os.path.join(self.processed_paths[0], os.pardir)) + shutil.rmtree(processed_dir) + + def __repr__(self) -> str: + return f"{self.name}(self.root={self.root}, self.name={self.name}, self.parameters={self.parameters}, self.transform={self.transform}, self.pre_transform={self.pre_transform}, self.pre_filter={self.pre_filter}, self.force_reload={self.force_reload})" @property def raw_dir(self) -> str: @@ -113,16 +120,16 @@ def processed_dir(self) -> str: @property def raw_file_names(self) -> list[str]: - names = ["", f"_{self.parameters.year}"] - return [f"{self.name}_{name}.txt" for name in names] + #names = ["county", f"{self.parameters.year}"] + return [f"county_graph.csv", f"county_stats_{self.parameters.year}.csv"] @property def processed_file_names(self) -> str: return "data.pt" def download(self) -> None: - """ - Downloads the dataset from the specified URL and saves it to the raw directory. + r"""Downloads the dataset from the specified URL and saves it to the raw + directory. Raises: FileNotFoundError: If the dataset URL is not found. 
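For context, a minimal sketch of how the split indices produced by `random_splitting` end up on the `Data` object, as in the constructor hunk above. `toy_random_splitting` is a hypothetical stand-in for `topobenchmarkx.io.load.split_utils.random_splitting`, and the "train"/"valid"/"test" key names are assumptions:

import numpy as np
import torch
from torch_geometric.data import Data

def toy_random_splitting(y, train_prop=0.5, seed=0):
    # Shuffle all node indices and cut them into train/valid/test chunks.
    idx = np.random.default_rng(seed).permutation(len(y))
    n_train = int(train_prop * len(y))
    n_val = (len(y) - n_train) // 2
    return {
        "train": idx[:n_train],
        "valid": idx[n_train : n_train + n_val],
        "test": idx[n_train + n_val :],
    }

data = Data(x=torch.randn(10, 3), y=torch.randint(0, 2, (10,)))
splits = toy_random_splitting(data.y)
# Same pattern as in the diff: the split indices are stored as mask attributes.
data.train_mask = torch.from_numpy(splits["train"])
data.val_mask = torch.from_numpy(splits["valid"])
data.test_mask = torch.from_numpy(splits["test"])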
@@ -139,35 +146,32 @@ def download(self) -> None: file_format=self.file_format, ) - # Extract the downloaded file if it is compressed - fs.cp( - f"{self.raw_dir}/{self.name}.{self.file_format}", self.raw_dir, extract=True - ) - - # Move the etracted files to the datasets/domain/dataset_name/raw/ directory - for filename in fs.ls(osp.join(self.raw_dir, self.name)): - fs.mv(filename, osp.join(self.raw_dir, osp.basename(filename))) - fs.rm(osp.join(self.raw_dir, self.name)) - - # Delete also f'{self.raw_dir}/{self.name}.{self.file_format}' - fs.rm(f"{self.raw_dir}/{self.name}.{self.file_format}") + folder = self.raw_dir + filename = f"{self.name}.{self.file_format}" + path = osp.join(folder, filename) + extract_zip(path, folder) + # Delete zip file + os.unlink(path) + #shutil.rmtree(path) + # Move files from osp.join(folder, self.name) to folder + for file in os.listdir(osp.join(folder, self.name)): + shutil.move(osp.join(folder, self.name, file), folder) + + # Delete osp.join(folder, self.name) dir + shutil.rmtree(osp.join(folder, self.name)) + def process(self) -> None: - """ - Process the data for the dataset. + r"""Process the data for the dataset. This method loads the US county demographics data, applies any pre-processing transformations if specified, and saves the processed data to the appropriate location. - - Returns: - None """ data = load_us_county_demos( - self.raw_dir, year=self.parameters.year, y_col=self.parameters.task_variable + self.raw_dir, + year=self.parameters.year, + y_col=self.parameters.task_variable, ) data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) - - def __repr__(self) -> str: - return f"{self.name}()" diff --git a/topobenchmarkx/dataset_statistics.py b/topobenchmarkx/dataset_statistics.py new file mode 100755 index 00000000..a6a6dfe7 --- /dev/null +++ b/topobenchmarkx/dataset_statistics.py @@ -0,0 +1,245 @@ +import random +from typing import Any + +import hydra +import lightning as L +import numpy as np +import rootutils + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +import torch +from lightning import Callback, LightningModule, Trainer +from lightning.pytorch.loggers import Logger +from omegaconf import DictConfig, OmegaConf + +from topobenchmarkx.data.dataloaders import DefaultDataModule +from topobenchmarkx.utils import ( + RankedLogger, + extras, + get_metric_value, + instantiate_callbacks, + instantiate_loggers, + log_hyperparameters, + task_wrapper, +) + +from topobenchmarkx.utils.config_resolvers import ( + get_default_transform, + get_monitor_metric, + get_monitor_mode, + infer_in_channels, + infere_list_length, +) +import pandas as pd +import os +# ------------------------------------------------------------------------------------ # +# the setup_root above is equivalent to: +# - adding project root dir to PYTHONPATH +# (so you don't need to force user to install project as a package) +# (necessary before importing any local modules e.g. `from src import utils`) +# - setting up PROJECT_ROOT environment variable +# (which is used as a base for paths in "configs/paths/default.yaml") +# (this way all filepaths are the same no matter where you run the code) +# - loading environment variables from ".env" in root dir +# +# you can remove it if you: +# 1. either install project as a package or move entry files to project root dir +# 2. set `root_dir` to "." 
in "configs/paths/default.yaml" +# +# more info: https://github.com/ashleve/rootutils +# ------------------------------------------------------------------------------------ # + + +OmegaConf.register_new_resolver("get_default_transform", get_default_transform) +OmegaConf.register_new_resolver("get_monitor_metric", get_monitor_metric) +OmegaConf.register_new_resolver("get_monitor_mode", get_monitor_mode) +OmegaConf.register_new_resolver("infer_in_channels", infer_in_channels) +OmegaConf.register_new_resolver("infere_list_length", infere_list_length) +OmegaConf.register_new_resolver( + "parameter_multiplication", lambda x, y: int(int(x) * int(y)) +) + +torch.set_num_threads(1) +log = RankedLogger(__name__, rank_zero_only=True) + + + +def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: + """Trains the model. Can additionally evaluate on a testset, using best + weights obtained during training. + + This method is wrapped in optional @task_wrapper decorator, that controls + the behavior during failure. Useful for multiruns, saving info about the + crash, etc. + + :param cfg: A DictConfig configuration composed by Hydra. + :return: A tuple with metrics and dict with all instantiated objects. + """ + + # Set seed for random number generators in pytorch, numpy and python.random + # if cfg.get("seed"): + L.seed_everything(cfg.seed, workers=True) + # Seed for torch + torch.manual_seed(cfg.seed) + # Seed for numpy + np.random.seed(cfg.seed) + # Seed for python random + random.seed(cfg.seed) + + if cfg.model.model_domain == "cell": + cfg.dataset.transforms.graph2cell_lifting.max_cell_length=1000 + + # Instantiate and load dataset + dataset = hydra.utils.instantiate(cfg.dataset, _recursive_=False) + dataset = dataset.load() + + one_graph_flag = True + if cfg.dataset.parameters.batch_size != 1: + cfg.dataset.parameters.batch_size != 1 + one_graph_flag = False + + + log.info(f"Instantiating datamodule <{cfg.dataset._target_}>") + + if cfg.dataset.parameters.task_level == "node": + datamodule = DefaultDataModule(dataset_train=dataset) + + elif cfg.dataset.parameters.task_level == "graph": + datamodule = DefaultDataModule( + dataset_train=dataset[0], + dataset_val=dataset[1], + dataset_test=dataset[2], + batch_size=cfg.dataset.parameters.batch_size, + ) + + else: + raise ValueError("Invalid task_level") + + if one_graph_flag == True: + dataloaders = [datamodule.train_dataloader()] + else: + dataloaders = [datamodule.train_dataloader(), datamodule.val_dataloader(), datamodule.test_dataloader()] + + dict_collector = { + "num_hyperedges": 0, + "zero_cell": 0, + "one_cell": 0, + "two_cell": 0, + "three_cell": 0, + } + + cell_dict = { + "3":0, + "4":0, + "5":0, + "6":0, + "7":0, + "8":0, + "9":0, + "10":0, + "greater_than_10":0 + + } + + for loader in dataloaders: + for batch in loader: + if cfg.model.model_domain == "hypergraph": + dict_collector["zero_cell"] += batch.x.shape[0] + dict_collector["num_hyperedges"] += batch.x_hyperedges.shape[0] + + elif cfg.model.model_domain == "simplicial": + dict_collector["zero_cell"] += batch.x_0.shape[0] + dict_collector["one_cell"] +=batch.x_1.shape[0] + dict_collector["two_cell"] +=batch.x_2.shape[0] + dict_collector["three_cell"] += batch.x_3.shape[0] + + elif cfg.model.model_domain == "cell": + dict_collector["zero_cell"] += batch.x_0.shape[0] + dict_collector["one_cell"] += batch.x_1.shape[0] + dict_collector["two_cell"] += batch.x_2.shape[0] + cell_sizes, cell_counts = torch.unique(batch.incidence_2.to_dense().sum(0), return_counts=True) + 
cell_sizes = cell_sizes.long() + for i in range(len(cell_sizes)): + if cell_sizes[i].item() > 10: + cell_dict["greater_than_10"] += cell_counts[i].item() + else: + cell_dict[str(cell_sizes[i].item())] += cell_counts[i].item() + + # Build the output path of the statistics table + filename = f"{cfg.paths['root_dir']}/tables/dataset_statistics.csv" + + dict_collector['dataset'] = cfg.dataset.parameters.data_name + dict_collector['domain'] = cfg.model.model_domain + + df = pd.DataFrame.from_dict(dict_collector, orient='index') + if not os.path.exists(filename): + # Save to csv file with a header row + df.T.to_csv(filename, header=True) + else: + # read csv file with header + df_saved = pd.read_csv(filename, index_col=0) + # add new row + df_saved = pd.concat([df_saved, pd.DataFrame([dict_collector])], ignore_index=True) + # write to csv file + df_saved.to_csv(filename) + + if cfg.model.model_domain == "cell": + filename = f"{cfg.paths['root_dir']}/tables/cell_statistics.csv" + # Create a dict from two arrays + # cell_dict = dict(zip(cell_sizes.long().tolist(), cell_counts.long().tolist())) + # keys = list(cell_dict.keys()) + # for key in keys: + # cell_dict[str(key)] = cell_dict.pop(key) + + # # Check if there are cells size of which greater than 10 + # n_large_cells = 0 + # subset_keys = [key for key in sorted(cell_dict.keys()) if int(key) > 10] + + # for key in subset_keys: + # n_large_cells += cell_dict.pop(key) + + # cell_dict["greater_than_10"] = n_large_cells + + cell_dict['dataset'] = cfg.dataset.parameters.data_name + cell_dict['domain'] = cfg.model.model_domain + + df = pd.DataFrame.from_dict(cell_dict, orient='index') + if not os.path.exists(filename): + # Save to csv file with a header row + df.T.to_csv(filename, header=True) + else: + # read csv file with header + df_saved = pd.read_csv(filename, index_col=0) + # add new row + df_saved = pd.concat([df_saved, df.T], ignore_index=True) + # write to csv file + df_saved.to_csv(filename) + + + return + + + +@hydra.main( + version_base="1.3", config_path="../configs", config_name="train.yaml" +) +def main(cfg: DictConfig) -> None: + """Main entry point for collecting dataset statistics. + + :param cfg: DictConfig configuration composed by Hydra. + :return: None. + """ + # apply extra utilities + # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.) + extras(cfg) + + + train(cfg) + + + # results are written to CSV, nothing to return + return + + +if __name__ == "__main__": + main() diff --git a/topobenchmarkx/eval.py b/topobenchmarkx/eval.py index 492a2c54..b5c1fd5b 100755 --- a/topobenchmarkx/eval.py +++ b/topobenchmarkx/eval.py @@ -5,7 +5,6 @@ from lightning import LightningDataModule, LightningModule, Trainer from lightning.pytorch.loggers import Logger from omegaconf import DictConfig - from src.utils import ( RankedLogger, extras, @@ -39,11 +38,13 @@ def evaluate(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: """Evaluates given checkpoint on a datamodule testset. - This method is wrapped in optional @task_wrapper decorator, that controls the behavior during - failure. Useful for multiruns, saving info about the crash, etc. + This method is wrapped in optional @task_wrapper decorator, that controls + the behavior during failure. Useful for multiruns, saving info about the + crash, etc. :param cfg: DictConfig configuration composed by Hydra. - :return: tuple[dict, dict] with metrics and dict with all instantiated objects. + :return: tuple[dict, dict] with metrics and dict with all instantiated + objects.
""" assert cfg.ckpt_path @@ -82,7 +83,9 @@ def evaluate(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: return metric_dict, object_dict -@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml") +@hydra.main( + version_base="1.3", config_path="../configs", config_name="eval.yaml" +) def main(cfg: DictConfig) -> None: """Main entry point for evaluation. diff --git a/topobenchmarkx/evaluators/comparisons.py b/topobenchmarkx/evaluators/comparisons.py index ee316f70..ac3980e9 100644 --- a/topobenchmarkx/evaluators/comparisons.py +++ b/topobenchmarkx/evaluators/comparisons.py @@ -2,41 +2,51 @@ from scipy.stats import friedmanchisquare, wilcoxon -def signed_ranks_test(result1, result2): - """ - Calculates the p-value for the Wilcoxon signed-rank test between the results of two models. +def signed_ranks_test(results_1, results_2): + r"""Calculates the p-value for the Wilcoxon signed-rank test between the + results of two models. https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wilcoxon.html - :param results: A 2xN numpy array with the results from the two models. N is the number of datasets over which the models have been tested on. - :return: The p-value of the test + Args: + results_1 (numpy.array): A numpy array with the results from the first model. N + is the number of datasets over which the models have been tested on. + results_2 (numpy.array): A numpy array with the results from the second model. Needs to have the same shape as results_1. + Returns: + float: The p-value of the test. """ - xs = result1 - result2 + xs = results_1 - results_2 return wilcoxon(xs[xs != 0])[1] def friedman_test(results): - """ - Calculates the p-value of the Friedman test between M models on N datasets. + r"""Calculates the p-value of the Friedman test between M models on N + datasets. https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.friedmanchisquare.html - :param results: A MxN numpy array with the results of M models over N dataset - :return: The p-value of the test + Args: + results (numpy.array): A MxN numpy array with the results of M models. + Returns: + float: The p-value of the test. """ res = [r for r in results] return friedmanchisquare(*res)[1] def compare_models(results, p_limit=0.05, verbose=False): - """ - Compares different models. First it uses the Friedman test to check that the models are significantly different, then it uses pairwise comparisons to study the ranking of the models. - - :param results: A MxN numpy array with the results of M models over N dataset - :param p_limit: The limit below which a hypothesis is considered false - :param verbose: Whether to print the results of the tests or not - :return average_rank: The average ranks of the models - :return groups: List of lists with the groups of models that are statistically similar + """Compares different models. First it uses the Friedman test to check that + the models are significantly different, then it uses pairwise comparisons + to study the ranking of the models. + + Args: + results (numpy.array): A MxN numpy array with the results of M models + over N dataset. + p_limit (float, optional): The limit below which a hypothesis is considered false. (default: 0.05) + verbose (bool, optional): Whether to print the results of the tests or not. (default: False) + Returns: + numpy.array: The average ranks of the models + list: A list of lists with the indices of the models that are in the same group. The first group is the best one. 
""" M = results.shape[0] @@ -56,8 +66,12 @@ def compare_models(results, p_limit=0.05, verbose=False): model_idx = np.where(np.argsort(average_ranks) == i)[0][0] group = [model_idx] while i + idx < M: - next_model_idx = np.where(np.argsort(average_ranks) == i + idx)[0][0] - p = signed_ranks_test(results[model_idx, :], results[model_idx + idx, :]) + next_model_idx = np.where(np.argsort(average_ranks) == i + idx)[0][ + 0 + ] + p = signed_ranks_test( + results[model_idx, :], results[model_idx + idx, :] + ) if verbose: print( f"P-value for Wilcoxon test between models {model_idx} and {next_model_idx}: {p}" @@ -77,6 +91,7 @@ def compare_models(results, p_limit=0.05, verbose=False): [0.6, 0.65, 0.7, 0.9, 0.5, 0.552, 0.843, 0.78, 0.665, 0.876], ] ) + print("Signed ranks test:") print(signed_ranks_test(results[0, :], results[1, :])) results2 = np.array( [ @@ -85,6 +100,7 @@ def compare_models(results, p_limit=0.05, verbose=False): [0.1, 0.2, 0.2, 0.3, 0.4, 0.5, 0.22, 0.32, 0.11, 0.4], ] ) + print("Friedman test with very different results:") print(friedman_test(results2)) results3 = np.array( [ @@ -93,7 +109,12 @@ def compare_models(results, p_limit=0.05, verbose=False): [0.89, 0.91, 0.79, 0.81, 0.69], ] ) + print("Friedman test with similar results:") print(friedman_test(results3)) + print("-"*50) + print("Compare models with different results:") print(compare_models(results2, verbose=True)) - print(compare_models(results3)) + print("-"*50) + print("Compare models with similar results:") + print(compare_models(results3, verbose=True)) diff --git a/topobenchmarkx/evaluators/evaluator.py b/topobenchmarkx/evaluators/evaluator.py index aa5eb670..3f1771f5 100755 --- a/topobenchmarkx/evaluators/evaluator.py +++ b/topobenchmarkx/evaluators/evaluator.py @@ -6,25 +6,18 @@ class TorchEvaluator: - r"""Evaluator class that is responsible for computing the metrics for a given task. - - Parameters - ---------- - task : str - The task type. It can be either "classification" or "regression". - - **kwargs : - Additional arguments for the class. The arguments depend on the task. - In "classification" scenario, the following arguments are expected: - - num_classes : int - The number of classes. - - classification_metrics : list - A list of classification metrics to be computed. - - In "regression" scenario, the following arguments are expected: - - regression_metrics : list - A list of regression metrics to be computed. - + r"""Evaluator class that is responsible for computing the metrics for a + given task. + + Args: + task (str): The task type. It can be either "classification" or "regression". + **kwargs : Additional arguments for the class. The arguments depend on the task. + In "classification" scenario, the following arguments are expected: + - num_classes (int): The number of classes. + - classification_metrics (list[str]): A list of classification metrics to be computed. + + In "regression" scenario, the following arguments are expected: + - regression_metrics (list[str]): A list of regression metrics to be computed. 
""" def __init__(self, task, **kwargs): @@ -43,7 +36,7 @@ def __init__(self, task, **kwargs): elif self.task == "multilabel classification": parameters = {"num_classes": kwargs["num_classes"]} parameters["task"] = "multilabel" - metric_names = kwargs["classification_metrics"] + metric_names = kwargs["classification_metrics"] elif self.task == "regression": parameters = {} @@ -57,31 +50,28 @@ def __init__(self, task, **kwargs): # ) metrics = {} - for name in metric_names: + for name in metric_names: if name in ["recall", "precision", "auroc"]: - metrics[name] = METRICS[name](average='macro', **parameters) - + metrics[name] = METRICS[name](average="macro", **parameters) + else: metrics[name] = METRICS[name](**parameters) self.metrics = MetricCollection(metrics) - - - self.best_metric = {} + def __repr__(self) -> str: + return f"{self.__class__.__name__}(task={self.task}, metrics={self.metrics})" + def update(self, model_out: dict): - """Update the metrics with the model output. - - Parameters - ---------- - model_out : dict - The model output. It should contain the following keys: - - logits : torch.Tensor - The model predictions. - - labels : torch.Tensor - The ground truth labels. - + r"""Update the metrics with the model output. + + Args: + model_out (dict): The model output. It should contain the following keys: + - logits : torch.Tensor + The model predictions. + - labels : torch.Tensor + The ground truth labels. """ preds = model_out["logits"].cpu() target = model_out["labels"].cpu() @@ -96,27 +86,25 @@ def update(self, model_out: dict): raise ValueError(f"Invalid task {self.task}") def compute(self): - """Compute the metrics. + r"""Compute the metrics. - Returns - ------- - res_dict : dict - A dictionary containing the computed metrics. + Args: + res_dict (dict): A dictionary containing the computed metrics. """ + return self.metrics.compute() - res_dict = self.metrics.compute() - - return res_dict + def reset(self): + """Reset the metrics. - def reset( - self, - ): - """Reset the metrics. 
This method should be called after each epoch""" + This method should be called after each epoch + """ self.metrics.reset() if __name__ == "__main__": evaluator = TorchEvaluator( - task="classification", num_classes=3, classification_metrics=["accuracy"] + task="classification", + num_classes=3, + classification_metrics=["accuracy"], ) print(evaluator.task) diff --git a/topobenchmarkx/graph_search.sh b/topobenchmarkx/graph_search.sh deleted file mode 100644 index d8f49210..00000000 --- a/topobenchmarkx/graph_search.sh +++ /dev/null @@ -1,14 +0,0 @@ -# GCN -python train.py dataset=cocitation_cora model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5 model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 logger.wandb.project=topobenchmark_22Apr2024 trainer=cpu --multirun #tags "[first_tag, second_tag]" -python train.py dataset=cocitation_citeseer model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5 model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 logger.wandb.project=topobenchmark_22Apr2024 trainer=cpu --multirun -python train.py dataset=cocitation_pubmed model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5 model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 logger.wandb.project=topobenchmark_22Apr2024 trainer=cpu --multirun -python train.py dataset=PROTEINS_TU model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5 model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 dataset.parameters.batch_size=128,256 logger.wandb.project=topobenchmark_22Apr2024 trainer=cpu --multirun -python train.py dataset=NCI1 model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5 model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 dataset.parameters.batch_size=128,256 logger.wandb.project=topobenchmark_22Apr2024 trainer=cpu --multirun - -# python train.py dataset=IMDB-BINARY model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5 model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 dataset.parameters.batch_size=128,256 logger.wandb.project=topobenchmark_22Apr2024 trainer=cpu --multirun -# python train.py dataset=IMDB-MULTI model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5 model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 dataset.parameters.batch_size=128,256 logger.wandb.project=topobenchmark_22Apr2024 trainer=cpu --multirun -python train.py dataset=MUTAG model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5 
model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 dataset.parameters.batch_size=32,64 logger.wandb.project=topobenchmark_22Apr2024 trainer=cpu --multirun -python train.py dataset=ZINC model=graph/gcn model.optimizer.lr=0.01,0.001 model.optimizer.weight_decay=0 model.backbone.hidden_channels=16,32,64,128 model.backbone.num_layers=1,2,3,4 dataset.parameters.batch_size=128,256 dataset.parameters.data_seed=0 model.backbone.dropout=0,0.25,0.5 logger.wandb.project=topobenchmark_22Apr2024 callbacks.early_stopping.patience=10 trainer=default --multirun -python train.py dataset=REDDIT-BINARY model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5 model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 dataset.parameters.batch_size=128,256 logger.wandb.project=topobenchmark_22Apr2024 trainer=default --multirun - - diff --git a/topobenchmarkx/hp_scripts/simplicial/SCN.sh b/topobenchmarkx/hp_scripts/simplicial/SCN.sh deleted file mode 100644 index 63c9045e..00000000 --- a/topobenchmarkx/hp_scripts/simplicial/SCN.sh +++ /dev/null @@ -1,114 +0,0 @@ -# Create a logger file in the same repo to keep track of the experiments executed - -# SCN model - Fixed split -python train.py \ - dataset=ZINC \ - model=simplicial/scn \ - model.backbone.n_layers=1,2,4 \ - model.feature_encoder.out_channels=16,64 \ - model.optimizer.lr=0.01,0.001 \ - dataset.parameters.batch_size=128 \ - dataset.parameters.data_seed=0,3 \ - dataset.transforms.graph2simplicial_lifting.complex_dim=3 \ - dataset.transforms.graph2simplicial_lifting.signed=False \ - trainer=default \ - trainer.check_val_every_n_epoch=5 \ - callbacks.early_stopping.patience=10 \ - callbacks.early_stopping.min_delta=0.005 \ - logger.wandb.project=topobenchmark_22Apr2024 \ - --multirun - -# Batch size = 1 -python train.py \ - dataset=cocitation_cora \ - model=simplicial/scn \ - model.optimizer.lr=0.01,0.001 \ - model.feature_encoder.out_channels=32,64 \ - model.backbone.n_layers=1,2 \ - dataset.parameters.data_seed=0,3,5 \ - dataset.transforms.graph2simplicial_lifting.complex_dim=3 \ - dataset.transforms.graph2simplicial_lifting.signed=False \ - trainer=default \ - trainer.check_val_every_n_epoch=5 \ - callbacks.early_stopping.patience=10 \ - logger.wandb.project=topobenchmark_22Apr2024 \ - --multirun - -python train.py \ - dataset=cocitation_citeseer \ - model=simplicial/scn \ - model.optimizer.lr=0.01,0.001 \ - model.feature_encoder.out_channels=32,64 \ - model.backbone.n_layers=1,2 \ - dataset.parameters.data_seed=0,3,5 \ - dataset.transforms.graph2simplicial_lifting.complex_dim=3 \ - dataset.transforms.graph2simplicial_lifting.signed=False \ - trainer=default \ - trainer.check_val_every_n_epoch=5 \ - callbacks.early_stopping.patience=10 \ - logger.wandb.project=topobenchmark_22Apr2024 \ - --multirun - -python train.py \ - dataset=cocitation_pubmed \ - model=simplicial/scn \ - model.optimizer.lr=0.01,0.001 \ - model.feature_encoder.out_channels=32,64 \ - model.backbone.n_layers=1,2 \ - dataset.parameters.data_seed=0,3,5 \ - dataset.transforms.graph2simplicial_lifting.complex_dim=3 \ - dataset.transforms.graph2simplicial_lifting.signed=False \ - trainer=default \ - trainer.check_val_every_n_epoch=5 \ - callbacks.early_stopping.patience=10 \ - logger.wandb.project=topobenchmark_22Apr2024 \ - --multirun - -# Vary batch size -python train.py \ - 
dataset=PROTEINS_TU \ - model=simplicial/scn \ - model.optimizer.lr=0.01,0.001 \ - model.feature_encoder.out_channels=16,64 \ - model.backbone.n_layers=1,2 \ - dataset.parameters.batch_size=32 \ - dataset.parameters.data_seed=0,3,5 \ - dataset.transforms.graph2simplicial_lifting.complex_dim=3 \ - dataset.transforms.graph2simplicial_lifting.signed=False \ - trainer=default \ - trainer.check_val_every_n_epoch=5 \ - callbacks.early_stopping.patience=10 \ - logger.wandb.project=topobenchmark_22Apr2024 \ - --multirun - -python train.py \ - dataset=NCI1 \ - model=simplicial/scn \ - model.optimizer.lr=0.01,0.001 \ - model.feature_encoder.out_channels=16,64 \ - model.backbone.n_layers=1,2 \ - dataset.parameters.batch_size=32 \ - dataset.parameters.data_seed=0,3,5 \ - dataset.transforms.graph2simplicial_lifting.complex_dim=3 \ - dataset.transforms.graph2simplicial_lifting.signed=False \ - trainer=default \ - trainer.check_val_every_n_epoch=5 \ - callbacks.early_stopping.patience=10 \ - logger.wandb.project=topobenchmark_22Apr2024 \ - --multirun - -python train.py \ - dataset=MUTAG \ - model=simplicial/scn \ - model.optimizer.lr=0.01,0.001 \ - model.feature_encoder.out_channels=16,64 \ - model.backbone.n_layers=1,2 \ - dataset.parameters.batch_size=32 \ - dataset.parameters.data_seed=0,3,5 \ - dataset.transforms.graph2simplicial_lifting.complex_dim=3 \ - dataset.transforms.graph2simplicial_lifting.signed=False \ - trainer=default \ - trainer.check_val_every_n_epoch=5 \ - callbacks.early_stopping.patience=10 \ - logger.wandb.project=topobenchmark_22Apr2024 \ - --multirun diff --git a/topobenchmarkx/io/load/download_utils.py b/topobenchmarkx/io/load/download_utils.py index f6c66f96..66c71159 100644 --- a/topobenchmarkx/io/load/download_utils.py +++ b/topobenchmarkx/io/load/download_utils.py @@ -5,8 +5,7 @@ # Function to extract file ID from Google Drive URL def get_file_id_from_url(url): - """ - Extracts the file ID from a Google Drive file URL. + r"""Extracts the file ID from a Google Drive file URL. Args: url (str): The Google Drive file URL. @@ -21,10 +20,14 @@ def get_file_id_from_url(url): query_params = parse_qs(parsed_url.query) if "id" in query_params: # Case 1: URL format contains '?id=' file_id = query_params["id"][0] - elif "file/d/" in parsed_url.path: # Case 2: URL format contains '/file/d/' + elif ( + "file/d/" in parsed_url.path + ): # Case 2: URL format contains '/file/d/' file_id = parsed_url.path.split("/")[3] else: - raise ValueError("The provided URL is not a valid Google Drive file URL.") + raise ValueError( + "The provided URL is not a valid Google Drive file URL." + ) return file_id @@ -32,8 +35,8 @@ def get_file_id_from_url(url): def download_file_from_drive( file_link, path_to_save, dataset_name, file_format="tar.gz" ): - """ - Downloads a file from a Google Drive link and saves it to the specified path. + r"""Downloads a file from a Google Drive link and saves it to the specified + path. Args: file_link (str): The Google Drive link of the file to download. diff --git a/topobenchmarkx/io/load/heterophilic.py b/topobenchmarkx/io/load/heterophilic.py index c7901130..67b06e4e 100644 --- a/topobenchmarkx/io/load/heterophilic.py +++ b/topobenchmarkx/io/load/heterophilic.py @@ -1,35 +1,55 @@ -import numpy as np import os +import urllib.request + +import numpy as np import torch import torch_geometric -import urllib.request def load_heterophilic_data(name, path): - file_name = f'{name}.npz' + r"""Load a heterophilic dataset from a .npz file. 
+ + Args: + name (str): The name of the dataset. + path (str): The path to the directory containing the dataset file. + Returns: + torch_geometric.data.Data: The dataset. + """ + file_name = f"{name}.npz" data = np.load(os.path.join(path, file_name)) - x = torch.tensor(data['node_features']) - y = torch.tensor(data['node_labels']) - edge_index = torch.tensor(data['edges']).T + x = torch.tensor(data["node_features"]) + y = torch.tensor(data["node_labels"]) + edge_index = torch.tensor(data["edges"]).T # Make edge_index undirected edge_index = torch_geometric.utils.to_undirected(edge_index) # Remove self-loops edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index) - + data = torch_geometric.data.Data(x=x, y=y, edge_index=edge_index) return data + def download_hetero_datasets(name, path): - url = 'https://github.com/OpenGSL/HeterophilousDatasets/raw/main/data/' - name = f'{name}.npz' + r"""Download a heterophilic dataset from the OpenGSL repository. + + Args: + name (str): The name of the dataset. + path (str): The path to the directory where the dataset will be saved. + Raises: + Exception: If the download fails. + """ + url = "https://github.com/OpenGSL/HeterophilousDatasets/raw/main/data/" + name = f"{name}.npz" try: - print(f'Downloading {name}') + print(f"Downloading {name}") path2save = os.path.join(path, name) urllib.request.urlretrieve(url + name, path2save) - print('Done!') + print("Done!") except Exception as e: - raise Exception('''Download failed! Make sure you have stable Internet connection and enter the right name''') from e \ No newline at end of file + raise Exception( + """Download failed! Make sure you have stable Internet connection and enter the right name""" + ) from e diff --git a/topobenchmarkx/io/load/loader.py b/topobenchmarkx/io/load/loader.py index c0eb168e..1f2b7fb4 100755 --- a/topobenchmarkx/io/load/loader.py +++ b/topobenchmarkx/io/load/loader.py @@ -9,28 +9,23 @@ class AbstractLoader(ABC): """Abstract class that provides an interface to load data. - Parameters - ---------- - parameters : DictConfig - Configuration parameters. + Args: + parameters (DictConfig): Configuration parameters. """ def __init__(self, parameters: DictConfig): self.cfg = parameters + def __repr__(self) -> str: + return f"{self.__class__.__name__}(parameters={self.cfg})" + @abstractmethod def load( self, ) -> torch_geometric.data.Data: """Load data into Data. - Parameters - ---------- - None - - Returns - ------- - Data - Data object containing the loaded data. + Raise: + NotImplementedError: If the method is not implemented. """ raise NotImplementedError diff --git a/topobenchmarkx/io/load/loaders.py b/topobenchmarkx/io/load/loaders.py index 3cb9d0e5..5829cd9c 100755 --- a/topobenchmarkx/io/load/loaders.py +++ b/topobenchmarkx/io/load/loaders.py @@ -8,8 +8,8 @@ from omegaconf import DictConfig from topobenchmarkx.data.datasets import CustomDataset -from topobenchmarkx.data.us_county_demos_dataset import USCountyDemosDataset from topobenchmarkx.data.heteriphilic_dataset import HeteroDataset +from topobenchmarkx.data.us_county_demos_dataset import USCountyDemosDataset from topobenchmarkx.io.load.loader import AbstractLoader from topobenchmarkx.io.load.preprocessor import Preprocessor from topobenchmarkx.io.load.split_utils import ( @@ -28,29 +28,24 @@ class CellComplexLoader(AbstractLoader): r"""Loader for cell complex datasets. - Parameters - ---------- - parameters : DictConfig - Configuration parameters. + Args: + parameters (DictConfig): Configuration parameters. 
""" def __init__(self, parameters: DictConfig): super().__init__(parameters) self.parameters = parameters + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(parameters={self.parameters})" def load( self, ) -> CustomDataset: r"""Load cell complex dataset. - Parameters - ---------- - None - - Returns - ------- - CustomDataset - CustomDataset object containing the loaded data. + Returns: + CustomDataset: CustomDataset object containing the loaded data. """ data = load_cell_complex_dataset(self.parameters) dataset = CustomDataset([data]) @@ -60,29 +55,24 @@ def load( class SimplicialLoader(AbstractLoader): r"""Loader for simplicial datasets. - Parameters - ---------- - parameters : DictConfig - Configuration parameters. + Args: + parameters (DictConfig): Configuration parameters. """ def __init__(self, parameters: DictConfig): super().__init__(parameters) self.parameters = parameters + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(parameters={self.parameters})" def load( self, ) -> CustomDataset: r"""Load simplicial dataset. - Parameters - ---------- - None - - Returns - ------- - CustomDataset - CustomDataset object containing the loaded data. + Returns: + CustomDataset: CustomDataset object containing the loaded data. """ data = load_simplicial_dataset(self.parameters) dataset = CustomDataset([data]) @@ -92,30 +82,26 @@ def load( class HypergraphLoader(AbstractLoader): r"""Loader for hypergraph datasets. - Parameters - ---------- - parameters : DictConfig - Configuration parameters. + Args: + parameters (DictConfig): Configuration parameters. + transforms (DictConfig, optional): The parameters for the transforms to be applied to the dataset. (default: None) """ def __init__(self, parameters: DictConfig, transforms=None): super().__init__(parameters) self.parameters = parameters self.transforms_config = transforms + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(parameters={self.parameters}, transforms={self.transforms_config})" def load( self, ) -> CustomDataset: r"""Load hypergraph dataset. - - Parameters - ---------- - None - - Returns - ------- - CustomDataset - CustomDataset object containing the loaded data. + + Returns: + CustomDataset: CustomDataset object containing the loaded data. """ data = load_hypergraph_pickle_dataset(self.parameters) data = load_split(data, self.parameters) @@ -126,29 +112,38 @@ def load( class GraphLoader(AbstractLoader): r"""Loader for graph datasets. - Parameters - ---------- - parameters : DictConfig - Configuration parameters. + Args: + parameters (DictConfig): Configuration parameters. The parameters must contain the following keys: + - data_dir (str): The directory where the dataset is stored. + - data_name (str): The name of the dataset. + - data_type (str): The type of the dataset. + - split_type (str): The type of split to be used. It can be "fixed", "random", or "k-fold". + + If split_type is "random", the parameters must also contain the following keys: + - data_seed (int): The seed for the split. + - data_split_dir (str): The directory where the split is stored. + - train_prop (float): The proportion of the training set. + If split_type is "k-fold", the parameters must also contain the following keys: + - data_split_dir (str): The directory where the split is stored. + - k (int): The number of folds. + - data_seed (int): The seed for the split. + The parameters can be defined in a yaml file and then loaded using `omegaconf.OmegaConf.load('path/to/dataset/config.yaml')`. 
+        transforms (DictConfig, optional): The parameters for the transforms to be applied to the dataset. The parameters for a transformation can be defined in a yaml file and then loaded using `omegaconf.OmegaConf.load('path/to/transform/config.yaml')`. (default: None) """ - def __init__(self, parameters: DictConfig, transforms=None): super().__init__(parameters) self.parameters = parameters # Still not instantiated self.transforms_config = transforms + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(parameters={self.parameters}, transforms={self.transforms_config})" def load(self) -> CustomDataset: r"""Load graph dataset. - Parameters - ---------- - None - - Returns - ------- - CustomDataset - CustomDataset object containing the loaded data. + Returns: + CustomDataset: CustomDataset object containing the loaded data. """ data_dir = os.path.join( self.parameters["data_dir"], self.parameters["data_name"] ) @@ -163,7 +158,9 @@ def load(self) -> CustomDataset: name=self.parameters["data_name"], ) if self.transforms_config is not None: - dataset = Preprocessor(data_dir, dataset, self.transforms_config) + dataset = Preprocessor( + data_dir, dataset, self.transforms_config + ) dataset = load_graph_cocitation_split(dataset, self.parameters) @@ -184,16 +181,19 @@ def load(self) -> CustomDataset: use_node_attr=False, ) if self.transforms_config is not None: - dataset = Preprocessor(data_dir, dataset, self.transforms_config) + dataset = Preprocessor( + data_dir, dataset, self.transforms_config + ) dataset = load_graph_tudataset_split(dataset, self.parameters) elif self.parameters.data_name in ["ZINC"]: datasets = [ - torch_geometric.datasets.ZINC( - root=self.parameters["data_dir"], - subset=True, - split=split, - ) for split in ["train", "val", "test"] + torch_geometric.datasets.ZINC( + root=self.parameters["data_dir"], + subset=True, + split=split, + ) + for split in ["train", "val", "test"] ] assert self.parameters.split_type == "fixed" @@ -221,7 +221,9 @@ def load(self) -> CustomDataset: ) # Split back into train/val/test datasets - dataset = assing_train_val_test_mask_to_graphs(joined_dataset, split_idx) + dataset = assing_train_val_test_mask_to_graphs( + joined_dataset, split_idx + ) elif self.parameters.data_name in ["AQSOL"]: datasets = [] @@ -256,7 +258,9 @@ def load(self) -> CustomDataset: ) # Split back into train/val/test datasets - dataset = assing_train_val_test_mask_to_graphs(joined_dataset, split_idx) + dataset = assing_train_val_test_mask_to_graphs( + joined_dataset, split_idx + ) elif self.parameters.data_name in ["US-county-demos"]: dataset = USCountyDemosDataset( root=self.parameters["data_dir"], @@ -268,13 +272,22 @@ def load(self) -> CustomDataset: if self.transforms_config is not None: # force_reload=True because in this dataset many variables can be treated as y dataset = Preprocessor( - data_dir, dataset, self.transforms_config, force_reload=True + data_dir, + dataset, + self.transforms_config, + force_reload=True, ) # We need to map original dataset into custom one to make batching work dataset = CustomDataset([dataset[0]]) - - elif self.parameters.data_name in ["amazon_ratings", "questions", "minesweeper","roman_empire", "tolokers"]: + + elif self.parameters.data_name in [ + "amazon_ratings", + "questions", + "minesweeper", + "roman_empire", + "tolokers", + ]: dataset = HeteroDataset( root=self.parameters["data_dir"], name=self.parameters["data_name"], @@ -284,7 +297,10 @@ def load(self) -> CustomDataset: ) if self.transforms_config is not None: # force_reload=True because in this dataset
many variables can be treated as y dataset = Preprocessor( - data_dir, dataset, self.transforms_config, force_reload=True + data_dir, + dataset, + self.transforms_config, + force_reload=False, ) # We need to map original dataset into custom one to make batching work @@ -301,10 +317,9 @@ class ManualGraphLoader(AbstractLoader): r"""Loader for manual graph datasets. - Parameters - ---------- - parameters : DictConfig - Configuration parameters. + Args: + parameters (DictConfig): Configuration parameters. + transforms (DictConfig, optional): The parameters for the transforms to be applied to the dataset. (default: None) """ def __init__(self, parameters: DictConfig, transforms=None): @@ -312,18 +327,15 @@ def __init__(self, parameters: DictConfig, transforms=None): self.parameters = parameters # Still not instantiated self.transforms_config = transforms + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(parameters={self.parameters}, transforms={self.transforms_config})" def load(self) -> CustomDataset: r"""Load manual graph dataset. - Parameters - ---------- - None - - Returns - ------- - CustomDataset - CustomDataset object containing the loaded data. + Returns: + CustomDataset: CustomDataset object containing the loaded data. """ data = manual_graph() @@ -331,7 +343,9 @@ data_dir = os.path.join( self.parameters["data_dir"], self.parameters["data_name"] ) - processor_dataset = Preprocessor(data_dir, data, self.transforms_config) + processor_dataset = Preprocessor( + data_dir, data, self.transforms_config + ) dataset = CustomDataset([processor_dataset[0]]) return dataset @@ -367,7 +381,7 @@ def manual_graph(): for tetrahedron in tetrahedrons: for i in range(len(tetrahedron)): for j in range(i + 1, len(tetrahedron)): - edges.append([tetrahedron[i], tetrahedron[j]]) # noqa: PERF401 + edges.append([tetrahedron[i], tetrahedron[j]])  # noqa: PERF401 # Create a graph G = nx.Graph() @@ -381,7 +395,11 @@ def manual_graph(): edge_list = torch.Tensor(list(G.edges())).T.long() # Generate a scalar feature for each of the 9 nodes - x = torch.tensor([1, 5, 10, 50, 100, 500, 1000, 5000, 10000]).unsqueeze(1).float() + x = ( + torch.tensor([1, 5, 10, 50, 100, 500, 1000, 5000, 10000]) + .unsqueeze(1) + .float() + ) data = torch_geometric.data.Data( x=x, @@ -419,7 +437,7 @@ def manual_simple_graph(): for tetrahedron in tetrahedrons: for i in range(len(tetrahedron)): for j in range(i + 1, len(tetrahedron)): - edges.append([tetrahedron[i], tetrahedron[j]]) # noqa: PERF401 + edges.append([tetrahedron[i], tetrahedron[j]])  # noqa: PERF401 # Create a graph G = nx.Graph() diff --git a/topobenchmarkx/io/load/preprocessor.py b/topobenchmarkx/io/load/preprocessor.py index fe22e9e0..5cb8a125 100644 --- a/topobenchmarkx/io/load/preprocessor.py +++ b/topobenchmarkx/io/load/preprocessor.py @@ -7,48 +7,59 @@ from topobenchmarkx.io.load.utils import ensure_serializable, make_hash +from torch_geometric.data.dataset import * + + + class Preprocessor(torch_geometric.data.InMemoryDataset): r"""Preprocessor for datasets. - Parameters - ---------- - data_dir : str - Path to the directory containing the data. - data_list : list - List of data objects. - transforms_config : DictConfig - Configuration parameters for the transforms. - **kwargs: optional - Additional arguments. + Args: + data_dir (str): Path to the directory containing the data. + data_list (list): List of data objects. + transforms_config (DictConfig): Configuration parameters for the transforms.
+        force_reload (bool): Whether to force reload the data. (default: False) + **kwargs: Optional additional arguments. """ def __init__( - self, data_dir, data_list, transforms_config, force_reload=False, **kwargs + self, + data_dir, + data_list, + transforms_config, + force_reload=False, + **kwargs, ): if isinstance(data_list, torch_geometric.data.Dataset): data_list = [data_list.get(idx) for idx in range(len(data_list))] elif isinstance(data_list, torch_geometric.data.Data): data_list = [data_list] self.data_list = data_list - pre_transform = self.instantiate_pre_transform(data_dir, transforms_config) + pre_transform = self.instantiate_pre_transform( + data_dir, transforms_config + ) + # PyTorch Geometric introduces force_reload from version 2.5.0, but there is a weird bug, so we handle the flag manually + self.force_reload = force_reload + super().__init__( self.processed_data_dir, None, pre_transform, - force_reload=force_reload, + # force_reload=force_reload, **kwargs, ) self.save_transform_parameters() self.load(self.processed_paths[0]) + + def __repr__(self): + return f"{self.__class__.__name__}(data_dir={self.root}, data_list={self.data_list}, processed_data_dir={self.processed_data_dir}, processed_file_names={self.processed_file_names})" @property def processed_dir(self) -> str: r"""Return the path to the processed directory. - Returns - ------- - str - Path to the processed directory. + Returns: + str: Path to the processed directory. """ return self.root @@ -57,9 +68,7 @@ def processed_file_names(self) -> str: r"""Return the name of the processed file. - Returns - ------- - str - Name of the processed file. + Returns: + str: Name of the processed file. """ return "data.pt" @@ -68,23 +77,19 @@ def instantiate_pre_transform( self, data_dir, transforms_config ) -> torch_geometric.transforms.Compose: r"""Instantiate the pre-transforms. - Parameters - ---------- - data_dir : str - Path to the directory containing the data. - transforms_config : DictConfig - Configuration parameters for the transforms. - - Returns - ------- - torch_geometric.transforms.Compose - Pre-transform object. + Args: + data_dir (str): Path to the directory containing the data. + transforms_config (DictConfig): Configuration parameters for the transforms. + Returns: + torch_geometric.transforms.Compose: Pre-transform object. """ pre_transforms_dict = hydra.utils.instantiate(transforms_config) pre_transforms = torch_geometric.transforms.Compose( list(pre_transforms_dict.values()) ) - self.set_processed_data_dir(pre_transforms_dict, data_dir, transforms_config) + self.set_processed_data_dir( + pre_transforms_dict, data_dir, transforms_config + ) return pre_transforms def set_processed_data_dir( self, pre_transforms_dict, data_dir, transforms_config ) -> None: r"""Set the processed data directory. - Parameters - ---------- - pre_transforms_dict : dict - Dictionary containing the pre-transforms. - data_dir : str - Path to the directory containing the data. - transforms_config : DictConfig - Configuration parameters for the transforms. + Args: + pre_transforms_dict (dict): Dictionary containing the pre-transforms. + data_dir (str): Path to the directory containing the data. + transforms_config (DictConfig): Configuration parameters for the transforms.
""" # Use self.transform_parameters to define unique save/load path for each transform parameters repo_name = "_".join(list(transforms_config.keys())) @@ -109,7 +110,9 @@ def set_processed_data_dir( } params_hash = make_hash(transforms_parameters) self.transforms_parameters = ensure_serializable(transforms_parameters) - self.processed_data_dir = os.path.join(*[data_dir, repo_name, f"{params_hash}"]) + self.processed_data_dir = os.path.join( + *[data_dir, repo_name, f"{params_hash}"] + ) def save_transform_parameters(self) -> None: r"""Save the transform parameters.""" @@ -126,9 +129,13 @@ def save_transform_parameters(self) -> None: saved_transform_parameters = json.load(f) if saved_transform_parameters != self.transforms_parameters: - raise ValueError("Different transform parameters for the same data_dir") - - print(f"Transform parameters are the same, using existing data_dir: {self.processed_data_dir}") + raise ValueError( + "Different transform parameters for the same data_dir" + ) + + print( + f"Transform parameters are the same, using existing data_dir: {self.processed_data_dir}" + ) def process(self) -> None: r"""Process the data.""" @@ -138,4 +145,4 @@ def process(self) -> None: self._data_list = None # Reset cache. assert isinstance(self._data, torch_geometric.data.Data) - self.save(self.data_list, self.processed_paths[0]) + self.save(self.data_list, self.processed_paths[0]) \ No newline at end of file diff --git a/topobenchmarkx/io/load/split_utils.py b/topobenchmarkx/io/load/split_utils.py index 5d201c4f..df89ce09 100644 --- a/topobenchmarkx/io/load/split_utils.py +++ b/topobenchmarkx/io/load/split_utils.py @@ -9,21 +9,16 @@ # Generate splits in different fasions def k_fold_split(labels, parameters): - """ - Returns train and valid indices as in K-Fold Cross-Validation. If the split already exists - it loads it automatically, otherwise it creates the split file for the subsequent runs. - - Parameters - ---------- - labels : torch.Tensor - Label tensor. - parameters : DictConfig - Configuration parameters. - - Returns - ------- - dict - Dictionary containing the train, validation and test indices. + r"""Returns train and valid indices as in K-Fold Cross-Validation. If the + split already exists it loads it automatically, otherwise it creates the + split file for the subsequent runs. + + Args: + labels (torch.Tensor): Label tensor. + parameters (DictConfig): Configuration parameters. + + Returns: + dict: Dictionary containing the train, validation and test indices, with keys "train", "valid", and "test". """ data_dir = parameters.data_split_dir @@ -48,13 +43,22 @@ def k_fold_split(labels, parameters): skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=42) - for fold_n, (train_idx, valid_idx) in enumerate(skf.split(x_idx, labels)): - split_idx = {"train": train_idx, "valid": valid_idx, "test": valid_idx} + for fold_n, (train_idx, valid_idx) in enumerate( + skf.split(x_idx, labels) + ): + split_idx = { + "train": train_idx, + "valid": valid_idx, + "test": valid_idx, + } # Check that all nodes/graph have been assigned to some split assert np.all( np.sort( - np.array(split_idx["train"].tolist() + split_idx["valid"].tolist()) + np.array( + split_idx["train"].tolist() + + split_idx["valid"].tolist() + ) ) == np.sort(np.arange(len(labels))) ), "Not every sample has been loaded." 
@@ -81,22 +85,15 @@ def k_fold_split(labels, parameters): def random_splitting(labels, parameters, global_data_seed=42): - """Adapted from https://github.com/CUAI/Non-Homophily-Benchmarks + r"""Adapted from https://github.com/CUAI/Non-Homophily-Benchmarks randomly splits label into train/valid/test splits. - Parameters - ---------- - labels : torch.Tensor - Label tensor. - parameters : DictConfig - Configuration parameters. - global_data_seed : int - Seed for the random number generator. - - Returns - ------- - dict - Dictionary containing the train, validation and test indices. + Args: + labels (torch.Tensor): Label tensor. + parameters (DictConfig): Configuration parameters. + global_data_seed (int, optional): Seed for the random number generator. (default: 42) + Returns: + dict: Dictionary containing the train, validation and test indices with keys "train", "valid", and "test". """ fold = parameters["data_seed"] data_dir = parameters["data_split_dir"] @@ -160,21 +157,14 @@ def random_splitting(labels, parameters, global_data_seed=42): def load_split(data, cfg, train_prop=0.5): - r"""Loads the split for generated by rand_train_test_idx function. - - Parameters - ---------- - data : torch_geometric.data.Data - Graph dataset. - cfg : DictConfig - Configuration parameters. - train_prop : float - Proportion of training data. - - Returns - ------- - torch_geometric.data.Data - Graph dataset with the specified split. + r"""Loads the split generated by rand_train_test_idx function. + + Args: + data (torch_geometric.data.Data): Graph dataset. + cfg (DictConfig): Configuration parameters. + train_prop (float): Proportion of training data. + Returns: + torch_geometric.data.Data: Graph dataset with the specified split. """ data_dir = os.path.join(cfg["data_split_dir"], f"train_prop={train_prop}") @@ -199,17 +189,11 @@ def load_split(data, cfg, train_prop=0.5): def assing_train_val_test_mask_to_graphs(dataset, split_idx): r"""Splits the graph dataset into train, validation, and test datasets. - Parameters - ---------- - dataset : torch_geometric.data.Dataset - Graph dataset. - split_idx : dict - Dictionary containing the indices for the train, validation, and test splits. - - Returns - ------- - datasets : list - List containing the train, validation, and test datasets. + Args: + dataset (torch_geometric.data.Dataset): Graph dataset. + split_idx (dict): Dictionary containing the indices for the train, validation, and test splits. + Returns: + list: List containing the train, validation, and test datasets. """ data_train_lst, data_val_lst, data_test_lst = [], [], [] @@ -253,17 +237,11 @@ def assing_train_val_test_mask_to_graphs(dataset, split_idx): def load_graph_tudataset_split(dataset, cfg): r"""Loads the graph dataset with the specified split. - Parameters - ---------- - dataset : torch_geometric.data.Dataset - Graph dataset. - cfg : DictConfig - Configuration parameters. - - Returns - ------- - list - List containing the train, validation, and test splits. + Args: + dataset (torch_geometric.data.Dataset): Graph dataset. + cfg (DictConfig): Configuration parameters. + Returns: + list: List containing the train, validation, and test splits. """ # Extract labels from dataset object assert ( @@ -282,8 +260,8 @@ def load_graph_tudataset_split(dataset, cfg): f"split_type {cfg.split_type} not valid. 
Choose either 'test' or 'k-fold'" ) - train_dataset, val_dataset, test_dataset = assing_train_val_test_mask_to_graphs( - dataset, split_idx + train_dataset, val_dataset, test_dataset = ( + assing_train_val_test_mask_to_graphs(dataset, split_idx) ) return [train_dataset, val_dataset, test_dataset] @@ -292,17 +270,11 @@ def load_graph_tudataset_split(dataset, cfg): def load_graph_cocitation_split(dataset, cfg): r"""Loads cocitation graph datasets with the specified split. - Parameters - ---------- - dataset : torch_geometric.data.Dataset - Graph dataset. - cfg : DictConfig - Configuration parameters. - - Returns - ------- - list - List containing the train, validation, and test splits. + Args: + dataset (torch_geometric.data.Dataset): Graph dataset. + cfg (DictConfig): Configuration parameters. + Returns: + list: List containing the train, validation, and test splits. """ # Extract labels from dataset object diff --git a/topobenchmarkx/io/load/us_county_demos.py b/topobenchmarkx/io/load/us_county_demos.py index fbbed4b9..6828692a 100644 --- a/topobenchmarkx/io/load/us_county_demos.py +++ b/topobenchmarkx/io/load/us_county_demos.py @@ -5,25 +5,20 @@ def load_us_county_demos(path, year=2012, y_col="Election"): - r"""Load US County Demos dataset - - Parameters - ---------- - path: str - Path to the dataset. - year: int - Year to load the features. - y_col: str - Column to use as label. - - Returns - ------- - torch_geometric.data.Data - Data object of the graph for the US County Demos dataset. + r"""Load US County Demos dataset. + + Args: + path (str): Path to the dataset. + year (int, optional): Year to load the features. (default: 2012) + y_col (str, optional): Column to use as label. Can be one of ['Election', 'MedianIncome', 'MigraRate', 'BirthRate', 'DeathRate', 'BachelorRate', 'UnemploymentRate']. (default: "Election") + Returns: + torch_geometric.data.Data: Data object of the graph for the US County Demos dataset. """ - + edges_df = pd.read_csv(f"{path}/county_graph.csv") - stat = pd.read_csv(f"{path}/county_stats_{year}.csv", encoding="ISO-8859-1") + stat = pd.read_csv( + f"{path}/county_stats_{year}.csv", encoding="ISO-8859-1" + ) keep_cols = [ "FIPS", @@ -36,12 +31,12 @@ def load_us_county_demos(path, year=2012, y_col="Election"): "BachelorRate", "UnemploymentRate", ] - + # Select columns, replace ',' with '.' 
and convert to numeric stat = stat.loc[:, keep_cols] - stat["MedianIncome"] = stat["MedianIncome"].replace(',','.', regex=True) - stat = stat.apply(pd.to_numeric, errors='coerce') - + stat["MedianIncome"] = stat["MedianIncome"].replace(",", ".", regex=True) + stat = stat.apply(pd.to_numeric, errors="coerce") + # Step 2: Substitute NaN values with column mean for column in stat.columns: if column != "FIPS": @@ -58,7 +53,9 @@ def load_us_county_demos(path, year=2012, y_col="Election"): edges_df = edges_df[src_ & dst_] # Remove rows from stat df where edges_df['SRC'] or edges_df['DST'] are not present - stat = stat[stat["FIPS"].isin(edges_df["SRC"]) & stat["FIPS"].isin(edges_df["DST"])] + stat = stat[ + stat["FIPS"].isin(edges_df["SRC"]) & stat["FIPS"].isin(edges_df["DST"]) + ] stat = stat.reset_index(drop=True) # Remove rows where SRC == DST @@ -91,7 +88,9 @@ def load_us_county_demos(path, year=2012, y_col="Election"): ) # Remove isolated nodes (Note: this function maps the nodes to [0, ..., num_nodes] automatically) - edge_index, _, mask = torch_geometric.utils.remove_isolated_nodes(edge_index) + edge_index, _, mask = torch_geometric.utils.remove_isolated_nodes( + edge_index + ) # Conver mask to index index = np.arange(mask.size(0))[mask] @@ -104,7 +103,9 @@ def load_us_county_demos(path, year=2012, y_col="Election"): stat["FIPS"] = stat.reset_index()["index"] # Create Election variable - stat["Election"] = (stat["DEM"] - stat["GOP"]) / (stat["DEM"] + stat["GOP"]) + stat["Election"] = (stat["DEM"] - stat["GOP"]) / ( + stat["DEM"] + stat["GOP"] + ) # Drop DEM and GOP columns and FIPS stat = stat.drop(columns=["DEM", "GOP", "FIPS"]) diff --git a/topobenchmarkx/io/load/utils.py b/topobenchmarkx/io/load/utils.py index 46690ccc..024f8817 100755 --- a/topobenchmarkx/io/load/utils.py +++ b/topobenchmarkx/io/load/utils.py @@ -18,19 +18,12 @@ def get_complex_connectivity(complex, max_rank, signed=False): r"""Gets the connectivity matrices for the complex. - Parameters - ---------- - complex : topnetx.CellComplex, topnetx.SimplicialComplex - Cell complex. - max_rank : int - Maximum rank of the complex. - signed : bool - If True, returns signed connectivity matrices. - + Args: + complex (topnetx.CellComplex, topnetx.SimplicialComplex): Cell complex. + max_rank (int): Maximum rank of the complex. + signed (bool, optional): If True, returns signed connectivity matrices. (default: False) Returns - ------- - dict - Dictionary containing the connectivity matrices. + dict: Dictionary containing the connectivity matrices. """ practical_shape = list( np.pad(list(complex.shape), (0, max_rank + 1 - len(complex.shape))) @@ -54,13 +47,15 @@ def get_complex_connectivity(complex, max_rank, signed=False): if connectivity_info == "incidence": connectivity[f"{connectivity_info}_{rank_idx}"] = ( generate_zero_sparse_connectivity( - m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx] + m=practical_shape[rank_idx - 1], + n=practical_shape[rank_idx], ) ) else: connectivity[f"{connectivity_info}_{rank_idx}"] = ( generate_zero_sparse_connectivity( - m=practical_shape[rank_idx], n=practical_shape[rank_idx] + m=practical_shape[rank_idx], + n=practical_shape[rank_idx], ) ) connectivity["shape"] = practical_shape @@ -70,37 +65,32 @@ def get_complex_connectivity(complex, max_rank, signed=False): def generate_zero_sparse_connectivity(m, n): r"""Generates a zero sparse connectivity matrix. - Parameters - ---------- - m : int - Number of rows. - n : int - Number of columns. 
- - Returns - ------- - torch.sparse_coo_tensor - Zero sparse connectivity matrix. + Args: + m (int): Number of rows. + n (int): Number of columns. + Returns: + torch.sparse_coo_tensor: Zero sparse connectivity matrix. """ return torch.sparse_coo_tensor((m, n)).coalesce() def load_cell_complex_dataset(cfg): - r"""Loads cell complex datasets.""" + r"""Loads cell complex datasets. + + Args: + cfg (DictConfig): Configuration parameters. + """ def load_simplicial_dataset(cfg): r"""Loads simplicial datasets. - Parameters - ---------- - cfg : DictConfig - Configuration parameters. + Args: + cfg (DictConfig): Configuration parameters. It needs to contain the following keys: + - data_name (str): Name of the dataset. - Returns - ------- - torch_geometric.data.Data - Simplicial dataset. + Returns: + torch_geometric.data.Data: Simplicial dataset. """ if cfg["data_name"] != "KarateClub": return NotImplementedError @@ -184,15 +174,11 @@ def load_simplicial_dataset(cfg): def load_hypergraph_pickle_dataset(cfg): r"""Loads hypergraph datasets from pickle files. - Parameters - ---------- - cfg : DictConfig - Configuration parameters. + Args: + cfg (DictConfig): Configuration parameters. - Returns - ------- - torch_geometric.data.Data - Hypergraph dataset. + Returns: + torch_geometric.data.Data: Hypergraph dataset. """ data_dir = cfg["data_dir"] print(f"Loading {cfg['data_domain']} dataset name: {cfg['data_name']}") @@ -237,7 +223,9 @@ def load_hypergraph_pickle_dataset(cfg): # check that every node is in some hyperedge if len(np.unique(node_list)) != num_nodes: # add self hyperedges to isolated nodes - isolated_nodes = np.setdiff1d(np.arange(num_nodes), np.unique(node_list)) + isolated_nodes = np.setdiff1d( + np.arange(num_nodes), np.unique(node_list) + ) for node in isolated_nodes: node_list += [node] @@ -290,15 +278,12 @@ def load_hypergraph_pickle_dataset(cfg): def get_Planetoid_pyg(cfg): r"""Loads Planetoid graph datasets from torch_geometric. - Parameters - ---------- - cfg : DictConfig - Configuration parameters. - - Returns - ------- - torch_geometric.data.Data - Graph dataset. + Args: + cfg (DictConfig): Configuration parameters. It needs to contain the following keys: + - data_dir (str): Path to the directory containing the data. + - data_name (str): Name of the dataset. + Returns: + torch_geometric.data.Data: Graph dataset. """ data_dir, data_name = cfg["data_dir"], cfg["data_name"] dataset = torch_geometric.datasets.Planetoid(data_dir, data_name) @@ -310,15 +295,12 @@ def get_Planetoid_pyg(cfg): def get_TUDataset_pyg(cfg): r"""Loads TU graph datasets from torch_geometric. - Parameters - ---------- - cfg : DictConfig - Configuration parameters. - - Returns - ------- - list - List containing the graph dataset. + Args: + cfg (DictConfig): Configuration parameters. It needs to contain the following keys: + - data_dir (str): Path to the directory containing the data. + - data_name (str): Name of the dataset. + Returns: + list: List containing the graphs in the dataset. """ data_dir, data_name = cfg["data_dir"], cfg["data_name"] dataset = torch_geometric.datasets.TUDataset(root=data_dir, name=data_name) @@ -329,15 +311,10 @@ def get_TUDataset_pyg(cfg): def ensure_serializable(obj): r"""Ensures that the object is serializable. - Parameters - ---------- - obj : object - Object to ensure serializability. - - Returns - ------- - object - Object that is serializable. + Args: + obj (object): Object to ensure serializability. + Returns: + object: Object that is serializable. 
""" if isinstance(obj, dict): for key, value in obj.items(): @@ -360,15 +337,10 @@ def make_hash(o): contains only other hashable types (including any lists, tuples, sets, and dictionaries). - Parameters - ---------- - o : dict, list, tuple, set - Object to hash. - - Returns - ------- - int - Hash of the object. + Args: + o (dict, list, tuple, set): Object to hash. + Returns: + int: Hash of the object. """ sha1 = hashlib.sha1() sha1.update(str.encode(str(o))) diff --git a/topobenchmarkx/models/__init__.py b/topobenchmarkx/models/__init__.py index e69de29b..2e4f1c9a 100755 --- a/topobenchmarkx/models/__init__.py +++ b/topobenchmarkx/models/__init__.py @@ -0,0 +1,11 @@ +import topobenchmarkx.models.encoders +import topobenchmarkx.models.head_models +import topobenchmarkx.models.losses +import topobenchmarkx.models.readouts +import topobenchmarkx.models.wrappers + +from topobenchmarkx.models.default_network import TopologicalNetworkModule + +__all__ = [ + "TopologicalNetworkModule", +] diff --git a/topobenchmarkx/models/abstractions/__init__.py b/topobenchmarkx/models/abstractions/__init__.py deleted file mode 100755 index e69de29b..00000000 diff --git a/topobenchmarkx/models/abstractions/encoder.py b/topobenchmarkx/models/abstractions/encoder.py deleted file mode 100644 index fbe3189f..00000000 --- a/topobenchmarkx/models/abstractions/encoder.py +++ /dev/null @@ -1,22 +0,0 @@ -from abc import ABC, abstractmethod - -import torch -import torch_geometric - -class AbstractInitFeaturesEncoder(torch.nn.Module): - """abstract class that provides an interface to define a custom initial feature encoders""" - - def __init__(self): - return - - @abstractmethod - def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data: - """Forward pass of the feature encoder model - - Parameters: - :data: torch_geometric.data.Data - - Returns: - :data: torch_geometric.data.Data - - """ diff --git a/topobenchmarkx/models/default_network.py b/topobenchmarkx/models/default_network.py new file mode 100755 index 00000000..49c4fb45 --- /dev/null +++ b/topobenchmarkx/models/default_network.py @@ -0,0 +1,305 @@ +from typing import Any + +import torch +from lightning import LightningModule +from torchmetrics import MeanMetric +from torch_geometric.data import Data + +class TopologicalNetworkModule(LightningModule): + r"""A `LightningModule` to define a network. + + Args: + backbone (torch.nn.Module): The backbone model to train. + backbone_wrapper (torch.nn.Module): The backbone wrapper class. + readout (torch.nn.Module): The readout class. + head_model (torch.nn.Module): The head model. + loss (torch.nn.Module): The loss class. + feature_encoder (torch.nn.Module, optional): The feature encoder. 
(default: None) + """ + def __init__( + self, + backbone: torch.nn.Module, + backbone_wrapper: torch.nn.Module, + readout: torch.nn.Module, + head_model: torch.nn.Module, + loss: torch.nn.Module, + feature_encoder: torch.nn.Module | None = None, + **kwargs, + ) -> None: + super().__init__() + + # This line allows to access init params with 'self.hparams' attribute + # also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False, ignore=[]) + + self.feature_encoder = feature_encoder + self.backbone = backbone_wrapper(backbone) + self.readout = readout + self.head_model = head_model + + # Evaluator + self.evaluator = None + self.train_metrics_logged = False + + # Loss function + self.task_level = self.hparams["head_model"].task_level + self.loss = loss + + # Tracking best so far validation accuracy + self.val_acc_best = MeanMetric() + self.metric_collector_val = [] + self.metric_collector_val2 = [] + self.metric_collector_test = [] + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(backbone={self.backbone}, readout={self.readout}, head_model={self.head_model}, loss={self.loss}, feature_encoder={self.feature_encoder})" + + def forward(self, batch: Data) -> dict: + r"""Perform a forward pass through the model `self.backbone`. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + torch.Tensor: A tensor of logits. + """ + return self.backbone(batch) + + def model_step( + self, batch: Data + ) -> dict: + r"""Perform a single model step on a batch of data. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the model output. + """ + + # Feature Encoder + batch = self.feature_encoder(batch) + + # Domain model + model_out = self.forward(batch) + + # Readout + model_out = self.readout(model_out=model_out, batch=batch) + + # Head model + model_out = self.head_model(model_out=model_out, batch=batch) + + # Loss + model_out = self.process_outputs(model_out=model_out, batch=batch) + + # Metric + model_out = self.loss(model_out=model_out, batch=batch) + self.evaluator.update(model_out) + + return model_out + + def training_step(self, batch: Data, batch_idx: int) -> torch.Tensor: + r"""Perform a single training step on a batch of data from the training + set. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + batch_idx (int): The index of the current batch. + Returns: + torch.Tensor: A tensor of losses between model predictions and targets. + """ + self.state_str = "Training" + model_out = self.model_step(batch) + + # Update and log metrics + self.log( + "train/loss", + model_out["loss"], + on_step=False, + on_epoch=True, + prog_bar=True, + batch_size=1, + ) + + # Return loss for backpropagation step + return model_out["loss"] + + def validation_step( + self, batch: Data, batch_idx: int + ) -> None: + r"""Perform a single validation step on a batch of data from the validation + set. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + batch_idx (int): The index of the current batch. + """ + self.state_str = "Validation" + model_out = self.model_step(batch) + + # Log Loss + self.log( + "val/loss", + model_out["loss"], + on_step=False, + on_epoch=True, + prog_bar=True, + batch_size=1, + ) + + def test_step( + self, batch: Data, batch_idx: int + ) -> None: + r"""Perform a single test step on a batch of data from the test + set. 
+ + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + batch_idx (int): The index of the current batch. + """ + self.state_str = "Test" + model_out = self.model_step(batch) + + # Log loss + self.log( + "test/loss", + model_out["loss"], + on_step=False, + on_epoch=True, + prog_bar=True, + batch_size=1, + ) + + def process_outputs(self, model_out: dict, batch: Data) -> dict: + r"""Process model outputs. + + Args: + model_out (dict): Dictionary containing the model output. + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. + """ + + # Get the correct mask + if self.state_str == "Training": + mask = batch.train_mask + elif self.state_str == "Validation": + mask = batch.val_mask + elif self.state_str == "Test": + mask = batch.test_mask + else: + raise ValueError("Invalid state_str") + + if self.task_level == "node": + # Keep only the data points of the current split + for key, val in model_out.items(): + if key in ["logits", "labels"]: + model_out[key] = val[mask] + + return model_out + + def log_metrics(self, mode=None): + r"""Log metrics. + + Args: + mode (str, optional): The mode of the model, either "train", "val", or "test". (default: None) + """ + metrics_dict = self.evaluator.compute() + for key in metrics_dict: + self.log( + f"{mode}/{key}", + metrics_dict[key], + prog_bar=True, + on_step=False, + ) + + # Reset evaluator for next epoch + self.evaluator.reset() + + def on_validation_epoch_start(self) -> None: + r"""According to the PyTorch Lightning documentation, this hook is called at + the beginning of the validation epoch. + + https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks + + Note that the validation loop runs within the train epoch, so we have to log the train metrics + before we reset the evaluator to start the validation loop. + """ + + # Log train metrics and reset evaluator + self.log_metrics(mode="train") + self.train_metrics_logged = True + + def on_train_epoch_end(self) -> None: + r"""Lightning hook that is called when a train epoch ends. This hook is used to log the train metrics.""" + # Log train metrics and reset evaluator + if not self.train_metrics_logged: + self.log_metrics(mode="train") + self.train_metrics_logged = True + + def on_validation_epoch_end(self) -> None: + r"""Lightning hook that is called when a validation epoch ends. This hook is used to log the validation metrics.""" + # Log validation metrics and reset evaluator + self.log_metrics(mode="val") + + def on_test_epoch_end(self) -> None: + r"""Lightning hook that is called when a test epoch ends. This hook is used to log the test metrics.""" + self.log_metrics(mode="test") + print() + + def on_train_epoch_start(self) -> None: + r"""Lightning hook that is called when a train epoch begins. This hook is used to reset the train metrics.""" + self.evaluator.reset() + self.train_metrics_logged = False + + def on_val_epoch_start(self) -> None: + r"""Lightning hook that is called when a validation epoch begins. This hook is used to reset the validation metrics.""" + self.evaluator.reset() + + def on_test_epoch_start(self) -> None: + r"""Lightning hook that is called when a test epoch begins. This hook is used to reset the test metrics.""" + self.evaluator.reset() + + def setup(self, stage: str) -> None: + r"""Lightning hook that is called at the beginning of fit (train + + validate), validate, test, or predict.
+ + This is a good hook when you need to build models dynamically or adjust + something about them. This hook is called on every process when using + DDP. + + Args: + stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. + """ + if self.hparams.compile and stage == "fit": + self.backbone = torch.compile(self.backbone) + + def configure_optimizers(self) -> dict[str, Any]: + r"""Choose what optimizers and learning-rate schedulers to use in your + optimization. Normally you'd need one. But in the case of GANs or + similar you might have multiple. + + Examples: + https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers + + Returns: + dict: A dict containing the configured optimizers and learning-rate schedulers to be used for training. + """ + optimizer = self.hparams.optimizer( + params=list(self.trainer.model.parameters()) + + list(self.readout.parameters()) + ) + if self.hparams.scheduler is not None: + scheduler = self.hparams.scheduler(optimizer=optimizer) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "val/loss", + "interval": "epoch", + "frequency": 1, + }, + } + return {"optimizer": optimizer} + + +if __name__ == "__main__": + _ = TopologicalNetworkModule(None, None, None, None, None) diff --git a/topobenchmarkx/models/encoders/__init__.py b/topobenchmarkx/models/encoders/__init__.py index e69de29b..fec596d4 100644 --- a/topobenchmarkx/models/encoders/__init__.py +++ b/topobenchmarkx/models/encoders/__init__.py @@ -0,0 +1,15 @@ +from topobenchmarkx.models.encoders.encoder import AbstractFeatureEncoder +from topobenchmarkx.models.encoders.all_cell_encoder import AllCellFeatureEncoder + +# ... import other encoder classes here +# For example: +# from topobenchmarkx.models.encoders.other_encoder_1 import OtherEncoder1 +# from topobenchmarkx.models.encoders.other_encoder_2 import OtherEncoder2 + +__all__ = [ + "AbstractFeatureEncoder", + "AllCellFeatureEncoder", + # "OtherEncoder1", + # "OtherEncoder2", + # ... add other encoder classes here +] \ No newline at end of file diff --git a/topobenchmarkx/models/encoders/all_cell_encoder.py b/topobenchmarkx/models/encoders/all_cell_encoder.py new file mode 100644 index 00000000..cf086c4a --- /dev/null +++ b/topobenchmarkx/models/encoders/all_cell_encoder.py @@ -0,0 +1,104 @@ +import torch +import torch_geometric +from torch_geometric.nn.norm import GraphNorm +from topobenchmarkx.models.encoders.encoder import AbstractFeatureEncoder + + +class AllCellFeatureEncoder(AbstractFeatureEncoder): + r"""Encoder class to apply BaseEncoder to the features of higher order structures. The class creates a BaseEncoder for each dimension specified in selected_dimensions. Then during the forward pass, the BaseEncoders are applied to the features of the corresponding dimensions. + + Args: + in_channels (list[int]): Input dimensions for the features. + out_channels (int): Output dimension for the features. + proj_dropout (float, optional): Dropout for the BaseEncoders. (default: 0) + selected_dimensions (list[int], optional): List of indexes to apply the BaseEncoders to. (default: None) + **kwargs: Additional arguments.
+ """ + def __init__( + self, + in_channels, + out_channels, + proj_dropout=0, + selected_dimensions=None, + **kwargs + ): + super().__init__(**kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.dimensions = ( + selected_dimensions + if selected_dimensions is not None + else range(len(self.in_channels)) + ) + for i in self.dimensions: + setattr( + self, + f"encoder_{i}", + BaseEncoder( + self.in_channels[i], + self.out_channels, + dropout=proj_dropout, + ), + ) + def __repr__(self): + return f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, dimensions={self.dimensions})" + + def forward( + self, data: torch_geometric.data.Data + ) -> torch_geometric.data.Data: + r"""Forward pass. The method applies the BaseEncoders to the features of the selected_dimensions. + + Args: + data (torch_geometric.data.Data): Input data object which should contain x_{i} features for each i in the selected_dimensions. + + Returns: + torch_geometric.data.Data: Output data object with updated x_{i} features. + """ + if not hasattr(data, "x_0"): + data.x_0 = data.x + + for i in self.dimensions: + if hasattr(data, f"x_{i}") and hasattr(self, f"encoder_{i}"): + batch = getattr(data, f"batch_{i}") + data[f"x_{i}"] = getattr(self, f"encoder_{i}")( + data[f"x_{i}"], batch + ) + return data + +class BaseEncoder(torch.nn.Module): + r"""Encoder class that uses two linear layers with GraphNorm, Relu + activation function, and dropout between the two layers. + + Args: + in_channels (int): Dimension of input features. + out_channels (int): Dimensions of output features. + dropout (float, optional): Percentage of channels to discard between the two linear layers. (default: 0) + """ + def __init__(self, in_channels, out_channels, dropout=0): + super().__init__() + self.linear1 = torch.nn.Linear(in_channels, out_channels) + self.linear2 = torch.nn.Linear(out_channels, out_channels) + self.relu = torch.nn.ReLU() + self.BN = GraphNorm(out_channels) + self.dropout = torch.nn.Dropout(dropout) + + def __repr__(self): + return f"{self.__class__.__name__}(in_channels={self.linear1.in_features}, out_channels={self.linear1.out_features})" + + def forward(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor: + r"""Forward pass of the encoder. It applies two linear layers with GraphNorm, Relu activation function, and dropout between the two layers. + + Args: + x (torch.Tensor): Input tensor of dimensions [N, in_channels]. + batch (torch.Tensor): The batch vector which assigns each element to a specific example. + Returns: + torch.Tensor: Output tensor of shape [N, out_channels]. + """ + x = self.linear1(x) + x = self.BN(x, batch=batch) if batch.shape[0] > 0 else self.BN(x) + x = self.dropout(self.relu(x)) + x = self.linear2(x) + return x + + diff --git a/topobenchmarkx/models/encoders/default_encoders.py b/topobenchmarkx/models/encoders/default_encoders.py deleted file mode 100644 index 775015c2..00000000 --- a/topobenchmarkx/models/encoders/default_encoders.py +++ /dev/null @@ -1,187 +0,0 @@ -import torch -import torch_geometric -from torch_geometric.nn.norm import GraphNorm - -from topobenchmarkx.models.abstractions.encoder import AbstractInitFeaturesEncoder -from topobenchmarkx.models.encoders.perceiver import Perceiver - - -class BaseEncoder(torch.nn.Module): - r"""Encoder class that uses two linear layers with GraphNorm, Relu activation function, and dropout between the two layers. 
- - Parameters - ---------- - in_channels: int - Dimension of input features. - out_channels: int - Dimensions of output features. - dropout: float - Percentage of channels to discard between the two linear layers. - """ - def __init__(self, in_channels, out_channels, dropout=0): - super().__init__() - self.linear1 = torch.nn.Linear(in_channels, out_channels) - self.linear2 = torch.nn.Linear(out_channels, out_channels) - self.relu = torch.nn.ReLU() - self.BN = GraphNorm(out_channels) - self.dropout = torch.nn.Dropout(dropout) - - def forward(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor: - r""" - Forward pass - - Parameters - ---------- - x: torch.Tensor - Input tensor of dimensions [N, in_channels]. - batch: torch.Tensor - The batch vector which assigns each element to a specific example. - - Returns - ------- - torch.Tensor - Output tensor of shape [N, out_channels]. - """ - x = self.linear1(x) - x = self.BN(x, batch=batch) if batch.shape[0] > 0 else self.BN(x) - x = self.dropout(self.relu(x)) - x = self.linear2(x) - return x - - -class BaseFeatureEncoder(AbstractInitFeaturesEncoder): - r"""Encoder class to apply BaseEncoder to the features of higher order structures. - - Parameters - ---------- - in_channels: list(int) - Input dimensions for the features. - out_channels: list(int) - Output dimensions for the features. - proj_dropout: float - Dropout for the BaseEncoders. - selected_dimensions: list(int) - List of indexes to apply the BaseEncoders to. - """ - def __init__( - self, in_channels, out_channels, proj_dropout=0, selected_dimensions=None - ): - super(AbstractInitFeaturesEncoder, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.dimensions = ( - selected_dimensions - if selected_dimensions is not None - else range(len(self.in_channels)) - ) - for i in self.dimensions: - setattr( - self, - f"encoder_{i}", - BaseEncoder( - self.in_channels[i], self.out_channels, dropout=proj_dropout - ), - ) - - def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data: - r""" - Forward pass - - Parameters - ---------- - data: torch_geometric.data.Data - Input data object which should contain x_{i} features for each i in the selected_dimensions. - - Returns - ------- - torch_geometric.data.Data - Output data object. - """ - if not hasattr(data, "x_0"): - data.x_0 = data.x - - for i in self.dimensions: - if hasattr(data, f"x_{i}") and hasattr(self, f"encoder_{i}"): - batch = getattr(data, f"batch_{i}") - data[f"x_{i}"] = getattr(self, f"encoder_{i}")(data[f"x_{i}"], batch) - return data - - -class SetFeatureEncoder(AbstractInitFeaturesEncoder): - r"""Encoder class to apply BaseEncoder to the node features and Perceiver to the features of higher order structures. - - Parameters - ---------- - in_channels: list(int) - Input dimensions for the features. - out_channels: list(int) - Output dimensions for the features. - proj_dropout: float - Dropout for the BaseEncoders. - selected_dimensions: list(int) - List of indexes to apply the BaseEncoders to. 
- """ - def __init__( - self, in_channels, out_channels, proj_dropout=0, selected_dimensions=None - ): - super(AbstractInitFeaturesEncoder, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.dimensions = ( - selected_dimensions - if selected_dimensions is not None - else range(len(self.in_channels)) - ) - for idx, i in enumerate(self.dimensions): - if idx == 0: - setattr( - self, - f"encoder_{i}", - BaseEncoder( - self.in_channels[i], self.out_channels, dropout=proj_dropout - ), - ) - else: - setattr( - self, - f"encoder_{i}", - Perceiver( - dim=self.out_channels, - depth=1, - cross_heads=4, - cross_dim_head=self.out_channels, - latent_dim_head=self.out_channels, - ), - ) - - def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data: - r""" - Forward pass - - Parameters - ---------- - data: torch_geometric.data.Data - Input data object which should contain x_{i} features for each i in the selected_dimensions. - - Returns - ------- - torch_geometric.data.Data - Output data object. - """ - if not hasattr(data, "x_0"): - data.x_0 = data.x - - for idx, i in enumerate(self.dimensions): - if idx == 0: - if hasattr(data, f"x_{i}") and hasattr(self, f"encoder_{i}"): - batch = data.batch if i == 0 else getattr(data, f"batch_{i}") - data[f"x_{i}"] = getattr(self, f"encoder_{i}")( - data[f"x_{i}"], batch - ) - else: - if hasattr(data, f"x_{i}") and hasattr(self, f"encoder_{i}"): - cell_features = data["x_0"][data[f"x_{i}"].long()] - data[f"x_{i}"] = getattr(self, f"encoder_{i}")(cell_features) - else: - data[f"x_{i}"] = torch.tensor([], device=data.x_0.device) - return data diff --git a/topobenchmarkx/models/encoders/encoder.py b/topobenchmarkx/models/encoders/encoder.py new file mode 100644 index 00000000..6426ef6b --- /dev/null +++ b/topobenchmarkx/models/encoders/encoder.py @@ -0,0 +1,30 @@ +from abc import abstractmethod + +import torch +import torch_geometric + + +class AbstractFeatureEncoder(torch.nn.Module): + r"""Abstract class that provides an interface to define a custom feature encoder.""" + + def __init__(self, **kwargs): + super().__init__() + return + + def __repr__(self): + return f"{self.__class__.__name__}()" + + def __call__(self, data): + return self.forward(data) + + @abstractmethod + def forward( + self, data: torch_geometric.data.Data + ) -> torch_geometric.data.Data: + r"""Forward pass of the feature encoder model. + + Args: + data (torch_geometric.data.Data): Input data object which should contain x features. + Returns: + torch_geometric.data.Data: Output data object with updated x features. + """ \ No newline at end of file diff --git a/topobenchmarkx/models/encoders/perceiver.py b/topobenchmarkx/models/encoders/perceiver.py index 2ca12cca..b12bb36f 100644 --- a/topobenchmarkx/models/encoders/perceiver.py +++ b/topobenchmarkx/models/encoders/perceiver.py @@ -8,7 +8,6 @@ # helpers - def exists(val): return val is not None @@ -65,36 +64,32 @@ def cached_fn(*args, _cache=True, **kwargs): class PreNorm(nn.Module): r"""Class to wrap together LayerNorm and a specified function. - - Parameters - ---------- - dim: int - Size of the dimension to normalize. - fn: torch.nn.Module - Function after LayerNorm. - context_dim: int - Size of the context to normalize. + + Args: + dim (int): Size of the dimension to normalize. + fn (torch.nn.Module): Function after LayerNorm. + context_dim (int, optional): Size of the context to normalize. 
(default: None) """ + def __init__(self, dim, fn, context_dim=None): super().__init__() self.fn = fn self.norm = nn.LayerNorm(dim) - self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None - + self.norm_context = ( + nn.LayerNorm(context_dim) if exists(context_dim) else None + ) + + def __repr__(self): + return f"{self.__class__.__name__}(dim={self.norm.normalized_shape[0]}, fn={self.fn}, context_dim={self.norm_context.normalized_shape[0] if exists(self.norm_context) else None})" + def forward(self, x, **kwargs): - r"""Forward pass. - - Parameters - ---------- - x: torch.Tensor - Input tensor. - kwargs: dict - Dictionary of keyword arguments. - - Returns - ------- - torch.Tensor - Output tensor. + r"""Forward pass of the PreNorm class. + + Args: + x (torch.Tensor): Input tensor. + **kwargs: Additional arguments. If context_dim is not None the context tensor should be passed. + Returns: + torch.Tensor: Output tensor. """ x = self.norm(x) @@ -108,57 +103,54 @@ def forward(self, x, **kwargs): class GEGLU(nn.Module): r"""GEGLU activation function.""" + def forward(self, x): - r"""Forward pass. - - Parameters - ---------- - x: torch.Tensor - Input tensor. + r"""Forward pass of the GEGLU activation function. + + Args: + x (torch.Tensor): Input tensor. + Returns: + torch.Tensor: Output tensor. """ x, gates = x.chunk(2, dim=-1) return x * F.gelu(gates) + class FeedForward(nn.Module): - r"""Feedforward network. - - Parameters - ---------- - dim: int - Size of the input dimension. - mult: int - Multiplier for the hidden dimension. + r"""Feedforward network with two linear layers and GEGLU activation function in between. + + Args: + dim (int): Size of the input dimension. + mult (int, optional): Multiplier for the hidden dimension. (default: 4) """ def __init__(self, dim, mult=4): super().__init__() self.net = nn.Sequential( nn.Linear(dim, dim * mult * 2), GEGLU(), nn.Linear(dim * mult, dim) ) - + + def __repr__(self): + return f"{self.__class__.__name__}(dim={self.net[0].in_features}, mult={self.net[0].out_features // self.net[0].in_features})" + def forward(self, x): - r"""Forward pass. - - Parameters - ---------- - x: torch.Tensor - Input tensor. + r"""Forward pass of the FeedForward class. + + Args: + x (torch.Tensor): Input tensor. + Returns: + torch.Tensor: Output tensor. """ return self.net(x) class Attention(nn.Module): - r"""Attention function. - - Parameters - ---------- - query_dim: int - Size of the query dimension. - context_dim: int - Size of the context dimension. - heads: int - Number of heads. - dim_head: int - Size for each head. + r"""Attention class to calculate the attention weights. + + Args: + query_dim (int): Size of the query dimension. + context_dim (int, optional): Size of the context dimension. (default: None) + heads (int, optional): Number of heads. (default: 8) + dim_head (int, optional): Size for each head. (default: 64) """ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64): super().__init__() @@ -171,22 +163,18 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64): self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) self.to_out = nn.Linear(inner_dim, query_dim) + def __repr__(self): + return f"{self.__class__.__name__}(query_dim={self.to_q.in_features}, context_dim={self.to_kv.in_features // 2}, heads={self.heads}, dim_head={self.to_q.out_features // self.heads})" + def forward(self, x, context=None, mask=None): - r"""Forward pass. - - Parameters - ---------- - x: torch.Tensor - Input tensor. 
- context: torch.Tensor - Context tensor. - mask: torch.Tensor - Mask for attention calculation purposes. - - Returns - ------- - torch.Tensor - Output tensor. + r"""Forward pass of the Attention class. + + Args: + x (torch.Tensor): Input tensor. + context (torch.Tensor, optional): Context tensor. (default: None) + mask (torch.Tensor, optional): Mask for attention calculation purposes. (default: None) + Returns: + torch.Tensor: Output tensor. """ h = self.heads @@ -194,7 +182,9 @@ def forward(self, x, context=None, mask=None): context = default(context, x) k, v = self.to_kv(context).chunk(2, dim=-1) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + q, k, v = map( + lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v) + ) sim = einsum("b i d, b j d -> b i j", q, k) * self.scale @@ -216,29 +206,20 @@ def forward(self, x, context=None, mask=None): class Perceiver(nn.Module): - r"""Perceiver model. - - Parameters - ---------- - depth: int - Number of layers to add to the model. - dim: int - Size of the input dimension. - num_latents: int - Number of latent vectors. - cross_heads: int - Number of heads for cross attention. - latent_heads: int - Number of heads for latent attention. - cross_dim_head: int - Size of the cross attention head. - latent_dim_head: int - Size of the latent attention head. - weight_tie_layers: bool - Whether to tie the weights of the layers. - decoder_ff: bool - Whether to use a feedforward network in the decoder. + r"""Perceiver model. For more information https://arxiv.org/abs/2103.03206. + + Args: + depth (int): Number of layers to add to the model. + dim (int): Size of the input dimension. + num_latents (int, optional): Number of latent vectors. (default: 1) + cross_heads (int, optional): Number of heads for cross attention. (default: 1) + latent_heads (int, optional): Number of heads for latent attention. (default: 8) + cross_dim_head (int, optional): Size of the cross attention head. (default: 64) + latent_dim_head (int, optional): Size of the latent attention head. (default: 64) + weight_tie_layers (bool, optional): Whether to tie the weights of the layers. (default: False) + decoder_ff (bool, optional): Whether to use a feedforward network in the decoder. 
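These constructor arguments mirror how the (now commented-out) SetFeatureEncoder configured the module for higher-order cells. A hedged instantiation sketch, with purely illustrative sizes; note the keyword-only signature:

import torch

from topobenchmarkx.models.encoders.perceiver import Perceiver

# depth/cross_heads follow the SetFeatureEncoder settings; 64 is an arbitrary feature width.
pooler = Perceiver(depth=1, dim=64, cross_heads=4, cross_dim_head=64, latent_dim_head=64)
cell_features = torch.randn(8, 5, 64)  # (batch, cells per set, feature dim)
latents = pooler(cell_features)        # with no queries, the call returns the latent summaries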
(default: False) """ + def __init__( self, *, depth, @@ -267,7 +248,10 @@ def __init__( PreNorm( latent_dim, Attention( - latent_dim, dim, heads=cross_heads, dim_head=cross_dim_head + latent_dim, + dim, + heads=cross_heads, + dim_head=cross_dim_head, ), context_dim=dim, ), @@ -275,16 +259,20 @@ def __init__( ] ) - def get_latent_attn(): + def get_latent_attn(): return PreNorm( latent_dim, - Attention(latent_dim, heads=latent_heads, dim_head=latent_dim_head), + Attention( latent_dim, heads=latent_heads, dim_head=latent_dim_head ), ) - def get_latent_ff(): + def get_latent_ff(): return PreNorm(latent_dim, FeedForward(latent_dim)) - - get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff)) + + get_latent_attn, get_latent_ff = map( cache_fn, (get_latent_attn, get_latent_ff) ) self.layers = nn.ModuleList([]) cache_args = {"_cache": weight_tie_layers} @@ -292,36 +280,54 @@ def get_latent_ff(): for _ in range(depth): self.layers.append( nn.ModuleList( - [get_latent_attn(**cache_args), get_latent_ff(**cache_args)] + [ + get_latent_attn(**cache_args), + get_latent_ff(**cache_args), + ] ) ) self.decoder_cross_attn = PreNorm( queries_dim, Attention( - queries_dim, latent_dim, heads=cross_heads, dim_head=cross_dim_head + queries_dim, + latent_dim, + heads=cross_heads, + dim_head=cross_dim_head, ), context_dim=latent_dim, ) self.decoder_ff = ( - PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None + PreNorm(queries_dim, FeedForward(queries_dim)) + if decoder_ff + else None ) + self.dim = dim + self.num_latents = num_latents + self.cross_heads = cross_heads + self.latent_heads = latent_heads + self.cross_dim_head = cross_dim_head + self.latent_dim_head = latent_dim_head + self.weight_tie_layers = weight_tie_layers + self.has_decoder_ff = decoder_ff + # self.to_logits = ( # nn.Linear(queries_dim, logits_dim) if exists(logits_dim) else nn.Identity() # ) + + def __repr__(self): + return f"{self.__class__.__name__}(depth={len(self.layers)}, dim={self.dim}, num_latents={self.num_latents}, cross_heads={self.cross_heads}, latent_heads={self.latent_heads}, cross_dim_head={self.cross_dim_head}, latent_dim_head={self.latent_dim_head}, weight_tie_layers={self.weight_tie_layers}, decoder_ff={self.has_decoder_ff})" def forward(self, data, mask=None, queries=None): - r"""Forward pass. - - Parameters - ---------- - data: torch.Tensor - Input tensor. - mask: torch.Tensor - Mask for attention calculation purposes. - queries: torch.Tensor - Queries tensor. + r"""Forward pass of the Perceiver model. + + Args: + data (torch.Tensor): Input tensor. + mask (torch.Tensor, optional): Mask for attention calculation purposes. (default: None) + queries (torch.Tensor, optional): Queries tensor. (default: None) + Returns: + torch.Tensor: Output tensor. """ b, *_ = data.shape @@ -365,4 +371,86 @@ def forward(self, data, mask=None, queries=None): # final linear out # return x #self.to_logits(latents) - return + return None + + + +# from topobenchmarkx.models.encoders.perceiver import Perceiver +# class SetFeatureEncoder(AbstractInitFeaturesEncoder): +# r"""Encoder class to apply BaseEncoder to the node features and Perceiver to the features of higher order structures. + +# Parameters +# ---------- +# in_channels: list(int) +# Input dimensions for the features. +# out_channels: list(int) +# Output dimensions for the features. +# proj_dropout: float +# Dropout for the BaseEncoders. +# selected_dimensions: list(int) +# List of indexes to apply the BaseEncoders to.
+# """ +# def __init__( +# self, in_channels, out_channels, proj_dropout=0, selected_dimensions=None +# ): +# super(AbstractInitFeaturesEncoder, self).__init__() +# self.in_channels = in_channels +# self.out_channels = out_channels +# self.dimensions = ( +# selected_dimensions +# if selected_dimensions is not None +# else range(len(self.in_channels)) +# ) +# for idx, i in enumerate(self.dimensions): +# if idx == 0: +# setattr( +# self, +# f"encoder_{i}", +# BaseEncoder( +# self.in_channels[i], self.out_channels, dropout=proj_dropout +# ), +# ) +# else: +# setattr( +# self, +# f"encoder_{i}", +# Perceiver( +# dim=self.out_channels, +# depth=1, +# cross_heads=4, +# cross_dim_head=self.out_channels, +# latent_dim_head=self.out_channels, +# ), +# ) + +# def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data: +# r""" +# Forward pass + +# Parameters +# ---------- +# data: torch_geometric.data.Data +# Input data object which should contain x_{i} features for each i in the selected_dimensions. + +# Returns +# ------- +# torch_geometric.data.Data +# Output data object. +# """ +# if not hasattr(data, "x_0"): +# data.x_0 = data.x + +# for idx, i in enumerate(self.dimensions): +# if idx == 0: +# if hasattr(data, f"x_{i}") and hasattr(self, f"encoder_{i}"): +# batch = data.batch if i == 0 else getattr(data, f"batch_{i}") +# data[f"x_{i}"] = getattr(self, f"encoder_{i}")( +# data[f"x_{i}"], batch +# ) +# else: +# if hasattr(data, f"x_{i}") and hasattr(self, f"encoder_{i}"): +# cell_features = data["x_0"][data[f"x_{i}"].long()] +# data[f"x_{i}"] = getattr(self, f"encoder_{i}")(cell_features) +# else: +# data[f"x_{i}"] = torch.tensor([], device=data.x_0.device) +# return data diff --git a/topobenchmarkx/models/head_model/__init__.py b/topobenchmarkx/models/head_model/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/topobenchmarkx/models/head_model/models.py b/topobenchmarkx/models/head_model/models.py deleted file mode 100644 index 2da2f79e..00000000 --- a/topobenchmarkx/models/head_model/models.py +++ /dev/null @@ -1,61 +0,0 @@ -import torch -from torch_geometric.utils import scatter - - -class DefaultHead(torch.nn.Module): - r"""Head model. - - Parameters - ---------- - in_channels: int - Input dimension. - out_channels: int - Output dimension. - task_level: str - Task level, either "graph" or "node". If "graph", the readout layer will pool the node embeddings to the graph level to obtain a single graph embedding for each batched graph. If "node", the readout layer will return the node embeddings. - pooling_type: str - Pooling type, either "max", "sum", or "mean". Specifies the type of pooling operation to be used for the graph-level embedding. - """ - def __init__( - self, - in_channels: int, - out_channels: int, - task_level: str, - pooling_type: str = "sum", - ): - super().__init__() - self.linear = torch.nn.Linear(in_channels, out_channels) - - assert task_level in ["graph", "node"], "Invalid task_level" - self.task_level = task_level - - assert pooling_type in ["max", "sum", "mean"], "Invalid pooling_type" - self.pooling_type = pooling_type - - def forward(self, model_out: dict): - r"""Forward pass. - - Parameters - ---------- - model_out: dict - Dictionary containing the model output. - - Returns - ------- - dict - Dictionary containing the updated model output. Resulting key is "logits". 
- """ - x = model_out["x_0"] - batch = model_out["batch_0"] - if self.task_level == "graph": - if self.pooling_type == "max": - x = scatter(x, batch, dim=0, reduce="max") - - elif self.pooling_type == "mean": - x = scatter(x, batch, dim=0, reduce="mean") - - elif self.pooling_type == "sum": - x = scatter(x, batch, dim=0, reduce="sum") - - model_out["logits"] = self.linear(x) - return model_out diff --git a/topobenchmarkx/models/head_models/__init__.py b/topobenchmarkx/models/head_models/__init__.py new file mode 100644 index 00000000..4e4577e5 --- /dev/null +++ b/topobenchmarkx/models/head_models/__init__.py @@ -0,0 +1,15 @@ +from topobenchmarkx.models.head_models.head_model import AbstractHeadModel +from topobenchmarkx.models.head_models.zero_cell_model import ZeroCellModel + +# ... import other readout classes here +# For example: +# from topobenchmarkx.models.readouts.other_readout_1 import OtherheadModel1 +# from topobenchmarkx.models.readouts.other_readout_2 import OtherheadModel2 + +__all__ = [ + "AbstractHeadModel", + "ZeroCellModel", + # "OtherheadModel1", + # "OtherheadModel2", + # ... add other readout classes here +] diff --git a/topobenchmarkx/models/head_models/head_model.py b/topobenchmarkx/models/head_models/head_model.py new file mode 100644 index 00000000..bd26baea --- /dev/null +++ b/topobenchmarkx/models/head_models/head_model.py @@ -0,0 +1,40 @@ +import torch +import torch_geometric +from abc import abstractmethod + +class AbstractHeadModel(torch.nn.Module): + r"""Abstract head model class. + + Args: + in_channels (int): Input dimension. + out_channels (int): Output dimension. + """ + def __init__( + self, + in_channels: int, + out_channels: int, + + ): + super().__init__() + self.linear = torch.nn.Linear(in_channels, out_channels) + + def __repr__(self): + return f"{self.__class__.__name__}(in_channels={self.linear.in_features}, out_channels={self.linear.out_features})" + + def __call__(self, model_out: dict, batch: torch_geometric.data.Data) -> dict: + x = self.forward(model_out, batch) + model_out["logits"] = self.linear(x) + return model_out + + @abstractmethod + def forward(self, model_out: dict, batch: torch_geometric.data.Data): + r"""Forward pass of the head model. + + Args: + model_out (dict): Dictionary containing the model output. + batch (torch_geometric.data.Data): Batch object containing the batched domain data. + Returns: + torch.Tensor: Output tensor over which the final linear layer is applied. + """ + pass + \ No newline at end of file diff --git a/topobenchmarkx/models/head_models/zero_cell_model.py b/topobenchmarkx/models/head_models/zero_cell_model.py new file mode 100644 index 00000000..8c6dacec --- /dev/null +++ b/topobenchmarkx/models/head_models/zero_cell_model.py @@ -0,0 +1,56 @@ +import torch +import torch_geometric +from torch_geometric.utils import scatter +from topobenchmarkx.models.head_models.head_model import AbstractHeadModel + +class ZeroCellModel(AbstractHeadModel): + r"""Zero cell head model. This model produces an output based only on the features of the nodes (the zero cells). The output is obtained by applying a linear layer to the input features. Based on the task level, the readout layer will pool the node embeddings to the graph level to obtain a single graph embedding for each batched graph or return a value for each node. + + Args: + in_channels (int): Input dimension. + out_channels (int): Output dimension. + task_level (str): Task level, either "graph" or "node". 
If "graph", the readout layer will pool the node embeddings to the graph level to obtain a single graph embedding for each batched graph. If "node", the readout layer will return the node embeddings. + pooling_type (str, optional): Pooling type, either "max", "sum", or "mean". Specifies the type of pooling operation to be used for the graph-level embedding. (default: "sum") + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + task_level: str, + pooling_type: str = "sum", + **kwargs, + ): + super().__init__(in_channels, out_channels) + + assert task_level in ["graph", "node"], "Invalid task_level" + self.task_level = task_level + + assert pooling_type in ["max", "sum", "mean"], "Invalid pooling_type" + self.pooling_type = pooling_type + + def __repr__(self): + return f"{self.__class__.__name__}(in_channels={self.linear.in_features}, out_channels={self.linear.out_features}, task_level={self.task_level}, pooling_type={self.pooling_type})" + + def forward(self, model_out: dict, batch: torch_geometric.data.Data): + r"""Forward pass of the zero cell head model. + + Args: + model_out (dict): Dictionary containing the model output. + batch (torch_geometric.data.Data): Batch object containing the batched domain data. + Returns: + torch.Tensor: Output tensor. + """ + x = model_out["x_0"] + batch = batch["batch_0"] + if self.task_level == "graph": + if self.pooling_type == "max": + x = scatter(x, batch, dim=0, reduce="max") + + elif self.pooling_type == "mean": + x = scatter(x, batch, dim=0, reduce="mean") + + elif self.pooling_type == "sum": + x = scatter(x, batch, dim=0, reduce="sum") + + return x \ No newline at end of file diff --git a/topobenchmarkx/models/losses/__init__.py b/topobenchmarkx/models/losses/__init__.py index e69de29b..bd2158c3 100755 --- a/topobenchmarkx/models/losses/__init__.py +++ b/topobenchmarkx/models/losses/__init__.py @@ -0,0 +1,15 @@ +from topobenchmarkx.models.losses.loss import AbstractltLoss +from topobenchmarkx.models.losses.default_loss import DefaultLoss + +# ... import other readout classes here +# For example: +# from topobenchmarkx.models.losses.other_loss_1 import OtherLoss1 +# from topobenchmarkx.models.losses.other_loss_2 import OtherLoss2 + +__all__ = [ + "AbstractltLoss", + "DefaultLoss" + # "OtherLoss1", + # "OtherLoss2", + # ... add other loss classes here +] \ No newline at end of file diff --git a/topobenchmarkx/models/losses/default_loss.py b/topobenchmarkx/models/losses/default_loss.py new file mode 100644 index 00000000..3f96383e --- /dev/null +++ b/topobenchmarkx/models/losses/default_loss.py @@ -0,0 +1,51 @@ +import torch +import torch_geometric +from topobenchmarkx.models.losses.loss import AbstractltLoss + +class DefaultLoss(AbstractltLoss): + r"""Abstract class that provides an interface to loss logic within + netowrk. + + Args: + task (str): Task type, either "classification" or "regression". + loss_type (str, optional): Loss type, either "cross_entropy", "mse", or "mae". 
(default: None) + """ + + def __init__(self, task, loss_type=None): + super().__init__() + self.task = task + if task == "classification" and loss_type == "cross_entropy": + self.criterion = torch.nn.CrossEntropyLoss() + + elif task == "regression" and loss_type == "mse": + self.criterion = torch.nn.MSELoss() + + elif task == "regression" and loss_type == "mae": + self.criterion = torch.nn.L1Loss() + + else: + raise Exception("Loss is not defined") + self.loss_type = loss_type + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(task={self.task}, loss_type={self.loss_type})' + + def forward(self, model_out: dict, batch: torch_geometric.data.Data): + r"""Forward pass of the loss function. + + Args: + model_out (dict): Dictionary containing the model output. + batch (torch_geometric.data.Data): Batch object containing the batched domain data. + Returns: + model_out (dict): Dictionary containing the model output with the loss. + """ + logits = model_out["logits"] + target = model_out["labels"] + + if self.task == "regression": + target = target.unsqueeze(1) + + model_out["loss"] = self.criterion(logits, target) + + return model_out + \ No newline at end of file diff --git a/topobenchmarkx/models/losses/loss.py b/topobenchmarkx/models/losses/loss.py index 42a6d47d..e41aca40 100755 --- a/topobenchmarkx/models/losses/loss.py +++ b/topobenchmarkx/models/losses/loss.py @@ -1,83 +1,18 @@ -import torch - -# import hydra -# from omegaconf import DictConfig - - -class DefaultLoss: - """Abstract class that provides an interface to loss logic within netowrk""" - - def __init__(self, task, loss_type=None): - self.task = task - if task == "classification" and loss_type == "cross_entropy": - self.criterion = torch.nn.CrossEntropyLoss() - - elif task == "regression" and loss_type == "mse": - self.criterion = torch.nn.MSELoss() - - elif task == "regression" and loss_type == "mae": - self.criterion = torch.nn.L1Loss() - - else: - raise Exception("Loss is not defined") - - def __call__(self, model_output): - """Loss logic based on model_output""" - - logits = model_output["logits"] - target = model_output["labels"] - - if self.task == "regression": - target = target.unsqueeze(1) - - model_output["loss"] = self.criterion(logits, target) - - return model_output - - -# class NodeTaskLoss: -# """Abstract class that provides an interface to loss logic within netowrk""" - -# def __init__(self, task): -# if task == "classification": -# self.criterion = torch.nn.CrossEntropyLoss() - -# elif task == "regression": -# self.criterion == torch.nn.mse() - -# else: -# raise Exception("Loss is not defined") - -# def __call__(self, model_output): -# """Loss logic based on model_output""" - -# logits = model_output["logits"] -# target = model_output["labels"] -# model_output["loss"] = self.criterion(logits, target) - -# return model_output - - -# from abc import ABC, abstractmethod - -# import hydra -# from omegaconf import DictConfig - -# # logger = logging.getLogger(__name__) - - -# class AbstractLoss(ABC): -# """Abstract class that provides an interface to loss logic within netowrk""" - -# def __init__(self, cfg: DictConfig): -# self.cfg = cfg - -# @abstractmethod -# def init_loss( -# self, -# ): -# """Initialize loss""" - -# @abstractmethod -# def forward(self, model_output): -# """Loss logic based on model_output""" +import torch_geometric +from abc import ABC, abstractmethod + +class AbstractltLoss(ABC): + r"""Abstract class for the loss class.""" + def __init__(self,): + super().__init__() + + def 
__call__(self, model_out: dict, batch: torch_geometric.data.Data) -> dict: + r"""Loss logic based on model_output.""" + return self.forward(model_out, batch) + + @abstractmethod + def forward(self, model_out: dict, batch: torch_geometric.data.Data): + pass + + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' \ No newline at end of file diff --git a/topobenchmarkx/models/losses/losses.py b/topobenchmarkx/models/losses/losses.py deleted file mode 100755 index fbdb7dc8..00000000 --- a/topobenchmarkx/models/losses/losses.py +++ /dev/null @@ -1,34 +0,0 @@ -# import hydra -# import torch -# from omegaconf import DictConfig - -# from topobenchmarkx.models.losses.loss import AbstractLoss - - -# class DefaultLoss(AbstractLoss): -# """Abstract class that provides an interface to loss logic within netowrk""" - -# def __init__(self, cfg: DictConfig): -# super().__init__(cfg) - -# def init_loss( -# self, -# ): -# if self.cfg.task == 'classification': -# self.criterion = torch.nn.CrossEntropyLoss() - -# elif self.cfg.task == 'regression': -# self.criterion == torch.nn.mse() - -# else: -# raise Exception("Loss is not defined") - - -# def forward(self, model_output): -# """Loss logic based on model_output""" - -# logits = model_output["logits"] -# target = model_output["labels"] -# model_output["loss"] = self.criterion(logits, target) - -# return model_output diff --git a/topobenchmarkx/models/network_module.py b/topobenchmarkx/models/network_module.py deleted file mode 100755 index 61b2ee8e..00000000 --- a/topobenchmarkx/models/network_module.py +++ /dev/null @@ -1,339 +0,0 @@ -from typing import Any, Union - -import torch -from lightning import LightningModule -from torchmetrics import MeanMetric - -# import topomodelx - - -class NetworkModule(LightningModule): - """A `LightningModule` implements 8 key methods: - - Docs: - https://lightning.ai/docs/pytorch/latest/common/lightning_module.html - """ - - def __init__( - self, - backbone: torch.nn.Module, - backbone_wrapper: torch.nn.Module, - readout: torch.nn.Module, - head_model: torch.nn.Module, - loss: torch.nn.Module, - feature_encoder: Union[torch.nn.Module, None] = None, - **kwargs, - ) -> None: - """Initialize a `NetworkModule`. - - :param backbone: The backbone model to train. - :param readout: The readout class. - :param loss: The loss class. - :param optimizer: The optimizer to use for training. - :param scheduler: The learning rate scheduler to use for training. - """ - super().__init__() - - # This line allows to access init params with 'self.hparams' attribute - # also ensures init params will be stored in ckpt - self.save_hyperparameters( - logger=False, - ignore=[] - ) - - self.feature_encoder = feature_encoder - self.backbone = backbone_wrapper(backbone) - self.readout = readout - self.head_model = head_model - - # Evaluator - self.evaluator = None - self.train_metrics_logged = False - - # Loss function - self.task_level = self.hparams["head_model"].task_level - self.criterion = loss - - # Tracking best so far validation accuracy - self.val_acc_best = MeanMetric() - self.metric_collector_val = [] - self.metric_collector_val2 = [] - self.metric_collector_test = [] - - def forward(self, batch) -> dict: - """Perform a forward pass through the model `self.backbone`. - - :param x: A tensor of images. - :return: A tensor of logits. - """ - return self.backbone(batch) - - def model_step(self, batch) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Perform a single model step on a batch of data. 
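For orientation, the model_step shown below is the call chain that the refactored components now cover piecewise. Condensed (process_outputs and evaluator bookkeeping omitted):

# Schematic of the removed model_step, for reference only.
def model_step(self, batch):
    if self.feature_encoder:
        batch = self.feature_encoder(batch)      # e.g. an AbstractFeatureEncoder subclass
    model_out = self.forward(batch)              # wrapped backbone, e.g. GNNWrapper
    model_out = self.readout(model_out, batch)   # e.g. PropagateSignalDown or NoReadOut
    model_out = self.head_model(model_out)       # adds model_out["logits"]
    model_out = self.criterion(model_out)        # adds model_out["loss"]
    return model_out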
- - :param batch: A batch of data (a tuple) containing the input tensor of images and target labels. - - :return: A tuple containing (in order): - - A tensor of losses. - - A tensor of predictions. - - A tensor of target labels. - """ - # Pipeline - if self.feature_encoder: - batch = self.feature_encoder(batch) - - model_out = self.forward(batch) - model_out = self.readout(model_out, batch) - model_out = self.head_model(model_out) - - # Criterion and metric - model_out = self.process_outputs(batch, model_out) - model_out = self.criterion(model_out) - self.evaluator.update(model_out) - - return model_out - - def training_step(self, batch, batch_idx: int) -> torch.Tensor: - """Perform a single training step on a batch of data from the training set. - - :param batch: A batch of data (a tuple) containing the input tensor of images and target - labels. - :param batch_idx: The index of the current batch. - :return: A tensor of losses between model predictions and targets. - """ - self.state_str = "Training" - model_out = self.model_step(batch) - - # Update and log metrics - self.log( - "train/loss", - model_out["loss"], - on_step=False, - on_epoch=True, - prog_bar=True, - batch_size=1, - ) - - # Return loss for backpropagation step - return model_out["loss"] - - def validation_step( - self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int - ) -> None: - """Perform a single validation step on a batch of data from the validation set. - - :param batch: A batch of data (a tuple) containing the input tensor of images and target - labels. - :param batch_idx: The index of the current batch. - """ - self.state_str = "Validation" - model_out = self.model_step(batch) - - # # Keep only validation data points - # if self.task_level == "node": - # for key, val in model_out.items(): - # # if key not in ["loss", "hyperedge"]: - # if key in ["logits", "labels"]: - # model_out[key] = val[batch.val_mask] - - # # Criterion - # model_out = self.criterion(model_out) - - # # Evaluation - # self.evaluator.update(model_out) - # self.metric_collector_val.append((model_out["logits"], model_out["labels"])) - - # Log Loss - self.log( - "val/loss", - model_out["loss"], - on_step=False, - on_epoch=True, - prog_bar=True, - batch_size=1, - ) - - def test_step( - self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int - ) -> None: - """Perform a single test step on a batch of data from the test set. - - :param batch: A batch of data (a tuple) containing the input tensor of images and target - labels. - :param batch_idx: The index of the current batch. 
- """ - self.state_str = "Test" - model_out = self.model_step(batch) - - # if self.task_level == "node": - # # Keep only test data points - # for key, val in model_out.items(): - # if key in ["logits", "labels"]: - # model_out[key] = val[batch.test_mask] - - # # Criterion - # model_out = self.criterion(model_out) - - # Log loss - self.log( - "test/loss", - model_out["loss"], - on_step=False, - on_epoch=True, - prog_bar=True, - batch_size=1, - ) - - # Evaluation - # self.evaluator.update(model_out) - # self.metric_collector_test.append((model_out["logits"], model_out["labels"])) - - def process_outputs(self, batch, model_out: dict) -> dict: - """Process model outputs.""" - - # Get the correct mask - if self.state_str == "Training": - mask = batch.train_mask - elif self.state_str == "Validation": - mask = batch.val_mask - elif self.state_str == "Test": - mask = batch.test_mask - else: - raise ValueError("Invalid state_str") - - if self.task_level == "node": - # Keep only train data points - for key, val in model_out.items(): - if key in ["logits", "labels"]: - model_out[key] = val[mask] - - return model_out - - def log_metrics(self, mode=None): - """Log metrics.""" - metrics_dict = self.evaluator.compute() - for key in metrics_dict: - self.log( - f"{mode}/{key}", - metrics_dict[key], - prog_bar=True, - on_step=False, - ) - - # Reset evaluator for next epoch - self.evaluator.reset() - - def on_validation_epoch_start(self) -> None: - """According pytorch lightning documentation, this hook is called at the beginning of the validation epoch. - - https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks - - Note that the validation step is within the train epoch. Hence here we have to log the train metrics - before we reset the evaluator to start the validation loop. - """ - - # Log train metrics and reset evaluator - self.log_metrics(mode="train") - self.train_metrics_logged = True - - def on_train_epoch_end(self) -> None: - # Log train metrics and reset evaluator - if not self.train_metrics_logged: - self.log_metrics(mode="train") - self.train_metrics_logged = True - - - def on_validation_epoch_end(self) -> None: - """Lightning hook that is called when a test epoch ends.""" - # Log validation metrics and reset evaluator - self.log_metrics(mode="val") - - def on_test_epoch_end(self) -> None: - """Lightning hook that is called when a test epoch ends.""" - self.log_metrics(mode="test") - print() - - def on_train_epoch_start(self) -> None: - """Lightning hook that is called when a test epoch ends.""" - self.evaluator.reset() - self.train_metrics_logged = False - - def on_val_epoch_start(self) -> None: - """Lightning hook that is called when a test epoch ends.""" - self.evaluator.reset() - - def on_test_epoch_start(self) -> None: - """Lightning hook that is called when a test epoch ends.""" - self.evaluator.reset() - - def setup(self, stage: str) -> None: - """Lightning hook that is called at the beginning of fit (train + validate), validate, - test, or predict. - - This is a good hook when you need to build models dynamically or adjust something about - them. This hook is called on every process when using DDP. - - :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. - """ - if self.hparams.compile and stage == "fit": - self.net = torch.compile(self.net) - - def configure_optimizers(self) -> dict[str, Any]: - """Choose what optimizers and learning-rate schedulers to use in your optimization. - Normally you'd need one. 
But in the case of GANs or similar you might have multiple. - - Examples: - https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers - - :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training. - """ - optimizer = self.hparams.optimizer( - params=list(self.trainer.model.parameters()) - + list(self.readout.parameters()) - ) - if self.hparams.scheduler is not None: - scheduler = self.hparams.scheduler(optimizer=optimizer) - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": scheduler, - "monitor": "val/loss", - "interval": "epoch", - "frequency": 1, - }, - } - return {"optimizer": optimizer} - - -# Collect validation statistics -# self.val_acc_best.update(model_out["metrics"]["acc"]) -# self.metric_collector.append(model_out["metrics"]["acc"]) - - -# def on_train_start(self) -> None: -# """Lightning hook that is called when training begins.""" -# # by default lightning executes validation step sanity checks before training starts, -# # so it's worth to make sure validation metrics don't store results from these checks -# # self.val_loss.reset() -# # self.val_acc.reset() -# self.val_acc_best.reset() - - -# def on_validation_epoch_end(self) -> None: -# "Lightning hook that is called when a validation epoch ends." -# pass -# self.criterion = torch.nn.CrossEntropyLoss() - -# self.evaluator = evaluator -# # metric objects for calculating and averaging accuracy across batches -# self.train_acc = Accuracy(task="multiclass", num_classes=7) -# self.val_acc = Accuracy(task="multiclass", num_classes=7) -# self.test_acc = Accuracy(task="multiclass", num_classes=7) - -# for averaging loss across batches -# self.train_loss = MeanMetric() -# self.val_loss = MeanMetric() -# self.test_loss = MeanMetric() - -if __name__ == "__main__": - _ = NetworkModule(None, None, None, None) diff --git a/topobenchmarkx/models/readouts/__init__.py b/topobenchmarkx/models/readouts/__init__.py index e69de29b..4280f106 100644 --- a/topobenchmarkx/models/readouts/__init__.py +++ b/topobenchmarkx/models/readouts/__init__.py @@ -0,0 +1,18 @@ +from topobenchmarkx.models.readouts.readout import AbstractReadOut +from topobenchmarkx.models.readouts.propagate_signal_down import PropagateSignalDown +from topobenchmarkx.models.readouts.identical import NoReadOut + +# ... import other readout classes here
# For example: +# from topobenchmarkx.models.readouts.other_readout_1 import OtherReadout1 +# from topobenchmarkx.models.readouts.other_readout_2 import OtherReadout2 + +# Export all readouts and the dictionary +__all__ = [ + "AbstractReadOut", + "PropagateSignalDown", + "NoReadOut", + # "OtherReadout1", + # "OtherReadout2", + # ... add other readout classes here +] diff --git a/topobenchmarkx/models/readouts/identical.py b/topobenchmarkx/models/readouts/identical.py new file mode 100644 index 00000000..32c96620 --- /dev/null +++ b/topobenchmarkx/models/readouts/identical.py @@ -0,0 +1,23 @@ + +import torch_geometric +from topobenchmarkx.models.readouts.readout import AbstractReadOut + + +class NoReadOut(AbstractReadOut): + r"""No readout layer. This readout layer does not perform any operation on the node embeddings.""" + def __init__(self, **kwargs): + super().__init__() + + def forward(self, model_out: dict, batch: torch_geometric.data.Data) -> dict: + r"""Forward pass of the no readout layer. It returns the model output without any modification. + + Args: + model_out (dict): Dictionary containing the model output.
+ batch (torch_geometric.data.Data): Batch object containing the batched domain data. + Returns: + model_out (dict): Dictionary containing the model output. + """ + return model_out + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" diff --git a/topobenchmarkx/models/readouts/old_readout.py b/topobenchmarkx/models/readouts/old_readout.py deleted file mode 100755 index 9db97518..00000000 --- a/topobenchmarkx/models/readouts/old_readout.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch -from torch_geometric.utils import scatter - -from topobenchmarkx.models.abstractions.readout import AbstractReadOut - - -class GNNBatchReadOut(AbstractReadOut): - r"""Readout layer for GNNs that operates on the batch level. - - Parameters - ---------- - in_channels: int - Input dimension. - out_channels: int - Output dimension. - task_level: str - Task level, either "graph" or "node". If "graph", the readout layer will pool the node embeddings to the graph level to obtain a single graph embedding for each batched graph. If "node", the readout layer will return the node embeddings. - pooling_type: str - Pooling type, either "max", "sum", or "mean". Specifies the type of pooling operation to be used for the graph-level embedding. - """ - def __init__( - self, - in_channels: int, - out_channels: int, - task_level: str, - pooling_type: str = "sum", - ): - super(AbstractReadOut, self).__init__() - self.linear = torch.nn.Linear(in_channels, out_channels) - - assert task_level in ["graph", "node"], "Invalid task_level" - self.task_level = task_level - - assert pooling_type in ["max", "sum", "mean"], "Invalid pooling_type" - self.pooling_type = pooling_type - - def forward(self, model_out: dict): - r"""Forward pass. - - Parameters - ---------- - model_out: dict - Dictionary containing the model output. - - Returns - ------- - dict - Dictionary containing the updated model output. Resulting key is "logits". - """ - x = model_out["x_0"] - batch = model_out["batch"] - if self.task_level == "graph": - if self.pooling_type == "max": - x = scatter(x, batch, dim=0, reduce="max") - - elif self.pooling_type == "mean": - x = scatter(x, batch, dim=0, reduce="mean") - - elif self.pooling_type == "sum": - x = scatter(x, batch, dim=0, reduce="sum") - - model_out["logits"] = self.linear(x) - return model_out diff --git a/topobenchmarkx/models/readouts/propagate_signal_down.py b/topobenchmarkx/models/readouts/propagate_signal_down.py new file mode 100644 index 00000000..4088dbbf --- /dev/null +++ b/topobenchmarkx/models/readouts/propagate_signal_down.py @@ -0,0 +1,59 @@ +import torch +import torch_geometric +import topomodelx +from topobenchmarkx.models.readouts.readout import AbstractReadOut + +class PropagateSignalDown(AbstractReadOut): + r"""Propagate signal down readout layer. This readout layer propagates the signal from cells of a certain order to the cells of the lower order. + + Args: + num_cell_dimensions (int): Highest order of cells considered by the model. + hidden_dim (int): Dimension of the cells representations. + readout_name (str): Readout name. 
+ """ + def __init__(self, **kwargs): + super().__init__() + + self.name = kwargs["readout_name"] + self.dimensions = range(kwargs["num_cell_dimensions"] - 1, 0, -1) + hidden_dim = kwargs["hidden_dim"] + + for i in self.dimensions: + setattr( + self, + f"agg_conv_{i}", + topomodelx.base.conv.Conv( + hidden_dim, hidden_dim, aggr_norm=False + ), + ) + + setattr(self, f"ln_{i}", torch.nn.LayerNorm(hidden_dim)) + + setattr( + self, + f"projector_{i}", + torch.nn.Linear(2 * hidden_dim, hidden_dim), + ) + + def forward(self, model_out: dict, batch: torch_geometric.data.Data): + r"""Forward pass of the propagate signal down readout layer. The layer takes the embeddings of the cells of a certain order and applies a convolutional layer to them. Layer normalization is then applied to the features. The output is concatenated with the initial embeddings of the cells and the result is projected with the use of a linear layer to the dimensions of the cells of lower rank. The process is repeated until the nodes embeddings, which are the cells of rank 0, are reached. + + Args: + model_out (dict): Dictionary containing the model output. + batch (torch_geometric.data.Data): Batch object containing the batched domain data. + Returns: + model_out (dict): Dictionary containing the model output. + """ + for i in self.dimensions: + x_i = getattr(self, f"agg_conv_{i}")( + model_out[f"x_{i}"], batch[f"incidence_{i}"] + ) + x_i = getattr(self, f"ln_{i}")(x_i) + model_out[f"x_{i-1}"] = getattr(self, f"projector_{i}")( + torch.cat([x_i, model_out[f"x_{i-1}"]], dim=1) + ) + + return model_out + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(num_cell_dimensions={len(self.dimensions)}, hidden_dim={self.hidden_dim}, readout_name={self.name}" diff --git a/topobenchmarkx/models/readouts/readout.py b/topobenchmarkx/models/readouts/readout.py index c2189d4d..d663da16 100755 --- a/topobenchmarkx/models/readouts/readout.py +++ b/topobenchmarkx/models/readouts/readout.py @@ -1,59 +1,29 @@ import torch import torch_geometric -from torch_geometric.utils import scatter - - -from topobenchmarkx.models.readouts.readouts import PropagateSignalDown -# Implemented Poolings -READOUTS = { - "PropagateSignalDown": PropagateSignalDown -} +from abc import abstractmethod class AbstractReadOut(torch.nn.Module): r"""Readout layer for GNNs that operates on the batch level. - - Parameters - ---------- - in_channels: int - Input dimension. - out_channels: int - Output dimension. - task_level: str - Task level, either "graph" or "node". If "graph", the readout layer will pool the node embeddings to the graph level to obtain a single graph embedding for each batched graph. If "node", the readout layer will return the node embeddings. - pooling_type: str - Pooling type, either "max", "sum", or "mean". Specifies the type of pooling operation to be used for the graph-level embedding. """ - def __init__( - self, - **kwargs - ): - super().__init__() - - self.signal_readout = kwargs["readout_name"] != "None" - if self.signal_readout: - signal_readout_name = kwargs.get("readout_name") - self.readout = READOUTS[signal_readout_name](**kwargs) - def forward(self, model_out: dict, batch: torch_geometric.data.Data): - r"""Forward pass. - - Parameters - ---------- - model_out: dict - Dictionary containing the model output. + def __init__(self,): + super().__init__() - Returns - ------- - dict - Dictionary containing the updated model output. Resulting key is "logits". 
- """ - # Propagate signal - if self.signal_readout: - model_out = self.readout(model_out, batch) + def __repr__(self): + return f"{self.__class__.__name__}()" - return model_out - + def __call__(self, model_out: dict, batch: torch_geometric.data.Data) -> dict: + """Readout logic based on model_output.""" + return self.forward(model_out, batch) + @abstractmethod + def forward(self, model_out: dict, batch: torch_geometric.data.Data): + r"""Forward pass. - \ No newline at end of file + Args: + model_out (dict): Dictionary containing the model output. + batch (torch_geometric.data.Data): Batch object containing the batched domain data. + Returns: + dict: Dictionary containing the updated model output. + """ \ No newline at end of file diff --git a/topobenchmarkx/models/readouts/readouts.py b/topobenchmarkx/models/readouts/readouts.py deleted file mode 100644 index 2a340fd9..00000000 --- a/topobenchmarkx/models/readouts/readouts.py +++ /dev/null @@ -1,46 +0,0 @@ -import torch -import topomodelx - - -class PropagateSignalDown(torch.nn.Module): - def __init__(self, **kwargs): - super().__init__() - - self.dimensions = range(kwargs["num_cell_dimensions"] - 1, 0, -1) - hidden_dim = kwargs["hidden_dim"] - - for i in self.dimensions: - setattr( - self, - f"agg_conv_{i}", - topomodelx.base.conv.Conv( - hidden_dim, - hidden_dim, - aggr_norm=False - ) - ) - - setattr( - self, - f"ln_{i}", - torch.nn.LayerNorm(hidden_dim) - ) - - setattr( - self, - f"projector_{i}", - torch.nn.Linear(2*hidden_dim, hidden_dim) - ) - - def __call__(self, model_out, batch): - return self.forward(model_out, batch) - - - def forward(self, model_out, batch): - for i in self.dimensions: - x_i = getattr(self, f"agg_conv_{i}")(model_out[f"x_{i}"], batch[f"incidence_{i}"]) - x_i = getattr(self, f"ln_{i}")(x_i) - model_out[f"x_{i-1}"] = getattr(self, f"projector_{i}")(torch.cat([x_i, model_out[f"x_{i-1}"]], dim=1)) - - return model_out - \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/__init__.py b/topobenchmarkx/models/wrappers/__init__.py index d9f57a40..99d31d5c 100755 --- a/topobenchmarkx/models/wrappers/__init__.py +++ b/topobenchmarkx/models/wrappers/__init__.py @@ -1,25 +1,29 @@ -import hydra # noqa: F401 -import torch -from omegaconf import DictConfig # noqa: F401 +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper +from topobenchmarkx.models.wrappers.graph import GNNWrapper +from topobenchmarkx.models.wrappers.hypergraph import HypergraphWrapper +from topobenchmarkx.models.wrappers.simplicial import SANWrapper, SCNWrapper, SCCNNWrapper, SCCNWrapper +from topobenchmarkx.models.wrappers.cell import CANWrapper, CCCNWrapper, CWNWrapper, CCXNWrapper +# ... 
import other readout classes here +# For example: +# from topobenchmarkx.models.wrappers.other_wrapper_1 import OtherWrapper1 +# from topobenchmarkx.models.wrappers.other_wrapper_2 import OtherWrapper2 -class DefaultLoss: - """Abstract class that provides an interface to loss logic within netowrk""" - def __init__(self, task): - if task == "classification": - self.criterion = torch.nn.CrossEntropyLoss() - - elif task == "regression": - self.criterion = torch.nn.mse() - else: - raise Exception("Loss is not defined") - - def __call__(self, model_output): - """Loss logic based on model_output""" - - logits = model_output["logits"] - target = model_output["labels"] - model_output["loss"] = self.criterion(logits, target) - - return model_output +# Export all wrappers +__all__ = [ + "DefaultWrapper", + "GNNWrapper", + "HypergraphWrapper", + "SANWrapper", + "SCNWrapper", + "SCCNNWrapper", + "SCCNWrapper", + "CANWrapper", + "CCCNWrapper", + "CWNWrapper", + "CCXNWrapper", + # "OtherWrapper1", + # "OtherWrapper2", + # ... add other readout classes here +] \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/cell/__init__.py b/topobenchmarkx/models/wrappers/cell/__init__.py new file mode 100644 index 00000000..efa0e5d9 --- /dev/null +++ b/topobenchmarkx/models/wrappers/cell/__init__.py @@ -0,0 +1,20 @@ +from topobenchmarkx.models.wrappers.cell.can_wrapper import CANWrapper +from topobenchmarkx.models.wrappers.cell.cccn_wrapper import CCCNWrapper +from topobenchmarkx.models.wrappers.cell.cwn_wrapper import CWNWrapper +from topobenchmarkx.models.wrappers.cell.ccxn_wrapper import CCXNWrapper + +# ... import other readout classes here +# For example: +# from topobenchmarkx.models.readouts.other_readout_1 import OtherWrapper1 +# from topobenchmarkx.models.readouts.other_readout_2 import OtherWrapper2 + +__all__ = [ + "CANWrapper", + "CCCNWrapper", + "CWNWrapper", + "CCXNWrapper", + + # "OtherWrapper1", + # "OtherWrapper2", + # ... add other readout classes here +] \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/cell/can_wrapper.py b/topobenchmarkx/models/wrappers/cell/can_wrapper.py new file mode 100644 index 00000000..7b3da4d3 --- /dev/null +++ b/topobenchmarkx/models/wrappers/cell/can_wrapper.py @@ -0,0 +1,27 @@ +import torch +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class CANWrapper(DefaultWrapper): + r"""Wrapper for the CAN model. This wrapper defines the forward pass of the model. The CAN model returns the embeddings of the cells of rank 1. The embeddings of the cells of rank 0 are computed as the sum of the embeddings of the cells of rank 1 connected to them.""" + + def forward(self, batch): + r"""Forward pass for the CAN wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. 
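All of these wrapper files follow the same contract: subclass DefaultWrapper, call the backbone with whatever structure matrices it needs, and return a dict carrying labels, batch_0 and the per-rank features x_i; the shared DefaultWrapper then takes care of the rest of the call logic. A minimal hypothetical example:

from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper


class MyGraphWrapper(DefaultWrapper):
    """Hypothetical wrapper illustrating the minimal output contract."""

    def forward(self, batch):
        x_0 = self.backbone(batch.x_0, batch.edge_index)  # backbone-specific call
        return {"labels": batch.y, "batch_0": batch.batch_0, "x_0": x_0}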
+ """ + + x_1 = self.backbone( + x_0=batch.x_0, + x_1=batch.x_1, + adjacency_0=batch.adjacency_0.coalesce(), + down_laplacian_1=batch.down_laplacian_1.coalesce(), + up_laplacian_1=batch.up_laplacian_1.coalesce(), + ) + + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + model_out["x_1"] = x_1 + model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1) + return model_out \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/cell/cccn_wrapper.py b/topobenchmarkx/models/wrappers/cell/cccn_wrapper.py new file mode 100644 index 00000000..8f275d94 --- /dev/null +++ b/topobenchmarkx/models/wrappers/cell/cccn_wrapper.py @@ -0,0 +1,26 @@ +import torch +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class CCCNWrapper(DefaultWrapper): + r"""Wrapper for the CCCN model. This wrapper defines the forward pass of the model. The CCCN model returns the embeddings of the cells of rank 1. The embeddings of the cells of rank 0 are computed as the sum of the embeddings of the cells of rank 1 connected to them.""" + + def forward(self, batch): + r"""Forward pass for the CCCN wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. + """ + + x_1 = self.backbone( + batch.x_1, + batch.down_laplacian_1.coalesce(), + batch.up_laplacian_1.coalesce(), + ) + + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + + model_out["x_1"] = x_1 + model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1) + return model_out \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/cell/ccxn_wrapper.py b/topobenchmarkx/models/wrappers/cell/ccxn_wrapper.py new file mode 100644 index 00000000..0c80c813 --- /dev/null +++ b/topobenchmarkx/models/wrappers/cell/ccxn_wrapper.py @@ -0,0 +1,26 @@ +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class CCXNWrapper(DefaultWrapper): + r"""Wrapper for the CCXN model. This wrapper defines the forward pass of the model. The CCXN model returns the embeddings of the cells of rank 0, 1, and 2.""" + + def forward(self, batch): + r"""Forward pass for the CCXN wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched domain data. + Returns: + dict: Dictionary containing the updated model output. + """ + + x_0, x_1, x_2 = self.backbone( + x_0=batch.x_0, + x_1=batch.x_1, + adjacency_0=batch.adjacency_0, + incidence_2_t=batch.incidence_2.T, + ) + + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + model_out["x_0"] = x_0 + model_out["x_1"] = x_1 + model_out["x_2"] = x_2 + return model_out diff --git a/topobenchmarkx/models/wrappers/cell/cwn_wrapper.py b/topobenchmarkx/models/wrappers/cell/cwn_wrapper.py new file mode 100644 index 00000000..a4efd697 --- /dev/null +++ b/topobenchmarkx/models/wrappers/cell/cwn_wrapper.py @@ -0,0 +1,28 @@ +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class CWNWrapper(DefaultWrapper): + r"""Wrapper for the CWN model. This wrapper defines the forward pass of the model. The CWN model returns the embeddings of the cells of rank 0, 1, and 2.""" + + def forward(self, batch): + r"""Forward pass for the CWN wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched domain data. + Returns: + dict: Dictionary containing the updated model output. 
+ """ + + x_0, x_1, x_2 = self.backbone( + x_0=batch.x_0, + x_1=batch.x_1, + x_2=batch.x_2, + incidence_1_t=batch.incidence_1.T, + adjacency_0=batch.adjacency_1, + incidence_2=batch.incidence_2, + ) + + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + model_out["x_0"] = x_0 + model_out["x_1"] = x_1 + model_out["x_2"] = x_2 + return model_out \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/graph/__init__.py b/topobenchmarkx/models/wrappers/graph/__init__.py new file mode 100644 index 00000000..74d5787d --- /dev/null +++ b/topobenchmarkx/models/wrappers/graph/__init__.py @@ -0,0 +1,15 @@ +from topobenchmarkx.models.wrappers.graph.gnn_wrapper import GNNWrapper + +# ... import other readout classes here +# For example: +# from topobenchmarkx.models.readouts.other_readout_1 import OtherWrapper1 +# from topobenchmarkx.models.readouts.other_readout_2 import OtherWrapper2 + +# Export all wrappers +__all__ = [ + "GNNWrapper", + + # "OtherWrapper1", + # "OtherWrapper2", + # ... add other readout classes here +] diff --git a/topobenchmarkx/models/wrappers/graph/gnn_wrapper.py b/topobenchmarkx/models/wrappers/graph/gnn_wrapper.py new file mode 100644 index 00000000..0e7b816f --- /dev/null +++ b/topobenchmarkx/models/wrappers/graph/gnn_wrapper.py @@ -0,0 +1,19 @@ +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class GNNWrapper(DefaultWrapper): + r"""Wrapper for the GNN models. This wrapper defines the forward pass of the model. The GNN models return the embeddings of the cells of rank 0.""" + + def forward(self, batch): + r"""Forward pass for the GNN wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. + """ + x_0 = self.backbone(batch.x_0, batch.edge_index) + + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + model_out["x_0"] = x_0 + + return model_out \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/hypergraph/__init__.py b/topobenchmarkx/models/wrappers/hypergraph/__init__.py new file mode 100644 index 00000000..869f46b0 --- /dev/null +++ b/topobenchmarkx/models/wrappers/hypergraph/__init__.py @@ -0,0 +1,15 @@ +from topobenchmarkx.models.wrappers.hypergraph.hypergraph_wrapper import HypergraphWrapper + +# ... import other readout classes here +# For example: +# from topobenchmarkx.models.readouts.other_readout_1 import OtherWrapper1 +# from topobenchmarkx.models.readouts.other_readout_2 import OtherWrapper2 + +# Export all wrappers +__all__ = [ + "HypergraphWrapper", + + # "OtherWrapper1", + # "OtherWrapper2", + # ... add other readout classes here +] \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/hypergraph/hypergraph_wrapper.py b/topobenchmarkx/models/wrappers/hypergraph/hypergraph_wrapper.py new file mode 100644 index 00000000..c98a1572 --- /dev/null +++ b/topobenchmarkx/models/wrappers/hypergraph/hypergraph_wrapper.py @@ -0,0 +1,19 @@ +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class HypergraphWrapper(DefaultWrapper): + r"""Wrapper for the hypergraph models. This wrapper defines the forward pass of the model. The hypergraph model return the embeddings of the cells of rank 0, and 1 (the hyperedges).""" + + def forward(self, batch): + r"""Forward pass for the hypergraph wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. 
+ """ + x_0, x_1 = self.backbone(batch.x_0, batch.incidence_hyperedges) + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + model_out["x_0"] = x_0 + model_out["hyperedge"] = x_1 + + return model_out \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/default_wrapper.py b/topobenchmarkx/models/wrappers/old_wrapper.py similarity index 81% rename from topobenchmarkx/models/wrappers/default_wrapper.py rename to topobenchmarkx/models/wrappers/old_wrapper.py index b3afd5b0..405693d4 100755 --- a/topobenchmarkx/models/wrappers/default_wrapper.py +++ b/topobenchmarkx/models/wrappers/old_wrapper.py @@ -1,18 +1,16 @@ from abc import ABC, abstractmethod -import topomodelx import torch -from torch_geometric.nn.norm import GraphNorm -import torch.nn as nn - +import torch.nn as nn class DefaultWrapper(ABC, torch.nn.Module): - """Abstract class that provides an interface to handle the network output""" + """Abstract class that provides an interface to handle the network + output.""" def __init__(self, backbone, **kwargs): super().__init__() - self.backbone = backbone + self.backbone = backbone out_channels = kwargs["out_channels"] self.dimensions = range(kwargs["num_cell_dimensions"]) @@ -24,84 +22,100 @@ def __init__(self, backbone, **kwargs): ) def __call__(self, batch): - """Define logic for forward pass""" + """Define logic for forward pass.""" model_out = self.forward(batch) model_out = self.residual_connection(model_out=model_out, batch=batch) return model_out def residual_connection(self, model_out, batch): for i in self.dimensions: - if (f"x_{i}" in batch) and hasattr(self, f"ln_{i}") and (f"x_{i}" in model_out): + if ( + (f"x_{i}" in batch) + and hasattr(self, f"ln_{i}") + and (f"x_{i}" in model_out) + ): residual = model_out[f"x_{i}"] + batch[f"x_{i}"] model_out[f"x_{i}"] = getattr(self, f"ln_{i}")(residual) return model_out - + @abstractmethod def forward(self, batch): - """Define handling output here""" + """Define handling output here.""" + class GNNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" # def __init__(self, backbone, **kwargs): # super().__init__(backbone) def forward(self, batch): - """Define logic for forward pass""" + """Define logic for forward pass.""" x_0 = self.backbone(batch.x_0, batch.edge_index) model_out = {"labels": batch.y, "batch_0": batch.batch_0} model_out["x_0"] = x_0 - + return model_out class HypergraphWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" + """Define logic for forward pass.""" x_0, x_1 = self.backbone(batch.x_0, batch.incidence_hyperedges) model_out = {"labels": batch.y, "batch_0": batch.batch_0} model_out["x_0"] = x_0 model_out["hyperedge"] = x_1 - + return model_out class SANWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" - x_1 = self.backbone(batch.x_1, batch.up_laplacian_1, batch.down_laplacian_1) + """Define logic for forward pass.""" + x_1 = self.backbone( + batch.x_1, batch.up_laplacian_1, batch.down_laplacian_1 + ) model_out = {"labels": batch.y, "batch_0": batch.batch_0} 
model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1) model_out["x_1"] = x_1 return model_out + class SCNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" - - + """Define logic for forward pass.""" + laplacian_0 = self.normalize_matrix(batch.hodge_laplacian_0) laplacian_1 = self.normalize_matrix(batch.hodge_laplacian_1) laplacian_2 = self.normalize_matrix(batch.hodge_laplacian_2) x_0, x_1, x_2 = self.backbone( - batch.x_0, batch.x_1, batch.x_2, laplacian_0, laplacian_1, laplacian_2 + batch.x_0, + batch.x_1, + batch.x_2, + laplacian_0, + laplacian_1, + laplacian_2, ) - + model_out = {"labels": batch.y, "batch_0": batch.batch_0} model_out["x_2"] = x_2 model_out["x_1"] = x_1 model_out["x_0"] = x_0 return model_out - + def normalize_matrix(self, matrix): matrix_ = matrix.to_dense() n, _ = matrix_.shape @@ -117,14 +131,16 @@ def normalize_matrix(self, matrix): diag_indices, diag_sum, matrix_.shape, device=matrix.device ).coalesce() normalized_matrix = diag_matrix @ (matrix @ diag_matrix) - return normalized_matrix + return normalized_matrix + class SCCNNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" - + """Define logic for forward pass.""" + x_all = (batch.x_0, batch.x_1, batch.x_2) laplacian_all = ( batch.hodge_laplacian_0, @@ -136,22 +152,23 @@ def forward(self, batch): incidence_all = (batch.incidence_1, batch.incidence_2) x_0, x_1, x_2 = self.backbone(x_all, laplacian_all, incidence_all) - + model_out = {"labels": batch.y, "batch_0": batch.batch_0} - + model_out["x_0"] = x_0 model_out["x_1"] = x_1 model_out["x_2"] = x_2 - + return model_out class SCCNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" - + """Define logic for forward pass.""" + features = { f"rank_{r}": batch[f"x_{r}"] for r in range(self.backbone.layers[0].max_rank + 1) @@ -169,30 +186,37 @@ def forward(self, batch): # TODO: First decide which strategy is the best then make code general model_out = {"labels": batch.y, "batch_0": batch.batch_0} if len(output) == 3: - x_0, x_1, x_2 = output["rank_0"], output["rank_1"], output["rank_2"] - + x_0, x_1, x_2 = ( + output["rank_0"], + output["rank_1"], + output["rank_2"], + ) + model_out["x_2"] = x_2 model_out["x_1"] = x_1 model_out["x_0"] = x_0 elif len(output) == 2: x_0, x_1 = output["rank_0"], output["rank_1"] - + model_out["x_1"] = x_1 model_out["x_0"] = x_0 - + else: - raise ValueError(f"Invalid number of output tensors: {len(output)}") + raise ValueError( + f"Invalid number of output tensors: {len(output)}" + ) return model_out class CANWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" - + """Define logic for forward pass.""" + x_1 = self.backbone( x_0=batch.x_0, x_1=batch.x_1, @@ -200,7 +224,7 @@ def forward(self, batch): down_laplacian_1=batch.down_laplacian_1.coalesce(), 
up_laplacian_1=batch.up_laplacian_1.coalesce(), ) - + model_out = {"labels": batch.y, "batch_0": batch.batch_0} model_out["x_1"] = x_1 model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1) @@ -208,30 +232,32 @@ class CWNDCMWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" - + """Define logic for forward pass.""" + x_1 = self.backbone( batch.x_1, batch.down_laplacian_1.coalesce(), batch.up_laplacian_1.coalesce(), ) - + model_out = {"labels": batch.y, "batch_0": batch.batch_0} - + model_out["x_1"] = x_1 - model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1) + model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1) return model_out class CWNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" - + """Define logic for forward pass.""" + x_0, x_1, x_2 = self.backbone( x_0=batch.x_0, x_1=batch.x_1, @@ -249,18 +275,19 @@ class CCXNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" - + """Define logic for forward pass.""" + x_0, x_1, x_2 = self.backbone( x_0=batch.x_0, x_1=batch.x_1, adjacency_0=batch.adjacency_0, incidence_2_t=batch.incidence_2.T, ) - + model_out = {"labels": batch.y, "batch_0": batch.batch_0} model_out["x_0"] = x_0 model_out["x_1"] = x_1 diff --git a/topobenchmarkx/models/wrappers/simplicial/__init__.py b/topobenchmarkx/models/wrappers/simplicial/__init__.py new file mode 100644 index 00000000..7dd8b690 --- /dev/null +++ b/topobenchmarkx/models/wrappers/simplicial/__init__.py @@ -0,0 +1,21 @@ +from topobenchmarkx.models.wrappers.simplicial.san_wrapper import SANWrapper +from topobenchmarkx.models.wrappers.simplicial.scn_wrapper import SCNWrapper +from topobenchmarkx.models.wrappers.simplicial.sccnn_wrapper import SCCNNWrapper +from topobenchmarkx.models.wrappers.simplicial.sccn_wrapper import SCCNWrapper + +# ... import other wrapper classes here +# For example: +# from topobenchmarkx.models.wrappers.simplicial.other_wrapper_1 import OtherWrapper1 +# from topobenchmarkx.models.wrappers.simplicial.other_wrapper_2 import OtherWrapper2 + +# Export all wrappers +__all__ = [ + "SANWrapper", + "SCNWrapper", + "SCCNNWrapper", + "SCCNWrapper", + + # "OtherWrapper1", + # "OtherWrapper2", + # ... add other wrapper classes here +] \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/simplicial/san_wrapper.py b/topobenchmarkx/models/wrappers/simplicial/san_wrapper.py new file mode 100644 index 00000000..022fc8f5 --- /dev/null +++ b/topobenchmarkx/models/wrappers/simplicial/san_wrapper.py @@ -0,0 +1,22 @@ +import torch +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class SANWrapper(DefaultWrapper): + r"""Wrapper for the SAN model. This wrapper defines the forward pass of the model. The SAN model returns the embeddings of the cells of rank 1.
The embeddings of the cells of rank 0 are computed as the sum of the embeddings of the cells of rank 1 connected to them.""" + + def forward(self, batch): + r"""Forward pass for the SAN wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. + """ + x_1 = self.backbone( + batch.x_1, batch.up_laplacian_1, batch.down_laplacian_1 + ) + + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1) + model_out["x_1"] = x_1 + return model_out \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/simplicial/sccn_wrapper.py b/topobenchmarkx/models/wrappers/simplicial/sccn_wrapper.py new file mode 100644 index 00000000..9433274b --- /dev/null +++ b/topobenchmarkx/models/wrappers/simplicial/sccn_wrapper.py @@ -0,0 +1,53 @@ +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class SCCNWrapper(DefaultWrapper): + r"""Wrapper for the SCCN model. This wrapper defines the forward pass of the model. The SCCN model returns the embeddings of the cells of any rank.""" + + def forward(self, batch): + r"""Forward pass for the SCCN wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. + """ + + features = { + f"rank_{r}": batch[f"x_{r}"] + for r in range(self.backbone.layers[0].max_rank + 1) + } + incidences = { + f"rank_{r}": batch[f"incidence_{r}"] + for r in range(1, self.backbone.layers[0].max_rank + 1) + } + adjacencies = { + f"rank_{r}": batch[f"hodge_laplacian_{r}"] + for r in range(self.backbone.layers[0].max_rank + 1) + } + output = self.backbone(features, incidences, adjacencies) + + # TODO: First decide which strategy is the best then make code general + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + if len(output) == 3: + x_0, x_1, x_2 = ( + output["rank_0"], + output["rank_1"], + output["rank_2"], + ) + + model_out["x_2"] = x_2 + model_out["x_1"] = x_1 + model_out["x_0"] = x_0 + + elif len(output) == 2: + x_0, x_1 = output["rank_0"], output["rank_1"] + + model_out["x_1"] = x_1 + model_out["x_0"] = x_0 + + else: + raise ValueError( + f"Invalid number of output tensors: {len(output)}" + ) + + return model_out \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/simplicial/sccnn_wrapper.py b/topobenchmarkx/models/wrappers/simplicial/sccnn_wrapper.py new file mode 100644 index 00000000..07e7687e --- /dev/null +++ b/topobenchmarkx/models/wrappers/simplicial/sccnn_wrapper.py @@ -0,0 +1,33 @@ +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class SCCNNWrapper(DefaultWrapper): + r"""Wrapper for the SCCNN model. This wrapper defines the forward pass of the model. The SCCNN model returns the embeddings of the cells of rank 0, 1, and 2.""" + + def forward(self, batch): + r"""Forward pass for the SCCNN wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. 
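# Aside: the len(output)-based branching in SCCNWrapper above (see its TODO)
# could be made rank-agnostic. A hedged sketch, assuming the backbone keeps
# returning a dict keyed "rank_0", "rank_1", ...; this is not the settled design:
def collect_rank_outputs(output: dict, batch) -> dict:
    """Copy every returned rank into model_out, whatever the max rank is."""
    model_out = {"labels": batch.y, "batch_0": batch.batch_0}
    for key, tensor in output.items():
        rank = int(key.split("_")[1])  # "rank_2" -> 2
        model_out[f"x_{rank}"] = tensor
    return model_out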
+ """ + + x_all = (batch.x_0, batch.x_1, batch.x_2) + laplacian_all = ( + batch.hodge_laplacian_0, + batch.down_laplacian_1, + batch.up_laplacian_1, + batch.down_laplacian_2, + batch.up_laplacian_2, + ) + + incidence_all = (batch.incidence_1, batch.incidence_2) + x_0, x_1, x_2 = self.backbone(x_all, laplacian_all, incidence_all) + + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + + model_out["x_0"] = x_0 + model_out["x_1"] = x_1 + model_out["x_2"] = x_2 + + return model_out \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/simplicial/scn_wrapper.py b/topobenchmarkx/models/wrappers/simplicial/scn_wrapper.py new file mode 100644 index 00000000..75cb5005 --- /dev/null +++ b/topobenchmarkx/models/wrappers/simplicial/scn_wrapper.py @@ -0,0 +1,57 @@ +import torch +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class SCNWrapper(DefaultWrapper): + r"""Wrapper for the SCNW model. This wrapper defines the forward pass of the model. The SCNW model returns the embeddings of the cells of rank 0, 1, and 2.""" + + def forward(self, batch): + r"""Forward pass for the SCNW wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. + """ + + laplacian_0 = self.normalize_matrix(batch.hodge_laplacian_0) + laplacian_1 = self.normalize_matrix(batch.hodge_laplacian_1) + laplacian_2 = self.normalize_matrix(batch.hodge_laplacian_2) + x_0, x_1, x_2 = self.backbone( + batch.x_0, + batch.x_1, + batch.x_2, + laplacian_0, + laplacian_1, + laplacian_2, + ) + + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + model_out["x_2"] = x_2 + model_out["x_1"] = x_1 + model_out["x_0"] = x_0 + + return model_out + + def normalize_matrix(self, matrix): + r"""Normalize the input matrix. The normalization is performed using the diagonal matrix of the inverse square root of the sum of the absolute values of the rows. + + Args: + matrix (torch.sparse.FloatTensor): Input matrix to be normalized. + Returns: + torch.sparse.FloatTensor: Normalized matrix. + """ + matrix_ = matrix.to_dense() + n, _ = matrix_.shape + abs_matrix = abs(matrix_) + diag_sum = abs_matrix.sum(axis=1) + + # Handle division by zero + idxs = torch.where(diag_sum != 0) + diag_sum[idxs] = 1.0 / torch.sqrt(diag_sum[idxs]) + + diag_indices = torch.stack([torch.arange(n), torch.arange(n)]) + diag_matrix = torch.sparse_coo_tensor( + diag_indices, diag_sum, matrix_.shape, device=matrix.device + ).coalesce() + normalized_matrix = diag_matrix @ (matrix @ diag_matrix) + return normalized_matrix \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/wrapper.py b/topobenchmarkx/models/wrappers/wrapper.py new file mode 100755 index 00000000..b9d6fd9e --- /dev/null +++ b/topobenchmarkx/models/wrappers/wrapper.py @@ -0,0 +1,57 @@ +from abc import ABC, abstractmethod +import torch +import torch.nn as nn + +class DefaultWrapper(ABC, torch.nn.Module): + r"""Abstract class that provides an interface to handle the network + output. + + Args: + backbone (torch.nn.Module): Backbone model. + out_channels (int): Number of output channels. + num_cell_dimensions (int): Number of cell dimensions. 
+ """ + def __init__(self, backbone, **kwargs): + super().__init__() + self.backbone = backbone + out_channels = kwargs["out_channels"] + self.dimensions = range(kwargs["num_cell_dimensions"]) + + for i in self.dimensions: + setattr( + self, + f"ln_{i}", + nn.LayerNorm(out_channels), + ) + + def __repr__(self): + return f"{self.__class__.__name__}(backbone={self.backbone}, out_channels={self.backbone.out_channels}, dimensions={self.dimensions})" + + def __call__(self, batch): + r"""Forward pass for the model. This method calls the forward method and adds the residual connection.""" + model_out = self.forward(batch) + model_out = self.residual_connection(model_out=model_out, batch=batch) + return model_out + + def residual_connection(self, model_out, batch): + r"""Residual connection for the model. This method sums, for the embeddings of the cells of any rank, the output of the model with the input embeddings and applies layer normalization.""" + for i in self.dimensions: + if ( + (f"x_{i}" in batch) + and hasattr(self, f"ln_{i}") + and (f"x_{i}" in model_out) + ): + residual = model_out[f"x_{i}"] + batch[f"x_{i}"] + model_out[f"x_{i}"] = getattr(self, f"ln_{i}")(residual) + return model_out + + @abstractmethod + def forward(self, batch): + r"""Forward pass for the model. This method should be implemented by the child class. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. + """ + pass diff --git a/topobenchmarkx/play.ipynb b/topobenchmarkx/play.ipynb index 0f18b067..ac254ec3 100644 --- a/topobenchmarkx/play.ipynb +++ b/topobenchmarkx/play.ipynb @@ -7,8 +7,9 @@ "outputs": [], "source": [ "import torch\n", - "run_1 = torch.load('/home/lev/projects/TopoBenchmarkX/run_1')\n", - "run_2 = torch.load('/home/lev/projects/TopoBenchmarkX/run_2')\n" + "\n", + "run_1 = torch.load(\"/home/lev/projects/TopoBenchmarkX/run_1\")\n", + "run_2 = torch.load(\"/home/lev/projects/TopoBenchmarkX/run_2\")" ] }, { @@ -65,8 +66,7 @@ ], "source": [ "for key in run_1.keys():\n", - " print(f'{key} {(run_1[key] == run_2[key]).all()}')\n", - " " + " print(f\"{key} {(run_1[key] == run_2[key]).all()}\")" ] }, { @@ -76,7 +76,11 @@ "outputs": [], "source": [ "import numpy as np\n", - "run1 = np.load('/home/lev/projects/TopoBenchmarkX/datasets/data_splits/US-county-demos/train_prop=0.5_global_seed=42/0.npz', allow_pickle=True)" + "\n", + "run1 = np.load(\n", + " \"/home/lev/projects/TopoBenchmarkX/datasets/data_splits/US-county-demos/train_prop=0.5_global_seed=42/0.npz\",\n", + " allow_pickle=True,\n", + ")" ] }, { @@ -85,7 +89,10 @@ "metadata": {}, "outputs": [], "source": [ - "run2 = np.load('/home/lev/projects/TopoBenchmarkX/datasets/data_splits/US-county-demos/train_prop=0.5_global_seed=42/0.npz', allow_pickle=True)" + "run2 = np.load(\n", + " \"/home/lev/projects/TopoBenchmarkX/datasets/data_splits/US-county-demos/train_prop=0.5_global_seed=42/0.npz\",\n", + " allow_pickle=True,\n", + ")" ] }, { diff --git a/topobenchmarkx/run_cellular_scripts.sh b/topobenchmarkx/run_cellular_scripts.sh new file mode 100644 index 00000000..075d6824 --- /dev/null +++ b/topobenchmarkx/run_cellular_scripts.sh @@ -0,0 +1,9 @@ +# Run the scripts from the hypergraph directory +bash /TopoBenchmarkX/hp_scripts/main_exp/cellular/CWN.sh +bash /TopoBenchmarkX/hp_scripts/main_exp/cellular/CCCN.sh +bash /TopoBenchmarkX/hp_scripts/main_exp/cellular/CAN.sh +bash /TopoBenchmarkX/hp_scripts/main_exp/cellular/CCXN.sh + +# Run in case we have time +# 
diff --git a/topobenchmarkx/run_graph_scripts.sh b/topobenchmarkx/run_graph_scripts.sh new file mode 100644 index 00000000..21685efb --- /dev/null +++ b/topobenchmarkx/run_graph_scripts.sh @@ -0,0 +1,5 @@ +# Run the scripts from the graph directory +bash ~/projects/TopoBenchmarkX/hp_scripts/main_exp/graph/gcn.sh +bash ~/projects/TopoBenchmarkX/hp_scripts/main_exp/graph/gin.sh +bash ~/projects/TopoBenchmarkX/hp_scripts/main_exp/graph/gat.sh + diff --git a/topobenchmarkx/run_hypergraph_scripts.sh b/topobenchmarkx/run_hypergraph_scripts.sh new file mode 100644 index 00000000..42c205e2 --- /dev/null +++ b/topobenchmarkx/run_hypergraph_scripts.sh @@ -0,0 +1,7 @@ +# Run the scripts from the hypergraph directory +bash ~/projects/TopoBenchmarkX/hp_scripts/main_exp/hypergraph/edgnn.sh +bash ~/projects/TopoBenchmarkX/hp_scripts/main_exp/hypergraph/allsettransformer.sh +bash ~/projects/TopoBenchmarkX/hp_scripts/main_exp/hypergraph/unignn2.sh + +# Run in case we have time +# bash ~/projects/TopoBenchmarkX/hp_scripts/main_exp/hypergraph/left_out.sh \ No newline at end of file diff --git a/topobenchmarkx/simplicial.sh b/topobenchmarkx/simplicial.sh index 2e20131e..9573e371 100644 --- a/topobenchmarkx/simplicial.sh +++ b/topobenchmarkx/simplicial.sh @@ -23,6 +23,9 @@ # # python train.py dataset=IMDB-BINARY model=simplicial/sccn model.optimizer.lr=0.01,0.001 model.feature_encoder.out_channels=16,64 model.backbone.n_layers=1,2 dataset.parameters.batch_size=128 dataset.parameters.data_seed=0,3,5 trainer.check_val_every_n_epoch=5 callbacks.early_stopping.patience=10 trainer=default logger.wandb.project=topobenchmark_0503 model.backbone_wrapper.wrapper_readout=original,signal_prop_down model.readout.pooling_type=sum,mean --multirun # # python train.py dataset=IMDB-MULTI model=simplicial/sccn model.optimizer.lr=0.01,0.001 model.feature_encoder.out_channels=16,64 model.backbone.n_layers=1,2 dataset.parameters.batch_size=128 dataset.parameters.data_seed=0,3,5 trainer.check_val_every_n_epoch=5 callbacks.early_stopping.patience=10 trainer=default logger.wandb.project=topobenchmark_0503 model.backbone_wrapper.wrapper_readout=original,signal_prop_down model.readout.pooling_type=sum,mean --multirun +# dataset.transforms.one_hot_node_degree_features.degrees_fields=x + + # # SCCNN # # Fixed split # python train.py dataset=ZINC model=simplicial/sccnn model.optimizer.lr=0.01 model.feature_encoder.out_channels=16,64 model.backbone.n_layers=2,4 dataset.parameters.batch_size=128 dataset.parameters.data_seed=0 trainer.check_val_every_n_epoch=5 callbacks.early_stopping.patience=10 trainer=default logger.wandb.project=topobenchmark_0503 dataset.transforms.graph2simplicial_lifting.complex_dim=3 model.backbone_wrapper.wrapper_readout=original,signal_prop_down model.readout.pooling_type=sum,mean callbacks.early_stopping.min_delta=0.005 dataset.transforms.graph2simplicial_lifting.signed=True,False --multirun diff --git a/topobenchmarkx/stat.sh b/topobenchmarkx/stat.sh new file mode 100644 index 00000000..67e10248 --- /dev/null +++ b/topobenchmarkx/stat.sh @@ -0,0 +1,61 @@ + +# Description: Compute dataset statistics for each dataset with the selected models.
+# ----Node regression datasets: US County Demographics---- +models=( 'cell/cwn' ) +for model in ${models[*]} +do + + +python dataset_statistics.py \ + dataset=us_country_demos \ + model=$model + + +# ----Cocitation datasets---- +datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' ) + +for dataset in ${datasets[*]} +do + python dataset_statistics.py \ + dataset=$dataset \ + model=$model + +done + +# ----Graph regression dataset---- +# Train on ZINC dataset +python dataset_statistics.py \ + dataset=ZINC \ + model=$model \ + dataset.transforms.one_hot_node_degree_features.degrees_fields=x + + +# ----Heterophilic datasets---- + +datasets=( roman_empire amazon_ratings minesweeper ) + +for dataset in ${datasets[*]} +do + python dataset_statistics.py \ + dataset=$dataset \ + model=$model +done + +# ----TU graph datasets---- +# MUTAG has very few samples, so we use a smaller batch size +# Train on MUTAG dataset +python dataset_statistics.py \ + dataset=MUTAG \ + model=$model + +# Train the rest of the TU graph datasets +datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'IMDB-BINARY' 'IMDB-MULTI' ) + +for dataset in ${datasets[*]} +do + python dataset_statistics.py \ + dataset=$dataset \ + model=$model +done + +done \ No newline at end of file diff --git a/topobenchmarkx/train.py b/topobenchmarkx/train.py index 33fa9e68..d777de81 100755 --- a/topobenchmarkx/train.py +++ b/topobenchmarkx/train.py @@ -1,25 +1,18 @@ -import numpy as np import random -from typing import Any, Optional +from typing import Any import hydra import lightning as L +import numpy as np import rootutils +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) import torch from lightning import Callback, LightningModule, Trainer from lightning.pytorch.loggers import Logger from omegaconf import DictConfig, OmegaConf -from topobenchmarkx.utils.config_resolvers import ( - get_default_transform, - get_monitor_metric, - get_monitor_mode, - infer_in_channels, - infere_list_length, -) - -from topobenchmarkx.data.dataloader_fullbatch import DefaultDataModule +from topobenchmarkx.data.dataloaders import DefaultDataModule from topobenchmarkx.utils import ( RankedLogger, extras, @@ -30,6 +23,14 @@ task_wrapper, ) +from topobenchmarkx.utils.config_resolvers import ( + get_default_transform, + get_monitor_metric, + get_monitor_mode, + infer_in_channels, + infere_list_length, +) + # ------------------------------------------------------------------------------------ # # the setup_root above is equivalent to: # - adding project root dir to PYTHONPATH @@ -47,7 +48,6 @@ # more info: https://github.com/ashleve/rootutils # ------------------------------------------------------------------------------------ # -rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) OmegaConf.register_new_resolver("get_default_transform", get_default_transform) OmegaConf.register_new_resolver("get_monitor_metric", get_monitor_metric) @@ -64,18 +64,19 @@ @task_wrapper def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: - """Trains the model. Can additionally evaluate on a testset, using best weights obtained during - training. + """Trains the model. Can additionally evaluate on a test set, using best + weights obtained during training. - This method is wrapped in optional @task_wrapper decorator, that controls the behavior during - failure. Useful for multiruns, saving info about the crash, etc. + This method is wrapped in optional @task_wrapper decorator, that controls + the behavior during failure.
Useful for multiruns, saving info about the + crash, etc. :param cfg: A DictConfig configuration composed by Hydra. :return: A tuple with metrics and dict with all instantiated objects. """ # Set seed for random number generators in pytorch, numpy and python.random - #if cfg.get("seed"): + # if cfg.get("seed"): L.seed_everything(cfg.seed, workers=True) # Seed for torch torch.manual_seed(cfg.seed) @@ -84,7 +85,6 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: # Seed for python random random.seed(cfg.seed) - # Instantiate and load dataset dataset = hydra.utils.instantiate(cfg.dataset, _recursive_=False) dataset = dataset.load() @@ -140,7 +140,9 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: if cfg.get("train"): log.info("Starting training!") - trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")) + trainer.fit( + model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path") + ) train_metrics = trainer.callback_metrics @@ -148,7 +150,9 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: log.info("Starting testing!") ckpt_path = trainer.checkpoint_callback.best_model_path if ckpt_path == "": - log.warning("Best ckpt not found! Using current weights for testing...") + log.warning( + "Best ckpt not found! Using current weights for testing..." + ) ckpt_path = None trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) log.info(f"Best ckpt path: {ckpt_path}") @@ -164,8 +168,8 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: def count_number_of_parameters( model: torch.nn.Module, only_trainable: bool = True ) -> int: - """ - Counts the number of trainable params. If all params, specify only_trainable = False. + """Counts the number of trainable params. If all params, specify + only_trainable = False. Ref: - https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9?u=brando_miranda @@ -173,15 +177,19 @@ def count_number_of_parameters( :return: """ if only_trainable: - num_params: int = sum(p.numel() for p in model.parameters() if p.requires_grad) + num_params: int = sum( p.numel() for p in model.parameters() if p.requires_grad ) else: # counts trainable and non-trainable num_params: int = sum(p.numel() for p in model.parameters() if p) assert num_params > 0, f"Err: {num_params=}" return int(num_params) -@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml") -def main(cfg: DictConfig) -> Optional[float]: +@hydra.main( + version_base="1.3", config_path="../configs", config_name="train.yaml" +) +def main(cfg: DictConfig) -> float | None: """Main entry point for training. :param cfg: DictConfig configuration composed by Hydra.
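One nit carried over in the hunk above: in the non-trainable branch, sum(p.numel() for p in model.parameters() if p) evaluates each parameter tensor's truth value, which raises a RuntimeError for tensors with more than one element. A hedged fix, outside the scope of this PR:

import torch

def count_all_parameters(model: torch.nn.Module) -> int:
    # counts trainable and non-trainable parameters alike
    return sum(p.numel() for p in model.parameters())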
@@ -204,6 +212,4 @@ def main(cfg: DictConfig) -> Optional[float]: if __name__ == "__main__": - main() - diff --git a/topobenchmarkx/transforms/__init__.py b/topobenchmarkx/transforms/__init__.py index e69de29b..5baa9ac9 100755 --- a/topobenchmarkx/transforms/__init__.py +++ b/topobenchmarkx/transforms/__init__.py @@ -0,0 +1,62 @@ +# Data manipulation transforms +from topobenchmarkx.transforms.data_manipulations import ( + CalculateSimplicialCurvature, + EqualGausFeatures, + IdentityTransform, + InfereKNNConnectivity, + InfereRadiusConnectivity, + KeepOnlyConnectedComponent, + KeepSelectedDataFields, + NodeDegrees, + NodeFeaturesToFloat, + OneHotDegreeFeatures, +) + +# Feature liftings +from topobenchmarkx.transforms.feature_liftings import ( + ConcatentionLifting, + ProjectionSum, + SetLifting, +) + +# Topology liftings +from topobenchmarkx.transforms.liftings import ( + CellCyclesLifting, + HypergraphKHopLifting, + HypergraphKNearestNeighborsLifting, + SimplicialCliqueLifting, + SimplicialNeighborhoodLifting, +) + +# Dictionary of all available transforms +TRANSFORMS = { + # Graph -> Hypergraph + "HypergraphKHopLifting": HypergraphKHopLifting, + "HypergraphKNearestNeighborsLifting": HypergraphKNearestNeighborsLifting, + # Graph -> Simplicial Complex + "SimplicialNeighborhoodLifting": SimplicialNeighborhoodLifting, + "SimplicialCliqueLifting": SimplicialCliqueLifting, + # Graph -> Cell Complex + "CellCyclesLifting": CellCyclesLifting, + # Feature Liftings + "ProjectionSum": ProjectionSum, + "ConcatentionLifting": ConcatentionLifting, + "SetLifting": SetLifting, + # Data Manipulations + "Identity": IdentityTransform, + "InfereKNNConnectivity": InfereKNNConnectivity, + "InfereRadiusConnectivity": InfereRadiusConnectivity, + "NodeDegrees": NodeDegrees, + "OneHotDegreeFeatures": OneHotDegreeFeatures, + "EqualGausFeatures": EqualGausFeatures, + "NodeFeaturesToFloat": NodeFeaturesToFloat, + "CalculateSimplicialCurvature": CalculateSimplicialCurvature, + "KeepOnlyConnectedComponent": KeepOnlyConnectedComponent, + "KeepSelectedDataFields": KeepSelectedDataFields, +} + + + +__all__ = [ + "TRANSFORMS", +] \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/__init__.py b/topobenchmarkx/transforms/data_manipulations/__init__.py index e69de29b..e626400b 100644 --- a/topobenchmarkx/transforms/data_manipulations/__init__.py +++ b/topobenchmarkx/transforms/data_manipulations/__init__.py @@ -0,0 +1,26 @@ +from topobenchmarkx.transforms.data_manipulations.identity_transform import IdentityTransform +from topobenchmarkx.transforms.data_manipulations.infere_knn_connectivity import InfereKNNConnectivity +from topobenchmarkx.transforms.data_manipulations.infere_radius_connectivity import InfereRadiusConnectivity +from topobenchmarkx.transforms.data_manipulations.equal_gaus_features import EqualGausFeatures +from topobenchmarkx.transforms.data_manipulations.node_features_to_float import NodeFeaturesToFloat +from topobenchmarkx.transforms.data_manipulations.node_degrees import NodeDegrees +from topobenchmarkx.transforms.data_manipulations.keep_only_connected_component import KeepOnlyConnectedComponent +from topobenchmarkx.transforms.data_manipulations.calculate_simplicial_curvature import CalculateSimplicialCurvature +from topobenchmarkx.transforms.data_manipulations.one_hot_degree_features import OneHotDegreeFeatures +from topobenchmarkx.transforms.data_manipulations.keep_selected_data_fields import KeepSelectedDataFields + + + + +__all__ = [ + "IdentityTransform",
"InfereKNNConnectivity", + "InfereRadiusConnectivity", + "EqualGausFeatures", + "NodeFeaturesToFloat", + "NodeDegrees", + "KeepOnlyConnectedComponent", + "CalculateSimplicialCurvature", + "OneHotDegreeFeatures", + "KeepSelectedDataFields", +] \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/calculate_simplicial_curvature.py b/topobenchmarkx/transforms/data_manipulations/calculate_simplicial_curvature.py new file mode 100644 index 00000000..d0e9e572 --- /dev/null +++ b/topobenchmarkx/transforms/data_manipulations/calculate_simplicial_curvature.py @@ -0,0 +1,93 @@ +import torch +import torch_geometric + +class CalculateSimplicialCurvature(torch_geometric.transforms.BaseTransform): + r"""A transform that calculates the simplicial curvature of the input graph. + + Args: + kwargs (optional): Parameters for the transform. + """ + + def __init__(self, **kwargs): + super().__init__() + self.type = "simplicial_curvature" + self.parameters = kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" + + def forward(self, data: torch_geometric.data.Data): + r"""Apply the transform to the input data. + + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. + """ + data = self.one_cell_curvature(data) + data = self.zero_cell_curvature(data) + data = self.two_cell_curvature(data) + return data + + def zero_cell_curvature( + self, + data: torch_geometric.data.Data, + ) -> torch_geometric.data.Data: + r"""Calculate the zero cell curvature of the input data. + + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: Data with the zero cell curvature. + """ + data["0_cell_curvature"] = torch.mm( + abs(data["incidence_1"]), data["1_cell_curvature"] + ) + return data + + def one_cell_curvature( + self, + data: torch_geometric.data.Data, + ) -> torch_geometric.data.Data: + r"""Calculate the one cell curvature of the input data. + + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: Data with the one cell curvature. + """ + data["1_cell_curvature"] = ( + 4 + - torch.mm(abs(data["incidence_1"]).T, data["0_cell_degrees"]) + + 3 * data["1_cell_degrees"] + ) + return data + + def two_cell_curvature( + self, + data: torch_geometric.data.Data, + ) -> torch_geometric.data.Data: + r"""Calculate the two cell curvature of the input data. + + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: Data with the two cell curvature. + """ + # Term 1 is simply the degree of the 2-cell (i.e. 
each triangle belongs to n tetrahedra) + term1 = data["2_cell_degrees"] + # Find triangles that belong to multiple tetrahedra + two_cell_degrees = data["2_cell_degrees"].clone() + idx = torch.where(data["2_cell_degrees"] > 1)[0] + two_cell_degrees[idx] = 0 + up = data["incidence_3"].to_dense() @ data["incidence_3"].to_dense().T + down = ( + data["incidence_2"].to_dense().T @ data["incidence_2"].to_dense() + ) + mask = torch.eye(up.size()[0]).bool() + up.masked_fill_(mask, 0) + down.masked_fill_(mask, 0) + diff = (down - up) * 1 + term2 = diff.sum(1, keepdim=True) + data["2_cell_curvature"] = 3 + term1 - term2 + return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/equal_gaus_features.py b/topobenchmarkx/transforms/data_manipulations/equal_gaus_features.py new file mode 100644 index 00000000..1926d47d --- /dev/null +++ b/topobenchmarkx/transforms/data_manipulations/equal_gaus_features.py @@ -0,0 +1,38 @@ +import torch +import torch_geometric + +class EqualGausFeatures(torch_geometric.transforms.BaseTransform): + r"""A transform that generates equal Gaussian features for all nodes in the + input graph. + + Args: + mean (float): The mean of the Gaussian distribution. + std (float): The standard deviation of the Gaussian distribution. + num_features (int): The number of features to generate. + """ + + def __init__(self, **kwargs): + super().__init__() + self.type = "generate_non_informative_features" + + # Generate a single feature vector from a Gaussian distribution + self.mean = kwargs["mean"] + self.std = kwargs["std"] + self.feature_vector = kwargs["num_features"] + self.feature_vector = torch.normal( + mean=self.mean, std=self.std, size=(1, self.feature_vector) + ) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, mean={self.mean!r}, std={self.std!r}, feature_vector={self.feature_vector!r})" + + def forward(self, data: torch_geometric.data.Data): + r"""Apply the transform to the input data. + + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. + """ + data.x = self.feature_vector.expand(data.num_nodes, -1) + return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/identity_transform.py b/topobenchmarkx/transforms/data_manipulations/identity_transform.py new file mode 100644 index 00000000..14b7bdc1 --- /dev/null +++ b/topobenchmarkx/transforms/data_manipulations/identity_transform.py @@ -0,0 +1,26 @@ +import torch_geometric + +class IdentityTransform(torch_geometric.transforms.BaseTransform): + r"""An identity transform that does nothing to the input data. + + Args: + kwargs (optional): Parameters for the base transform. + """ + + def __init__(self, **kwargs): + super().__init__() + self.type = "domain2domain" + self.parameters = kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" + + def forward(self, data: torch_geometric.data.Data): + r"""Apply the transform to the input data. + + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The same data.
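# Aside: one_cell_curvature above computes, per edge (u, v),
#   4 - (deg(u) + deg(v)) + 3 * (#triangles containing the edge),
# a Forman-style curvature. Toy check on a single filled triangle
# (hypothetical tensors, dense for readability):
import torch

incidence_1 = torch.tensor([[1., 1., 0.],   # 3 nodes x 3 edges of a triangle
                            [1., 0., 1.],
                            [0., 1., 1.]])
deg0 = incidence_1.sum(1, keepdim=True)     # every node touches 2 edges
deg1 = torch.ones(3, 1)                     # every edge lies in 1 triangle
curv1 = 4 - incidence_1.T @ deg0 + 3 * deg1
assert (curv1 == 3).all()                   # 4 - (2 + 2) + 3 = 3 per edge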
+ """ + return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/infere_knn_connectivity.py b/topobenchmarkx/transforms/data_manipulations/infere_knn_connectivity.py new file mode 100644 index 00000000..403c280f --- /dev/null +++ b/topobenchmarkx/transforms/data_manipulations/infere_knn_connectivity.py @@ -0,0 +1,32 @@ +import torch_geometric +from torch_geometric.nn import knn_graph + +class InfereKNNConnectivity(torch_geometric.transforms.BaseTransform): + r"""A transform that generates the k-nearest neighbor connectivity of the + input point cloud. + + Args: + kwargs (optional): Parameters for the base transform.""" + + def __init__(self, **kwargs): + super().__init__() + self.type = "infere_knn_connectivity" + self.parameters = kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" + + def forward(self, data: torch_geometric.data.Data): + r"""Apply the transform to the input data. + + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. + """ + + edge_index = knn_graph(data.x, **self.parameters["args"]) + + # Remove duplicates + data.edge_index = edge_index + return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/infere_radius_connectivity.py b/topobenchmarkx/transforms/data_manipulations/infere_radius_connectivity.py new file mode 100644 index 00000000..d4481984 --- /dev/null +++ b/topobenchmarkx/transforms/data_manipulations/infere_radius_connectivity.py @@ -0,0 +1,28 @@ +import torch_geometric +from torch_geometric.nn import radius_graph + +class InfereRadiusConnectivity(torch_geometric.transforms.BaseTransform): + r"""A transform that generates the radius connectivity of the input point + cloud. + + Args: + kwargs (optional): Parameters for the base transform.""" + + def __init__(self, **kwargs): + super().__init__() + self.type = "infere_radius_connectivity" + self.parameters = kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" + + def forward(self, data: torch_geometric.data.Data): + r"""Apply the transform to the input data. + + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. + """ + data.edge_index = radius_graph(data.x, **self.parameters["args"]) + return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/keep_only_connected_component.py b/topobenchmarkx/transforms/data_manipulations/keep_only_connected_component.py new file mode 100644 index 00000000..d65643a4 --- /dev/null +++ b/topobenchmarkx/transforms/data_manipulations/keep_only_connected_component.py @@ -0,0 +1,35 @@ +import torch_geometric +from torch_geometric.transforms import LargestConnectedComponents + +class KeepOnlyConnectedComponent(torch_geometric.transforms.BaseTransform): + """A transform that keeps only the largest connected components of the + input graph. + + Args: + kwargs (optional): Parameters for the base transform. + """ + + def __init__(self, **kwargs): + super().__init__() + self.type = "keep_connected_component" + self.parameters = kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" + + def forward(self, data: torch_geometric.data.Data): + """Apply the transform to the input data. 
+ + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. + """ + + # torch_geometric.transforms.largest_connected_components() + num_components = self.parameters["num_components"] + lcc = LargestConnectedComponents( + num_components=num_components, connection="strong" + ) + data = lcc(data) + return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/keep_selected_data_fields.py b/topobenchmarkx/transforms/data_manipulations/keep_selected_data_fields.py new file mode 100644 index 00000000..c9d400a8 --- /dev/null +++ b/topobenchmarkx/transforms/data_manipulations/keep_selected_data_fields.py @@ -0,0 +1,35 @@ +import torch_geometric + +class KeepSelectedDataFields(torch_geometric.transforms.BaseTransform): + r"""A transform that keeps only the selected fields of the input data. + + Args: + kwargs (optional): Parameters for the base transform. + """ + + def __init__(self, **kwargs): + super().__init__() + self.type = "keep_selected_data_fields" + self.parameters = kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" + + def forward(self, data: torch_geometric.data.Data): + r"""Apply the transform to the input data. + + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. + """ + # Keeps all the fields + fields_to_keep = ( + self.parameters["base_fields"] + + self.parameters["preserved_fields"] + ) + + for key in data: + if key not in fields_to_keep: + del data[key] + return data diff --git a/topobenchmarkx/transforms/data_manipulations/manipulations.py b/topobenchmarkx/transforms/data_manipulations/manipulations.py index 3ae9d463..7cbc147a 100644 --- a/topobenchmarkx/transforms/data_manipulations/manipulations.py +++ b/topobenchmarkx/transforms/data_manipulations/manipulations.py @@ -5,49 +5,55 @@ class IdentityTransform(torch_geometric.transforms.BaseTransform): - r"""An identity transform that does nothing to the input data.""" + r"""An identity transform that does nothing to the input data. + + Args: + kwargs (optional): Parameters for the base transform. + """ def __init__(self, **kwargs): super().__init__() self.type = "domain2domain" self.parameters = kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The (un)transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The same data. """ return data class InfereKNNConnectivity(torch_geometric.transforms.BaseTransform): - r"""A transform that generates the k-nearest neighbor connectivity of the input point cloud.""" + r"""A transform that generates the k-nearest neighbor connectivity of the + input point cloud. + + Args: + kwargs (optional): Parameters for the base transform. + """ def __init__(self, **kwargs): super().__init__() self.type = "infere_knn_connectivity" self.parameters = kwargs + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" + def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. 
- Parameters - ---------- - data : torch_geometric.data.Data - The input data. - Returns - ------- - torch_geometric.data.Data - The transformed data. + + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ - edge_index = knn_graph(data.x, **self.parameters["args"]) # Remove duplicates @@ -56,35 +62,41 @@ def forward(self, data: torch_geometric.data.Data): class InfereRadiusConnectivity(torch_geometric.transforms.BaseTransform): - r"""A transform that generates the radius connectivity of the input point cloud.""" + r"""A transform that generates the radius connectivity of the input point + cloud. + + Args: + kwargs (optional): Parameters for the base transform. + """ def __init__(self, **kwargs): super().__init__() self.type = "infere_radius_connectivity" self.parameters = kwargs + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" + def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - Returns - ------- - torch_geometric.data.Data - The transformed data. + + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ data.edge_index = radius_graph(data.x, **self.parameters["args"]) return data class EqualGausFeatures(torch_geometric.transforms.BaseTransform): - r"""A transform that generates equal Gaussian features for all nodes in the input graph. + r"""A transform that generates equal Gaussian features for all nodes in the + input graph. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + mean (float): The mean of the Gaussian distribution. + std (float): The standard deviation of the Gaussian distribution. + num_features (int): The number of features to generate. """ def __init__(self, **kwargs): @@ -99,18 +111,16 @@ def __init__(self, **kwargs): mean=self.mean, std=self.std, size=(1, self.feature_vector) ) + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, mean={self.mean!r}, std={self.std!r}, feature_vector={self.feature_vector!r}" + def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ data.x = self.feature_vector.expand(data.num_nodes, -1) return data @@ -118,29 +128,25 @@ def forward(self, data: torch_geometric.data.Data): class NodeFeaturesToFloat(torch_geometric.transforms.BaseTransform): r"""A transform that converts the node features of the input graph to float. - - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + + Args: + kwargs (optional): Parameters for the base transform. """ def __init__(self, **kwargs): super().__init__() self.type = "map_node_features_to_float" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r})" def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. 
+ Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ data.x = data.x.float() return data @@ -149,33 +155,31 @@ def forward(self, data: torch_geometric.data.Data): class NodeDegrees(torch_geometric.transforms.BaseTransform): r"""A transform that calculates the node degrees of the input graph. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + kwargs (optional): Parameters for the base transform. """ def __init__(self, **kwargs): super().__init__() self.type = "node_degrees" self.parameters = kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r}" def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ - field_to_process = [key for key in data - for field_substring in self.parameters["selected_fields"] - if field_substring in key and key != "incidence_0" + field_to_process = [ + key + for key in data.keys() + for field_substring in self.parameters["selected_fields"] + if field_substring in key and key != "incidence_0" ] for field in field_to_process: data = self.calculate_node_degrees(data, field) @@ -187,16 +191,11 @@ def calculate_node_degrees( ) -> torch_geometric.data.Data: r"""Calculate the node degrees of the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - field : str - The field to calculate the node degrees. - - Returns - ------- - torch_geometric.data.Data + Args: + data (torch_geometric.data.Data): The input data. + field (str): The field to calculate the node degrees. + Returns: + torch_geometric.data.Data: The transformed data. """ if data[field].is_sparse: degrees = abs(data[field].to_dense()).sum(1) @@ -204,17 +203,25 @@ def calculate_node_degrees( assert ( field == "edge_index" ), "Following logic of finding degrees is only implemented for edge_index" + + # Get number of nodes + if data.get("num_nodes", None): + max_num_nodes = data["num_nodes"] + else: + max_num_nodes = data["x"].shape[0] degrees = ( torch_geometric.utils.to_dense_adj( data[field], - max_num_nodes=data["x"].shape[0], # data["num_nodes"] + max_num_nodes=max_num_nodes, ) .squeeze(0) .sum(1) ) if "incidence" in field: - field_name = str(int(field.split("_")[1]) - 1) + "_cell" + "_degrees" + field_name = ( + str(int(field.split("_")[1]) - 1) + "_cell" + "_degrees" + ) else: field_name = "node_degrees" @@ -223,13 +230,11 @@ def calculate_node_degrees( class KeepOnlyConnectedComponent(torch_geometric.transforms.BaseTransform): - """ - A transform that keeps only the largest connected components of the input graph. + """A transform that keeps only the largest connected components of the + input graph. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + kwargs (optional): Parameters for the base transform. 
""" def __init__(self, **kwargs): @@ -237,19 +242,16 @@ def __init__(self, **kwargs): self.type = "keep_connected_component" self.parameters = kwargs + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" + def forward(self, data: torch_geometric.data.Data): - """ - Apply the transform to the input data. + """Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ from torch_geometric.transforms import LargestConnectedComponents @@ -263,13 +265,10 @@ def forward(self, data: torch_geometric.data.Data): class CalculateSimplicialCurvature(torch_geometric.transforms.BaseTransform): - """ - A transform that calculates the simplicial curvature of the input graph. + """A transform that calculates the simplicial curvature of the input graph. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + kwargs (optional): Parameters for the base transform. """ def __init__(self, **kwargs): @@ -277,19 +276,16 @@ def __init__(self, **kwargs): self.type = "simplicial_curvature" self.parameters = kwargs + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r}" + def forward(self, data: torch_geometric.data.Data): - """ - Apply the transform to the input data. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data. + """Apply the transform to the input data. - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ data = self.one_cell_curvature(data) data = self.zero_cell_curvature(data) @@ -300,18 +296,12 @@ def zero_cell_curvature( self, data: torch_geometric.data.Data, ) -> torch_geometric.data.Data: - """ - Calculate the zero cell curvature of the input data. + """Calculate the zero cell curvature of the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - Data with the zero cell curvature. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: Data with the zero cell curvature added as a field. """ data["0_cell_curvature"] = torch.mm( abs(data["incidence_1"]), data["1_cell_curvature"] @@ -324,15 +314,10 @@ def one_cell_curvature( ) -> torch_geometric.data.Data: r"""Calculate the one cell curvature of the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - Data with the one cell curvature. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: Data with the one cell curvature added as a field. """ data["1_cell_curvature"] = ( 4 @@ -347,15 +332,10 @@ def two_cell_curvature( ) -> torch_geometric.data.Data: r"""Calculate the two cell curvature of the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - Data with the two cell curvature. + Args: + data (torch_geometric.data.Data): The input data. 
+ Returns: + torch_geometric.data.Data: Data with the two cell curvature added as a field. """ # Term 1 is simply the degree of the 2-cell (i.e. each triangle belong to n tetrahedrons) term1 = data["2_cell_degrees"] @@ -364,7 +344,9 @@ def two_cell_curvature( idx = torch.where(data["2_cell_degrees"] > 1)[0] two_cell_degrees[idx] = 0 up = data["incidence_3"].to_dense() @ data["incidence_3"].to_dense().T - down = data["incidence_2"].to_dense().T @ data["incidence_2"].to_dense() + down = ( + data["incidence_2"].to_dense().T @ data["incidence_2"].to_dense() + ) mask = torch.eye(up.size()[0]).bool() up.masked_fill_(mask, 0) down.masked_fill_(mask, 0) @@ -375,12 +357,11 @@ def two_cell_curvature( class OneHotDegreeFeatures(torch_geometric.transforms.BaseTransform): - r"""A transform that adds the node degree as one hot encodings to the node features. + r"""A transform that adds the node degree as one hot encodings to the node + features. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + kwargs (optional): Parameters for the base transform. """ def __init__(self, **kwargs): @@ -390,35 +371,32 @@ def __init__(self, **kwargs): self.features_fields = kwargs["features_fields"] self.transform = OneHotDegree(max_degree=kwargs["max_degrees"]) + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, degrees_field={self.deg_field!r}, features_field={self.features_fields!r}" + def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ data = self.transform.forward( - data, degrees_field=self.deg_field, features_field=self.features_fields + data, + degrees_field=self.deg_field, + features_field=self.features_fields, ) return data class OneHotDegree(torch_geometric.transforms.BaseTransform): - r"""Adds the node degree as one hot encodings to the node features - - Parameters - ---------- - max_degree : int - The maximum degree of the graph. - cat : bool, optional - If set to `True`, the one hot encodings are concatenated to the node features. + r"""Adds the node degree as one hot encodings to the node features. + + Args: + max_degree (int): The maximum degree of the graph. + cat (bool, optional): Whether to concatenate the one hot encoding to the node features. (default: False) """ def __init__( @@ -429,24 +407,23 @@ def __init__( self.max_degree = max_degree self.cat = cat + def __repr__(self) -> str: + return f"{self.__class__.__name__}(max_degree={self.max_degree}, cat={self.cat})" + def forward( - self, data: torch_geometric.data.Data, degrees_field: str, features_field: str + self, + data: torch_geometric.data.Data, + degrees_field: str, + features_field: str, ) -> torch_geometric.data.Data: r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - degrees_field : str - The field containing the node degrees. - features_field : str - The field containing the node features. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + degrees_field (str): The field containing the node degrees. + features_field (str): The field containing the node features. 
+ Returns: + torch_geometric.data.Data: The transformed data. """ assert data.edge_index is not None @@ -473,10 +450,8 @@ def __repr__(self) -> str: class KeepSelectedDataFields(torch_geometric.transforms.BaseTransform): r"""A transform that keeps only the selected fields of the input data. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + kwargs (optional): Parameters for the base transform. """ def __init__(self, **kwargs): @@ -484,22 +459,21 @@ def __init__(self, **kwargs): self.type = "keep_selected_data_fields" self.parameters = kwargs + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r}" + def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ # Keeps all the fields fields_to_keep = ( - self.parameters["base_fields"] + self.parameters["preserved_fields"] + self.parameters["base_fields"] + + self.parameters["preserved_fields"] ) # if len(self.parameters["keep_fields"]) == 1: # return data diff --git a/topobenchmarkx/transforms/data_manipulations/node_degrees.py b/topobenchmarkx/transforms/data_manipulations/node_degrees.py new file mode 100644 index 00000000..d8f07aa9 --- /dev/null +++ b/topobenchmarkx/transforms/data_manipulations/node_degrees.py @@ -0,0 +1,78 @@ +import torch +import torch_geometric + +class NodeDegrees(torch_geometric.transforms.BaseTransform): + r"""A transform that calculates the node degrees of the input graph. + + Args: + kwargs (optional): Parameters for the base transform. + """ + + def __init__(self, **kwargs): + super().__init__() + self.type = "node_degrees" + self.parameters = kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" + + def forward(self, data: torch_geometric.data.Data): + r"""Apply the transform to the input data. + + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. + """ + field_to_process = [ + key + for key in data.keys() + for field_substring in self.parameters["selected_fields"] + if field_substring in key and key != "incidence_0" + ] + for field in field_to_process: + data = self.calculate_node_degrees(data, field) + + return data + + def calculate_node_degrees( + self, data: torch_geometric.data.Data, field: str + ) -> torch_geometric.data.Data: + r"""Calculate the node degrees of the input data. + + Args: + data (torch_geometric.data.Data): The input data. + field (str): The field to calculate the node degrees. + Returns: + torch_geometric.data.Data: The transformed data. 
+ """ + if data[field].is_sparse: + degrees = abs(data[field].to_dense()).sum(1) + else: + assert ( + field == "edge_index" + ), "Following logic of finding degrees is only implemented for edge_index" + + # Get number of nodes + if data.get("num_nodes", None): + max_num_nodes = data["num_nodes"] + else: + max_num_nodes = data["x"].shape[0] + degrees = ( + torch_geometric.utils.to_dense_adj( + data[field], + max_num_nodes=max_num_nodes, + ) + .squeeze(0) + .sum(1) + ) + + if "incidence" in field: + field_name = ( + str(int(field.split("_")[1]) - 1) + "_cell" + "_degrees" + ) + else: + field_name = "node_degrees" + + data[field_name] = degrees.unsqueeze(1) + return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/node_features_to_float.py b/topobenchmarkx/transforms/data_manipulations/node_features_to_float.py new file mode 100644 index 00000000..a49689f4 --- /dev/null +++ b/topobenchmarkx/transforms/data_manipulations/node_features_to_float.py @@ -0,0 +1,26 @@ +import torch_geometric + +class NodeFeaturesToFloat(torch_geometric.transforms.BaseTransform): + r"""A transform that converts the node features of the input graph to float. + + Args: + kwargs (optional): Parameters for the base transform. + """ + + def __init__(self, **kwargs): + super().__init__() + self.type = "map_node_features_to_float" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r})" + + def forward(self, data: torch_geometric.data.Data): + r"""Apply the transform to the input data. + + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. + """ + data.x = data.x.float() + return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/one_hot_degree.py b/topobenchmarkx/transforms/data_manipulations/one_hot_degree.py new file mode 100644 index 00000000..80a2611a --- /dev/null +++ b/topobenchmarkx/transforms/data_manipulations/one_hot_degree.py @@ -0,0 +1,55 @@ +import torch +import torch_geometric +from torch_geometric.utils import one_hot + +class OneHotDegree(torch_geometric.transforms.BaseTransform): + r"""Adds the node degree as one hot encodings to the node features. + + Args: + max_degree (int): The maximum degree of the graph. + cat (bool, optional): If set to `True`, the one hot encodings are concatenated to the node features. (default: False) + """ + def __init__( + self, + max_degree: int, + cat: bool = False, + **kwargs, + ) -> None: + self.max_degree = max_degree + self.cat = cat + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(max_degree={self.max_degree}, cat={self.cat})" + + def forward( + self, + data: torch_geometric.data.Data, + degrees_field: str, + features_field: str, + ) -> torch_geometric.data.Data: + r"""Apply the transform to the input data. + + Args: + data (torch_geometric.data.Data): The input data. + degrees_field (str): The field containing the node degrees. + features_field (str): The field containing the node features. + Returns: + torch_geometric.data.Data: The transformed data. 
+ """ + assert data.edge_index is not None + + deg = data[degrees_field].to(torch.long) + + if len(deg.shape) == 2: + deg = deg.squeeze(1) + + deg = one_hot(deg, num_classes=self.max_degree + 1) + + if self.cat: + x = data[features_field] + x = x.view(-1, 1) if x.dim() == 1 else x + data[features_field] = torch.cat([x, deg.to(x.dtype)], dim=-1) + else: + data[features_field] = deg + + return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/one_hot_degree_features.py b/topobenchmarkx/transforms/data_manipulations/one_hot_degree_features.py new file mode 100644 index 00000000..c12d1c9c --- /dev/null +++ b/topobenchmarkx/transforms/data_manipulations/one_hot_degree_features.py @@ -0,0 +1,37 @@ +import torch_geometric +from topobenchmarkx.transforms.data_manipulations.one_hot_degree import OneHotDegree + + + +class OneHotDegreeFeatures(torch_geometric.transforms.BaseTransform): + r"""A transform that adds the node degree as one hot encodings to the node + features. + + Args: + kwargs (optional): Parameters for the base transform. + """ + def __init__(self, **kwargs): + super().__init__() + self.type = "one_hot_degree_features" + self.deg_field = kwargs["degrees_fields"] + self.features_fields = kwargs["features_fields"] + self.transform = OneHotDegree(**kwargs) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, degrees_field={self.deg_field!r}, features_field={self.features_fields!r})" + + def forward(self, data: torch_geometric.data.Data): + r"""Apply the transform to the input data. + + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. + """ + data = self.transform.forward( + data, + degrees_field=self.deg_field, + features_field=self.features_fields, + ) + + return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_transform.py b/topobenchmarkx/transforms/data_transform.py index e66c108b..701cd58a 100755 --- a/topobenchmarkx/transforms/data_transform.py +++ b/topobenchmarkx/transforms/data_transform.py @@ -1,70 +1,13 @@ -# from abc import ABC, abstractmethod - import torch_geometric - -from topobenchmarkx.transforms.data_manipulations.manipulations import ( - CalculateSimplicialCurvature, - EqualGausFeatures, - IdentityTransform, - InfereKNNConnectivity, - InfereRadiusConnectivity, - KeepOnlyConnectedComponent, - KeepSelectedDataFields, - NodeDegrees, - NodeFeaturesToFloat, - OneHotDegreeFeatures, -) -from topobenchmarkx.transforms.feature_liftings.feature_liftings import ( - ConcatentionLifting, - ProjectionSum, - SetLifting, -) -from topobenchmarkx.transforms.liftings.graph2cell import CellCyclesLifting -from topobenchmarkx.transforms.liftings.graph2hypergraph import ( - HypergraphKHopLifting, - HypergraphKNearestNeighborsLifting, -) -from topobenchmarkx.transforms.liftings.graph2simplicial import ( - SimplicialCliqueLifting, - SimplicialNeighborhoodLifting, -) - -TRANSFORMS = { - # Graph -> Hypergraph - "HypergraphKHopLifting": HypergraphKHopLifting, - "HypergraphKNearestNeighborsLifting": HypergraphKNearestNeighborsLifting, - # Graph -> Simplicial Complex - "SimplicialNeighborhoodLifting": SimplicialNeighborhoodLifting, - "SimplicialCliqueLifting": SimplicialCliqueLifting, - # Graph -> Cell Complex - "CellCyclesLifting": CellCyclesLifting, - # Feature Liftings - "ProjectionSum": ProjectionSum, - "ConcatentionLifting": ConcatentionLifting, - "SetLifting": SetLifting, - # Data Manipulations - "Identity": 
IdentityTransform, - "InfereKNNConnectivity": InfereKNNConnectivity, - "InfereRadiusConnectivity": InfereRadiusConnectivity, - "NodeDegrees": NodeDegrees, - "OneHotDegreeFeatures": OneHotDegreeFeatures, - "EqualGausFeatures": EqualGausFeatures, - "NodeFeaturesToFloat": NodeFeaturesToFloat, - "CalculateSimplicialCurvature": CalculateSimplicialCurvature, - "KeepOnlyConnectedComponent": KeepOnlyConnectedComponent, - "KeepSelectedDataFields": KeepSelectedDataFields, -} - +from topobenchmarkx.transforms import TRANSFORMS class DataTransform(torch_geometric.transforms.BaseTransform): - """Abstract class that provides an interface to define a custom data lifting. + r"""Abstract class that provides an interface to define a custom data + lifting. - Parameters - ---------- - transform_name : str - The name of the transform to be used. - **kwargs : optional - Additional arguments for the class. + Args: + transform_name (str): The name of the transform to be used. + **kwargs: Additional arguments for the class. """ def __init__(self, transform_name, **kwargs): @@ -74,21 +17,20 @@ def __init__(self, transform_name, **kwargs): self.parameters = kwargs self.transform = ( - TRANSFORMS[transform_name](**kwargs) if transform_name is not None else None + TRANSFORMS[transform_name](**kwargs) + if transform_name is not None + else None ) - def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data: - """Forward pass of the lifting. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. + def forward( + self, data: torch_geometric.data.Data + ) -> torch_geometric.data.Data: + r"""Forward pass of the lifting. - Returns - ------- - transformed_data : torch_geometric.data.Data - The lifted data. + Args: + data (torch_geometric.data.Data): The input data to be lifted. + Returns: + transformed_data (torch_geometric.data.Data): The lifted data. """ transformed_data = self.transform(data) return transformed_data diff --git a/topobenchmarkx/transforms/feature_liftings/__init__.py b/topobenchmarkx/transforms/feature_liftings/__init__.py index e69de29b..c7c378c7 100644 --- a/topobenchmarkx/transforms/feature_liftings/__init__.py +++ b/topobenchmarkx/transforms/feature_liftings/__init__.py @@ -0,0 +1,11 @@ +from topobenchmarkx.transforms.feature_liftings.feature_liftings import ( + ConcatentionLifting, + ProjectionSum, + SetLifting, +) + +__all__ = [ + "ConcatentionLifting", + "ProjectionSum", + "SetLifting", +] \ No newline at end of file diff --git a/topobenchmarkx/transforms/feature_liftings/feature_liftings.py b/topobenchmarkx/transforms/feature_liftings/feature_liftings.py index 3b7117f8..70ee93b2 100644 --- a/topobenchmarkx/transforms/feature_liftings/feature_liftings.py +++ b/topobenchmarkx/transforms/feature_liftings/feature_liftings.py @@ -5,30 +5,29 @@ class ProjectionSum(torch_geometric.transforms.BaseTransform): r"""Lifts r-cell features to r+1-cells by projection. - Parameters - ---------- - **kwargs : optional - Additional arguments for the class. + Args: + kwargs (optional): Additional arguments for the class. """ - def __init__(self, **kwargs): super().__init__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" def lift_features( self, data: torch_geometric.data.Data | dict ) -> torch_geometric.data.Data | dict: - r"""Projects r-cell features of a graph to r+1-cell structures using the incidence matrix. - - Parameters - ---------- - data : torch_geometric.data.Data | dict - The input data to be lifted. 
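Given the registry lookup in DataTransform.__init__ above, any registered transform can be instantiated by name. A usage sketch, assuming the repository's dependencies are installed; the toy graph is invented:

import torch
import torch_geometric
from topobenchmarkx.transforms.data_transform import DataTransform

data = torch_geometric.data.Data(
    x=torch.randn(4, 7),
    edge_index=torch.tensor([[0, 1, 2, 0], [1, 2, 0, 3]]),
)
lifting = DataTransform("SimplicialCliqueLifting", complex_dim=2)
lifted = lifting(data)  # dispatches to TRANSFORMS["SimplicialCliqueLifting"]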
- - Returns - ------- - torch_geometric.data.Data | dict - The lifted data.""" - keys = sorted([key.split("_")[1] for key in data if "incidence" in key]) + r"""Projects r-cell features of a graph to r+1-cell structures using the + incidence matrix. + + Args: + data (torch_geometric.data.Data | dict): The input data to be lifted. + Returns: + torch_geometric.data.Data | dict: The data with the lifted features. + """ + keys = sorted( + [key.split("_")[1] for key in data if "incidence" in key] + ) for elem in keys: if f"x_{elem}" not in data: idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1 @@ -43,15 +42,10 @@ def forward( ) -> torch_geometric.data.Data | dict: r"""Applies the lifting to the input data. - Parameters - ---------- - data : torch_geometric.data.Data | dict - The input data to be lifted. - - Returns - ------- - torch_geometric.data.Data | dict - The lifted data. + Args: + data (torch_geometric.data.Data | dict): The input data to be lifted. + Returns: + torch_geometric.data.Data | dict: The lifted data. """ data = self.lift_features(data) return data @@ -60,31 +54,29 @@ def forward( class ConcatentionLifting(torch_geometric.transforms.BaseTransform): r"""Lifts r-cell features to r+1-cells by concatenation. - Parameters - ---------- - **kwargs : optional - Additional arguments for the class. + Args: + kwargs (optional): Additional arguments for the class. """ - def __init__(self, **kwargs): super().__init__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" def lift_features( self, data: torch_geometric.data.Data | dict ) -> torch_geometric.data.Data | dict: - r"""Concatenates r-cell features to r+1-cell structures using the incidence matrix. - - Parameters - ---------- - data : torch_geometric.data.Data | dict - The input data to be lifted. - - Returns - ------- - torch_geometric.data.Data | dict - The lifted data.""" + r"""Concatenates r-cell features to r+1-cell structures using the + incidence matrix. - keys = sorted([key.split("_")[1] for key in data if "incidence" in key]) + Args: + data (torch_geometric.data.Data | dict): The input data to be lifted. + Returns: + torch_geometric.data.Data | dict: The lifted data. + """ + keys = sorted( + [key.split("_")[1] for key in data if "incidence" in key] + ) for elem in keys: if f"x_{elem}" not in data: idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1 @@ -112,15 +104,10 @@ def forward( ) -> torch_geometric.data.Data | dict: r"""Applies the lifting to the input data. - Parameters - ---------- - data : torch_geometric.data.Data | dict - The input data to be lifted. - - Returns - ------- - torch_geometric.data.Data | dict - The lifted data. + Args: + data (torch_geometric.data.Data | dict): The input data to be lifted. + Returns: + torch_geometric.data.Data | dict: The lifted data. """ data = self.lift_features(data) return data @@ -129,34 +116,33 @@ def forward( class SetLifting(torch_geometric.transforms.BaseTransform): r"""Lifts r-cell features to r+1-cells by set operations. - Parameters - ---------- - **kwargs : optional - Additional arguments for the class. + Args: + kwargs (optional): Additional arguments for the class. """ - def __init__(self, **kwargs): super().__init__() + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + def lift_features( self, data: torch_geometric.data.Data | dict ) -> torch_geometric.data.Data | dict: - r"""Concatenates r-cell features to r+1-cell structures using the incidence matrix. 
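The body of ProjectionSum.lift_features is partly elided in the hunk above; the following is only a plausible sketch of the projection step it describes, with invented shapes (the exact aggregation may differ):

import torch

x_0 = torch.randn(4, 7)                     # rank-0 (node) features
incidence_1 = torch.rand(4, 3).to_sparse()  # nodes x edges, hypothetical
# each edge feature is the sum of the features of its incident nodes
x_1 = torch.mm(incidence_1.to_dense().abs().t(), x_0)  # shape (3, 7)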
+ r"""Concatenates r-cell features to r+1-cell structures using the + incidence matrix. - Parameters - ---------- - data : torch_geometric.data.Data | dict - The input data to be lifted. - - Returns - ------- - torch_geometric.data.Data | dict - The lifted data.""" + Args: + data (torch_geometric.data.Data | dict): The input data to be lifted. + Returns: + torch_geometric.data.Data | dict: The lifted data. + """ - keys = sorted([key.split("_")[1] for key in data if "incidence" in key]) + keys = sorted( + [key.split("_")[1] for key in data if "incidence" in key] + ) for elem in keys: if f"x_{elem}" not in data: - #idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1 + # idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1 incidence = data["incidence_" + elem] _, n = incidence.shape @@ -192,15 +178,10 @@ def forward( ) -> torch_geometric.data.Data | dict: r"""Applies the lifting to the input data. - Parameters - ---------- - data : torch_geometric.data.Data | dict - The input data to be lifted. - - Returns - ------- - torch_geometric.data.Data | dict - The lifted data. + Args: + data (torch_geometric.data.Data | dict): The input data to be lifted. + Returns: + torch_geometric.data.Data | dict: The lifted data. """ data = self.lift_features(data) return data diff --git a/topobenchmarkx/transforms/liftings/__init__.py b/topobenchmarkx/transforms/liftings/__init__.py index e69de29b..ac9a7d4d 100755 --- a/topobenchmarkx/transforms/liftings/__init__.py +++ b/topobenchmarkx/transforms/liftings/__init__.py @@ -0,0 +1,18 @@ +from topobenchmarkx.transforms.liftings.graph2cell import CellCyclesLifting + +from topobenchmarkx.transforms.liftings.graph2hypergraph import ( + HypergraphKHopLifting, + HypergraphKNearestNeighborsLifting, +) +from topobenchmarkx.transforms.liftings.graph2simplicial import ( + SimplicialCliqueLifting, + SimplicialNeighborhoodLifting, +) + +__all__ = [ + "CellCyclesLifting", + "HypergraphKHopLifting", + "HypergraphKNearestNeighborsLifting", + "SimplicialCliqueLifting", + "SimplicialNeighborhoodLifting", +] \ No newline at end of file diff --git a/topobenchmarkx/transforms/liftings/graph2cell.py b/topobenchmarkx/transforms/liftings/graph2cell.py index ffff21ae..84a86066 100755 --- a/topobenchmarkx/transforms/liftings/graph2cell.py +++ b/topobenchmarkx/transforms/liftings/graph2cell.py @@ -17,51 +17,43 @@ class Graph2CellLifting(GraphLifting): r"""Abstract class for lifting graphs to cell complexes. - Parameters - ---------- - complex_dim : int, optional - The dimension of the cell complex to be generated. Default is 2. - **kwargs : optional - Additional arguments for the class. + Args: + complex_dim (int, optional): The dimension of the cell complex to be generated. (default: 2) + kwargs (optional): Additional arguments for the class. """ - def __init__(self, complex_dim=2, **kwargs): super().__init__(**kwargs) self.complex_dim = complex_dim self.type = "graph2cell" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(complex_dim={self.complex_dim!r}, type={self.type!r})" @abstractmethod def lift_topology(self, data: torch_geometric.data.Data) -> dict: r"""Lifts the topology of a graph to cell complex domain. - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. - - Returns - ------- - dict - The lifted topology. + Args: + data (torch_geometric.data.Data): The input data to be lifted. + Returns: + dict: The lifted topology. 
""" raise NotImplementedError - def _get_lifted_topology(self, cell_complex: CellComplex, graph: nx.Graph) -> dict: + def _get_lifted_topology( + self, cell_complex: CellComplex, graph: nx.Graph + ) -> dict: r"""Returns the lifted topology. - Parameters - ---------- - cell_complex : CellComplex - The cell complex. - graph : nx.Graph - The input graph. - - Returns - ------- - dict - The lifted topology. + Args: + cell_complex (CellComplex): The cell complex. + graph (nx.Graph): The input graph. + Returns: + dict: The lifted topology. """ - lifted_topology = get_complex_connectivity(cell_complex, self.complex_dim) + lifted_topology = get_complex_connectivity( + cell_complex, self.complex_dim + ) lifted_topology["x_0"] = torch.stack( list(cell_complex.get_cell_attributes("features", 0).values()) ) @@ -76,34 +68,28 @@ def _get_lifted_topology(self, cell_complex: CellComplex, graph: nx.Graph) -> di class CellCyclesLifting(Graph2CellLifting): - r"""Lifts graphs to cell complexes by identifying the cycles as 2-cells. + r"""Lifts graphs to cell complexes by taking as 2-cells a cycle base for the graph. - Parameters - ---------- - max_cell_length : int, optional - The maximum length of the cycles to be lifted. Default is None. - **kwargs : optional - Additional arguments for the class. + Args: + max_cell_length (int, optional): The maximum length of the cycles to be lifted. Default is None. + kwargs (optional): Additional arguments for the class. """ - def __init__(self, max_cell_length=None, **kwargs): super().__init__(**kwargs) self.complex_dim = 2 self.max_cell_length = max_cell_length - + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(max_cell_length={self.max_cell_length!r}, complex_dim={self.complex_dim!r})" + def lift_topology(self, data: torch_geometric.data.Data) -> dict: - r"""Finds the cycles of a graph and lifts them to 2-cells. + r"""Finds a cycle base for the graph and lifts its cycles to 2-cells. - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. - - Returns - ------- - dict - The lifted topology. - """ + Args: + data (torch_geometric.data.Data): The input data to be lifted. + Returns: + dict: The lifted topology. + """ G = self._generate_graph_from_data(data) cycles = nx.cycle_basis(G) cell_complex = CellComplex(G) @@ -112,7 +98,9 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: cycles = [cycle for cycle in cycles if len(cycle) != 1] # Eliminate cycles that are greater than the max_cell_lenght if self.max_cell_length is not None: - cycles = [cycle for cycle in cycles if len(cycle) <= self.max_cell_length] + cycles = [ + cycle for cycle in cycles if len(cycle) <= self.max_cell_length + ] if len(cycles) != 0: cell_complex.add_cells_from(cycles, rank=self.complex_dim) return self._get_lifted_topology(cell_complex, G) diff --git a/topobenchmarkx/transforms/liftings/graph2hypergraph.py b/topobenchmarkx/transforms/liftings/graph2hypergraph.py index 7ed3bd2f..62307eda 100755 --- a/topobenchmarkx/transforms/liftings/graph2hypergraph.py +++ b/topobenchmarkx/transforms/liftings/graph2hypergraph.py @@ -14,60 +14,51 @@ class Graph2HypergraphLifting(GraphLifting): r"""Abstract class for lifting graphs to hypergraphs. - Parameters - ---------- - **kwargs : optional - Additional arguments for the class. + Args: + kwargs (optional): Additional arguments for the class. 
""" def __init__(self, **kwargs): super().__init__(**kwargs) self.type = "graph2hypergraph" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r})" @abstractmethod def lift_topology(self, data: torch_geometric.data.Data) -> dict: r"""Lifts the topology of a graph to hypergraph domain. - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. - - Returns - ------- - dict - The lifted topology. + Args: + data (torch_geometric.data.Data): The input data to be lifted. + Returns: + dict: The lifted topology. """ raise NotImplementedError class HypergraphKHopLifting(Graph2HypergraphLifting): - r"""Lifts graphs to hypergraph domain by considering k-hop neighborhoods. - - Parameters - ---------- - k_value : int, optional - The number of hops to consider. Default is 1. - **kwargs : optional - Additional arguments for the class. + r"""Lifts graphs to hypergraph domain by considering k-hop neighborhoods of a node. This lifting extracts a number of hyperedges equal to the number of nodes in the graph. + + Args: + k_value (int, optional): The number of hops to consider. (default: 1) + kwargs (optional): Additional arguments for the class. """ - def __init__(self, k_value=1, **kwargs): super().__init__(**kwargs) self.k = k_value + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(k={self.k!r})" def lift_topology(self, data: torch_geometric.data.Data) -> dict: - r"""Lifts the topology of a graph to hypergraph domain by considering k-hop neighborhoods. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. + r"""Lifts the topology of a graph to hypergraph domain by considering + k-hop neighborhoods. - Returns - ------- - dict - The lifted topology. + Args: + data (torch_geometric.data.Data): The input data to be lifted. + Returns: + dict: The lifted topology. """ # Check if data has instance x: if hasattr(data, "x") and data.x is not None: @@ -79,13 +70,17 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: edge_index = torch_geometric.utils.to_undirected(data.edge_index) # Detect isolated nodes - isolated_nodes = [i for i in range(num_nodes) if i not in edge_index[0]] + isolated_nodes = [ + i for i in range(num_nodes) if i not in edge_index[0] + ] if len(isolated_nodes) > 0: # Add completely isolated nodes to the edge_index edge_index = torch.cat( [ edge_index, - torch.tensor([isolated_nodes, isolated_nodes], dtype=torch.long), + torch.tensor( + [isolated_nodes, isolated_nodes], dtype=torch.long + ), ], dim=1, ) @@ -106,36 +101,31 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: class HypergraphKNearestNeighborsLifting(Graph2HypergraphLifting): - r"""Lifts graphs to hypergraph domain by considering k-nearest neighbors. - - Parameters - ---------- - k_value : int, optional - The number of nearest neighbors to consider. Default is 1. - loop: boolean, optional - If True the hyperedges will contain the node they were created from. - **kwargs : optional - Additional arguments for the class. + r"""Lifts graphs to hypergraph domain by considering k-nearest neighbors. This lifting extracts a number of hyperedges equal to the number of nodes in the graph. The hyperedges all contain the same number of nodes, which is equal to the number of nearest neighbors considered. + + Args: + k_value (int, optional): The number of nearest neighbors to consider. (default: 1) + loop (bool, optional): If True the hyperedges will contain the node they were created from. 
(default: True) + cosine (bool, optional): If True the cosine distance will be used instead of the Euclidean distance. (default: False) + kwargs (optional): Additional arguments for the class. """ - - def __init__(self, k_value=1, loop=True, **kwargs): + def __init__(self, k_value=1, loop=True, cosine=False, **kwargs): super().__init__() self.k = k_value self.loop = loop - self.transform = torch_geometric.transforms.KNNGraph(self.k, self.loop) + self.transform = torch_geometric.transforms.KNNGraph(self.k, self.loop, cosine=cosine) + def __repr__(self) -> str: + return f"{self.__class__.__name__}(k={self.k!r}, loop={self.loop!r})" + def lift_topology(self, data: torch_geometric.data.Data) -> dict: - r"""Lifts the topology of a graph to hypergraph domain by considering k-nearest neighbors. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. - - Returns - ------- - dict - The lifted topology. + r"""Lifts the topology of a graph to hypergraph domain by considering + k-nearest neighbors. + + Args: + data (torch_geometric.data.Data): The input data to be lifted. + Returns: + dict: The lifted topology. """ num_nodes = data.x.shape[0] data.pos = data.x diff --git a/topobenchmarkx/transforms/liftings/graph2simplicial.py b/topobenchmarkx/transforms/liftings/graph2simplicial.py index b647b9db..89f961a5 100755 --- a/topobenchmarkx/transforms/liftings/graph2simplicial.py +++ b/topobenchmarkx/transforms/liftings/graph2simplicial.py @@ -20,33 +20,27 @@ class Graph2SimplicialLifting(GraphLifting): r"""Abstract class for lifting graphs to simplicial complexes. - Parameters - ---------- - complex_dim : int, optional - The dimension of the simplicial complex to be generated. Default is 2. - **kwargs : optional - Additional arguments for the class. + Args: + complex_dim (int, optional): The maximum dimension of the simplicial complex to be generated. (default: 2) + kwargs (optional): Additional arguments for the class. """ - def __init__(self, complex_dim=2, **kwargs): super().__init__(**kwargs) self.complex_dim = complex_dim self.type = "graph2simplicial" self.signed = kwargs.get("signed", False) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(complex_dim={self.complex_dim!r}, type={self.type!r}, signed={self.signed!r})" @abstractmethod def lift_topology(self, data: torch_geometric.data.Data) -> dict: r"""Lifts the topology of a graph to simplicial complex domain. - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. - - Returns - ------- - dict - The lifted topology. + Args: + data (torch_geometric.data.Data): The input data to be lifted. + Returns: + dict: The lifted topology. """ raise NotImplementedError @@ -55,61 +49,59 @@ def _get_lifted_topology( ) -> dict: r"""Returns the lifted topology. - Parameters - ---------- - simplicial_complex : SimplicialComplex - The simplicial complex. - graph : nx.Graph - The input graph. - - Returns - ------- - dict - The lifted topology. + Args: + simplicial_complex (SimplicialComplex): The simplicial complex. + graph (nx.Graph): The input graph. + Returns: + dict: The lifted topology. 
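Stepping back to the k-NN lifting above: KNNGraph reads data.pos, which is why lift_topology aliases it to data.x. A sketch with invented features (KNNGraph relies on torch-cluster, which the Dockerfile installs):

import torch
import torch_geometric

data = torch_geometric.data.Data(x=torch.randn(5, 3))
data.pos = data.x
knn = torch_geometric.transforms.KNNGraph(k=2, loop=True)
data = knn(data)  # each node plus its 2 nearest neighbors forms one hyperedge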
""" lifted_topology = get_complex_connectivity( simplicial_complex, self.complex_dim, signed=self.signed ) lifted_topology["x_0"] = torch.stack( - list(simplicial_complex.get_simplex_attributes("features", 0).values()) + list( + simplicial_complex.get_simplex_attributes( + "features", 0 + ).values() + ) ) # If new edges have been added during the lifting process, we discard the edge attributes if self.contains_edge_attr and simplicial_complex.shape[1] == ( graph.number_of_edges() ): lifted_topology["x_1"] = torch.stack( - list(simplicial_complex.get_simplex_attributes("features", 1).values()) + list( + simplicial_complex.get_simplex_attributes( + "features", 1 + ).values() + ) ) return lifted_topology class SimplicialNeighborhoodLifting(Graph2SimplicialLifting): - r"""Lifts graphs to simplicial complex domain by considering k-hop neighborhoods. - - Parameters - ---------- - max_k_simplices : int, optional - The maximum number of k-simplices to consider. Default is 5000. - **kwargs : optional - Additional arguments for the class. - """ + r"""Lifts graphs to simplicial complex domain by considering k-hop + neighborhoods. For each node its neighborhood is selected and then all the possible simplices, when considering the neighborhood as a clique, are added to the simplicial complex. For this reason this lifting does not conserve the initial graph topology. + Args: + max_k_simplices (int, optional): The maximum number of k-simplices to consider. (default: 5000) + kwargs (optional): Additional arguments for the class. + """ def __init__(self, max_k_simplices=5000, **kwargs): super().__init__(**kwargs) self.max_k_simplices = max_k_simplices + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(max_k_simplices={self.max_k_simplices!r})" def lift_topology(self, data: torch_geometric.data.Data) -> dict: - r"""Lifts the topology of a graph to simplicial complex domain by considering k-hop neighborhoods. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. + r"""Lifts the topology of a graph to simplicial complex domain by + considering k-hop neighborhoods. - Returns - ------- - dict - The lifted topology. + Args: + data (torch_geometric.data.Data): The input data to be lifted. + Returns: + dict: The lifted topology. """ graph = self._generate_graph_from_data(data) simplicial_complex = SimplicialComplex(graph) @@ -117,7 +109,9 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: simplices = [set() for _ in range(2, self.complex_dim + 1)] for n in range(graph.number_of_nodes()): # Find 1-hop node n neighbors - neighbors, _, _, _ = torch_geometric.utils.k_hop_subgraph(n, 1, edge_index) + neighbors, _, _, _ = torch_geometric.utils.k_hop_subgraph( + n, 1, edge_index + ) if n not in neighbors: neighbors.append(n) neighbors = neighbors.numpy() @@ -135,29 +129,23 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: class SimplicialCliqueLifting(Graph2SimplicialLifting): - r"""Lifts graphs to simplicial complex domain by identifying the cliques as k-simplices. + r"""Lifts graphs to simplicial complex domain by identifying the cliques as k-simplices, considering also all the combinations with lower rank. - Parameters - ---------- - **kwargs : optional - Additional arguments for the class. + Args: + kwargs (optional): Additional arguments for the class. 
""" def __init__(self, **kwargs): super().__init__(**kwargs) def lift_topology(self, data: torch_geometric.data.Data) -> dict: - r"""Lifts the topology of a graph to a simplicial complex by identifying the cliques as k-simplices. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. + r"""Lifts the topology of a graph to a simplicial complex by identifying + the cliques as k-simplices. - Returns - ------- - dict - The lifted topology. + Args: + data (torch_geometric.data.Data): The input data to be lifted. + Returns: + dict: The lifted topology. """ graph = self._generate_graph_from_data(data) simplicial_complex = SimplicialComplex(graph) diff --git a/topobenchmarkx/transforms/liftings/graph_lifting.py b/topobenchmarkx/transforms/liftings/graph_lifting.py index ec25308b..4dab940b 100644 --- a/topobenchmarkx/transforms/liftings/graph_lifting.py +++ b/topobenchmarkx/transforms/liftings/graph_lifting.py @@ -19,90 +19,78 @@ class GraphLifting(torch_geometric.transforms.BaseTransform): - r"""Abstract class for lifting graph topologies to higher-order topological domains. + r"""Abstract class for lifting graph topologies to higher-order topological + domains. - Parameters - ---------- - feature_lifting : str, optional - The feature lifting method to be used. Default is 'projection'. - preserve_edge_attr : bool, optional - Whether to preserve edge attributes. Default is False. - **kwargs : optional - Additional arguments for the class. + Args: + feature_lifting (str, optional): The feature lifting method to be used. (default: 'projection') + preserve_edge_attr (bool, optional): Whether to preserve edge attributes. (default: False) + kwargs (optional): Additional arguments for the class. """ - def __init__( self, feature_lifting="projection", preserve_edge_attr=False, **kwargs ): super().__init__() self.feature_lifting = FEATURE_LIFTINGS[feature_lifting]() self.preserve_edge_attr = preserve_edge_attr + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(feature_lifting={self.feature_lifting!r}, preserve_edge_attr={self.preserve_edge_attr!r})" @abstractmethod def lift_topology(self, data: torch_geometric.data.Data) -> dict: r"""Lifts the topology of a graph to higher-order topological domains. - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. - - Returns - ------- - dict - The lifted topology. + Args: + data (torch_geometric.data.Data): The input data to be lifted. + Returns: + dict: The lifted topology. """ raise NotImplementedError - def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data: + def forward( + self, data: torch_geometric.data.Data + ) -> torch_geometric.data.Data: r"""Applies the full lifting (topology + features) to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. - - Returns - ------- - torch_geometric.data.Data - The lifted data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The output data. """ initial_data = data.to_dict() lifted_topology = self.lift_topology(data) lifted_topology = self.feature_lifting(lifted_topology) - lifted_data = torch_geometric.data.Data(**initial_data, **lifted_topology) + lifted_data = torch_geometric.data.Data( + **initial_data, **lifted_topology + ) return lifted_data def _data_has_edge_attr(self, data: torch_geometric.data.Data) -> bool: r"""Checks if the input data object has edge attributes. 
- Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - bool - Whether the data object has edge attributes. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + bool: Whether the data object has edge attributes. """ return hasattr(data, "edge_attr") and data.edge_attr is not None - def _generate_graph_from_data(self, data: torch_geometric.data.Data) -> nx.Graph: + def _generate_graph_from_data( + self, data: torch_geometric.data.Data + ) -> nx.Graph: r"""Generates a NetworkX graph from the input data object. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - nx.Graph - The generated NetworkX graph. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + nx.Graph: The generated NetworkX graph. """ # Check if data object have edge_attr, return list of tuples as [(node_id, {'features':data}, 'dim':1)] or ?? - nodes = [(n, dict(features=data.x[n], dim=0)) for n in range(data.x.shape[0])] + nodes = [ + (n, dict(features=data.x[n], dim=0)) + for n in range(data.x.shape[0]) + ] if self.preserve_edge_attr and self._data_has_edge_attr(data): # In case edge features are given, assign features to every edge @@ -125,7 +113,9 @@ def _generate_graph_from_data(self, data: torch_geometric.data.Data) -> nx.Graph # If edge_attr is not present, return list list of edges edges = [ (i.item(), j.item()) - for i, j in zip(data.edge_index[0], data.edge_index[1], strict=False) + for i, j in zip( + data.edge_index[0], data.edge_index[1], strict=False + ) ] self.contains_edge_attr = False graph = nx.Graph() diff --git a/topobenchmarkx/utils/__init__.py b/topobenchmarkx/utils/__init__.py index cb6b0bc7..b2beffb3 100755 --- a/topobenchmarkx/utils/__init__.py +++ b/topobenchmarkx/utils/__init__.py @@ -1,8 +1,17 @@ from topobenchmarkx.utils.instantiators import ( instantiate_callbacks, # noqa: F401 - instantiate_loggers, # noqa: F401 + instantiate_loggers, # noqa: F401 +) +from topobenchmarkx.utils.logging_utils import ( + log_hyperparameters, # noqa: F401 ) -from topobenchmarkx.utils.logging_utils import log_hyperparameters # noqa: F401 from topobenchmarkx.utils.pylogger import RankedLogger # noqa: F401 -from topobenchmarkx.utils.rich_utils import enforce_tags, print_config_tree # noqa: F401 -from topobenchmarkx.utils.utils import extras, get_metric_value, task_wrapper # noqa: F401 +from topobenchmarkx.utils.rich_utils import ( + enforce_tags, + print_config_tree, +) +from topobenchmarkx.utils.utils import ( + extras, + get_metric_value, + task_wrapper, +) diff --git a/topobenchmarkx/utils/config_resolvers.py b/topobenchmarkx/utils/config_resolvers.py index 9af53491..f6fcab17 100644 --- a/topobenchmarkx/utils/config_resolvers.py +++ b/topobenchmarkx/utils/config_resolvers.py @@ -1,22 +1,13 @@ def get_default_transform(data_domain, model): r"""Get default transform for a given data domain and model. - - Parameters - ---------- - data_domain: str - Data domain. - model: str - Model name. Should be in the format "model_domain/name". - - Returns - ------- - str - Default transform. - - Raises - ------ - ValueError - If the combination of data_domain and model is invalid. + + Args: + data_domain (str): Data domain. + model (str): Model name. Should be in the format "model_domain/name". + Returns: + str: Default transform. + Raises: + ValueError: If the combination of data_domain and model is invalid. 
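The node and edge construction in _generate_graph_from_data above translates directly into NetworkX; a runnable sketch with invented tensors:

import networkx as nx
import torch

x = torch.randn(3, 7)
edge_index = torch.tensor([[0, 1], [1, 2]])
nodes = [(n, dict(features=x[n], dim=0)) for n in range(x.shape[0])]
edges = [
    (i.item(), j.item())
    for i, j in zip(edge_index[0], edge_index[1], strict=False)
]
graph = nx.Graph()
graph.add_nodes_from(nodes)
graph.add_edges_from(edges)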
""" model_domain = model.split("/")[0] if data_domain == model_domain: @@ -31,27 +22,16 @@ def get_default_transform(data_domain, model): def get_monitor_metric(task, metric): r"""Get monitor metric for a given task and loss. - - Parameters - ---------- - task: str - Task, either "classification" or "regression". - loss: str - Name of the loss function. - - Returns - ------- - str - Monitor metric. - - Raises - ------ - ValueError - If the task is invalid. + + Args: + task (str): Task, either "classification" or "regression". + loss (str): Name of the loss function. + Returns: + str: Monitor metric. + Raises: + ValueError: If the task is invalid. """ - if task == "classification": - return f"val/{metric}" - elif task == "regression": + if task == "classification" or task == "regression": return f"val/{metric}" else: raise ValueError(f"Invalid task {task}") @@ -59,21 +39,13 @@ def get_monitor_metric(task, metric): def get_monitor_mode(task): r"""Get monitor mode for a given task. - - Parameters - ---------- - task: str - Task, either "classification" or "regression". - - Returns - ------- - str - Monitor mode, either "max" or "min". - - Raises - ------ - ValueError - If the task is invalid. + + Args: + task (str): Task, either "classification" or "regression". + Returns: + str: Monitor mode, either "max" or "min". + Raises: + ValueError: If the task is invalid. """ if task == "classification": return "max" @@ -85,31 +57,20 @@ def get_monitor_mode(task): def infer_in_channels(dataset): r"""Infer the number of input channels for a given dataset. - - Parameters - ---------- - dataset: torch_geometric.data.Dataset - Input dataset. - - Returns - ------- - list - List with dimensions of the input channels. + + Args: + dataset (torch_geometric.data.Dataset): Input dataset. + Returns: + list: List with dimensions of the input channels. """ def find_complex_lifting(dataset): r"""Find if there is a complex lifting in the dataset. - - Parameters - ---------- - dataset: torch_geometric.data.Dataset - Input dataset. - - Returns - ------- - bool - True if there is a complex lifting, False otherwise. - str - Name of the complex lifting, if it exists. + + Args: + dataset (torch_geometric.data.Dataset): Input dataset. + Returns: + bool: True if there is a complex lifting, False otherwise. + str: Name of the complex lifting, if it exists. """ if "transforms" not in dataset: return False, None @@ -125,18 +86,12 @@ def find_complex_lifting(dataset): def check_for_type_feature_lifting(dataset, lifting): r"""Check the type of feature lifting in the dataset. - - Parameters - ---------- - dataset: torch_geometric.data.Dataset - Input dataset. - lifting: str - Name of the complex lifting. - - Returns - ------- - str - Type of feature lifting. + + Args: + dataset (torch_geometric.data.Dataset): Input dataset. + lifting (str): Name of the complex lifting. + Returns: + str: Type of feature lifting. 
""" lifting_params_keys = dataset.transforms[lifting].keys() if "feature_lifting" in lifting_params_keys: @@ -169,21 +124,29 @@ def check_for_type_feature_lifting(dataset, lifting): lifting ].complex_dim else: - if not dataset.transforms[lifting].preserve_edge_attr: + # Case when the dataset has not edge attributes + if dataset.transforms[lifting].preserve_edge_attr == False: + if feature_lifting == "projection": - return [dataset.parameters.num_features[0]] * dataset.transforms[ - lifting - ].complex_dim + return [ + dataset.parameters.num_features[0] + ] * dataset.transforms[lifting].complex_dim + elif feature_lifting == "concatenation": return_value = [dataset.parameters.num_features] - for i in range(2, dataset.transforms[lifting].complex_dim + 1): - return_value += [int(dataset.parameters.num_features * i)] + for i in range( + 2, dataset.transforms[lifting].complex_dim + 1 + ): + return_value += [ + int(dataset.parameters.num_features * i) + ] return return_value + else: - return [dataset.parameters.num_features] * dataset.transforms[ - lifting - ].complex_dim + return [ + dataset.parameters.num_features + ] * dataset.transforms[lifting].complex_dim else: return list(dataset.parameters.num_features) + [ @@ -199,5 +162,13 @@ def check_for_type_feature_lifting(dataset, lifting): else: return [dataset.parameters.num_features[0]] + def infere_list_length(list): - return len(list) \ No newline at end of file + r"""Infer the length of a list. + + Args: + list (list): Input list. + Returns: + int: Length of the input list. + """ + return len(list) diff --git a/topobenchmarkx/utils/instantiators.py b/topobenchmarkx/utils/instantiators.py index 3b94bf4e..8f926cc8 100755 --- a/topobenchmarkx/utils/instantiators.py +++ b/topobenchmarkx/utils/instantiators.py @@ -9,10 +9,12 @@ def instantiate_callbacks(callbacks_cfg: DictConfig) -> list[Callback]: - """Instantiates callbacks from config. + r"""Instantiates callbacks from config. - :param callbacks_cfg: A DictConfig object containing callback configurations. - :return: A list of instantiated callbacks. + Args: + callbacks_cfg (DictConfig): A DictConfig object containing callback configurations. + Returns: + list[Callback]: A list of instantiated callbacks. """ callbacks: list[Callback] = [] @@ -32,10 +34,12 @@ def instantiate_callbacks(callbacks_cfg: DictConfig) -> list[Callback]: def instantiate_loggers(logger_cfg: DictConfig) -> list[Logger]: - """Instantiates loggers from config. + r"""Instantiates loggers from config. - :param logger_cfg: A DictConfig object containing logger configurations. - :return: A list of instantiated loggers. + Args: + logger_cfg (DictConfig): A DictConfig object containing logger configurations. + Returns: + list[Logger]: A list of instantiated loggers. """ logger: list[Logger] = [] diff --git a/topobenchmarkx/utils/logging_utils.py b/topobenchmarkx/utils/logging_utils.py index 84feb757..459675fe 100755 --- a/topobenchmarkx/utils/logging_utils.py +++ b/topobenchmarkx/utils/logging_utils.py @@ -10,15 +10,16 @@ @rank_zero_only def log_hyperparameters(object_dict: dict[str, Any]) -> None: - """Controls which config parts are saved by Lightning loggers. + r"""Controls which config parts are saved by Lightning loggers. Additionally saves: - Number of model parameters - :param object_dict: A dictionary containing the following objects: - - `"cfg"`: A DictConfig object containing the main config. - - `"model"`: The Lightning model. - - `"trainer"`: The Lightning trainer. 
+ Args: + object_dict (dict[str, Any]): A dictionary containing the following objects: + - `"cfg"`: A DictConfig object containing the main config. + - `"model"`: The Lightning model. + - `"trainer"`: The Lightning trainer. """ hparams = {} @@ -51,6 +52,7 @@ def log_hyperparameters(object_dict: dict[str, Any]) -> None: hparams["tags"] = cfg.get("tags") hparams["ckpt_path"] = cfg.get("ckpt_path") hparams["seed"] = cfg.get("seed") + hparams["paths"] = cfg.get("paths") # send hparams to all loggers for logger in trainer.loggers: diff --git a/topobenchmarkx/utils/pylogger.py b/topobenchmarkx/utils/pylogger.py index 31a76c37..3a24e0df 100755 --- a/topobenchmarkx/utils/pylogger.py +++ b/topobenchmarkx/utils/pylogger.py @@ -1,7 +1,10 @@ import logging -from typing import Mapping, Optional +from collections.abc import Mapping -from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only +from lightning_utilities.core.rank_zero import ( + rank_prefixed_message, + rank_zero_only, +) class RankedLogger(logging.LoggerAdapter): @@ -11,31 +14,37 @@ def __init__( self, name: str = __name__, rank_zero_only: bool = False, - extra: Optional[Mapping[str, object]] = None, + extra: Mapping[str, object] | None = None, ) -> None: - """Initializes a multi-GPU-friendly python command line logger that logs on all processes - with their rank prefixed in the log message. + r"""Initializes a multi-GPU-friendly python command line logger that + logs on all processes with their rank prefixed in the log message. - :param name: The name of the logger. Default is ``__name__``. - :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. - :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. + Args: + name (str, optional): The name of the logger. (default: __name__) + rank_zero_only (bool, optional): Whether to force all logs to only occur on the rank zero process. (default: False) + extra (Mapping[str, object], optional): A dict-like object which provides contextual information. See `logging.LoggerAdapter`. (default: None) """ logger = logging.getLogger(name) super().__init__(logger=logger, extra=extra) self.rank_zero_only = rank_zero_only + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(name={self.logger.name!r}, rank_zero_only={self.rank_zero_only!r}, extra={self.extra})" def log( - self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs + self, level: int, msg: str, rank: int | None = None, *args, **kwargs ) -> None: - """Delegate a log call to the underlying logger, after prefixing its message with the rank - of the process it's being logged from. If `'rank'` is provided, then the log will only - occur on that rank/process. + r"""Delegate a log call to the underlying logger, after prefixing its + message with the rank of the process it's being logged from. If + `'rank'` is provided, then the log will only occur on that + rank/process. - :param level: The level to log at. Look at `logging.__init__.py` for more information. - :param msg: The message to log. - :param rank: The rank to log at. - :param args: Additional args to pass to the underlying logging function. - :param kwargs: Any additional keyword args to pass to the underlying logging function. + Args: + level (int): The level to log at. Look at `logging.__init__.py` for more information. + msg (str): The message to log. + rank (int, optional): The rank to log at. 
(default: None) + args: Additional args to pass to the underlying logging function. + kwargs: Any additional keyword args to pass to the underlying logging function. """ if self.isEnabledFor(level): msg, kwargs = self.process(msg, kwargs) @@ -49,7 +58,5 @@ def log( if current_rank == 0: self.logger.log(level, msg, *args, **kwargs) else: - if rank is None: - self.logger.log(level, msg, *args, **kwargs) - elif current_rank == rank: + if rank is None or current_rank == rank: self.logger.log(level, msg, *args, **kwargs) diff --git a/topobenchmarkx/utils/rich_utils.py b/topobenchmarkx/utils/rich_utils.py index 692b8158..7a900459 100755 --- a/topobenchmarkx/utils/rich_utils.py +++ b/topobenchmarkx/utils/rich_utils.py @@ -29,13 +29,14 @@ def print_config_tree( resolve: bool = False, save_to_file: bool = False, ) -> None: - """Prints the contents of a DictConfig as a tree structure using the Rich library. - - :param cfg: A DictConfig composed by Hydra. - :param print_order: Determines in what order config components are printed. Default is ``("data", "model", - "callbacks", "logger", "trainer", "paths", "extras")``. - :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. - :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. + r"""Prints the contents of a DictConfig as a tree structure using the Rich + library. + + Args: + cfg (DictConfig): A DictConfig composed by Hydra. + print_order (Sequence[str], optional): Determines in what order config components are printed. (default: `("data", "model", "callbacks", "logger", "trainer", "paths", "extras")`). + resolve (bool, optional): Whether to resolve reference fields of DictConfig. (default: False) + save_to_file (bool, optional): Whether to export config to the hydra output folder. (default: False) """ style = "dim" tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) @@ -80,17 +81,23 @@ def print_config_tree( @rank_zero_only def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: - """Prompts user to input tags from command line if no tags are provided in config. + r"""Prompts user to input tags from command line if no tags are provided in + config. - :param cfg: A DictConfig composed by Hydra. - :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. + Args: + cfg (DictConfig): A DictConfig composed by Hydra. + save_to_file (bool, optional): Whether to export tags to the hydra output folder. (default: False). """ if not cfg.get("tags"): if "id" in HydraConfig().cfg.hydra.job: raise ValueError("Specify tags before launching a multirun!") - log.warning("No tags provided in config. Prompting user to input tags...") - tags = Prompt.ask("Enter a list of comma separated tags", default="dev") + log.warning( + "No tags provided in config. Prompting user to input tags..." 
+ ) + tags = Prompt.ask( + "Enter a list of comma separated tags", default="dev" + ) tags = [t.strip() for t in tags.split(",") if t != ""] with open_dict(cfg): diff --git a/topobenchmarkx/utils/utils.py b/topobenchmarkx/utils/utils.py index 436a85ae..2c6768ae 100755 --- a/topobenchmarkx/utils/utils.py +++ b/topobenchmarkx/utils/utils.py @@ -1,6 +1,7 @@ import warnings +from collections.abc import Callable from importlib.util import find_spec -from typing import Any, Callable, Optional +from typing import Any from omegaconf import DictConfig @@ -10,14 +11,15 @@ def extras(cfg: DictConfig) -> None: - """Applies optional utilities before the task is started. + r"""Applies optional utilities before the task is started. Utilities: - Ignoring python warnings - Setting tags from command line - Rich config printing - :param cfg: A DictConfig object containing the config tree. + Args: + cfg (DictConfig): A DictConfig object containing the config tree. """ # return if no `extras` config if not cfg.get("extras"): @@ -26,7 +28,9 @@ def extras(cfg: DictConfig) -> None: # disable python warnings if cfg.extras.get("ignore_warnings"): - log.info("Disabling python warnings! ") + log.info( + "Disabling python warnings! " + ) warnings.filterwarnings("ignore") # prompt user to input tags from command line if none are provided in the config @@ -36,12 +40,15 @@ def extras(cfg: DictConfig) -> None: # pretty print config tree using Rich library if cfg.extras.get("print_config"): - log.info("Printing config tree with Rich! ") + log.info( + "Printing config tree with Rich! " + ) rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True) def task_wrapper(task_func: Callable) -> Callable: - """Optional decorator that controls the failure behavior when executing the task function. + r"""Optional decorator that controls the failure behavior when executing the + task function. This wrapper can be used to: - make sure loggers are closed even if the task function raises an exception (prevents multirun failure) @@ -56,10 +63,10 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: ... return metric_dict, object_dict ``` - - :param task_func: The task function to be wrapped. - - :return: The wrapped task function. + Args: + task_func: The task function to be wrapped. + Returns: + The wrapped task function. """ def wrap(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: @@ -96,13 +103,15 @@ def wrap(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: def get_metric_value( - metric_dict: dict[str, Any], metric_name: Optional[str] -) -> Optional[float]: - """Safely retrieves value of the metric logged in LightningModule. - - :param metric_dict: A dict containing metric values. - :param metric_name: If provided, the name of the metric to retrieve. - :return: If a metric name was provided, the value of the metric. + metric_dict: dict[str, Any], metric_name: str | None +) -> float | None: + r"""Safely retrieves value of the metric logged in LightningModule. + + Args: + metric_dict: A dict containing metric values. + metric_name: If provided, the name of the metric to retrieve. + Returns: + If a metric name was provided, the value of the metric. """ if not metric_name: log.info("Metric name is None! 
Skipping metric value retrieval...") diff --git a/tutorials/add_new_dataset.ipynb b/tutorials/add_new_dataset.ipynb index 75cc9e18..2486d5d9 100644 --- a/tutorials/add_new_dataset.ipynb +++ b/tutorials/add_new_dataset.ipynb @@ -227,9 +227,7 @@ "source": [ "import os.path as osp\n", "from collections.abc import Callable\n", - "from typing import Optional\n", "\n", - "import torch\n", "from omegaconf import DictConfig\n", "from torch_geometric.data import Data, InMemoryDataset\n", "from torch_geometric.io import fs\n", @@ -286,9 +284,9 @@ " root: str,\n", " name: str,\n", " parameters: DictConfig,\n", - " transform: Optional[Callable] = None,\n", - " pre_transform: Optional[Callable] = None,\n", - " pre_filter: Optional[Callable] = None,\n", + " transform: Callable | None = None,\n", + " pre_transform: Callable | None = None,\n", + " pre_filter: Callable | None = None,\n", " force_reload: bool = True,\n", " ) -> None:\n", " # Assign the class variables that would be needed for steps 1, 2, 4, and 3\n", @@ -405,57 +403,50 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "import os\n", - "import torch\n", - "import torch_geometric\n", "import urllib.request\n", "\n", "\n", - "def hetero_load(name, path='./data/hetero_data'):\n", - " file_name = f'{name}.npz'\n", + "def hetero_load(name, path=\"./data/hetero_data\"):\n", + " file_name = f\"{name}.npz\"\n", "\n", " data = np.load(os.path.join(path, file_name))\n", "\n", - " x = torch.tensor(data['node_features'])\n", - " y = torch.tensor(data['node_labels'])\n", - " edge_index = torch.tensor(data['edges']).T\n", + " x = torch.tensor(data[\"node_features\"])\n", + " y = torch.tensor(data[\"node_labels\"])\n", + " edge_index = torch.tensor(data[\"edges\"]).T\n", "\n", " # Make edge_index undirected\n", " edge_index = torch_geometric.utils.to_undirected(edge_index)\n", "\n", " # Remove self-loops\n", " edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index)\n", - " \n", + "\n", " data = torch_geometric.data.Data(x=x, y=y, edge_index=edge_index)\n", " return data\n", "\n", + "\n", "def download_hetero_datasets(name, path):\n", - " url = 'https://github.com/OpenGSL/HeterophilousDatasets/raw/main/data/'\n", - " name = f'{name}.npz'\n", + " url = \"https://github.com/OpenGSL/HeterophilousDatasets/raw/main/data/\"\n", + " name = f\"{name}.npz\"\n", " try:\n", - " print(f'Downloading {name}')\n", + " print(f\"Downloading {name}\")\n", " path2save = os.path.join(path, name)\n", " urllib.request.urlretrieve(url + name, path2save)\n", - " print('Done!')\n", + " print(\"Done!\")\n", " except:\n", - " raise Exception('''Download failed! Make sure you have stable Internet connection and enter the right name''')\n", - "\n", + " raise Exception(\n", + " \"\"\"Download failed! 
Make sure you have stable Internet connection and enter the right name\"\"\"\n", + " )\n", "\n", "\n", - "import os.path as osp\n", "from collections.abc import Callable\n", - "from typing import Optional\n", "\n", - "import torch\n", "from omegaconf import DictConfig\n", - "from torch_geometric.data import Data, InMemoryDataset\n", - "from torch_geometric.io import fs\n", + "from torch_geometric.data import InMemoryDataset\n", "\n", "from topobenchmarkx.io.load.us_county_demos import load_us_county_demos\n", "\n", - "from topobenchmarkx.io.load.split_utils import random_splitting\n", - "\n", "\n", "class HeteroDataset(InMemoryDataset):\n", " r\"\"\"\n", @@ -496,14 +487,14 @@ " root: str,\n", " name: str,\n", " parameters: DictConfig,\n", - " transform: Optional[Callable] = None,\n", - " pre_transform: Optional[Callable] = None,\n", - " pre_filter: Optional[Callable] = None,\n", + " transform: Callable | None = None,\n", + " pre_transform: Callable | None = None,\n", + " pre_filter: Callable | None = None,\n", " force_reload: bool = True,\n", " use_node_attr: bool = False,\n", " use_edge_attr: bool = False,\n", " ) -> None:\n", - " self.name = name #.replace(\"_\", \"-\")\n", + " self.name = name # .replace(\"_\", \"-\")\n", " self.parameters = parameters\n", " super().__init__(\n", " root, transform, pre_transform, pre_filter, force_reload=force_reload\n", @@ -542,7 +533,7 @@ " @property\n", " def processed_file_names(self) -> str:\n", " return \"data.pt\"\n", - " \n", + "\n", " @property\n", " def raw_file_names(self) -> list[str]:\n", " \"\"\"Spefify the downloaded raw fine name\"\"\"\n", @@ -569,7 +560,7 @@ " Returns:\n", " None\n", " \"\"\"\n", - " \n", + "\n", " data = hetero_load(name=self.name, path=self.raw_dir)\n", " data = data if self.pre_transform is None else self.pre_transform(data)\n", " self.save([data], self.processed_paths[0])\n", @@ -578,23 +569,24 @@ " return f\"{self.name}()\"\n", "\n", "\n", + "data_dir = \"/home/lev/projects/TopoBenchmarkX/datasets\"\n", + "data_domain = \"graph\"\n", + "data_type = \"heterophilic\"\n", + "data_name = \"amazon_ratings\"\n", "\n", - "data_dir = '/home/lev/projects/TopoBenchmarkX/datasets'\n", - "data_domain = 'graph'\n", - "data_type = 'heterophilic'\n", - "data_name = 'amazon_ratings'\n", - "\n", - "data_dir = f'{data_dir}/{data_domain}/{data_type}'\n", + "data_dir = f\"{data_dir}/{data_domain}/{data_type}\"\n", "\n", - "parameters={\n", - " 'split_type': 'random',\n", - " 'k': 10,\n", - " 'train_prop': 0.5,\n", - " 'data_seed':0,\n", - " 'data_split_dir': f'/home/lev/projects/TopoBenchmarkX/datasets/data_splits/{data_name}'\n", - " }\n", + "parameters = {\n", + " \"split_type\": \"random\",\n", + " \"k\": 10,\n", + " \"train_prop\": 0.5,\n", + " \"data_seed\": 0,\n", + " \"data_split_dir\": f\"/home/lev/projects/TopoBenchmarkX/datasets/data_splits/{data_name}\",\n", + "}\n", "\n", - "dataset = HeteroDataset(name=data_name, root = data_dir, parameters=parameters, force_reload=True)" + "dataset = HeteroDataset(\n", + " name=data_name, root=data_dir, parameters=parameters, force_reload=True\n", + ")" ] } ],