From 8ed265dece5254a9e3a58ecc25eb1bcec82d0675 Mon Sep 17 00:00:00 2001
From: levtelyatnikov
Date: Tue, 14 May 2024 01:06:22 +0200
Subject: [PATCH 01/32] fixed sccn

---
 .pre-commit-config.yaml                       |  25 +--
 configs/model/simplicial/san.yaml             |   1 -
 configs/model/simplicial/sccn.yaml            |   4 +-
 configs/train.yaml                            |   2 +-
 custom_models/cell/cin.py                     |  40 ++--
 custom_models/hypergraph/edgnn.py             |  39 +++-
 custom_models/simplicial/sccnn.py             |  75 ++++---
 format_and_lint.sh                            |  10 +
 notebooks/curvature_results.ipynb             |  17 +-
 notebooks/data.ipynb                          |  33 +--
 notebooks/play.ipynb                          |  78 +++----
 notebooks/result_processing.ipynb             |   7 +-
 notebooks/test_feature_lifting_dev.ipynb      |  14 +-
 notebooks/test_hypergraph_lifting_dev.ipynb   |   6 -
 notebooks/test_simplicialclique_dev.ipynb     |   6 +-
 pyproject.toml                                |   9 +
 .../test_ConcatenationLifting.py              |   4 +-
 .../test_ProjectionLifting.py                 |   4 +-
 .../feature_liftings/test_SetLifting.py       |   8 +-
 .../liftings/cell/test_CellCyclesLifting.py   | 128 ++++++++++-
 .../hypergraph/test_HypergraphKHopLifting.py  |  10 +-
 ...test_HypergraphKNearestNeighborsLifting.py |  14 +-
 .../test_SimplicialCliqueLifting.py           | 169 +++++++++++++--
 .../test_SimplicialNeighborhoodLifting.py     |  18 +-
 topobenchmarkx/data/cornel_dataset.ipynb      |   7 +-
 ...dataloader_fullbatch.py => dataloaders.py} | 165 ++++----------
 topobenchmarkx/data/datasets.py               |   4 +
 topobenchmarkx/data/heteriphilic_dataset.py   |  38 ++--
 .../data/us_county_demos_dataset.py           |  35 +--
 topobenchmarkx/eval.py                        |  13 +-
 topobenchmarkx/evaluators/comparisons.py      |  33 +--
 topobenchmarkx/evaluators/evaluator.py        |  25 +--
 topobenchmarkx/io/load/download_utils.py      |  15 +-
 topobenchmarkx/io/load/heterophilic.py        |  28 +--
 topobenchmarkx/io/load/loaders.py             |  63 ++++--
 topobenchmarkx/io/load/preprocessor.py        |  29 ++-
 topobenchmarkx/io/load/split_utils.py         |  25 ++-
 topobenchmarkx/io/load/us_county_demos.py     |  32 +--
 topobenchmarkx/io/load/utils.py               |  10 +-
 topobenchmarkx/models/abstractions/encoder.py |  13 +-
 .../models/encoders/default_encoders.py       | 202 ++++++++++--------
 topobenchmarkx/models/encoders/perceiver.py   |  71 ++++--
 topobenchmarkx/models/head_model/models.py    |   9 +-
 topobenchmarkx/models/losses/loss.py          |  56 +----
 topobenchmarkx/models/network_module.py       |  57 ++---
 topobenchmarkx/models/readouts/__init__.py    |  26 +++
 topobenchmarkx/models/readouts/old_readout.py |  63 ------
 .../{readouts.py => propagate_signal_down.py} |  32 ++-
 topobenchmarkx/models/readouts/readout.py     |  24 +--
 topobenchmarkx/models/wrappers/__init__.py    |   9 +-
 .../models/wrappers/default_wrapper.py        | 143 ++++++++-----
 topobenchmarkx/play.ipynb                     |  19 +-
 topobenchmarkx/train.py                       |  59 ++---
 .../data_manipulations/manipulations.py       |  60 ++++--
 topobenchmarkx/transforms/data_transform.py   |  11 +-
 .../feature_liftings/feature_liftings.py      |  32 ++-
 .../transforms/liftings/graph2cell.py         |  12 +-
 .../transforms/liftings/graph2hypergraph.py   |  14 +-
 .../transforms/liftings/graph2simplicial.py   |  28 ++-
 .../transforms/liftings/graph_lifting.py      |  24 ++-
 topobenchmarkx/utils/__init__.py              |  17 +-
 topobenchmarkx/utils/config_resolvers.py      |  58 ++---
 topobenchmarkx/utils/instantiators.py         |   3 +-
 topobenchmarkx/utils/pylogger.py              |  26 +--
 topobenchmarkx/utils/rich_utils.py            |  14 +-
 topobenchmarkx/utils/utils.py                 |  18 +-
 tutorials/add_new_dataset.ipynb               |  86 ++++----
 67 files changed, 1406 insertions(+), 1023 deletions(-)
 create mode 100755 format_and_lint.sh
 rename topobenchmarkx/data/{dataloader_fullbatch.py => dataloaders.py} (60%)
 delete mode 100755 topobenchmarkx/models/readouts/old_readout.py
 rename topobenchmarkx/models/readouts/{readouts.py => propagate_signal_down.py} (60%)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 2122dbd9..9b0a82ab 100755
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -15,21 +15,16 @@ repos:
       - id: check-added-large-files
         args:
           - --maxkb=2048
-# - id: trailing-whitespace
       - id: requirements-txt-fixer
 
-  - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.4.4
-    hooks:
-      - id: ruff
-        #types_or: [ python, pyi, jupyter ]
-        #types_or: [ python, pyi ]
-        args: [ --fix ]
-      - id: ruff-format
-        #types_or: [ python, pyi, jupyter ]
-        #types_or: [ python, pyi ]
+  # - repo: https://github.com/astral-sh/ruff-pre-commit
+  #   rev: v0.4.4
+  #   hooks:
+  #     - id: ruff
+  #       args: [ --fix ]
+  #     - id: ruff-format
 
-  - repo: https://github.com/numpy/numpydoc
-    rev: v1.6.0
-    hooks:
-      - id: numpydoc-validation
\ No newline at end of file
+  # - repo: https://github.com/numpy/numpydoc
+  #   rev: v1.6.0
+  #   hooks:
+  #     - id: numpydoc-validation
diff --git a/configs/model/simplicial/san.yaml b/configs/model/simplicial/san.yaml
index eb13722f..fd91137c 100755
--- a/configs/model/simplicial/san.yaml
+++ b/configs/model/simplicial/san.yaml
@@ -36,7 +36,6 @@ head_model:
   out_channels: ${dataset.parameters.num_classes}
   pooling_type: sum
 
-
 loss:
   _target_: topobenchmarkx.models.losses.loss.DefaultLoss
   task: ${dataset.parameters.task}
diff --git a/configs/model/simplicial/sccn.yaml b/configs/model/simplicial/sccn.yaml
index b92ee200..4b8b245b 100755
--- a/configs/model/simplicial/sccn.yaml
+++ b/configs/model/simplicial/sccn.yaml
@@ -8,12 +8,12 @@ feature_encoder:
 backbone:
   _target_: topomodelx.nn.simplicial.sccn.SCCN
   channels: ${model.feature_encoder.out_channels}
-  max_rank: 1
+  max_rank: 2
   n_layers: 1
   update_func: "sigmoid"
 
 backbone_wrapper:
-  _target_: topobenchmarkx.models.wrappers.default_wrapper.SCCNNWrapper
+  _target_: topobenchmarkx.models.wrappers.default_wrapper.SCCNWrapper
   _partial_: true
   out_channels: ${model.feature_encoder.out_channels}
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
diff --git a/configs/train.yaml b/configs/train.yaml
index 57c43837..3b109e4c 100755
--- a/configs/train.yaml
+++ b/configs/train.yaml
@@ -5,7 +5,7 @@ defaults:
   - _self_
   - dataset: PROTEINS_TU #us_country_demos
-  - model: hypergraph/allsettransformer #hypergraph/unignn2 #allsettransformer
+  - model: simplicial/sccn #hypergraph/unignn2 #allsettransformer
   - evaluator: default
   - callbacks: default
   - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
diff --git a/custom_models/cell/cin.py b/custom_models/cell/cin.py
index deefffcf..db912b06 100644
--- a/custom_models/cell/cin.py
+++ b/custom_models/cell/cin.py
@@ -1,10 +1,10 @@
 """CWN class."""
 
 import torch
-import torch.nn.functional as F
-from topomodelx.nn.cell.cwn_layer import CWNLayer
 import torch.nn as nn
+import torch.nn.functional as F
 from topomodelx.base.conv import Conv
+from topomodelx.nn.cell.cwn_layer import CWNLayer
 from torch_geometric.nn.models import MLP
 
 
@@ -65,7 +65,8 @@ def forward(
         neighborhood_2_to_1,
         neighborhood_0_to_1,
     ):
-        """Forward computation through projection, convolutions, linear layers and average pooling.
+        """Forward computation through projection, convolutions, linear layers
+        and average pooling.
 
         Parameters
         ----------
@@ -192,15 +193,21 @@ def __init__(
         self.conv_1_to_1 = (
             conv_1_to_1
             if conv_1_to_1 is not None
-            else _CWNDefaultFirstConv(in_channels_1, in_channels_2, out_channels)
+            else _CWNDefaultFirstConv(
+                in_channels_1, in_channels_2, out_channels
+            )
         )
         self.conv_0_to_1 = (
             conv_0_to_1
             if conv_0_to_1 is not None
-            else _CWNDefaultSecondConv(in_channels_0, in_channels_1, out_channels)
+            else _CWNDefaultSecondConv(
+                in_channels_0, in_channels_1, out_channels
+            )
         )
         self.aggregate_fn = (
-            aggregate_fn if aggregate_fn is not None else _CWNDefaultAggregate()
+            aggregate_fn
+            if aggregate_fn is not None
+            else _CWNDefaultAggregate()
         )
         self.update_fn = (
             update_fn
@@ -325,11 +332,10 @@ def forward(
 
 
 class _CWNDefaultFirstConv(nn.Module):
-    r"""
-    Default implementation of the first convolutional step in CWNLayer.
+    r"""Default implementation of the first convolutional step in CWNLayer.
 
-    The self.forward method of this module must be treated as
-    a protocol for the first convolutional step in CWN layer.
+    The self.forward method of this module must be treated as a protocol for
+    the first convolutional step in CWN layer.
     """
 
     def __init__(
@@ -383,11 +389,10 @@ def forward(self, x_1, x_2, neighborhood_1_to_1, neighborhood_2_to_1):
 
 
 class _CWNDefaultSecondConv(nn.Module):
-    r"""
-    Default implementation of the second convolutional step in CWNLayer.
+    r"""Default implementation of the second convolutional step in CWNLayer.
 
-    The self.forward method of this module must be treated as
-    a protocol for the second convolutional step in CWN layer.
+    The self.forward method of this module must be treated as a protocol for
+    the second convolutional step in CWN layer.
     """
 
     def __init__(self, in_channels_0, out_channels) -> None:
@@ -417,11 +422,10 @@ def forward(self, x_0, neighborhood_0_to_1):
 
 
 class _CWNDefaultAggregate(nn.Module):
-    r"""
-    Default implementation of an aggregation step in CWNLayer.
+    r"""Default implementation of an aggregation step in CWNLayer.
 
-    The self.forward method of this module must be treated as
-    a protocol for the aggregation step in CWN layer.
+    The self.forward method of this module must be treated as a protocol for
+    the aggregation step in CWN layer.
""" def __init__(self) -> None: diff --git a/custom_models/hypergraph/edgnn.py b/custom_models/hypergraph/edgnn.py index 0888fc5b..3867ff62 100644 --- a/custom_models/hypergraph/edgnn.py +++ b/custom_models/hypergraph/edgnn.py @@ -46,7 +46,9 @@ def __init__( self.lins.append(nn.Linear(in_channels, hidden_channels)) self.normalizations.append(nn.BatchNorm1d(hidden_channels)) for _ in range(num_layers - 2): - self.lins.append(nn.Linear(hidden_channels, hidden_channels)) + self.lins.append( + nn.Linear(hidden_channels, hidden_channels) + ) self.normalizations.append(nn.BatchNorm1d(hidden_channels)) self.lins.append(nn.Linear(hidden_channels, out_channels)) elif Normalization == "ln": @@ -65,7 +67,9 @@ def __init__( self.lins.append(nn.Linear(in_channels, hidden_channels)) self.normalizations.append(nn.LayerNorm(hidden_channels)) for _ in range(num_layers - 2): - self.lins.append(nn.Linear(hidden_channels, hidden_channels)) + self.lins.append( + nn.Linear(hidden_channels, hidden_channels) + ) self.normalizations.append(nn.LayerNorm(hidden_channels)) self.lins.append(nn.Linear(hidden_channels, out_channels)) else: @@ -78,7 +82,9 @@ def __init__( self.lins.append(nn.Linear(in_channels, hidden_channels)) self.normalizations.append(nn.Identity()) for _ in range(num_layers - 2): - self.lins.append(nn.Linear(hidden_channels, hidden_channels)) + self.lins.append( + nn.Linear(hidden_channels, hidden_channels) + ) self.normalizations.append(nn.Identity()) self.lins.append(nn.Linear(hidden_channels, out_channels)) @@ -88,7 +94,7 @@ def reset_parameters(self): for lin in self.lins: lin.reset_parameters() for normalization in self.normalizations: - if not (normalization.__class__.__name__ == "Identity"): + if normalization.__class__.__name__ != "Identity": normalization.reset_parameters() def forward(self, x): @@ -245,7 +251,9 @@ def forward(self, X, vertex, edges, X0): class JumpLinkConv(nn.Module): - def __init__(self, in_features, out_features, mlp_layers=2, aggr="add", alpha=0.5): + def __init__( + self, in_features, out_features, mlp_layers=2, aggr="add", alpha=0.5 + ): super().__init__() self.W = MLP( in_features, @@ -339,7 +347,10 @@ def forward(self, X, vertex, edges, X0): ) # [E, C], reduce is 'mean' here as default deg_e = torch_scatter.scatter( - torch.ones(Xve.shape[0], device=Xve.device), edges, dim=-2, reduce="sum" + torch.ones(Xve.shape[0], device=Xve.device), + edges, + dim=-2, + reduce="sum", ) Xe = torch.cat([Xe, torch.log(deg_e)[..., None]], -1) @@ -350,7 +361,10 @@ def forward(self, X, vertex, edges, X0): ) # [N, C] deg_v = torch_scatter.scatter( - torch.ones(Xev.shape[0], device=Xev.device), vertex, dim=-2, reduce="sum" + torch.ones(Xev.shape[0], device=Xev.device), + vertex, + dim=-2, + reduce="sum", ) X = self.W3(torch.cat([Xv, X, X0, torch.log(deg_v)[..., None]], -1)) @@ -374,7 +388,7 @@ def __init__( normalization="None", AllSet_input_norm=False, ): - """EDGNN + """EDGNN. Args: num_features (int): number of input features @@ -390,7 +404,6 @@ def __init__( aggregate (str, optional): aggregation method. Defaults to 'add'. normalization (str, optional): normalization method. Defaults to 'None'. AllSet_input_norm (bool, optional): whether to normalize input features. Defaults to False. 
- """ super().__init__() act = {"Id": nn.Identity(), "relu": nn.ReLU(), "prelu": nn.PReLU()} @@ -402,8 +415,12 @@ def __init__( self.hidden_channels = self.in_channels self.mlp1_layers = MLP_num_layers - self.mlp2_layers = MLP_num_layers if MLP2_num_layers < 0 else MLP2_num_layers - self.mlp3_layers = MLP_num_layers if MLP3_num_layers < 0 else MLP3_num_layers + self.mlp2_layers = ( + MLP_num_layers if MLP2_num_layers < 0 else MLP2_num_layers + ) + self.mlp3_layers = ( + MLP_num_layers if MLP3_num_layers < 0 else MLP3_num_layers + ) self.nlayer = All_num_layers self.edconv_type = edconv_type diff --git a/custom_models/simplicial/sccnn.py b/custom_models/simplicial/sccnn.py index fa49754f..586ca909 100644 --- a/custom_models/simplicial/sccnn.py +++ b/custom_models/simplicial/sccnn.py @@ -5,7 +5,6 @@ from torch.nn.parameter import Parameter - class SCCNNCusctom(torch.nn.Module): """SCCNN implementation for complex classification. @@ -28,7 +27,6 @@ class SCCNNCusctom(torch.nn.Module): Update function for the simplicial complex convolution. n_layers: int Number of layers. - """ def __init__( @@ -44,9 +42,15 @@ def __init__( super().__init__() # first layer # we use an MLP to map the features on simplices of different dimensions to the same dimension - self.in_linear_0 = torch.nn.Linear(in_channels_all[0], hidden_channels_all[0]) - self.in_linear_1 = torch.nn.Linear(in_channels_all[1], hidden_channels_all[1]) - self.in_linear_2 = torch.nn.Linear(in_channels_all[2], hidden_channels_all[2]) + self.in_linear_0 = torch.nn.Linear( + in_channels_all[0], hidden_channels_all[0] + ) + self.in_linear_1 = torch.nn.Linear( + in_channels_all[1], hidden_channels_all[1] + ) + self.in_linear_2 = torch.nn.Linear( + in_channels_all[2], hidden_channels_all[2] + ) self.layers = torch.nn.ModuleList( SCCNNLayer( @@ -100,6 +104,7 @@ def forward(self, x_all, laplacian_all, incidence_all): # Layer """Simplicial Complex Convolutional Neural Network Layer.""" + class SCCNNLayer(torch.nn.Module): r"""Layer of a Simplicial Complex Convolutional Neural Network. @@ -215,7 +220,9 @@ def __init__( self.weight_0 = Parameter( torch.Tensor( - self.in_channels_0, self.out_channels_0, 1 + conv_order + 1 + conv_order + self.in_channels_0, + self.out_channels_0, + 1 + conv_order + 1 + conv_order, ) ) @@ -326,7 +333,9 @@ def chebyshev_conv(self, conv_operator, conv_order, x): Output tensor. x[:, :, k] = (conv_operator@....@conv_operator) @ x. 
""" num_simplices, num_channels = x.shape - X = torch.empty(size=(num_simplices, num_channels, conv_order)).to(x.device) + X = torch.empty(size=(num_simplices, num_channels, conv_order)).to( + x.device + ) if self.aggr_norm: X[:, :, 0] = torch.mm(conv_operator, x) @@ -388,7 +397,9 @@ def forward(self, x_all, laplacian_all, incidence_all): x_0, x_1, x_2 = x_all if self.sc_order == 2: - laplacian_0, laplacian_down_1, laplacian_up_1, laplacian_2 = laplacian_all + laplacian_0, laplacian_down_1, laplacian_up_1, laplacian_2 = ( + laplacian_all + ) elif self.sc_order > 2: ( laplacian_0, @@ -407,7 +418,6 @@ def forward(self, x_all, laplacian_all, incidence_all): # torch.eye(num_edges).to(x_0.device), # torch.eye(num_triangles).to(x_0.device), # ) - """ Convolution in the node space """ @@ -429,7 +439,9 @@ def forward(self, x_all, laplacian_all, incidence_all): x_1_to_0_laplacian = self.chebyshev_conv( laplacian_0, self.conv_order, x_1_to_0_upper ) - x_1_to_0 = torch.cat([x_1_to_0_upper.unsqueeze(2), x_1_to_0_laplacian], dim=2) + x_1_to_0 = torch.cat( + [x_1_to_0_upper.unsqueeze(2), x_1_to_0_laplacian], dim=2 + ) # ------------------- x_0_all = torch.cat((x_0_to_0, x_1_to_0), 2) @@ -460,13 +472,19 @@ def forward(self, x_all, laplacian_all, incidence_all): x_0_1_lower = torch.mm(b1.T, x_0) # Calculate lowwer chebyshev_conv - x_0_1_down = self.chebyshev_conv(laplacian_down_1, self.conv_order, x_0_1_lower) + x_0_1_down = self.chebyshev_conv( + laplacian_down_1, self.conv_order, x_0_1_lower + ) # Calculate upper chebyshev_conv (Note: in case of signed incidence should be always zero) - x_0_1_up = self.chebyshev_conv(laplacian_up_1, self.conv_order, x_0_1_lower) + x_0_1_up = self.chebyshev_conv( + laplacian_up_1, self.conv_order, x_0_1_lower + ) # Concatenate output of filters - x_0_to_1 = torch.cat([x_0_1_lower.unsqueeze(2), x_0_1_down, x_0_1_up], dim=2) + x_0_to_1 = torch.cat( + [x_0_1_lower.unsqueeze(2), x_0_1_down, x_0_1_up], dim=2 + ) # ------------------- # x_2_to_1 = torch.mm(b2, x_2) @@ -477,20 +495,23 @@ def forward(self, x_all, laplacian_all, incidence_all): x_2_1_upper = torch.mm(b2, x_2) # Calculate lowwer chebyshev_conv (Note: In case of signed incidence should be always zero) - x_2_1_down = self.chebyshev_conv(laplacian_down_1, self.conv_order, x_2_1_upper) + x_2_1_down = self.chebyshev_conv( + laplacian_down_1, self.conv_order, x_2_1_upper + ) # Calculate upper chebyshev_conv - x_2_1_up = self.chebyshev_conv(laplacian_up_1, self.conv_order, x_2_1_upper) + x_2_1_up = self.chebyshev_conv( + laplacian_up_1, self.conv_order, x_2_1_upper + ) - x_2_to_1 = torch.cat([x_2_1_upper.unsqueeze(2), x_2_1_down, x_2_1_up], dim=2) + x_2_to_1 = torch.cat( + [x_2_1_upper.unsqueeze(2), x_2_1_down, x_2_1_up], dim=2 + ) # ------------------- x_1_all = torch.cat((x_0_to_1, x_1_to_1, x_2_to_1), 2) - - """ - convolution in the face (triangle) space, depending on the SC order, - the exact form maybe a little different - """ + """Convolution in the face (triangle) space, depending on the SC order, + the exact form maybe a little different.""" # -------------------Logic to obtain update for 2-cells -------- # x_identity_2 = torch.unsqueeze(identity_2 @ x_2, 2) @@ -516,10 +537,16 @@ def forward(self, x_all, laplacian_all, incidence_all): # x_1_to_2 = torch.cat((x_1_to_2_identity, x_1_to_2), 2) x_1_2_lower = torch.mm(b2.T, x_1) - x_1_2_down = self.chebyshev_conv(laplacian_down_2, self.conv_order, x_1_2_lower) - x_1_2_down = self.chebyshev_conv(laplacian_up_2, self.conv_order, x_1_2_lower) + x_1_2_down = 
self.chebyshev_conv( + laplacian_down_2, self.conv_order, x_1_2_lower + ) + x_1_2_down = self.chebyshev_conv( + laplacian_up_2, self.conv_order, x_1_2_lower + ) - x_1_to_2 = torch.cat([x_1_2_lower.unsqueeze(2), x_1_2_down, x_1_2_down], dim=2) + x_1_to_2 = torch.cat( + [x_1_2_lower.unsqueeze(2), x_1_2_down, x_1_2_down], dim=2 + ) # That is my code, but to execute this part we need to have simplices order of k+1 in this case order of 3 # x_3_2_upper = x_1_to_2 = torch.mm(b2, x_3) diff --git a/format_and_lint.sh b/format_and_lint.sh new file mode 100755 index 00000000..e7795e8c --- /dev/null +++ b/format_and_lint.sh @@ -0,0 +1,10 @@ +#!/bin/sh + +# Run ruff to check for issues and fix them +ruff check . --fix + +# Run docformatter to reformat docstrings and comments +docformatter --in-place --recursive --wrap-summaries 79 --wrap-descriptions 79 . + +# Run black to format the code +black . \ No newline at end of file diff --git a/notebooks/curvature_results.ipynb b/notebooks/curvature_results.ipynb index 4c61ea97..ea591cdc 100644 --- a/notebooks/curvature_results.ipynb +++ b/notebooks/curvature_results.ipynb @@ -6,15 +6,14 @@ "metadata": {}, "outputs": [], "source": [ - "import pandas as pd\n", - "import wandb\n", - "import pandas as pd\n", "import ast\n", "import glob\n", - "import numpy as np\n", "import warnings\n", - "from datetime import date\n", "from collections import defaultdict\n", + "from datetime import date\n", + "\n", + "import pandas as pd\n", + "import wandb\n", "\n", "today = date.today()\n", "api = wandb.Api()\n", @@ -712,9 +711,9 @@ " ]\n", "\n", " if subset.empty:\n", - " print(f\"---------\")\n", + " print(\"---------\")\n", " print(f\"No results for {model} on {dataset}\")\n", - " print(f\"---------\")\n", + " print(\"---------\")\n", " continue\n", " # Suppress all warnings\n", " warnings.filterwarnings(\"ignore\")\n", @@ -749,7 +748,7 @@ " for col, unique in unique_colums_values.items():\n", " print(f\"{col}: {unique}\")\n", " print()\n", - " print(f\"---------\")\n", + " print(\"---------\")\n", "\n", " # Check if \"special colums\" are not in unique_colums_values\n", " # For example dataset.parameters.data_seed should not be in aggregation columns\n", @@ -1486,7 +1485,7 @@ "result_dict = pd.DataFrame.from_dict(\n", " {\n", " (i, j): nested_dict[i][j]\n", - " for i in nested_dict.keys()\n", + " for i in nested_dict\n", " for j in nested_dict[i].keys()\n", " },\n", " orient=\"index\",\n", diff --git a/notebooks/data.ipynb b/notebooks/data.ipynb index e34a964b..696a2b55 100755 --- a/notebooks/data.ipynb +++ b/notebooks/data.ipynb @@ -25,18 +25,17 @@ "%load_ext autoreload\n", "%autoreload 2\n", "\n", + "import hydra\n", "import torch\n", "import torch_geometric\n", - "from topobenchmarkx.data.datasets import CustomDataset\n", - "import hydra\n", - "from hydra import initialize, compose\n", + "from hydra import compose, initialize\n", + "from omegaconf import OmegaConf\n", + "\n", "from topobenchmarkx.data.dataloader_fullbatch import FullBatchDataModule\n", + "from topobenchmarkx.data.datasets import CustomDataset\n", "from topobenchmarkx.io.load.loaders import (\n", - " GraphLoader,\n", - " SimplicialLoader,\n", " HypergraphLoader,\n", ")\n", - "from omegaconf import DictConfig, OmegaConf\n", "from topobenchmarkx.utils.config_resolvers import (\n", " get_default_transform,\n", " get_monitor_metric,\n", @@ -44,7 +43,6 @@ " infer_in_channels,\n", ")\n", "\n", - "\n", "OmegaConf.register_new_resolver(\"get_default_transform\", get_default_transform)\n", 
"OmegaConf.register_new_resolver(\"get_monitor_metric\", get_monitor_metric)\n", "OmegaConf.register_new_resolver(\"get_monitor_mode\", get_monitor_mode)\n", @@ -114,9 +112,6 @@ } ], "source": [ - "import torch\n", - "import torch_geometric\n", - "import numpy as np\n", "\n", "nci1 = torch_geometric.datasets.TUDataset(\n", " root=\".\",\n", @@ -712,8 +707,7 @@ } ], "source": [ - "from lightning import Callback, LightningDataModule, LightningModule, Trainer\n", - "from lightning.pytorch.loggers import Logger\n", + "from lightning import LightningModule\n", "\n", "model: LightningModule = hydra.utils.instantiate(config.model)" ] @@ -1012,9 +1006,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "from topobenchmarkx.data.datasets import CustomDataset" - ] + "source": [] }, { "cell_type": "code", @@ -1050,7 +1042,7 @@ " for b in batch:\n", " values, keys = b[0], b[1]\n", " data = Data()\n", - " for key, value in zip(keys, values):\n", + " for key, value in zip(keys, values, strict=False):\n", " data[key] = value\n", "\n", " return data\n", @@ -1181,7 +1173,7 @@ "outputs": [], "source": [ "import torch\n", - "from torch.utils.data import Dataset, DataLoader\n", + "from torch.utils.data import DataLoader, Dataset\n", "\n", "\n", "class TextDataset(Dataset):\n", @@ -1495,7 +1487,6 @@ "source": [ "# Load data\n", "from topobenchmarkx.data.load.loaders import HypergraphLoader\n", - "from topobenchmarkx.data.dataloader_fullbatch import FullBatchDataModule\n", "\n", "data_loader = HypergraphLoader(config)\n", "data = data_loader.load()\n", @@ -1562,7 +1553,7 @@ "metadata": {}, "outputs": [], "source": [ - "#b in []topomodelx.nn.hypergraph.unigcnii.UniGCNII" + "# b in []topomodelx.nn.hypergraph.unigcnii.UniGCNII" ] }, { @@ -1570,9 +1561,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "import topomodelx" - ] + "source": [] }, { "cell_type": "code", diff --git a/notebooks/play.ipynb b/notebooks/play.ipynb index 563e7c2d..93dc9f33 100644 --- a/notebooks/play.ipynb +++ b/notebooks/play.ipynb @@ -6,55 +6,58 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "import os\n", + "import urllib.request\n", + "\n", + "import numpy as np\n", "import torch\n", "import torch_geometric\n", - "import urllib.request\n", "\n", "\n", "def hetero_load(name, path):\n", - " file_name = f'{name}.npz'\n", + " file_name = f\"{name}.npz\"\n", "\n", " data = np.load(os.path.join(path, file_name))\n", "\n", - " x = torch.tensor(data['node_features'])\n", - " y = torch.tensor(data['node_labels'])\n", - " edge_index = torch.tensor(data['edges']).T\n", + " x = torch.tensor(data[\"node_features\"])\n", + " y = torch.tensor(data[\"node_labels\"])\n", + " edge_index = torch.tensor(data[\"edges\"]).T\n", "\n", " # Make edge_index undirected\n", " edge_index = torch_geometric.utils.to_undirected(edge_index)\n", "\n", " # Remove self-loops\n", " edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index)\n", - " \n", + "\n", " data = torch_geometric.data.Data(x=x, y=y, edge_index=edge_index)\n", " return data\n", "\n", + "\n", "def download_hetero_datasets(name, path):\n", - " url = 'https://github.com/OpenGSL/HeterophilousDatasets/raw/main/data/'\n", - " name = f'{name}.npz'\n", + " url = \"https://github.com/OpenGSL/HeterophilousDatasets/raw/main/data/\"\n", + " name = f\"{name}.npz\"\n", " try:\n", - " print(f'Downloading {name}')\n", + " print(f\"Downloading {name}\")\n", " path2save = os.path.join(path, name)\n", " urllib.request.urlretrieve(url + 
name, path2save)\n", - " print('Done!')\n", + " print(\"Done!\")\n", " except:\n", - " raise Exception('''Download failed! Make sure you have stable Internet connection and enter the right name''')\n", - "\n", + " raise Exception(\n", + " \"\"\"Download failed! Make sure you have stable Internet connection and enter the right name\"\"\"\n", + " )\n", "\n", "\n", "import os.path as osp\n", "from collections.abc import Callable\n", - "from typing import Optional\n", "\n", - "import torch\n", "from omegaconf import DictConfig\n", "from torch_geometric.data import Data, InMemoryDataset\n", "from torch_geometric.io import fs\n", "\n", - "from topobenchmarkx.io.load.heterophilic import download_hetero_datasets, load_heterophilic_data\n", - "\n", + "from topobenchmarkx.io.load.heterophilic import (\n", + " download_hetero_datasets,\n", + " load_heterophilic_data,\n", + ")\n", "from topobenchmarkx.io.load.split_utils import random_splitting\n", "\n", "\n", @@ -97,14 +100,14 @@ " root: str,\n", " name: str,\n", " parameters: DictConfig,\n", - " transform: Optional[Callable] = None,\n", - " pre_transform: Optional[Callable] = None,\n", - " pre_filter: Optional[Callable] = None,\n", + " transform: Callable | None = None,\n", + " pre_transform: Callable | None = None,\n", + " pre_filter: Callable | None = None,\n", " force_reload: bool = True,\n", " use_node_attr: bool = False,\n", " use_edge_attr: bool = False,\n", " ) -> None:\n", - " self.name = name #.replace(\"_\", \"-\")\n", + " self.name = name # .replace(\"_\", \"-\")\n", " self.parameters = parameters\n", " super().__init__(\n", " root, transform, pre_transform, pre_filter, force_reload=force_reload\n", @@ -144,7 +147,7 @@ " @property\n", " def processed_file_names(self) -> str:\n", " return \"data.pt\"\n", - " \n", + "\n", " @property\n", " def raw_file_names(self) -> list[str]:\n", " \"\"\"Spefify the downloaded raw fine name\"\"\"\n", @@ -171,7 +174,7 @@ " Returns:\n", " None\n", " \"\"\"\n", - " \n", + "\n", " data = load_heterophilic_data(name=self.name, path=self.raw_dir)\n", " data = data if self.pre_transform is None else self.pre_transform(data)\n", " self.save([data], self.processed_paths[0])\n", @@ -180,23 +183,24 @@ " return f\"{self.name}()\"\n", "\n", "\n", + "data_dir = \"/home/lev/projects/TopoBenchmarkX/datasets\"\n", + "data_domain = \"graph\"\n", + "data_type = \"heterophilic\"\n", + "data_name = \"amazon_ratings\"\n", "\n", - "data_dir = '/home/lev/projects/TopoBenchmarkX/datasets'\n", - "data_domain = 'graph'\n", - "data_type = 'heterophilic'\n", - "data_name = 'amazon_ratings'\n", - "\n", - "data_dir = f'{data_dir}/{data_domain}/{data_type}'\n", + "data_dir = f\"{data_dir}/{data_domain}/{data_type}\"\n", "\n", - "parameters={\n", - " 'split_type': 'random',\n", - " 'k': 10,\n", - " 'train_prop': 0.5,\n", - " 'data_seed':0,\n", - " 'data_split_dir': f'/home/lev/projects/TopoBenchmarkX/datasets/data_splits/{data_name}'\n", - " }\n", + "parameters = {\n", + " \"split_type\": \"random\",\n", + " \"k\": 10,\n", + " \"train_prop\": 0.5,\n", + " \"data_seed\": 0,\n", + " \"data_split_dir\": f\"/home/lev/projects/TopoBenchmarkX/datasets/data_splits/{data_name}\",\n", + "}\n", "\n", - "dataset = HeteroDataset(name=data_name, root = data_dir, parameters=parameters, force_reload=True)" + "dataset = HeteroDataset(\n", + " name=data_name, root=data_dir, parameters=parameters, force_reload=True\n", + ")" ] }, { diff --git a/notebooks/result_processing.ipynb b/notebooks/result_processing.ipynb index a289ba0f..c684a18d 100644 --- 
a/notebooks/result_processing.ipynb +++ b/notebooks/result_processing.ipynb @@ -2679,7 +2679,7 @@ "a[\"dataset.transforms.graph2simplicial_lifting.feature_lifting\"][\n", " a[\"dataset.transforms.graph2simplicial_lifting.feature_lifting\"].isna()\n", "] = \"projection\"\n", - "#a = a[~a[\"test/mae\"].isna()]\n", + "# a = a[~a[\"test/mae\"].isna()]\n", "a = a[~a[\"test/accuracy\"].isna()]\n", "\n", "a = a.groupby(\n", @@ -2692,7 +2692,7 @@ ").agg({col: [\"mean\", \"std\"] for col in performance_cols})\n", "\n", "ascending = True\n", - "#a = a.sort_values(by=(\"test/mae\", \"mean\"), ascending=ascending)\n", + "# a = a.sort_values(by=(\"test/mae\", \"mean\"), ascending=ascending)\n", "a = a.sort_values(by=(\"test/accuracy\", \"mean\"), ascending=ascending)\n", "# Show all rows\n", "pd.set_option(\"display.max_rows\", None)\n", @@ -6313,7 +6313,7 @@ "result_dict = pd.DataFrame.from_dict(\n", " {\n", " (i, j): nested_dict[i][j]\n", - " for i in nested_dict.keys()\n", + " for i in nested_dict\n", " for j in nested_dict[i].keys()\n", " },\n", " orient=\"index\",\n", @@ -6605,7 +6605,6 @@ ], "source": [ "import matplotlib.pyplot as plt\n", - "import numpy as np\n", "\n", "\n", "# Define the vector field function\n", diff --git a/notebooks/test_feature_lifting_dev.ipynb b/notebooks/test_feature_lifting_dev.ipynb index 704b350a..c2661f1d 100644 --- a/notebooks/test_feature_lifting_dev.ipynb +++ b/notebooks/test_feature_lifting_dev.ipynb @@ -47,7 +47,9 @@ "from topobenchmarkx.transforms.feature_liftings.feature_liftings import (\n", " ProjectionLifting,\n", ")\n", - "from topobenchmarkx.transforms.liftings.graph2simplicial import SimplicialCliqueLifting\n", + "from topobenchmarkx.transforms.liftings.graph2simplicial import (\n", + " SimplicialCliqueLifting,\n", + ")\n", "\n", "\n", "class TestProjectionLifting:\n", @@ -134,13 +136,10 @@ "\"\"\"Test the message passing module.\"\"\"\n", "\n", "import rootutils\n", - "import torch\n", "\n", - "from topobenchmarkx.io.load.loaders import manual_simple_graph\n", "from topobenchmarkx.transforms.feature_liftings.feature_liftings import (\n", " ConcatentionLifting,\n", ")\n", - "from topobenchmarkx.transforms.liftings.graph2simplicial import SimplicialCliqueLifting\n", "\n", "\n", "class TestConcatentionLifting:\n", @@ -254,11 +253,10 @@ "\"\"\"Test the message passing module.\"\"\"\n", "\n", "import rootutils\n", - "import torch\n", "\n", - "from topobenchmarkx.io.load.loaders import manual_simple_graph\n", - "from topobenchmarkx.transforms.feature_liftings.feature_liftings import SetLifting\n", - "from topobenchmarkx.transforms.liftings.graph2simplicial import SimplicialCliqueLifting\n", + "from topobenchmarkx.transforms.feature_liftings.feature_liftings import (\n", + " SetLifting,\n", + ")\n", "\n", "\n", "class TestSetLifting:\n", diff --git a/notebooks/test_hypergraph_lifting_dev.ipynb b/notebooks/test_hypergraph_lifting_dev.ipynb index 635965f6..90abf896 100644 --- a/notebooks/test_hypergraph_lifting_dev.ipynb +++ b/notebooks/test_hypergraph_lifting_dev.ipynb @@ -113,14 +113,8 @@ "\"\"\"Test the message passing module.\"\"\"\n", "\n", "import rootutils\n", - "import torch\n", "\n", "rootutils.setup_root(\"./\", indicator=\".project-root\", pythonpath=True)\n", - "from topobenchmarkx.io.load.loaders import manual_graph\n", - "from topobenchmarkx.transforms.liftings.graph2hypergraph import (\n", - " HypergraphKHopLifting,\n", - " HypergraphKNearestNeighborsLifting,\n", - ")\n", "\n", "\n", "class TestHypergraphKNearestNeighborsLifting:\n", diff --git 
a/notebooks/test_simplicialclique_dev.ipynb b/notebooks/test_simplicialclique_dev.ipynb index 2b69bb64..b2f8ad8a 100644 --- a/notebooks/test_simplicialclique_dev.ipynb +++ b/notebooks/test_simplicialclique_dev.ipynb @@ -23,7 +23,9 @@ "\n", "rootutils.setup_root(\"./\", indicator=\".project-root\", pythonpath=True)\n", "from topobenchmarkx.io.load.loaders import manual_simple_graph\n", - "from topobenchmarkx.transforms.liftings.graph2simplicial import SimplicialCliqueLifting\n", + "from topobenchmarkx.transforms.liftings.graph2simplicial import (\n", + " SimplicialCliqueLifting,\n", + ")\n", "\n", "\n", "class TestSimplicialCliqueLifting:\n", @@ -219,10 +221,8 @@ "\"\"\"Test the message passing module.\"\"\"\n", "\n", "import rootutils\n", - "import torch\n", "\n", "rootutils.setup_root(\"./\", indicator=\".project-root\", pythonpath=True)\n", - "from topobenchmarkx.io.load.loaders import manual_simple_graph\n", "from topobenchmarkx.transforms.liftings.graph2simplicial import (\n", " SimplicialNeighborhoodLifting,\n", ")\n", diff --git a/pyproject.toml b/pyproject.toml index 3b155359..6d1e47be 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,9 +78,18 @@ all = ["TopoBenchmarkX[dev, doc]"] homepage="https://github.com/pyt-team/TopoBenchmarkX" repository="https://github.com/pyt-team/TopoBenchmarkX" +[tool.black] +line-length = 79 # PEP 8 standard for maximum line length +target-version = ['py310'] + +[tool.docformatter] +wrap-summaries = 79 +wrap-descriptions = 79 + [tool.ruff] target-version = "py310" extend-include = ["*.ipynb"] +line-length = 79 # PEP 8 standard for maximum line length [tool.ruff.format] docstring-code-format = false diff --git a/test/transforms/feature_liftings/test_ConcatenationLifting.py b/test/transforms/feature_liftings/test_ConcatenationLifting.py index 87deddef..6cf9122c 100644 --- a/test/transforms/feature_liftings/test_ConcatenationLifting.py +++ b/test/transforms/feature_liftings/test_ConcatenationLifting.py @@ -6,7 +6,9 @@ from topobenchmarkx.transforms.feature_liftings.feature_liftings import ( ConcatentionLifting, ) -from topobenchmarkx.transforms.liftings.graph2simplicial import SimplicialCliqueLifting +from topobenchmarkx.transforms.liftings.graph2simplicial import ( + SimplicialCliqueLifting, +) class TestConcatentionLifting: diff --git a/test/transforms/feature_liftings/test_ProjectionLifting.py b/test/transforms/feature_liftings/test_ProjectionLifting.py index 00dd38e5..065b8aff 100644 --- a/test/transforms/feature_liftings/test_ProjectionLifting.py +++ b/test/transforms/feature_liftings/test_ProjectionLifting.py @@ -6,7 +6,9 @@ from topobenchmarkx.transforms.feature_liftings.feature_liftings import ( ProjectionLifting, ) -from topobenchmarkx.transforms.liftings.graph2simplicial import SimplicialCliqueLifting +from topobenchmarkx.transforms.liftings.graph2simplicial import ( + SimplicialCliqueLifting, +) class TestProjectionLifting: diff --git a/test/transforms/feature_liftings/test_SetLifting.py b/test/transforms/feature_liftings/test_SetLifting.py index a36bf62c..145c3eb8 100644 --- a/test/transforms/feature_liftings/test_SetLifting.py +++ b/test/transforms/feature_liftings/test_SetLifting.py @@ -3,8 +3,12 @@ import torch from topobenchmarkx.io.load.loaders import manual_simple_graph -from topobenchmarkx.transforms.feature_liftings.feature_liftings import SetLifting -from topobenchmarkx.transforms.liftings.graph2simplicial import SimplicialCliqueLifting +from topobenchmarkx.transforms.feature_liftings.feature_liftings import ( + SetLifting, +) 
+from topobenchmarkx.transforms.liftings.graph2simplicial import ( + SimplicialCliqueLifting, +) class TestSetLifting: diff --git a/test/transforms/liftings/cell/test_CellCyclesLifting.py b/test/transforms/liftings/cell/test_CellCyclesLifting.py index 364152d7..37f9005d 100644 --- a/test/transforms/liftings/cell/test_CellCyclesLifting.py +++ b/test/transforms/liftings/cell/test_CellCyclesLifting.py @@ -22,14 +22,126 @@ def test_lift_topology(self): expected_incidence_1 = torch.tensor( [ - [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [ + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 0.0, + 1.0, + 0.0, + 0.0, + 1.0, + 0.0, + 1.0, + 1.0, + 1.0, + 1.0, + 0.0, + 0.0, + 0.0, + ], + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + ], + [ + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 1.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 1.0, + 1.0, + ], + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 1.0, + 0.0, + ], + [ + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 1.0, + ], ] ) diff --git a/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py b/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py index 384bf849..a961709d 100644 --- a/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py +++ b/test/transforms/liftings/hypergraph/test_HypergraphKHopLifting.py @@ -3,7 +3,9 @@ import torch from topobenchmarkx.io.load.loaders import manual_graph -from topobenchmarkx.transforms.liftings.graph2hypergraph import HypergraphKHopLifting +from topobenchmarkx.transforms.liftings.graph2hypergraph import ( + HypergraphKHopLifting, +) class TestHypergraphKHopLifting: @@ -38,7 +40,8 @@ def test_lift_topology(self): ) assert ( - expected_incidence_1 == lifted_data_k1.incidence_hyperedges.to_dense() + expected_incidence_1 + == lifted_data_k1.incidence_hyperedges.to_dense() ).all(), "Something is wrong with incidence_hyperedges (k=1)." assert ( expected_n_hyperedges == lifted_data_k1.num_hyperedges @@ -63,7 +66,8 @@ def test_lift_topology(self): ) assert ( - expected_incidence_1 == lifted_data_k2.incidence_hyperedges.to_dense() + expected_incidence_1 + == lifted_data_k2.incidence_hyperedges.to_dense() ).all(), "Something is wrong with incidence_hyperedges (k=2)." 
assert ( expected_n_hyperedges == lifted_data_k2.num_hyperedges diff --git a/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py b/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py index 01b4ce9a..0807dd9c 100644 --- a/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py +++ b/test/transforms/liftings/hypergraph/test_HypergraphKNearestNeighborsLifting.py @@ -16,8 +16,12 @@ def setup_method(self): self.data = manual_graph() # Initialise the HypergraphKNearestNeighborsLifting class - self.lifting_k2 = HypergraphKNearestNeighborsLifting(k_value=2, loop=True) - self.lifting_k3 = HypergraphKNearestNeighborsLifting(k_value=3, loop=True) + self.lifting_k2 = HypergraphKNearestNeighborsLifting( + k_value=2, loop=True + ) + self.lifting_k3 = HypergraphKNearestNeighborsLifting( + k_value=3, loop=True + ) def test_lift_topology(self): # Test the lift_topology method @@ -40,7 +44,8 @@ def test_lift_topology(self): ) assert ( - expected_incidence_1 == lifted_data_k2.incidence_hyperedges.to_dense() + expected_incidence_1 + == lifted_data_k2.incidence_hyperedges.to_dense() ).all(), "Something is wrong with incidence_hyperedges (k=2)." assert ( expected_n_hyperedges == lifted_data_k2.num_hyperedges @@ -65,7 +70,8 @@ def test_lift_topology(self): ) assert ( - expected_incidence_1 == lifted_data_k3.incidence_hyperedges.to_dense() + expected_incidence_1 + == lifted_data_k3.incidence_hyperedges.to_dense() ).all(), "Something is wrong with incidence_hyperedges (k=3)." assert ( expected_n_hyperedges == lifted_data_k3.num_hyperedges diff --git a/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py b/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py index d9ff1f18..8b23334c 100644 --- a/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py +++ b/test/transforms/liftings/simplicial/test_SimplicialCliqueLifting.py @@ -3,7 +3,9 @@ import torch from topobenchmarkx.io.load.loaders import manual_simple_graph -from topobenchmarkx.transforms.liftings.graph2simplicial import SimplicialCliqueLifting +from topobenchmarkx.transforms.liftings.graph2simplicial import ( + SimplicialCliqueLifting, +) class TestSimplicialCliqueLifting: @@ -14,8 +16,12 @@ def setup_method(self): self.data = manual_simple_graph() # Initialise the SimplicialCliqueLifting class - self.lifting_signed = SimplicialCliqueLifting(complex_dim=3, signed=True) - self.lifting_unsigned = SimplicialCliqueLifting(complex_dim=3, signed=False) + self.lifting_signed = SimplicialCliqueLifting( + complex_dim=3, signed=True + ) + self.lifting_unsigned = SimplicialCliqueLifting( + complex_dim=3, signed=False + ) def test_lift_topology(self): """Test the lift_topology method.""" @@ -26,20 +32,135 @@ def test_lift_topology(self): expected_incidence_1 = torch.tensor( [ - [-1.0, -1.0, -1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [1.0, 0.0, 0.0, 0.0, -1.0, -1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 1.0, 0.0, 0.0, 1.0, 0.0, -1.0, -1.0, -1.0, -1.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0], - [0.0, 0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, -1.0, -1.0], - [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0], - [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0], + [ + -1.0, + -1.0, + -1.0, + -1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 1.0, + 0.0, + 
0.0, + 0.0, + -1.0, + -1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 0.0, + 1.0, + 0.0, + 0.0, + 1.0, + 0.0, + -1.0, + -1.0, + -1.0, + -1.0, + 0.0, + 0.0, + 0.0, + ], + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + -1.0, + 0.0, + 0.0, + ], + [ + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 1.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ], + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + -1.0, + -1.0, + ], + [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 1.0, + 0.0, + ], + [ + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + 0.0, + 1.0, + ], ] ) assert ( - abs(expected_incidence_1) == lifted_data_unsigned.incidence_1.to_dense() - ).all(), "Something is wrong with unsigned incidence_1 (nodes to edges)." + abs(expected_incidence_1) + == lifted_data_unsigned.incidence_1.to_dense() + ).all(), ( + "Something is wrong with unsigned incidence_1 (nodes to edges)." + ) assert ( expected_incidence_1 == lifted_data_signed.incidence_1.to_dense() ).all(), "Something is wrong with signed incidence_1 (nodes to edges)." @@ -63,26 +184,26 @@ def test_lift_topology(self): ) assert ( - abs(expected_incidence_2) == lifted_data_unsigned.incidence_2.to_dense() + abs(expected_incidence_2) + == lifted_data_unsigned.incidence_2.to_dense() ).all(), "Something is wrong with unsigned incidence_2 (edges to triangles)." assert ( expected_incidence_2 == lifted_data_signed.incidence_2.to_dense() - ).all(), "Something is wrong with signed incidence_2 (edges to triangles)." + ).all(), ( + "Something is wrong with signed incidence_2 (edges to triangles)." + ) expected_incidence_3 = torch.tensor( [[-1.0], [1.0], [-1.0], [0.0], [1.0], [0.0]] ) assert ( - abs(expected_incidence_3) == lifted_data_unsigned.incidence_3.to_dense() - ).all(), ( - "Something is wrong with unsigned incidence_3 (triangles to tetrahedrons)." - ) + abs(expected_incidence_3) + == lifted_data_unsigned.incidence_3.to_dense() + ).all(), "Something is wrong with unsigned incidence_3 (triangles to tetrahedrons)." assert ( expected_incidence_3 == lifted_data_signed.incidence_3.to_dense() - ).all(), ( - "Something is wrong with signed incidence_3 (triangles to tetrahedrons)." - ) + ).all(), "Something is wrong with signed incidence_3 (triangles to tetrahedrons)." def test_lifted_features_signed(self): """Test the lift_features method in signed incidence cases.""" @@ -112,7 +233,9 @@ def test_lifted_features_signed(self): expected_features_1 == lifted_data.x_1 ).all(), "Something is wrong with x_1 features." 
-        expected_features_2 = torch.tensor([[0.0], [0.0], [0.0], [0.0], [0.0], [0.0]])
+        expected_features_2 = torch.tensor(
+            [[0.0], [0.0], [0.0], [0.0], [0.0], [0.0]]
+        )
 
         assert (
             expected_features_2 == lifted_data.x_2
diff --git a/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py b/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py
index 2ea913e2..ac07a745 100644
--- a/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py
+++ b/test/transforms/liftings/simplicial/test_SimplicialNeighborhoodLifting.py
@@ -16,7 +16,9 @@ def setup_method(self):
         self.data = manual_simple_graph()
 
         # Initialise the SimplicialNeighborhoodLifting class
-        self.lifting_signed = SimplicialNeighborhoodLifting(complex_dim=3, signed=True)
+        self.lifting_signed = SimplicialNeighborhoodLifting(
+            complex_dim=3, signed=True
+        )
         self.lifting_unsigned = SimplicialNeighborhoodLifting(
             complex_dim=3, signed=False
         )
@@ -247,8 +249,11 @@ def test_lift_topology(self):
         )
 
         assert (
-            abs(expected_incidence_1) == lifted_data_unsigned.incidence_1.to_dense()
-        ).all(), "Something is wrong with unsigned incidence_1 (nodes to edges)."
+            abs(expected_incidence_1)
+            == lifted_data_unsigned.incidence_1.to_dense()
+        ).all(), (
+            "Something is wrong with unsigned incidence_1 (nodes to edges)."
+        )
         assert (
             expected_incidence_1 == lifted_data_signed.incidence_1.to_dense()
         ).all(), "Something is wrong with signed incidence_1 (nodes to edges)."
@@ -1309,11 +1314,14 @@ def test_lift_topology(self):
         )
 
         assert (
-            abs(expected_incidence_2) == lifted_data_unsigned.incidence_2.to_dense()
+            abs(expected_incidence_2)
+            == lifted_data_unsigned.incidence_2.to_dense()
         ).all(), "Something is wrong with unsigned incidence_2 (edges to triangles)."
         assert (
             expected_incidence_2 == lifted_data_signed.incidence_2.to_dense()
-        ).all(), "Something is wrong with signed incidence_2 (edges to triangles)."
+        ).all(), (
+            "Something is wrong with signed incidence_2 (edges to triangles)."
+ ) def test_lifted_features_signed(self): # Test the lift_features method for signed case diff --git a/topobenchmarkx/data/cornel_dataset.ipynb b/topobenchmarkx/data/cornel_dataset.ipynb index 718f7ff1..e6eeae39 100644 --- a/topobenchmarkx/data/cornel_dataset.ipynb +++ b/topobenchmarkx/data/cornel_dataset.ipynb @@ -15,7 +15,6 @@ "\n", "import os.path as osp\n", "from collections.abc import Callable\n", - "from typing import Optional\n", "\n", "from torch_geometric.data import Data, InMemoryDataset\n", "from torch_geometric.io import fs\n", @@ -43,9 +42,9 @@ " root: str,\n", " name: str,\n", " parameters: dict = None,\n", - " transform: Optional[Callable] = None,\n", - " pre_transform: Optional[Callable] = None,\n", - " pre_filter: Optional[Callable] = None,\n", + " transform: Callable | None = None,\n", + " pre_transform: Callable | None = None,\n", + " pre_filter: Callable | None = None,\n", " force_reload: bool = True,\n", " use_node_attr: bool = False,\n", " use_edge_attr: bool = False,\n", diff --git a/topobenchmarkx/data/dataloader_fullbatch.py b/topobenchmarkx/data/dataloaders.py similarity index 60% rename from topobenchmarkx/data/dataloader_fullbatch.py rename to topobenchmarkx/data/dataloaders.py index 2ab9124d..c58edb68 100755 --- a/topobenchmarkx/data/dataloader_fullbatch.py +++ b/topobenchmarkx/data/dataloaders.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Any, Optional +from typing import Any import torch from lightning import LightningDataModule @@ -10,9 +10,10 @@ class MyData(Data): - """ - Data object class that overwrites some methods from torch_geometric.data.Data so that not only sparse matrices with adj in the name can work with the torch_geometric dataloaders. - """ + """Data object class that overwrites some methods from + torch_geometric.data.Data so that not only sparse matrices with adj in the + name can work with the torch_geometric dataloaders.""" + def is_valid(self, string): valid_names = ["adj", "incidence", "laplacian"] return any(name in string for name in valid_names) @@ -27,9 +28,8 @@ def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any: def to_data_list(batch): - """ - Workaround needed since torch_geometric doesn't work well with torch.sparse - """ + """Workaround needed since torch_geometric doesn't work well with + torch.sparse.""" for key in batch: if batch[key].is_sparse: sparse_data = batch[key].coalesce() @@ -78,7 +78,10 @@ def collate_fn(batch): torch.tensor([[batch_idx] * current_number_of_cells]) ) - if running_idx.get(f"cell_running_idx_number_{cell_dim}") is None: + if ( + running_idx.get(f"cell_running_idx_number_{cell_dim}") + is None + ): running_idx[f"cell_running_idx_number_{cell_dim}"] = ( current_number_of_cells # current_number_of_nodes ) @@ -89,9 +92,9 @@ def collate_fn(batch): + running_idx[f"cell_running_idx_number_{cell_dim}"] ).long() - running_idx[ - f"cell_running_idx_number_{cell_dim}" - ] += current_number_of_cells # current_number_of_nodes + running_idx[f"cell_running_idx_number_{cell_dim}"] += ( + current_number_of_cells # current_number_of_nodes + ) elif x_key == "x_hyperedges": cell_dim = x_key.split("_")[1] @@ -101,7 +104,10 @@ def collate_fn(batch): torch.tensor([[batch_idx] * current_number_of_hyperedges]) ) - if running_idx.get(f"cell_running_idx_number_{cell_dim}") is None: + if ( + running_idx.get(f"cell_running_idx_number_{cell_dim}") + is None + ): running_idx[f"cell_running_idx_number_{cell_dim}"] = ( current_number_of_hyperedges ) @@ -112,9 +118,9 @@ def 
collate_fn(batch): + running_idx[f"cell_running_idx_number_{cell_dim}"] ).long() - running_idx[ - f"cell_running_idx_number_{cell_dim}" - ] += current_number_of_hyperedges + running_idx[f"cell_running_idx_number_{cell_dim}"] += ( + current_number_of_hyperedges + ) else: # Function Batch.from_data_list creates a running index automatically pass @@ -122,7 +128,7 @@ def collate_fn(batch): data_list.append(data) batch = Batch.from_data_list(data_list) - + # Rename batch.batch to batch.batch_0 for consistency batch["batch_0"] = batch.pop("batch") @@ -132,108 +138,8 @@ def collate_fn(batch): return batch -# class FullBatchDataModule(LightningDataModule): -# """ - -# Read the docs: -# https://lightning.ai/docs/pytorch/latest/data/datamodule.html -# """ - -# def __init__( -# self, -# dataset, -# batch_size: int = 64, -# num_workers: int = 0, -# pin_memory: bool = False, -# ) -> None: -# """Initialize a `MNISTDataModule`. - -# :param data_dir: The data directory. Defaults to `"data/"`. -# :param train_val_test_split: The train, validation and test split. Defaults to `(55_000, 5_000, 10_000)`. -# :param batch_size: The batch size. Defaults to `64`. -# :param num_workers: The number of workers. Defaults to `0`. -# :param pin_memory: Whether to pin memory. Defaults to `False`. -# """ -# super().__init__() - -# # this line allows to access init params with 'self.hparams' attribute -# # also ensures init params will be stored in ckpt -# self.save_hyperparameters(logger=False) - -# self.dataset = dataset -# self.batch_size = batch_size - -# def train_dataloader(self) -> DataLoader: -# """Create and return the train dataloader. - -# :return: The train dataloader. -# """ -# return DataLoader( -# dataset=self.dataset, -# batch_size=1, -# num_workers=self.hparams.num_workers, -# pin_memory=self.hparams.pin_memory, -# # persistent_workers=True, -# shuffle=True, -# collate_fn=collate_fn, -# ) - -# def val_dataloader(self) -> DataLoader: -# """Create and return the validation dataloader. - -# :return: The validation dataloader. -# """ -# return DataLoader( -# dataset=self.dataset, -# batch_size=1, -# num_workers=self.hparams.num_workers, -# pin_memory=self.hparams.pin_memory, -# # persistent_workers=True, -# shuffle=False, -# collate_fn=collate_fn, -# ) - -# def test_dataloader(self) -> DataLoader: -# """Create and return the test dataloader. - -# :return: The test dataloader. -# """ -# return DataLoader( -# dataset=self.dataset, -# batch_size=1, -# num_workers=self.hparams.num_workers, -# pin_memory=self.hparams.pin_memory, -# # persistent_workers=True, -# shuffle=False, -# collate_fn=collate_fn, -# ) - -# def teardown(self, stage: Optional[str] = None) -> None: -# """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, -# `trainer.test()`, and `trainer.predict()`. - -# :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. -# Defaults to ``None``. -# """ - -# def state_dict(self) -> dict[Any, Any]: -# """Called when saving a checkpoint. Implement to generate and save the datamodule state. - -# :return: A dictionary containing the datamodule state that you want to save. -# """ -# return {} - -# def load_state_dict(self, state_dict: dict[str, Any]) -> None: -# """Called when loading a checkpoint. Implement to reload datamodule state given datamodule -# `state_dict()`. - -# :param state_dict: The datamodule state returned by `self.state_dict()`. 
-# """ - - class DefaultDataModule(LightningDataModule): - """ - Initializes the DefaultDataModule class. + """Initializes the DefaultDataModule class. Args: dataset_train: The training dataset. @@ -263,8 +169,10 @@ def __init__( # this line allows to access init params with 'self.hparams' attribute # also ensures init params will be stored in ckpt - self.save_hyperparameters(logger=False, ignore=["dataset_train", "dataset_val", "dataset_test"]) - + self.save_hyperparameters( + logger=False, + ignore=["dataset_train", "dataset_val", "dataset_test"], + ) self.dataset_train = dataset_train self.batch_size = batch_size @@ -324,24 +232,27 @@ def test_dataloader(self) -> DataLoader: collate_fn=collate_fn, ) - def teardown(self, stage: Optional[str] = None) -> None: - """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`, - `trainer.test()`, and `trainer.predict()`. + def teardown(self, stage: str | None = None) -> None: + """Lightning hook for cleaning up after `trainer.fit()`, + `trainer.validate()`, `trainer.test()`, and `trainer.predict()`. :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``. """ def state_dict(self) -> dict[Any, Any]: - """Called when saving a checkpoint. Implement to generate and save the datamodule state. + """Called when saving a checkpoint. Implement to generate and save the + datamodule state. - :return: A dictionary containing the datamodule state that you want to save. + :return: A dictionary containing the datamodule state that you want to + save. """ return {} def load_state_dict(self, state_dict: dict[str, Any]) -> None: - """Called when loading a checkpoint. Implement to reload datamodule state given datamodule - `state_dict()`. + """Called when loading a checkpoint. Implement to reload datamodule + state given datamodule `state_dict()`. - :param state_dict: The datamodule state returned by `self.state_dict()`. + :param state_dict: The datamodule state returned by + `self.state_dict()`. """ diff --git a/topobenchmarkx/data/datasets.py b/topobenchmarkx/data/datasets.py index 9d3c77ed..0ed9159e 100644 --- a/topobenchmarkx/data/datasets.py +++ b/topobenchmarkx/data/datasets.py @@ -9,6 +9,7 @@ class CustomDataset(torch_geometric.data.Dataset): data_lst: list List of torch_geometric.data.Data objects . """ + def __init__(self, data_lst): super().__init__() self.data_lst = data_lst @@ -32,6 +33,7 @@ def get(self, idx): def len(self): r"""Return length of the dataset. + Returns ------- int @@ -48,6 +50,7 @@ class TorchGeometricDataset(torch_geometric.data.Dataset): data_lst: list List of torch_geometric.data.Data objects . """ + def __init__(self, data_lst): super().__init__() self.data_lst = data_lst @@ -70,6 +73,7 @@ def get(self, idx): def len(self): r"""Return length of the dataset. 
+
         Returns
         -------
         int
diff --git a/topobenchmarkx/data/heteriphilic_dataset.py b/topobenchmarkx/data/heteriphilic_dataset.py
index 8a729440..073d8b65 100644
--- a/topobenchmarkx/data/heteriphilic_dataset.py
+++ b/topobenchmarkx/data/heteriphilic_dataset.py
@@ -1,19 +1,21 @@
 import os.path as osp
 from collections.abc import Callable
-from typing import Optional, ClassVar
+from typing import ClassVar

 import torch
 from omegaconf import DictConfig
 from torch_geometric.data import Data, InMemoryDataset
 from torch_geometric.io import fs

-from topobenchmarkx.io.load.heterophilic import download_hetero_datasets, load_heterophilic_data
+from topobenchmarkx.io.load.heterophilic import (
+    download_hetero_datasets,
+    load_heterophilic_data,
+)
 from topobenchmarkx.io.load.split_utils import random_splitting


 class HeteroDataset(InMemoryDataset):
-    r"""
-    Dataset class for US County Demographics dataset.
+    r"""Dataset class for heterophilic graph datasets.

     Args:
         root (str): Root directory where the dataset will be saved.
@@ -40,7 +42,6 @@ class HeteroDataset(InMemoryDataset):
         URLS (dict): Dictionary containing the URLs for downloading the dataset.
         FILE_FORMAT (dict): Dictionary containing the file formats for the dataset.
         RAW_FILE_NAMES (dict): Dictionary containing the raw file names for the dataset.
-
     """

     RAW_FILE_NAMES: ClassVar = {}
@@ -50,17 +51,21 @@ def __init__(
         root: str,
         name: str,
         parameters: DictConfig,
-        transform: Optional[Callable] = None,
-        pre_transform: Optional[Callable] = None,
-        pre_filter: Optional[Callable] = None,
+        transform: Callable | None = None,
+        pre_transform: Callable | None = None,
+        pre_filter: Callable | None = None,
         force_reload: bool = True,
         use_node_attr: bool = False,
         use_edge_attr: bool = False,
     ) -> None:
-        self.name = name #.replace("_", "-")
+        self.name = name  # .replace("_", "-")
         self.parameters = parameters
         super().__init__(
-            root, transform, pre_transform, pre_filter, force_reload=force_reload
+            root,
+            transform,
+            pre_transform,
+            pre_filter,
+            force_reload=force_reload,
         )

         # Step 3: Load the processed data
@@ -97,15 +102,15 @@ def processed_dir(self) -> str:
     @property
     def processed_file_names(self) -> str:
         return "data.pt"
-
+
     @property
     def raw_file_names(self) -> list[str]:
-        """Spefify the downloaded raw fine name"""
+        """Specify the downloaded raw file name."""
        return [f"{self.name}.npz"]

     def download(self) -> None:
-        """
-        Downloads the dataset from the specified URL and saves it to the raw directory.
+        """Downloads the dataset from the specified URL and saves it to the raw
+        directory.

         Raises:
             FileNotFoundError: If the dataset URL is not found.
@@ -115,8 +120,7 @@ def download(self) -> None:
         download_hetero_datasets(name=self.name, path=self.raw_dir)

     def process(self) -> None:
-        """
-        Process the data for the dataset.
+        """Process the data for the dataset.

         This method loads the heterophilic graph data, applies any pre-processing
         transformations if specified, and saves the processed data to the appropriate location.
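Since HeteroDataset is new in this patch, a minimal usage sketch may help; it is not part of the diff. It assumes a Hydra/OmegaConf-style config, and the `parameters` contents shown are illustrative placeholders (the real schema lives in the dataset configs, not here); the dataset name and the indexing pattern mirror the loader code later in this patch.

    # Hedged usage sketch (not part of the patch); config keys are assumed.
    from omegaconf import OmegaConf

    from topobenchmarkx.data.heteriphilic_dataset import HeteroDataset

    parameters = OmegaConf.create({"task": "classification"})  # placeholder keys
    dataset = HeteroDataset(
        root="datasets/graph",  # placeholder; download() puts <name>.npz under raw/
        name="minesweeper",     # one of the supported heterophilic dataset names
        parameters=parameters,
    )
    data = dataset[0]           # a single torch_geometric.data.Data graph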
@@ -124,7 +128,7 @@ def process(self) -> None:
         Returns:
             None
         """
-
+
         data = load_heterophilic_data(name=self.name, path=self.raw_dir)
         data = data if self.pre_transform is None else self.pre_transform(data)
         self.save([data], self.processed_paths[0])
diff --git a/topobenchmarkx/data/us_county_demos_dataset.py b/topobenchmarkx/data/us_county_demos_dataset.py
index 4cd162b3..118e7816 100644
--- a/topobenchmarkx/data/us_county_demos_dataset.py
+++ b/topobenchmarkx/data/us_county_demos_dataset.py
@@ -1,20 +1,19 @@
 import os.path as osp
 from collections.abc import Callable
-from typing import Optional, ClassVar
+from typing import ClassVar

 import torch
 from omegaconf import DictConfig
 from torch_geometric.data import Data, InMemoryDataset
 from torch_geometric.io import fs

-from topobenchmarkx.io.load.us_county_demos import load_us_county_demos
 from topobenchmarkx.io.load.download_utils import download_file_from_drive
 from topobenchmarkx.io.load.split_utils import random_splitting
+from topobenchmarkx.io.load.us_county_demos import load_us_county_demos


 class USCountyDemosDataset(InMemoryDataset):
-    r"""
-    Dataset class for US County Demographics dataset.
+    r"""Dataset class for US County Demographics dataset.

     Args:
         root (str): Root directory where the dataset will be saved.
@@ -41,7 +40,6 @@ class USCountyDemosDataset(InMemoryDataset):
         URLS (dict): Dictionary containing the URLs for downloading the dataset.
         FILE_FORMAT (dict): Dictionary containing the file formats for the dataset.
         RAW_FILE_NAMES (dict): Dictionary containing the raw file names for the dataset.
-
     """

     URLS: ClassVar = {
@@ -61,9 +59,9 @@ def __init__(
         root: str,
         name: str,
         parameters: DictConfig,
-        transform: Optional[Callable] = None,
-        pre_transform: Optional[Callable] = None,
-        pre_filter: Optional[Callable] = None,
+        transform: Callable | None = None,
+        pre_transform: Callable | None = None,
+        pre_filter: Callable | None = None,
         force_reload: bool = True,
         use_node_attr: bool = False,
         use_edge_attr: bool = False,
@@ -71,7 +69,11 @@ def __init__(
         self.name = name.replace("_", "-")
         self.parameters = parameters
         super().__init__(
-            root, transform, pre_transform, pre_filter, force_reload=force_reload
+            root,
+            transform,
+            pre_transform,
+            pre_filter,
+            force_reload=force_reload,
         )

         # Step 3: Load the processed data
@@ -121,8 +123,8 @@ def processed_file_names(self) -> str:
         return "data.pt"

     def download(self) -> None:
-        """
-        Downloads the dataset from the specified URL and saves it to the raw directory.
+        """Downloads the dataset from the specified URL and saves it to the raw
+        directory.

         Raises:
             FileNotFoundError: If the dataset URL is not found.
@@ -141,7 +143,9 @@ def download(self) -> None:

         # Extract the downloaded file if it is compressed
         fs.cp(
-            f"{self.raw_dir}/{self.name}.{self.file_format}", self.raw_dir, extract=True
+            f"{self.raw_dir}/{self.name}.{self.file_format}",
+            self.raw_dir,
+            extract=True,
         )

         # Move the extracted files to the datasets/domain/dataset_name/raw/ directory
@@ -153,8 +157,7 @@ def download(self) -> None:
         fs.rm(f"{self.raw_dir}/{self.name}.{self.file_format}")

     def process(self) -> None:
-        """
-        Process the data for the dataset.
+        """Process the data for the dataset.

         This method loads the US county demographics data, applies any pre-processing
         transformations if specified, and saves the processed data to the appropriate location.
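A corresponding sketch for USCountyDemosDataset follows; `year` and `task_variable` are the two parameters that process() actually reads (see the next hunk), while the root path is a placeholder and not dictated by this diff.

    # Hedged usage sketch (not part of the patch).
    from omegaconf import OmegaConf

    from topobenchmarkx.data.us_county_demos_dataset import USCountyDemosDataset

    parameters = OmegaConf.create({"year": 2012, "task_variable": "Election"})
    dataset = USCountyDemosDataset(
        root="datasets/graph",   # placeholder data directory
        name="US-county-demos",
        parameters=parameters,
    )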
@@ -163,7 +166,9 @@ def process(self) -> None:
         None
         """
         data = load_us_county_demos(
-            self.raw_dir, year=self.parameters.year, y_col=self.parameters.task_variable
+            self.raw_dir,
+            year=self.parameters.year,
+            y_col=self.parameters.task_variable,
         )

         data = data if self.pre_transform is None else self.pre_transform(data)
diff --git a/topobenchmarkx/eval.py b/topobenchmarkx/eval.py
index 492a2c54..b5c1fd5b 100755
--- a/topobenchmarkx/eval.py
+++ b/topobenchmarkx/eval.py
@@ -5,7 +5,6 @@
 from lightning import LightningDataModule, LightningModule, Trainer
 from lightning.pytorch.loggers import Logger
 from omegaconf import DictConfig
-
 from src.utils import (
     RankedLogger,
     extras,
@@ -39,11 +38,13 @@
 def evaluate(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
     """Evaluates given checkpoint on a datamodule testset.

-    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
-    failure. Useful for multiruns, saving info about the crash, etc.
+    This method is wrapped in an optional @task_wrapper decorator, which
+    controls the behavior during failure. Useful for multiruns, saving info
+    about the crash, etc.

     :param cfg: DictConfig configuration composed by Hydra.
-    :return: tuple[dict, dict] with metrics and dict with all instantiated objects.
+    :return: tuple[dict, dict] with metrics and dict with all instantiated
+        objects.
     """
     assert cfg.ckpt_path

@@ -82,7 +83,9 @@ def evaluate(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
     return metric_dict, object_dict


-@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
+@hydra.main(
+    version_base="1.3", config_path="../configs", config_name="eval.yaml"
+)
 def main(cfg: DictConfig) -> None:
     """Main entry point for evaluation.

diff --git a/topobenchmarkx/evaluators/comparisons.py b/topobenchmarkx/evaluators/comparisons.py
index ee316f70..27b09b50 100644
--- a/topobenchmarkx/evaluators/comparisons.py
+++ b/topobenchmarkx/evaluators/comparisons.py
@@ -3,12 +3,13 @@


 def signed_ranks_test(result1, result2):
-    """
-    Calculates the p-value for the Wilcoxon signed-rank test between the results of two models.
+    """Calculates the p-value for the Wilcoxon signed-rank test between the
+    results of two models.

     https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wilcoxon.html

-    :param results: A 2xN numpy array with the results from the two models. N is the number of datasets over which the models have been tested on.
+    :param result1: A length-N numpy array with the results of the first model
+        on the N datasets.
+    :param result2: A length-N numpy array with the results of the second model
+        on the same N datasets.
     :return: The p-value of the test
     """
     xs = result1 - result2
@@ -16,12 +17,13 @@ def signed_ranks_test(result1, result2):


 def friedman_test(results):
-    """
-    Calculates the p-value of the Friedman test between M models on N datasets.
+    """Calculates the p-value of the Friedman test between M models on N
+    datasets.

     https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.friedmanchisquare.html

-    :param results: A MxN numpy array with the results of M models over N dataset
+    :param results: An MxN numpy array with the results of M models over N
+        datasets
     :return: The p-value of the test
     """
     res = [r for r in results]
@@ -29,14 +31,17 @@ def friedman_test(results):


 def compare_models(results, p_limit=0.05, verbose=False):
-    """
-    Compares different models. First it uses the Friedman test to check that the models are significantly different, then it uses pairwise comparisons to study the ranking of the models.
+    """Compares different models. First it uses the Friedman test to check that
+    the models are significantly different, then it uses pairwise comparisons
+    to study the ranking of the models.

-    :param results: A MxN numpy array with the results of M models over N dataset
+    :param results: An MxN numpy array with the results of M models over N
+        datasets
     :param p_limit: The limit below which a hypothesis is considered false
     :param verbose: Whether to print the results of the tests or not
     :return average_rank: The average ranks of the models
-    :return groups: List of lists with the groups of models that are statistically similar
+    :return groups: List of lists with the groups of models that are
+        statistically similar
     """
     M = results.shape[0]

@@ -56,8 +61,12 @@ def compare_models(results, p_limit=0.05, verbose=False):
         model_idx = np.where(np.argsort(average_ranks) == i)[0][0]
         group = [model_idx]
         while i + idx < M:
-            next_model_idx = np.where(np.argsort(average_ranks) == i + idx)[0][0]
-            p = signed_ranks_test(results[model_idx, :], results[model_idx + idx, :])
+            next_model_idx = np.where(np.argsort(average_ranks) == i + idx)[0][
+                0
+            ]
+            p = signed_ranks_test(
+                results[model_idx, :], results[model_idx + idx, :]
+            )
             if verbose:
                 print(
                     f"P-value for Wilcoxon test between models {model_idx} and {next_model_idx}: {p}"
diff --git a/topobenchmarkx/evaluators/evaluator.py b/topobenchmarkx/evaluators/evaluator.py
index aa5eb670..35cec944 100755
--- a/topobenchmarkx/evaluators/evaluator.py
+++ b/topobenchmarkx/evaluators/evaluator.py
@@ -6,7 +6,8 @@


 class TorchEvaluator:
-    r"""Evaluator class that is responsible for computing the metrics for a given task.
+    r"""Evaluator class that is responsible for computing the metrics for a
+    given task.

     Parameters
     ----------
@@ -24,7 +25,6 @@ class TorchEvaluator:
         In "regression" scenario, the following arguments are expected:
         - regression_metrics : list
             A list of regression metrics to be computed.
-
     """

     def __init__(self, task, **kwargs):
@@ -43,7 +43,7 @@ def __init__(self, task, **kwargs):
         elif self.task == "multilabel classification":
             parameters = {"num_classes": kwargs["num_classes"]}
             parameters["task"] = "multilabel"
-            metric_names = kwargs["classification_metrics"] 
+            metric_names = kwargs["classification_metrics"]

         elif self.task == "regression":
             parameters = {}
@@ -57,16 +57,13 @@ def __init__(self, task, **kwargs):
         #     )

         metrics = {}
-        for name in metric_names: 
+        for name in metric_names:
             if name in ["recall", "precision", "auroc"]:
-                metrics[name] = METRICS[name](average='macro', **parameters)
-            
+                metrics[name] = METRICS[name](average="macro", **parameters)
+
             else:
                 metrics[name] = METRICS[name](**parameters)
         self.metrics = MetricCollection(metrics)
-        
-        
-
         self.best_metric = {}

@@ -81,7 +78,6 @@ def update(self, model_out: dict):
             The model predictions.
         - labels : torch.Tensor
             The ground truth labels.
-
         """
         preds = model_out["logits"].cpu()
         target = model_out["labels"].cpu()
@@ -111,12 +107,17 @@ def compute(self):
     def reset(
         self,
     ):
-        """Reset the metrics. This method should be called after each epoch"""
+        """Reset the metrics.
+ + This method should be called after each epoch + """ self.metrics.reset() if __name__ == "__main__": evaluator = TorchEvaluator( - task="classification", num_classes=3, classification_metrics=["accuracy"] + task="classification", + num_classes=3, + classification_metrics=["accuracy"], ) print(evaluator.task) diff --git a/topobenchmarkx/io/load/download_utils.py b/topobenchmarkx/io/load/download_utils.py index f6c66f96..6ffca758 100644 --- a/topobenchmarkx/io/load/download_utils.py +++ b/topobenchmarkx/io/load/download_utils.py @@ -5,8 +5,7 @@ # Function to extract file ID from Google Drive URL def get_file_id_from_url(url): - """ - Extracts the file ID from a Google Drive file URL. + """Extracts the file ID from a Google Drive file URL. Args: url (str): The Google Drive file URL. @@ -21,10 +20,14 @@ def get_file_id_from_url(url): query_params = parse_qs(parsed_url.query) if "id" in query_params: # Case 1: URL format contains '?id=' file_id = query_params["id"][0] - elif "file/d/" in parsed_url.path: # Case 2: URL format contains '/file/d/' + elif ( + "file/d/" in parsed_url.path + ): # Case 2: URL format contains '/file/d/' file_id = parsed_url.path.split("/")[3] else: - raise ValueError("The provided URL is not a valid Google Drive file URL.") + raise ValueError( + "The provided URL is not a valid Google Drive file URL." + ) return file_id @@ -32,8 +35,8 @@ def get_file_id_from_url(url): def download_file_from_drive( file_link, path_to_save, dataset_name, file_format="tar.gz" ): - """ - Downloads a file from a Google Drive link and saves it to the specified path. + """Downloads a file from a Google Drive link and saves it to the specified + path. Args: file_link (str): The Google Drive link of the file to download. diff --git a/topobenchmarkx/io/load/heterophilic.py b/topobenchmarkx/io/load/heterophilic.py index c7901130..f8b9c7b8 100644 --- a/topobenchmarkx/io/load/heterophilic.py +++ b/topobenchmarkx/io/load/heterophilic.py @@ -1,35 +1,39 @@ -import numpy as np import os +import urllib.request + +import numpy as np import torch import torch_geometric -import urllib.request def load_heterophilic_data(name, path): - file_name = f'{name}.npz' + file_name = f"{name}.npz" data = np.load(os.path.join(path, file_name)) - x = torch.tensor(data['node_features']) - y = torch.tensor(data['node_labels']) - edge_index = torch.tensor(data['edges']).T + x = torch.tensor(data["node_features"]) + y = torch.tensor(data["node_labels"]) + edge_index = torch.tensor(data["edges"]).T # Make edge_index undirected edge_index = torch_geometric.utils.to_undirected(edge_index) # Remove self-loops edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index) - + data = torch_geometric.data.Data(x=x, y=y, edge_index=edge_index) return data + def download_hetero_datasets(name, path): - url = 'https://github.com/OpenGSL/HeterophilousDatasets/raw/main/data/' - name = f'{name}.npz' + url = "https://github.com/OpenGSL/HeterophilousDatasets/raw/main/data/" + name = f"{name}.npz" try: - print(f'Downloading {name}') + print(f"Downloading {name}") path2save = os.path.join(path, name) urllib.request.urlretrieve(url + name, path2save) - print('Done!') + print("Done!") except Exception as e: - raise Exception('''Download failed! Make sure you have stable Internet connection and enter the right name''') from e \ No newline at end of file + raise Exception( + """Download failed! 
Make sure you have a stable Internet connection and enter the right name"""
+        ) from e
diff --git a/topobenchmarkx/io/load/loaders.py b/topobenchmarkx/io/load/loaders.py
index 3cb9d0e5..f24eb30d 100755
--- a/topobenchmarkx/io/load/loaders.py
+++ b/topobenchmarkx/io/load/loaders.py
@@ -8,8 +8,8 @@
 from omegaconf import DictConfig

 from topobenchmarkx.data.datasets import CustomDataset
-from topobenchmarkx.data.us_county_demos_dataset import USCountyDemosDataset
 from topobenchmarkx.data.heteriphilic_dataset import HeteroDataset
+from topobenchmarkx.data.us_county_demos_dataset import USCountyDemosDataset
 from topobenchmarkx.io.load.loader import AbstractLoader
 from topobenchmarkx.io.load.preprocessor import Preprocessor
 from topobenchmarkx.io.load.split_utils import (
@@ -163,7 +163,9 @@ def load(self) -> CustomDataset:
                 name=self.parameters["data_name"],
             )
             if self.transforms_config is not None:
-                dataset = Preprocessor(data_dir, dataset, self.transforms_config)
+                dataset = Preprocessor(
+                    data_dir, dataset, self.transforms_config
+                )

             dataset = load_graph_cocitation_split(dataset, self.parameters)

@@ -184,16 +186,19 @@ def load(self) -> CustomDataset:
                 use_node_attr=False,
             )
             if self.transforms_config is not None:
-                dataset = Preprocessor(data_dir, dataset, self.transforms_config)
+                dataset = Preprocessor(
+                    data_dir, dataset, self.transforms_config
+                )

             dataset = load_graph_tudataset_split(dataset, self.parameters)

         elif self.parameters.data_name in ["ZINC"]:
             datasets = [
-                torch_geometric.datasets.ZINC(
-                    root=self.parameters["data_dir"],
-                    subset=True,
-                    split=split,
-                ) for split in ["train", "val", "test"]
+                torch_geometric.datasets.ZINC(
+                    root=self.parameters["data_dir"],
+                    subset=True,
+                    split=split,
+                )
+                for split in ["train", "val", "test"]
             ]

             assert self.parameters.split_type == "fixed"
@@ -221,7 +226,9 @@ def load(self) -> CustomDataset:
             )

             # Split the data back into train/val/test datasets
-            dataset = assing_train_val_test_mask_to_graphs(joined_dataset, split_idx)
+            dataset = assing_train_val_test_mask_to_graphs(
+                joined_dataset, split_idx
+            )

         elif self.parameters.data_name in ["AQSOL"]:
             datasets = []
@@ -256,7 +263,9 @@ def load(self) -> CustomDataset:
             )

             # Split the data back into train/val/test datasets
-            dataset = assing_train_val_test_mask_to_graphs(joined_dataset, split_idx)
+            dataset = assing_train_val_test_mask_to_graphs(
+                joined_dataset, split_idx
+            )

         elif self.parameters.data_name in ["US-county-demos"]:
             dataset = USCountyDemosDataset(
@@ -268,13 +277,22 @@ def load(self) -> CustomDataset:
             if self.transforms_config is not None:
                 # force_reload=True because in these datasets many variables can be treated as y
                 dataset = Preprocessor(
-                    data_dir, dataset, self.transforms_config, force_reload=True
+                    data_dir,
+                    dataset,
+                    self.transforms_config,
+                    force_reload=True,
                 )

             # We need to map original dataset into custom one to make batching work
             dataset = CustomDataset([dataset[0]])
-
-        elif self.parameters.data_name in ["amazon_ratings", "questions", "minesweeper","roman_empire", "tolokers"]:
+
+        elif self.parameters.data_name in [
+            "amazon_ratings",
+            "questions",
+            "minesweeper",
+            "roman_empire",
+            "tolokers",
+        ]:
             dataset = HeteroDataset(
                 root=self.parameters["data_dir"],
                 name=self.parameters["data_name"],
@@ -284,7 +302,10 @@ def load(self) -> CustomDataset:
             if self.transforms_config is not None:
                 # force_reload=True because in these datasets many variables can be treated as y
                 dataset = Preprocessor(
+                    data_dir,
+                    dataset,
+                    self.transforms_config,
+                    force_reload=True,
                 )

             # We need to map original dataset into custom one to make batching work
@@ -331,7 +352,9 @@ def load(self) -> CustomDataset:
             data_dir = os.path.join(
                 self.parameters["data_dir"], self.parameters["data_name"]
             )
-            processor_dataset = Preprocessor(data_dir, data, self.transforms_config)
+            processor_dataset = Preprocessor(
+                data_dir, data, self.transforms_config
+            )

             dataset = CustomDataset([processor_dataset[0]])

         return dataset
@@ -367,7 +390,7 @@ def manual_graph():
     for tetrahedron in tetrahedrons:
         for i in range(len(tetrahedron)):
             for j in range(i + 1, len(tetrahedron)):
-                edges.append([tetrahedron[i], tetrahedron[j]]) # noqa: PERF401
+                edges.append([tetrahedron[i], tetrahedron[j]])  # noqa: PERF401

     # Create a graph
     G = nx.Graph()
@@ -381,7 +404,11 @@ def manual_graph():
     edge_list = torch.Tensor(list(G.edges())).T.long()

     # Generate a scalar feature for each of the 9 nodes
-    x = torch.tensor([1, 5, 10, 50, 100, 500, 1000, 5000, 10000]).unsqueeze(1).float()
+    x = (
+        torch.tensor([1, 5, 10, 50, 100, 500, 1000, 5000, 10000])
+        .unsqueeze(1)
+        .float()
+    )

     data = torch_geometric.data.Data(
         x=x,
@@ -419,7 +446,7 @@ def manual_simple_graph():
     for tetrahedron in tetrahedrons:
         for i in range(len(tetrahedron)):
             for j in range(i + 1, len(tetrahedron)):
-                edges.append([tetrahedron[i], tetrahedron[j]]) # noqa: PERF401
+                edges.append([tetrahedron[i], tetrahedron[j]])  # noqa: PERF401

     # Create a graph
     G = nx.Graph()
diff --git a/topobenchmarkx/io/load/preprocessor.py b/topobenchmarkx/io/load/preprocessor.py
index fe22e9e0..feb331e2 100644
--- a/topobenchmarkx/io/load/preprocessor.py
+++ b/topobenchmarkx/io/load/preprocessor.py
@@ -23,14 +23,21 @@ class Preprocessor(torch_geometric.data.InMemoryDataset):
     """

     def __init__(
-        self, data_dir, data_list, transforms_config, force_reload=False, **kwargs
+        self,
+        data_dir,
+        data_list,
+        transforms_config,
+        force_reload=False,
+        **kwargs,
     ):
         if isinstance(data_list, torch_geometric.data.Dataset):
             data_list = [data_list.get(idx) for idx in range(len(data_list))]
         elif isinstance(data_list, torch_geometric.data.Data):
             data_list = [data_list]
         self.data_list = data_list
-        pre_transform = self.instantiate_pre_transform(data_dir, transforms_config)
+        pre_transform = self.instantiate_pre_transform(
+            data_dir, transforms_config
+        )
         super().__init__(
             self.processed_data_dir,
             None,
@@ -84,7 +91,9 @@ def instantiate_pre_transform(
         pre_transforms = torch_geometric.transforms.Compose(
             list(pre_transforms_dict.values())
         )
-        self.set_processed_data_dir(pre_transforms_dict, data_dir, transforms_config)
+        self.set_processed_data_dir(
+            pre_transforms_dict, data_dir, transforms_config
+        )
         return pre_transforms

     def set_processed_data_dir(
@@ -109,7 +118,9 @@ def set_processed_data_dir(
         }
         params_hash = make_hash(transforms_parameters)
         self.transforms_parameters = ensure_serializable(transforms_parameters)
-        self.processed_data_dir = os.path.join(*[data_dir, repo_name, f"{params_hash}"])
+        self.processed_data_dir = os.path.join(
+            *[data_dir, repo_name, f"{params_hash}"]
+        )

     def save_transform_parameters(self) -> None:
         r"""Save the transform parameters."""
@@ -126,9 +137,13 @@ def save_transform_parameters(self) -> None:
                 saved_transform_parameters = json.load(f)

             if saved_transform_parameters != self.transforms_parameters:
-                raise ValueError("Different transform parameters for the same data_dir")
-
-            print(f"Transform parameters are the same, using existing data_dir: {self.processed_data_dir}")
+                raise ValueError(
+                    "Different transform parameters for the same data_dir"
+                )
+
+            print(
+                f"Transform parameters are the same, using existing data_dir: {self.processed_data_dir}"
+            )

     def process(self) -> None:
         r"""Process the data."""
diff --git a/topobenchmarkx/io/load/split_utils.py b/topobenchmarkx/io/load/split_utils.py
index 5d201c4f..ce662679 100644
--- a/topobenchmarkx/io/load/split_utils.py
+++ b/topobenchmarkx/io/load/split_utils.py
@@ -9,9 +9,9 @@

 # Generate splits in different fashions
 def k_fold_split(labels, parameters):
-    """
-    Returns train and valid indices as in K-Fold Cross-Validation. If the split already exists
-    it loads it automatically, otherwise it creates the split file for the subsequent runs.
+    """Returns train and valid indices as in K-Fold Cross-Validation. If the
+    split already exists it loads it automatically, otherwise it creates the
+    split file for the subsequent runs.

     Parameters
     ----------
@@ -48,13 +48,22 @@ def k_fold_split(labels, parameters):

     skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=42)

-    for fold_n, (train_idx, valid_idx) in enumerate(skf.split(x_idx, labels)):
-        split_idx = {"train": train_idx, "valid": valid_idx, "test": valid_idx}
+    for fold_n, (train_idx, valid_idx) in enumerate(
+        skf.split(x_idx, labels)
+    ):
+        split_idx = {
+            "train": train_idx,
+            "valid": valid_idx,
+            "test": valid_idx,
+        }

         # Check that all nodes/graph have been assigned to some split
         assert np.all(
             np.sort(
-                np.array(split_idx["train"].tolist() + split_idx["valid"].tolist())
+                np.array(
+                    split_idx["train"].tolist()
+                    + split_idx["valid"].tolist()
+                )
             )
             == np.sort(np.arange(len(labels)))
         ), "Not every sample has been loaded."
@@ -282,8 +291,8 @@ def load_graph_tudataset_split(dataset, cfg):
             f"split_type {cfg.split_type} not valid. Choose either 'test' or 'k-fold'"
         )

-    train_dataset, val_dataset, test_dataset = assing_train_val_test_mask_to_graphs(
-        dataset, split_idx
+    train_dataset, val_dataset, test_dataset = (
+        assing_train_val_test_mask_to_graphs(dataset, split_idx)
     )

     return [train_dataset, val_dataset, test_dataset]
diff --git a/topobenchmarkx/io/load/us_county_demos.py b/topobenchmarkx/io/load/us_county_demos.py
index fbbed4b9..516694d0 100644
--- a/topobenchmarkx/io/load/us_county_demos.py
+++ b/topobenchmarkx/io/load/us_county_demos.py
@@ -5,8 +5,8 @@


 def load_us_county_demos(path, year=2012, y_col="Election"):
-    r"""Load US County Demos dataset
-
+    r"""Load US County Demos dataset.
+
     Parameters
     ----------
     path: str
@@ -15,15 +15,17 @@ def load_us_county_demos(path, year=2012, y_col="Election"):
         Year to load the features.
     y_col: str
         Column to use as label.
-
+
     Returns
     -------
     torch_geometric.data.Data
         Data object of the graph for the US County Demos dataset.
     """
-
+
     edges_df = pd.read_csv(f"{path}/county_graph.csv")
-    stat = pd.read_csv(f"{path}/county_stats_{year}.csv", encoding="ISO-8859-1")
+    stat = pd.read_csv(
+        f"{path}/county_stats_{year}.csv", encoding="ISO-8859-1"
+    )

     keep_cols = [
         "FIPS",
@@ -36,12 +38,12 @@ def load_us_county_demos(path, year=2012, y_col="Election"):
         "BachelorRate",
         "UnemploymentRate",
     ]
-
+
     # Select columns, replace ',' with '.' and convert to numeric
     stat = stat.loc[:, keep_cols]
-    stat["MedianIncome"] = stat["MedianIncome"].replace(',','.', regex=True)
-    stat = stat.apply(pd.to_numeric, errors='coerce')
-
+    stat["MedianIncome"] = stat["MedianIncome"].replace(",", ".", regex=True)
+    stat = stat.apply(pd.to_numeric, errors="coerce")
+
     # Step 2: Substitute NaN values with column mean
     for column in stat.columns:
         if column != "FIPS":
@@ -58,7 +60,9 @@ def load_us_county_demos(path, year=2012, y_col="Election"):
     edges_df = edges_df[src_ & dst_]

     # Remove rows from stat df where edges_df['SRC'] or edges_df['DST'] are not present
-    stat = stat[stat["FIPS"].isin(edges_df["SRC"]) & stat["FIPS"].isin(edges_df["DST"])]
+    stat = stat[
+        stat["FIPS"].isin(edges_df["SRC"]) & stat["FIPS"].isin(edges_df["DST"])
+    ]
     stat = stat.reset_index(drop=True)

     # Remove rows where SRC == DST
@@ -91,7 +95,9 @@ def load_us_county_demos(path, year=2012, y_col="Election"):
     )

     # Remove isolated nodes (Note: this function maps the nodes to [0, ..., num_nodes] automatically)
-    edge_index, _, mask = torch_geometric.utils.remove_isolated_nodes(edge_index)
+    edge_index, _, mask = torch_geometric.utils.remove_isolated_nodes(
+        edge_index
+    )

     # Convert mask to index
     index = np.arange(mask.size(0))[mask]
@@ -104,7 +110,9 @@ def load_us_county_demos(path, year=2012, y_col="Election"):
     stat["FIPS"] = stat.reset_index()["index"]

     # Create Election variable
-    stat["Election"] = (stat["DEM"] - stat["GOP"]) / (stat["DEM"] + stat["GOP"])
+    stat["Election"] = (stat["DEM"] - stat["GOP"]) / (
+        stat["DEM"] + stat["GOP"]
+    )

     # Drop DEM and GOP columns and FIPS
     stat = stat.drop(columns=["DEM", "GOP", "FIPS"])
diff --git a/topobenchmarkx/io/load/utils.py b/topobenchmarkx/io/load/utils.py
index 46690ccc..eb5414eb 100755
--- a/topobenchmarkx/io/load/utils.py
+++ b/topobenchmarkx/io/load/utils.py
@@ -54,13 +54,15 @@ def get_complex_connectivity(complex, max_rank, signed=False):
             if connectivity_info == "incidence":
                 connectivity[f"{connectivity_info}_{rank_idx}"] = (
                     generate_zero_sparse_connectivity(
-                        m=practical_shape[rank_idx - 1], n=practical_shape[rank_idx]
+                        m=practical_shape[rank_idx - 1],
+                        n=practical_shape[rank_idx],
                     )
                 )
             else:
                 connectivity[f"{connectivity_info}_{rank_idx}"] = (
                     generate_zero_sparse_connectivity(
-                        m=practical_shape[rank_idx], n=practical_shape[rank_idx]
+                        m=practical_shape[rank_idx],
+                        n=practical_shape[rank_idx],
                     )
                 )
     connectivity["shape"] = practical_shape
@@ -237,7 +239,9 @@ def load_hypergraph_pickle_dataset(cfg):
     # check that every node is in some hyperedge
     if len(np.unique(node_list)) != num_nodes:
         # add self hyperedges to isolated nodes
-        isolated_nodes = np.setdiff1d(np.arange(num_nodes), np.unique(node_list))
+        isolated_nodes = np.setdiff1d(
+            np.arange(num_nodes), np.unique(node_list)
+        )
         for node in isolated_nodes:
             node_list += [node]
diff --git a/topobenchmarkx/models/abstractions/encoder.py b/topobenchmarkx/models/abstractions/encoder.py
index fbe3189f..6a5e4617 100644
--- a/topobenchmarkx/models/abstractions/encoder.py
+++ b/topobenchmarkx/models/abstractions/encoder.py
@@ -1,22 +1,25 @@
-from abc import ABC, abstractmethod
+from abc import abstractmethod

 import torch
 import torch_geometric

+
 class AbstractInitFeaturesEncoder(torch.nn.Module):
-    """abstract class that provides an interface to define a custom initial feature encoders"""
+    """Abstract class that provides an interface to define custom initial
+    feature encoders."""

     def __init__(self):
         return

     @abstractmethod
-    def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data:
-        """Forward pass of the feature encoder model
+    def forward(
+        self, data: torch_geometric.data.Data
+    ) -> torch_geometric.data.Data:
+        """Forward pass of the feature encoder model.

         Parameters:
         :data: torch_geometric.data.Data
         Returns:
         :data: torch_geometric.data.Data
-
         """
diff --git a/topobenchmarkx/models/encoders/default_encoders.py b/topobenchmarkx/models/encoders/default_encoders.py
index 775015c2..4ac0b8ac 100644
--- a/topobenchmarkx/models/encoders/default_encoders.py
+++ b/topobenchmarkx/models/encoders/default_encoders.py
@@ -2,13 +2,15 @@
 import torch_geometric
 from torch_geometric.nn.norm import GraphNorm

-from topobenchmarkx.models.abstractions.encoder import AbstractInitFeaturesEncoder
-from topobenchmarkx.models.encoders.perceiver import Perceiver
+from topobenchmarkx.models.abstractions.encoder import (
+    AbstractInitFeaturesEncoder,
+)


 class BaseEncoder(torch.nn.Module):
-    r"""Encoder class that uses two linear layers with GraphNorm, Relu activation function, and dropout between the two layers.
-
+    r"""Encoder class that uses two linear layers with GraphNorm, ReLU
+    activation function, and dropout between the two layers.
+
     Parameters
     ----------
     in_channels: int
@@ -18,6 +20,7 @@ class BaseEncoder(torch.nn.Module):
     dropout: float
         Percentage of channels to discard between the two linear layers.
     """
+
     def __init__(self, in_channels, out_channels, dropout=0):
         super().__init__()
         self.linear1 = torch.nn.Linear(in_channels, out_channels)
@@ -27,16 +30,15 @@ def __init__(self, in_channels, out_channels, dropout=0):
         self.dropout = torch.nn.Dropout(dropout)

     def forward(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
-        r"""
-        Forward pass
-
+        r"""Forward pass.
+
         Parameters
         ----------
         x: torch.Tensor
             Input tensor of dimensions [N, in_channels].
         batch: torch.Tensor
             The batch vector which assigns each element to a specific example.
-
+
         Returns
         -------
         torch.Tensor
@@ -50,8 +52,9 @@ def forward(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:


 class BaseFeatureEncoder(AbstractInitFeaturesEncoder):
-    r"""Encoder class to apply BaseEncoder to the features of higher order structures.
-
+    r"""Encoder class to apply BaseEncoder to the features of higher order
+    structures.
+
     Parameters
     ----------
     in_channels: list(int)
@@ -63,8 +66,13 @@ class BaseFeatureEncoder(AbstractInitFeaturesEncoder):
     selected_dimensions: list(int)
         List of indexes to apply the BaseEncoders to.
     """
+
     def __init__(
-        self, in_channels, out_channels, proj_dropout=0, selected_dimensions=None
+        self,
+        in_channels,
+        out_channels,
+        proj_dropout=0,
+        selected_dimensions=None,
     ):
         super(AbstractInitFeaturesEncoder, self).__init__()
         self.in_channels = in_channels
@@ -79,19 +87,22 @@ def __init__(
                 self,
                 f"encoder_{i}",
                 BaseEncoder(
-                    self.in_channels[i], self.out_channels, dropout=proj_dropout
+                    self.in_channels[i],
+                    self.out_channels,
+                    dropout=proj_dropout,
                 ),
             )

-    def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data:
-        r"""
-        Forward pass
-
+    def forward(
+        self, data: torch_geometric.data.Data
+    ) -> torch_geometric.data.Data:
+        r"""Forward pass.
+
         Parameters
         ----------
         data: torch_geometric.data.Data
             Input data object which should contain x_{i} features for each i in the selected_dimensions.
- + Returns ------- torch_geometric.data.Data @@ -103,85 +114,88 @@ def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data: for i in self.dimensions: if hasattr(data, f"x_{i}") and hasattr(self, f"encoder_{i}"): batch = getattr(data, f"batch_{i}") - data[f"x_{i}"] = getattr(self, f"encoder_{i}")(data[f"x_{i}"], batch) + data[f"x_{i}"] = getattr(self, f"encoder_{i}")( + data[f"x_{i}"], batch + ) return data -class SetFeatureEncoder(AbstractInitFeaturesEncoder): - r"""Encoder class to apply BaseEncoder to the node features and Perceiver to the features of higher order structures. - - Parameters - ---------- - in_channels: list(int) - Input dimensions for the features. - out_channels: list(int) - Output dimensions for the features. - proj_dropout: float - Dropout for the BaseEncoders. - selected_dimensions: list(int) - List of indexes to apply the BaseEncoders to. - """ - def __init__( - self, in_channels, out_channels, proj_dropout=0, selected_dimensions=None - ): - super(AbstractInitFeaturesEncoder, self).__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.dimensions = ( - selected_dimensions - if selected_dimensions is not None - else range(len(self.in_channels)) - ) - for idx, i in enumerate(self.dimensions): - if idx == 0: - setattr( - self, - f"encoder_{i}", - BaseEncoder( - self.in_channels[i], self.out_channels, dropout=proj_dropout - ), - ) - else: - setattr( - self, - f"encoder_{i}", - Perceiver( - dim=self.out_channels, - depth=1, - cross_heads=4, - cross_dim_head=self.out_channels, - latent_dim_head=self.out_channels, - ), - ) +# from topobenchmarkx.models.encoders.perceiver import Perceiver +# class SetFeatureEncoder(AbstractInitFeaturesEncoder): +# r"""Encoder class to apply BaseEncoder to the node features and Perceiver to the features of higher order structures. - def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data: - r""" - Forward pass - - Parameters - ---------- - data: torch_geometric.data.Data - Input data object which should contain x_{i} features for each i in the selected_dimensions. - - Returns - ------- - torch_geometric.data.Data - Output data object. - """ - if not hasattr(data, "x_0"): - data.x_0 = data.x +# Parameters +# ---------- +# in_channels: list(int) +# Input dimensions for the features. +# out_channels: list(int) +# Output dimensions for the features. +# proj_dropout: float +# Dropout for the BaseEncoders. +# selected_dimensions: list(int) +# List of indexes to apply the BaseEncoders to. 
+# """ +# def __init__( +# self, in_channels, out_channels, proj_dropout=0, selected_dimensions=None +# ): +# super(AbstractInitFeaturesEncoder, self).__init__() +# self.in_channels = in_channels +# self.out_channels = out_channels +# self.dimensions = ( +# selected_dimensions +# if selected_dimensions is not None +# else range(len(self.in_channels)) +# ) +# for idx, i in enumerate(self.dimensions): +# if idx == 0: +# setattr( +# self, +# f"encoder_{i}", +# BaseEncoder( +# self.in_channels[i], self.out_channels, dropout=proj_dropout +# ), +# ) +# else: +# setattr( +# self, +# f"encoder_{i}", +# Perceiver( +# dim=self.out_channels, +# depth=1, +# cross_heads=4, +# cross_dim_head=self.out_channels, +# latent_dim_head=self.out_channels, +# ), +# ) - for idx, i in enumerate(self.dimensions): - if idx == 0: - if hasattr(data, f"x_{i}") and hasattr(self, f"encoder_{i}"): - batch = data.batch if i == 0 else getattr(data, f"batch_{i}") - data[f"x_{i}"] = getattr(self, f"encoder_{i}")( - data[f"x_{i}"], batch - ) - else: - if hasattr(data, f"x_{i}") and hasattr(self, f"encoder_{i}"): - cell_features = data["x_0"][data[f"x_{i}"].long()] - data[f"x_{i}"] = getattr(self, f"encoder_{i}")(cell_features) - else: - data[f"x_{i}"] = torch.tensor([], device=data.x_0.device) - return data +# def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data: +# r""" +# Forward pass + +# Parameters +# ---------- +# data: torch_geometric.data.Data +# Input data object which should contain x_{i} features for each i in the selected_dimensions. + +# Returns +# ------- +# torch_geometric.data.Data +# Output data object. +# """ +# if not hasattr(data, "x_0"): +# data.x_0 = data.x + +# for idx, i in enumerate(self.dimensions): +# if idx == 0: +# if hasattr(data, f"x_{i}") and hasattr(self, f"encoder_{i}"): +# batch = data.batch if i == 0 else getattr(data, f"batch_{i}") +# data[f"x_{i}"] = getattr(self, f"encoder_{i}")( +# data[f"x_{i}"], batch +# ) +# else: +# if hasattr(data, f"x_{i}") and hasattr(self, f"encoder_{i}"): +# cell_features = data["x_0"][data[f"x_{i}"].long()] +# data[f"x_{i}"] = getattr(self, f"encoder_{i}")(cell_features) +# else: +# data[f"x_{i}"] = torch.tensor([], device=data.x_0.device) +# return data diff --git a/topobenchmarkx/models/encoders/perceiver.py b/topobenchmarkx/models/encoders/perceiver.py index 2ca12cca..fd218b62 100644 --- a/topobenchmarkx/models/encoders/perceiver.py +++ b/topobenchmarkx/models/encoders/perceiver.py @@ -65,7 +65,7 @@ def cached_fn(*args, _cache=True, **kwargs): class PreNorm(nn.Module): r"""Class to wrap together LayerNorm and a specified function. - + Parameters ---------- dim: int @@ -75,22 +75,25 @@ class PreNorm(nn.Module): context_dim: int Size of the context to normalize. """ + def __init__(self, dim, fn, context_dim=None): super().__init__() self.fn = fn self.norm = nn.LayerNorm(dim) - self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None + self.norm_context = ( + nn.LayerNorm(context_dim) if exists(context_dim) else None + ) def forward(self, x, **kwargs): r"""Forward pass. - + Parameters ---------- x: torch.Tensor Input tensor. kwargs: dict Dictionary of keyword arguments. - + Returns ------- torch.Tensor @@ -108,9 +111,10 @@ def forward(self, x, **kwargs): class GEGLU(nn.Module): r"""GEGLU activation function.""" + def forward(self, x): r"""Forward pass. 
- + Parameters ---------- x: torch.Tensor @@ -119,9 +123,10 @@ def forward(self, x): x, gates = x.chunk(2, dim=-1) return x * F.gelu(gates) + class FeedForward(nn.Module): r"""Feedforward network. - + Parameters ---------- dim: int @@ -129,6 +134,7 @@ class FeedForward(nn.Module): mult: int Multiplier for the hidden dimension. """ + def __init__(self, dim, mult=4): super().__init__() self.net = nn.Sequential( @@ -137,7 +143,7 @@ def __init__(self, dim, mult=4): def forward(self, x): r"""Forward pass. - + Parameters ---------- x: torch.Tensor @@ -148,7 +154,7 @@ def forward(self, x): class Attention(nn.Module): r"""Attention function. - + Parameters ---------- query_dim: int @@ -160,6 +166,7 @@ class Attention(nn.Module): dim_head: int Size for each head. """ + def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64): super().__init__() inner_dim = dim_head * heads @@ -173,7 +180,7 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64): def forward(self, x, context=None, mask=None): r"""Forward pass. - + Parameters ---------- x: torch.Tensor @@ -182,7 +189,7 @@ def forward(self, x, context=None, mask=None): Context tensor. mask: torch.Tensor Mask for attention calculation purposes. - + Returns ------- torch.Tensor @@ -194,7 +201,9 @@ def forward(self, x, context=None, mask=None): context = default(context, x) k, v = self.to_kv(context).chunk(2, dim=-1) - q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) + q, k, v = map( + lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v) + ) sim = einsum("b i d, b j d -> b i j", q, k) * self.scale @@ -217,7 +226,7 @@ def forward(self, x, context=None, mask=None): class Perceiver(nn.Module): r"""Perceiver model. - + Parameters ---------- depth: int @@ -239,6 +248,7 @@ class Perceiver(nn.Module): decoder_ff: bool Whether to use a feedforward network in the decoder. """ + def __init__( self, *, @@ -267,7 +277,10 @@ def __init__( PreNorm( latent_dim, Attention( - latent_dim, dim, heads=cross_heads, dim_head=cross_dim_head + latent_dim, + dim, + heads=cross_heads, + dim_head=cross_dim_head, ), context_dim=dim, ), @@ -275,16 +288,20 @@ def __init__( ] ) - def get_latent_attn(): + def get_latent_attn(): return PreNorm( latent_dim, - Attention(latent_dim, heads=latent_heads, dim_head=latent_dim_head), + Attention( + latent_dim, heads=latent_heads, dim_head=latent_dim_head + ), ) - def get_latent_ff(): + def get_latent_ff(): return PreNorm(latent_dim, FeedForward(latent_dim)) - - get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff)) + + get_latent_attn, get_latent_ff = map( + cache_fn, (get_latent_attn, get_latent_ff) + ) self.layers = nn.ModuleList([]) cache_args = {"_cache": weight_tie_layers} @@ -292,19 +309,27 @@ def get_latent_ff(): for _ in range(depth): self.layers.append( nn.ModuleList( - [get_latent_attn(**cache_args), get_latent_ff(**cache_args)] + [ + get_latent_attn(**cache_args), + get_latent_ff(**cache_args), + ] ) ) self.decoder_cross_attn = PreNorm( queries_dim, Attention( - queries_dim, latent_dim, heads=cross_heads, dim_head=cross_dim_head + queries_dim, + latent_dim, + heads=cross_heads, + dim_head=cross_dim_head, ), context_dim=latent_dim, ) self.decoder_ff = ( - PreNorm(queries_dim, FeedForward(queries_dim)) if decoder_ff else None + PreNorm(queries_dim, FeedForward(queries_dim)) + if decoder_ff + else None ) # self.to_logits = ( @@ -313,7 +338,7 @@ def get_latent_ff(): def forward(self, data, mask=None, queries=None): r"""Forward pass. 
-
+
         Parameters
         ----------
         data: torch.Tensor
@@ -365,4 +390,4 @@ def forward(self, data, mask=None, queries=None):

         # final linear out

#         return x #self.to_logits(latents)
-        return
+        return None
diff --git a/topobenchmarkx/models/head_model/models.py b/topobenchmarkx/models/head_model/models.py
index 2da2f79e..0c63f0be 100644
--- a/topobenchmarkx/models/head_model/models.py
+++ b/topobenchmarkx/models/head_model/models.py
@@ -4,7 +4,7 @@

 class DefaultHead(torch.nn.Module):
     r"""Head model.
-
+
     Parameters
     ----------
     in_channels: int
@@ -16,6 +16,7 @@ class DefaultHead(torch.nn.Module):
     pooling_type: str
         Pooling type, either "max", "sum", or "mean". Specifies the type of pooling operation to be used for the graph-level embedding.
     """
+
     def __init__(
         self,
         in_channels: int,
@@ -31,15 +32,15 @@ def __init__(

         assert pooling_type in ["max", "sum", "mean"], "Invalid pooling_type"
         self.pooling_type = pooling_type
-
+
     def forward(self, model_out: dict):
         r"""Forward pass.
-
+
         Parameters
         ----------
         model_out: dict
             Dictionary containing the model output.
-
+
         Returns
         -------
         dict
diff --git a/topobenchmarkx/models/losses/loss.py b/topobenchmarkx/models/losses/loss.py
index 42a6d47d..8517082d 100755
--- a/topobenchmarkx/models/losses/loss.py
+++ b/topobenchmarkx/models/losses/loss.py
@@ -1,11 +1,9 @@
 import torch

-# import hydra
-# from omegaconf import DictConfig
-

 class DefaultLoss:
-    """Abstract class that provides an interface to loss logic within netowrk"""
+    """Abstract class that provides an interface to the loss logic within the
+    network."""

     def __init__(self, task, loss_type=None):
         self.task = task
@@ -22,7 +20,7 @@ def __init__(self, task, loss_type=None):
             raise Exception("Loss is not defined")

     def __call__(self, model_output):
-        """Loss logic based on model_output"""
+        """Loss logic based on model_output."""

         logits = model_output["logits"]
         target = model_output["labels"]
@@ -33,51 +31,3 @@ def __call__(self, model_output):
         model_output["loss"] = self.criterion(logits, target)

         return model_output
-
-
-# class NodeTaskLoss:
-#     """Abstract class that provides an interface to loss logic within netowrk"""
-
-#     def __init__(self, task):
-#         if task == "classification":
-#             self.criterion = torch.nn.CrossEntropyLoss()
-
-#         elif task == "regression":
-#             self.criterion == torch.nn.mse()
-
-#         else:
-#             raise Exception("Loss is not defined")
-
-#     def __call__(self, model_output):
-#         """Loss logic based on model_output"""
-
-#         logits = model_output["logits"]
-#         target = model_output["labels"]
-#         model_output["loss"] = self.criterion(logits, target)
-
-#         return model_output
-
-
-# from abc import ABC, abstractmethod
-
-# import hydra
-# from omegaconf import DictConfig
-
-# # logger = logging.getLogger(__name__)
-
-
-# class AbstractLoss(ABC):
-#     """Abstract class that provides an interface to loss logic within netowrk"""
-
-#     def __init__(self, cfg: DictConfig):
-#         self.cfg = cfg
-
-#     @abstractmethod
-#     def init_loss(
-#         self,
-#     ):
-#         """Initialize loss"""
-
-#     @abstractmethod
-#     def forward(self, model_output):
-#         """Loss logic based on model_output"""
diff --git a/topobenchmarkx/models/network_module.py b/topobenchmarkx/models/network_module.py
index 61b2ee8e..ee3b5dcf 100755
--- a/topobenchmarkx/models/network_module.py
+++ b/topobenchmarkx/models/network_module.py
@@ -1,4 +1,4 @@
-from typing import Any, Union
+from typing import Any

 import torch
 from lightning import LightningModule
@@ -21,7 +21,7 @@ def __init__(
         readout: torch.nn.Module,
         head_model: torch.nn.Module,
         loss: torch.nn.Module,
-        feature_encoder: Union[torch.nn.Module, None] = None,
+        feature_encoder: torch.nn.Module | None = None,
         **kwargs,
     ) -> None:
         """Initialize a `NetworkModule`.
@@ -36,16 +36,13 @@ def __init__(

         # This line allows access to init params with 'self.hparams' attribute
         # also ensures init params will be stored in ckpt
-        self.save_hyperparameters(
-            logger=False,
-            ignore=[]
-        )
+        self.save_hyperparameters(logger=False, ignore=[])

         self.feature_encoder = feature_encoder
         self.backbone = backbone_wrapper(backbone)
         self.readout = readout
         self.head_model = head_model
-
+
         # Evaluator
         self.evaluator = None
         self.train_metrics_logged = False
@@ -68,7 +65,9 @@ def forward(self, batch) -> dict:
         """
         return self.backbone(batch)

-    def model_step(self, batch) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def model_step(
+        self, batch
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """Perform a single model step on a batch of data.

         :param batch: A batch of data (a tuple) containing the input data and target labels.
         :return: A tuple containing (in order):
             - A tensor of losses.
             - A tensor of predictions.
             - A tensor of target labels.
         """
-        # Pipeline 
+        # Pipeline
         if self.feature_encoder:
             batch = self.feature_encoder(batch)

@@ -94,10 +93,11 @@ def model_step(self, batch) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         return model_out

     def training_step(self, batch, batch_idx: int) -> torch.Tensor:
-        """Perform a single training step on a batch of data from the training set.
+        """Perform a single training step on a batch of data from the training
+        set.

-        :param batch: A batch of data (a tuple) containing the input tensor of images and target
-            labels.
+        :param batch: A batch of data (a tuple) containing the input data and
+            target labels.
         :param batch_idx: The index of the current batch.
         :return: A tensor of losses between model predictions and targets.
         """
@@ -120,10 +120,11 @@ def training_step(self, batch, batch_idx: int) -> torch.Tensor:
     def validation_step(
         self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int
     ) -> None:
-        """Perform a single validation step on a batch of data from the validation set.
+        """Perform a single validation step on a batch of data from the
+        validation set.

-        :param batch: A batch of data (a tuple) containing the input tensor of images and target
-            labels.
+        :param batch: A batch of data (a tuple) containing the input data and
+            target labels.
         :param batch_idx: The index of the current batch.
         """
         self.state_str = "Validation"
@@ -158,8 +159,8 @@ def test_step(
     ) -> None:
         """Perform a single test step on a batch of data from the test set.

-        :param batch: A batch of data (a tuple) containing the input tensor of images and target
-            labels.
+        :param batch: A batch of data (a tuple) containing the input data and
+            target labels.
         :param batch_idx: The index of the current batch.
         """
         self.state_str = "Test"
@@ -224,24 +225,24 @@ def log_metrics(self, mode=None):
             self.evaluator.reset()

     def on_validation_epoch_start(self) -> None:
-        """According pytorch lightning documentation, this hook is called at the beginning of the validation epoch.
+        """According to the PyTorch Lightning documentation, this hook is
+        called at the beginning of the validation epoch.

         https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks

         Note that the validation step is within the train epoch. Hence here we have to log the train metrics
         before we reset the evaluator to start the validation loop.
""" - + # Log train metrics and reset evaluator self.log_metrics(mode="train") self.train_metrics_logged = True - + def on_train_epoch_end(self) -> None: # Log train metrics and reset evaluator if not self.train_metrics_logged: self.log_metrics(mode="train") self.train_metrics_logged = True - def on_validation_epoch_end(self) -> None: """Lightning hook that is called when a test epoch ends.""" @@ -267,11 +268,12 @@ def on_test_epoch_start(self) -> None: self.evaluator.reset() def setup(self, stage: str) -> None: - """Lightning hook that is called at the beginning of fit (train + validate), validate, - test, or predict. + """Lightning hook that is called at the beginning of fit (train + + validate), validate, test, or predict. - This is a good hook when you need to build models dynamically or adjust something about - them. This hook is called on every process when using DDP. + This is a good hook when you need to build models dynamically or adjust + something about them. This hook is called on every process when using + DDP. :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. """ @@ -279,8 +281,9 @@ def setup(self, stage: str) -> None: self.net = torch.compile(self.net) def configure_optimizers(self) -> dict[str, Any]: - """Choose what optimizers and learning-rate schedulers to use in your optimization. - Normally you'd need one. But in the case of GANs or similar you might have multiple. + """Choose what optimizers and learning-rate schedulers to use in your + optimization. Normally you'd need one. But in the case of GANs or + similar you might have multiple. Examples: https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers diff --git a/topobenchmarkx/models/readouts/__init__.py b/topobenchmarkx/models/readouts/__init__.py index e69de29b..3038e843 100644 --- a/topobenchmarkx/models/readouts/__init__.py +++ b/topobenchmarkx/models/readouts/__init__.py @@ -0,0 +1,26 @@ +from topobenchmarkx.models.readouts.propagate_signal_down import ( + PropagateSignalDown, +) + +# ... import other readout classes here +# For example: +# from topobenchmarkx.models.readouts.other_readout_1 import OtherReadout1 +# from topobenchmarkx.models.readouts.other_readout_2 import OtherReadout2 + + +# Dictionary of all readouts +READOUTS = { + "PropagateSignalDown": PropagateSignalDown, + # "OtherReadout1": OtherReadout1, + # "OtherReadout2": OtherReadout2, + # ... add other readout mappings here +} + +# Export all readouts and the dictionary +__all__ = [ + "PropagateSignalDown", + # "OtherReadout1", + # "OtherReadout2", + # ... add other readout classes here + "READOUTS", +] diff --git a/topobenchmarkx/models/readouts/old_readout.py b/topobenchmarkx/models/readouts/old_readout.py deleted file mode 100755 index 9db97518..00000000 --- a/topobenchmarkx/models/readouts/old_readout.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch -from torch_geometric.utils import scatter - -from topobenchmarkx.models.abstractions.readout import AbstractReadOut - - -class GNNBatchReadOut(AbstractReadOut): - r"""Readout layer for GNNs that operates on the batch level. - - Parameters - ---------- - in_channels: int - Input dimension. - out_channels: int - Output dimension. - task_level: str - Task level, either "graph" or "node". If "graph", the readout layer will pool the node embeddings to the graph level to obtain a single graph embedding for each batched graph. If "node", the readout layer will return the node embeddings. 
- pooling_type: str - Pooling type, either "max", "sum", or "mean". Specifies the type of pooling operation to be used for the graph-level embedding. - """ - def __init__( - self, - in_channels: int, - out_channels: int, - task_level: str, - pooling_type: str = "sum", - ): - super(AbstractReadOut, self).__init__() - self.linear = torch.nn.Linear(in_channels, out_channels) - - assert task_level in ["graph", "node"], "Invalid task_level" - self.task_level = task_level - - assert pooling_type in ["max", "sum", "mean"], "Invalid pooling_type" - self.pooling_type = pooling_type - - def forward(self, model_out: dict): - r"""Forward pass. - - Parameters - ---------- - model_out: dict - Dictionary containing the model output. - - Returns - ------- - dict - Dictionary containing the updated model output. Resulting key is "logits". - """ - x = model_out["x_0"] - batch = model_out["batch"] - if self.task_level == "graph": - if self.pooling_type == "max": - x = scatter(x, batch, dim=0, reduce="max") - - elif self.pooling_type == "mean": - x = scatter(x, batch, dim=0, reduce="mean") - - elif self.pooling_type == "sum": - x = scatter(x, batch, dim=0, reduce="sum") - - model_out["logits"] = self.linear(x) - return model_out diff --git a/topobenchmarkx/models/readouts/readouts.py b/topobenchmarkx/models/readouts/propagate_signal_down.py similarity index 60% rename from topobenchmarkx/models/readouts/readouts.py rename to topobenchmarkx/models/readouts/propagate_signal_down.py index 2a340fd9..155d5a3b 100644 --- a/topobenchmarkx/models/readouts/readouts.py +++ b/topobenchmarkx/models/readouts/propagate_signal_down.py @@ -1,5 +1,5 @@ -import torch import topomodelx +import torch class PropagateSignalDown(torch.nn.Module): @@ -8,39 +8,35 @@ def __init__(self, **kwargs): self.dimensions = range(kwargs["num_cell_dimensions"] - 1, 0, -1) hidden_dim = kwargs["hidden_dim"] - + for i in self.dimensions: setattr( self, f"agg_conv_{i}", topomodelx.base.conv.Conv( - hidden_dim, - hidden_dim, - aggr_norm=False - ) + hidden_dim, hidden_dim, aggr_norm=False + ), ) - setattr( - self, - f"ln_{i}", - torch.nn.LayerNorm(hidden_dim) - ) + setattr(self, f"ln_{i}", torch.nn.LayerNorm(hidden_dim)) setattr( self, f"projector_{i}", - torch.nn.Linear(2*hidden_dim, hidden_dim) + torch.nn.Linear(2 * hidden_dim, hidden_dim), ) def __call__(self, model_out, batch): return self.forward(model_out, batch) - def forward(self, model_out, batch): - for i in self.dimensions: - x_i = getattr(self, f"agg_conv_{i}")(model_out[f"x_{i}"], batch[f"incidence_{i}"]) + for i in self.dimensions: + x_i = getattr(self, f"agg_conv_{i}")( + model_out[f"x_{i}"], batch[f"incidence_{i}"] + ) x_i = getattr(self, f"ln_{i}")(x_i) - model_out[f"x_{i-1}"] = getattr(self, f"projector_{i}")(torch.cat([x_i, model_out[f"x_{i-1}"]], dim=1)) - + model_out[f"x_{i-1}"] = getattr(self, f"projector_{i}")( + torch.cat([x_i, model_out[f"x_{i-1}"]], dim=1) + ) + return model_out - \ No newline at end of file diff --git a/topobenchmarkx/models/readouts/readout.py b/topobenchmarkx/models/readouts/readout.py index c2189d4d..7d618cab 100755 --- a/topobenchmarkx/models/readouts/readout.py +++ b/topobenchmarkx/models/readouts/readout.py @@ -1,18 +1,12 @@ import torch import torch_geometric -from torch_geometric.utils import scatter - -from topobenchmarkx.models.readouts.readouts import PropagateSignalDown -# Implemented Poolings -READOUTS = { - "PropagateSignalDown": PropagateSignalDown -} +from . 
import READOUTS
 
 
 class AbstractReadOut(torch.nn.Module):
     r"""Readout layer for GNNs that operates on the batch level.
-    
+
     Parameters
     ----------
     in_channels: int
@@ -24,10 +18,8 @@ class AbstractReadOut(torch.nn.Module):
         pooling_type: str
             Pooling type, either "max", "sum", or "mean". Specifies the type of pooling operation to be used for the graph-level embedding.
     """
-    def __init__(
-        self,
-        **kwargs
-    ):
+
+    def __init__(self, **kwargs):
         super().__init__()
 
         self.signal_readout = kwargs["readout_name"] != "None"
@@ -37,12 +29,12 @@ def __init__(
     def forward(self, model_out: dict, batch: torch_geometric.data.Data):
         r"""Forward pass.
-        
+
         Parameters
         ----------
         model_out: dict
             Dictionary containing the model output.
-        
+
         Returns
         -------
         dict
@@ -53,7 +45,3 @@ def __init__(
             model_out = self.readout(model_out, batch)
 
         return model_out
-
-
-
-
\ No newline at end of file
diff --git a/topobenchmarkx/models/wrappers/__init__.py b/topobenchmarkx/models/wrappers/__init__.py
index d9f57a40..705d292e 100755
--- a/topobenchmarkx/models/wrappers/__init__.py
+++ b/topobenchmarkx/models/wrappers/__init__.py
@@ -1,10 +1,11 @@
-import hydra # noqa: F401
+import hydra  # noqa: F401
 import torch
-from omegaconf import DictConfig # noqa: F401
+from omegaconf import DictConfig  # noqa: F401
 
 
 class DefaultLoss:
-    """Abstract class that provides an interface to loss logic within netowrk"""
+    """Abstract class that provides an interface to loss logic within
+    network."""
 
     def __init__(self, task):
         if task == "classification":
@@ -16,7 +17,7 @@ def __init__(self, task):
             raise Exception("Loss is not defined")
 
     def __call__(self, model_output):
-        """Loss logic based on model_output"""
+        """Loss logic based on model_output."""
 
         logits = model_output["logits"]
         target = model_output["labels"]
diff --git a/topobenchmarkx/models/wrappers/default_wrapper.py b/topobenchmarkx/models/wrappers/default_wrapper.py
index b3afd5b0..405693d4 100755
--- a/topobenchmarkx/models/wrappers/default_wrapper.py
+++ b/topobenchmarkx/models/wrappers/default_wrapper.py
@@ -1,18 +1,16 @@
 from abc import ABC, abstractmethod
 
-import topomodelx
 import torch
-from torch_geometric.nn.norm import GraphNorm
-import torch.nn as nn
-
+import torch.nn as nn
 
 class DefaultWrapper(ABC, torch.nn.Module):
-    """Abstract class that provides an interface to handle the network output"""
+    """Abstract class that provides an interface to handle the network
+    output."""
 
     def __init__(self, backbone, **kwargs):
         super().__init__()
-        self.backbone = backbone 
+        self.backbone = backbone
         out_channels = kwargs["out_channels"]
         self.dimensions = range(kwargs["num_cell_dimensions"])
 
@@ -24,84 +22,100 @@ def __init__(self, backbone, **kwargs):
         )
 
     def __call__(self, batch):
-        """Define logic for forward pass"""
+        """Define logic for forward pass."""
         model_out = self.forward(batch)
         model_out = self.residual_connection(model_out=model_out, batch=batch)
         return model_out
 
     def residual_connection(self, model_out, batch):
         for i in self.dimensions:
-            if (f"x_{i}" in batch) and hasattr(self, f"ln_{i}") and (f"x_{i}" in model_out):
+            if (
+                (f"x_{i}" in batch)
+                and hasattr(self, f"ln_{i}")
+                and (f"x_{i}" in model_out)
+            ):
                 residual = model_out[f"x_{i}"] + batch[f"x_{i}"]
                 model_out[f"x_{i}"] = getattr(self, f"ln_{i}")(residual)
         return model_out
-    
+
     @abstractmethod
     def forward(self, batch):
-        """Define handling output here"""
+        """Define handling output here."""
+
 
 class GNNWrapper(DefaultWrapper):
-    """Abstract class that provides an interface
to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" # def __init__(self, backbone, **kwargs): # super().__init__(backbone) def forward(self, batch): - """Define logic for forward pass""" + """Define logic for forward pass.""" x_0 = self.backbone(batch.x_0, batch.edge_index) model_out = {"labels": batch.y, "batch_0": batch.batch_0} model_out["x_0"] = x_0 - + return model_out class HypergraphWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" + """Define logic for forward pass.""" x_0, x_1 = self.backbone(batch.x_0, batch.incidence_hyperedges) model_out = {"labels": batch.y, "batch_0": batch.batch_0} model_out["x_0"] = x_0 model_out["hyperedge"] = x_1 - + return model_out class SANWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" - x_1 = self.backbone(batch.x_1, batch.up_laplacian_1, batch.down_laplacian_1) + """Define logic for forward pass.""" + x_1 = self.backbone( + batch.x_1, batch.up_laplacian_1, batch.down_laplacian_1 + ) model_out = {"labels": batch.y, "batch_0": batch.batch_0} model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1) model_out["x_1"] = x_1 return model_out + class SCNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" - - + """Define logic for forward pass.""" + laplacian_0 = self.normalize_matrix(batch.hodge_laplacian_0) laplacian_1 = self.normalize_matrix(batch.hodge_laplacian_1) laplacian_2 = self.normalize_matrix(batch.hodge_laplacian_2) x_0, x_1, x_2 = self.backbone( - batch.x_0, batch.x_1, batch.x_2, laplacian_0, laplacian_1, laplacian_2 + batch.x_0, + batch.x_1, + batch.x_2, + laplacian_0, + laplacian_1, + laplacian_2, ) - + model_out = {"labels": batch.y, "batch_0": batch.batch_0} model_out["x_2"] = x_2 model_out["x_1"] = x_1 model_out["x_0"] = x_0 return model_out - + def normalize_matrix(self, matrix): matrix_ = matrix.to_dense() n, _ = matrix_.shape @@ -117,14 +131,16 @@ def normalize_matrix(self, matrix): diag_indices, diag_sum, matrix_.shape, device=matrix.device ).coalesce() normalized_matrix = diag_matrix @ (matrix @ diag_matrix) - return normalized_matrix + return normalized_matrix + class SCCNNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" - + """Define logic for forward pass.""" + x_all = (batch.x_0, batch.x_1, batch.x_2) laplacian_all = ( batch.hodge_laplacian_0, @@ -136,22 +152,23 @@ def forward(self, batch): incidence_all = (batch.incidence_1, batch.incidence_2) x_0, x_1, x_2 = self.backbone(x_all, laplacian_all, incidence_all) - + model_out = {"labels": batch.y, "batch_0": batch.batch_0} - + model_out["x_0"] = x_0 model_out["x_1"] = x_1 model_out["x_2"] = x_2 - + return model_out class SCCNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + 
"""Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" - + """Define logic for forward pass.""" + features = { f"rank_{r}": batch[f"x_{r}"] for r in range(self.backbone.layers[0].max_rank + 1) @@ -169,30 +186,37 @@ def forward(self, batch): # TODO: First decide which strategy is the best then make code general model_out = {"labels": batch.y, "batch_0": batch.batch_0} if len(output) == 3: - x_0, x_1, x_2 = output["rank_0"], output["rank_1"], output["rank_2"] - + x_0, x_1, x_2 = ( + output["rank_0"], + output["rank_1"], + output["rank_2"], + ) + model_out["x_2"] = x_2 model_out["x_1"] = x_1 model_out["x_0"] = x_0 elif len(output) == 2: x_0, x_1 = output["rank_0"], output["rank_1"] - + model_out["x_1"] = x_1 model_out["x_0"] = x_0 - + else: - raise ValueError(f"Invalid number of output tensors: {len(output)}") + raise ValueError( + f"Invalid number of output tensors: {len(output)}" + ) return model_out class CANWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" - + """Define logic for forward pass.""" + x_1 = self.backbone( x_0=batch.x_0, x_1=batch.x_1, @@ -200,7 +224,7 @@ def forward(self, batch): down_laplacian_1=batch.down_laplacian_1.coalesce(), up_laplacian_1=batch.up_laplacian_1.coalesce(), ) - + model_out = {"labels": batch.y, "batch_0": batch.batch_0} model_out["x_1"] = x_1 model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1) @@ -208,30 +232,32 @@ def forward(self, batch): class CWNDCMWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" - + """Define logic for forward pass.""" + x_1 = self.backbone( batch.x_1, batch.down_laplacian_1.coalesce(), batch.up_laplacian_1.coalesce(), ) - + model_out = {"labels": batch.y, "batch_0": batch.batch_0} - + model_out["x_1"] = x_1 - model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1) + model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1) return model_out class CWNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" - + """Define logic for forward pass.""" + x_0, x_1, x_2 = self.backbone( x_0=batch.x_0, x_1=batch.x_1, @@ -249,18 +275,19 @@ def forward(self, batch): class CCXNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within network""" + """Abstract class that provides an interface to loss logic within + network.""" def forward(self, batch): - """Define logic for forward pass""" - + """Define logic for forward pass.""" + x_0, x_1, x_2 = self.backbone( x_0=batch.x_0, x_1=batch.x_1, adjacency_0=batch.adjacency_0, incidence_2_t=batch.incidence_2.T, ) - + model_out = {"labels": batch.y, "batch_0": batch.batch_0} model_out["x_0"] = x_0 model_out["x_1"] = x_1 diff --git a/topobenchmarkx/play.ipynb b/topobenchmarkx/play.ipynb index 0f18b067..ac254ec3 100644 --- a/topobenchmarkx/play.ipynb +++ b/topobenchmarkx/play.ipynb @@ -7,8 +7,9 @@ "outputs": [], "source": [ "import torch\n", - "run_1 = 
torch.load('/home/lev/projects/TopoBenchmarkX/run_1')\n", - "run_2 = torch.load('/home/lev/projects/TopoBenchmarkX/run_2')\n" + "\n", + "run_1 = torch.load(\"/home/lev/projects/TopoBenchmarkX/run_1\")\n", + "run_2 = torch.load(\"/home/lev/projects/TopoBenchmarkX/run_2\")" ] }, { @@ -65,8 +66,7 @@ ], "source": [ "for key in run_1.keys():\n", - " print(f'{key} {(run_1[key] == run_2[key]).all()}')\n", - " " + " print(f\"{key} {(run_1[key] == run_2[key]).all()}\")" ] }, { @@ -76,7 +76,11 @@ "outputs": [], "source": [ "import numpy as np\n", - "run1 = np.load('/home/lev/projects/TopoBenchmarkX/datasets/data_splits/US-county-demos/train_prop=0.5_global_seed=42/0.npz', allow_pickle=True)" + "\n", + "run1 = np.load(\n", + " \"/home/lev/projects/TopoBenchmarkX/datasets/data_splits/US-county-demos/train_prop=0.5_global_seed=42/0.npz\",\n", + " allow_pickle=True,\n", + ")" ] }, { @@ -85,7 +89,10 @@ "metadata": {}, "outputs": [], "source": [ - "run2 = np.load('/home/lev/projects/TopoBenchmarkX/datasets/data_splits/US-county-demos/train_prop=0.5_global_seed=42/0.npz', allow_pickle=True)" + "run2 = np.load(\n", + " \"/home/lev/projects/TopoBenchmarkX/datasets/data_splits/US-county-demos/train_prop=0.5_global_seed=42/0.npz\",\n", + " allow_pickle=True,\n", + ")" ] }, { diff --git a/topobenchmarkx/train.py b/topobenchmarkx/train.py index 33fa9e68..ff3430d3 100755 --- a/topobenchmarkx/train.py +++ b/topobenchmarkx/train.py @@ -1,25 +1,18 @@ -import numpy as np import random -from typing import Any, Optional +from typing import Any import hydra import lightning as L +import numpy as np import rootutils +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) import torch from lightning import Callback, LightningModule, Trainer from lightning.pytorch.loggers import Logger from omegaconf import DictConfig, OmegaConf -from topobenchmarkx.utils.config_resolvers import ( - get_default_transform, - get_monitor_metric, - get_monitor_mode, - infer_in_channels, - infere_list_length, -) - -from topobenchmarkx.data.dataloader_fullbatch import DefaultDataModule +from topobenchmarkx.data.dataloaders import DefaultDataModule from topobenchmarkx.utils import ( RankedLogger, extras, @@ -29,6 +22,13 @@ log_hyperparameters, task_wrapper, ) +from topobenchmarkx.utils.config_resolvers import ( + get_default_transform, + get_monitor_metric, + get_monitor_mode, + infer_in_channels, + infere_list_length, +) # ------------------------------------------------------------------------------------ # # the setup_root above is equivalent to: @@ -47,7 +47,6 @@ # more info: https://github.com/ashleve/rootutils # ------------------------------------------------------------------------------------ # -rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) OmegaConf.register_new_resolver("get_default_transform", get_default_transform) OmegaConf.register_new_resolver("get_monitor_metric", get_monitor_metric) @@ -64,18 +63,19 @@ @task_wrapper def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: - """Trains the model. Can additionally evaluate on a testset, using best weights obtained during - training. + """Trains the model. Can additionally evaluate on a testset, using best + weights obtained during training. - This method is wrapped in optional @task_wrapper decorator, that controls the behavior during - failure. Useful for multiruns, saving info about the crash, etc. + This method is wrapped in optional @task_wrapper decorator, that controls + the behavior during failure. 
Useful for multiruns, saving info about the
+    crash, etc.
 
     :param cfg: A DictConfig configuration composed by Hydra.
     :return: A tuple with metrics and dict with all instantiated objects.
     """
     # Set seed for random number generators in pytorch, numpy and python.random
-    #if cfg.get("seed"):
+    # if cfg.get("seed"):
     L.seed_everything(cfg.seed, workers=True)
     # Seed for torch
     torch.manual_seed(cfg.seed)
@@ -84,7 +84,6 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
     # Seed for python random
     random.seed(cfg.seed)
 
-
     # Instantiate and load dataset
     dataset = hydra.utils.instantiate(cfg.dataset, _recursive_=False)
     dataset = dataset.load()
@@ -140,7 +139,9 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
 
     if cfg.get("train"):
         log.info("Starting training!")
-        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))
+        trainer.fit(
+            model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path")
+        )
 
     train_metrics = trainer.callback_metrics
 
@@ -148,7 +149,9 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
         log.info("Starting testing!")
         ckpt_path = trainer.checkpoint_callback.best_model_path
         if ckpt_path == "":
-            log.warning("Best ckpt not found! Using current weights for testing...")
+            log.warning(
+                "Best ckpt not found! Using current weights for testing..."
+            )
             ckpt_path = None
         trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
         log.info(f"Best ckpt path: {ckpt_path}")
@@ -164,8 +167,8 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
 def count_number_of_parameters(
     model: torch.nn.Module, only_trainable: bool = True
 ) -> int:
-    """
-    Counts the number of trainable params. If all params, specify only_trainable = False.
+    """Counts the number of trainable params. If all params, specify
+    only_trainable = False.
 
     Ref:
        - https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9?u=brando_miranda
 
     :return:
     """
     if only_trainable:
-        num_params: int = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        num_params: int = sum(
+            p.numel() for p in model.parameters() if p.requires_grad
+        )
     else:  # counts trainable and non-trainable
         num_params: int = sum(p.numel() for p in model.parameters() if p)
     assert num_params > 0, f"Err: {num_params=}"
     return int(num_params)
 
 
-@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
-def main(cfg: DictConfig) -> Optional[float]:
+@hydra.main(
+    version_base="1.3", config_path="../configs", config_name="train.yaml"
+)
+def main(cfg: DictConfig) -> float | None:
     """Main entry point for training.
 
     :param cfg: DictConfig configuration composed by Hydra.
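For context, the resolvers registered near the top of train.py are what make config interpolations such as ${infere_list_length:${model.feature_encoder.in_channels}} in the model configs work. A minimal, self-contained sketch of the same OmegaConf mechanism, with a toy config and the resolver body inlined as a lambda for illustration (not the repo's actual registration code):

from omegaconf import OmegaConf

# Same registration pattern as train.py; the name is what configs reference.
OmegaConf.register_new_resolver("infere_list_length", lambda lst: len(lst))

cfg = OmegaConf.create(
    {
        "in_channels": [32, 32, 32],
        # Resolved lazily on access, like the ${...} entries in the configs.
        "num_cell_dimensions": "${infere_list_length:${in_channels}}",
    }
)
assert cfg.num_cell_dimensions == 3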
@@ -204,6 +211,4 @@ def main(cfg: DictConfig) -> Optional[float]: if __name__ == "__main__": - main() - diff --git a/topobenchmarkx/transforms/data_manipulations/manipulations.py b/topobenchmarkx/transforms/data_manipulations/manipulations.py index 3ae9d463..b75e1509 100644 --- a/topobenchmarkx/transforms/data_manipulations/manipulations.py +++ b/topobenchmarkx/transforms/data_manipulations/manipulations.py @@ -29,7 +29,8 @@ def forward(self, data: torch_geometric.data.Data): class InfereKNNConnectivity(torch_geometric.transforms.BaseTransform): - r"""A transform that generates the k-nearest neighbor connectivity of the input point cloud.""" + r"""A transform that generates the k-nearest neighbor connectivity of the + input point cloud.""" def __init__(self, **kwargs): super().__init__() @@ -38,6 +39,7 @@ def __init__(self, **kwargs): def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. + Parameters ---------- data : torch_geometric.data.Data @@ -56,7 +58,8 @@ def forward(self, data: torch_geometric.data.Data): class InfereRadiusConnectivity(torch_geometric.transforms.BaseTransform): - r"""A transform that generates the radius connectivity of the input point cloud.""" + r"""A transform that generates the radius connectivity of the input point + cloud.""" def __init__(self, **kwargs): super().__init__() @@ -65,6 +68,7 @@ def __init__(self, **kwargs): def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. + Parameters ---------- data : torch_geometric.data.Data @@ -79,7 +83,8 @@ def forward(self, data: torch_geometric.data.Data): class EqualGausFeatures(torch_geometric.transforms.BaseTransform): - r"""A transform that generates equal Gaussian features for all nodes in the input graph. + r"""A transform that generates equal Gaussian features for all nodes in the + input graph. Parameters ---------- @@ -173,9 +178,11 @@ def forward(self, data: torch_geometric.data.Data): torch_geometric.data.Data The transformed data. """ - field_to_process = [key for key in data - for field_substring in self.parameters["selected_fields"] - if field_substring in key and key != "incidence_0" + field_to_process = [ + key + for key in data + for field_substring in self.parameters["selected_fields"] + if field_substring in key and key != "incidence_0" ] for field in field_to_process: data = self.calculate_node_degrees(data, field) @@ -214,7 +221,9 @@ def calculate_node_degrees( ) if "incidence" in field: - field_name = str(int(field.split("_")[1]) - 1) + "_cell" + "_degrees" + field_name = ( + str(int(field.split("_")[1]) - 1) + "_cell" + "_degrees" + ) else: field_name = "node_degrees" @@ -223,8 +232,8 @@ def calculate_node_degrees( class KeepOnlyConnectedComponent(torch_geometric.transforms.BaseTransform): - """ - A transform that keeps only the largest connected components of the input graph. + """A transform that keeps only the largest connected components of the + input graph. Parameters ---------- @@ -238,8 +247,7 @@ def __init__(self, **kwargs): self.parameters = kwargs def forward(self, data: torch_geometric.data.Data): - """ - Apply the transform to the input data. + """Apply the transform to the input data. Parameters ---------- @@ -263,8 +271,7 @@ def forward(self, data: torch_geometric.data.Data): class CalculateSimplicialCurvature(torch_geometric.transforms.BaseTransform): - """ - A transform that calculates the simplicial curvature of the input graph. 
+ """A transform that calculates the simplicial curvature of the input graph. Parameters ---------- @@ -278,8 +285,7 @@ def __init__(self, **kwargs): self.parameters = kwargs def forward(self, data: torch_geometric.data.Data): - """ - Apply the transform to the input data. + """Apply the transform to the input data. Parameters ---------- @@ -300,8 +306,7 @@ def zero_cell_curvature( self, data: torch_geometric.data.Data, ) -> torch_geometric.data.Data: - """ - Calculate the zero cell curvature of the input data. + """Calculate the zero cell curvature of the input data. Parameters ---------- @@ -364,7 +369,9 @@ def two_cell_curvature( idx = torch.where(data["2_cell_degrees"] > 1)[0] two_cell_degrees[idx] = 0 up = data["incidence_3"].to_dense() @ data["incidence_3"].to_dense().T - down = data["incidence_2"].to_dense().T @ data["incidence_2"].to_dense() + down = ( + data["incidence_2"].to_dense().T @ data["incidence_2"].to_dense() + ) mask = torch.eye(up.size()[0]).bool() up.masked_fill_(mask, 0) down.masked_fill_(mask, 0) @@ -375,7 +382,8 @@ def two_cell_curvature( class OneHotDegreeFeatures(torch_geometric.transforms.BaseTransform): - r"""A transform that adds the node degree as one hot encodings to the node features. + r"""A transform that adds the node degree as one hot encodings to the node + features. Parameters ---------- @@ -404,14 +412,16 @@ def forward(self, data: torch_geometric.data.Data): The transformed data. """ data = self.transform.forward( - data, degrees_field=self.deg_field, features_field=self.features_fields + data, + degrees_field=self.deg_field, + features_field=self.features_fields, ) return data class OneHotDegree(torch_geometric.transforms.BaseTransform): - r"""Adds the node degree as one hot encodings to the node features + r"""Adds the node degree as one hot encodings to the node features. Parameters ---------- @@ -430,7 +440,10 @@ def __init__( self.cat = cat def forward( - self, data: torch_geometric.data.Data, degrees_field: str, features_field: str + self, + data: torch_geometric.data.Data, + degrees_field: str, + features_field: str, ) -> torch_geometric.data.Data: r"""Apply the transform to the input data. @@ -499,7 +512,8 @@ def forward(self, data: torch_geometric.data.Data): """ # Keeps all the fields fields_to_keep = ( - self.parameters["base_fields"] + self.parameters["preserved_fields"] + self.parameters["base_fields"] + + self.parameters["preserved_fields"] ) # if len(self.parameters["keep_fields"]) == 1: # return data diff --git a/topobenchmarkx/transforms/data_transform.py b/topobenchmarkx/transforms/data_transform.py index e66c108b..89431637 100755 --- a/topobenchmarkx/transforms/data_transform.py +++ b/topobenchmarkx/transforms/data_transform.py @@ -57,7 +57,8 @@ class DataTransform(torch_geometric.transforms.BaseTransform): - """Abstract class that provides an interface to define a custom data lifting. + """Abstract class that provides an interface to define a custom data + lifting. Parameters ---------- @@ -74,10 +75,14 @@ def __init__(self, transform_name, **kwargs): self.parameters = kwargs self.transform = ( - TRANSFORMS[transform_name](**kwargs) if transform_name is not None else None + TRANSFORMS[transform_name](**kwargs) + if transform_name is not None + else None ) - def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data: + def forward( + self, data: torch_geometric.data.Data + ) -> torch_geometric.data.Data: """Forward pass of the lifting. 
Parameters diff --git a/topobenchmarkx/transforms/feature_liftings/feature_liftings.py b/topobenchmarkx/transforms/feature_liftings/feature_liftings.py index 3b7117f8..90c9f99e 100644 --- a/topobenchmarkx/transforms/feature_liftings/feature_liftings.py +++ b/topobenchmarkx/transforms/feature_liftings/feature_liftings.py @@ -17,7 +17,8 @@ def __init__(self, **kwargs): def lift_features( self, data: torch_geometric.data.Data | dict ) -> torch_geometric.data.Data | dict: - r"""Projects r-cell features of a graph to r+1-cell structures using the incidence matrix. + r"""Projects r-cell features of a graph to r+1-cell structures using the + incidence matrix. Parameters ---------- @@ -27,8 +28,11 @@ def lift_features( Returns ------- torch_geometric.data.Data | dict - The lifted data.""" - keys = sorted([key.split("_")[1] for key in data if "incidence" in key]) + The lifted data. + """ + keys = sorted( + [key.split("_")[1] for key in data if "incidence" in key] + ) for elem in keys: if f"x_{elem}" not in data: idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1 @@ -72,7 +76,8 @@ def __init__(self, **kwargs): def lift_features( self, data: torch_geometric.data.Data | dict ) -> torch_geometric.data.Data | dict: - r"""Concatenates r-cell features to r+1-cell structures using the incidence matrix. + r"""Concatenates r-cell features to r+1-cell structures using the + incidence matrix. Parameters ---------- @@ -82,9 +87,12 @@ def lift_features( Returns ------- torch_geometric.data.Data | dict - The lifted data.""" + The lifted data. + """ - keys = sorted([key.split("_")[1] for key in data if "incidence" in key]) + keys = sorted( + [key.split("_")[1] for key in data if "incidence" in key] + ) for elem in keys: if f"x_{elem}" not in data: idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1 @@ -141,7 +149,8 @@ def __init__(self, **kwargs): def lift_features( self, data: torch_geometric.data.Data | dict ) -> torch_geometric.data.Data | dict: - r"""Concatenates r-cell features to r+1-cell structures using the incidence matrix. + r"""Concatenates r-cell features to r+1-cell structures using the + incidence matrix. Parameters ---------- @@ -151,12 +160,15 @@ def lift_features( Returns ------- torch_geometric.data.Data | dict - The lifted data.""" + The lifted data. + """ - keys = sorted([key.split("_")[1] for key in data if "incidence" in key]) + keys = sorted( + [key.split("_")[1] for key in data if "incidence" in key] + ) for elem in keys: if f"x_{elem}" not in data: - #idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1 + # idx_to_project = 0 if elem == "hyperedges" else int(elem) - 1 incidence = data["incidence_" + elem] _, n = incidence.shape diff --git a/topobenchmarkx/transforms/liftings/graph2cell.py b/topobenchmarkx/transforms/liftings/graph2cell.py index ffff21ae..2f2224a4 100755 --- a/topobenchmarkx/transforms/liftings/graph2cell.py +++ b/topobenchmarkx/transforms/liftings/graph2cell.py @@ -46,7 +46,9 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: """ raise NotImplementedError - def _get_lifted_topology(self, cell_complex: CellComplex, graph: nx.Graph) -> dict: + def _get_lifted_topology( + self, cell_complex: CellComplex, graph: nx.Graph + ) -> dict: r"""Returns the lifted topology. Parameters @@ -61,7 +63,9 @@ def _get_lifted_topology(self, cell_complex: CellComplex, graph: nx.Graph) -> di dict The lifted topology. 
""" - lifted_topology = get_complex_connectivity(cell_complex, self.complex_dim) + lifted_topology = get_complex_connectivity( + cell_complex, self.complex_dim + ) lifted_topology["x_0"] = torch.stack( list(cell_complex.get_cell_attributes("features", 0).values()) ) @@ -112,7 +116,9 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: cycles = [cycle for cycle in cycles if len(cycle) != 1] # Eliminate cycles that are greater than the max_cell_lenght if self.max_cell_length is not None: - cycles = [cycle for cycle in cycles if len(cycle) <= self.max_cell_length] + cycles = [ + cycle for cycle in cycles if len(cycle) <= self.max_cell_length + ] if len(cycles) != 0: cell_complex.add_cells_from(cycles, rank=self.complex_dim) return self._get_lifted_topology(cell_complex, G) diff --git a/topobenchmarkx/transforms/liftings/graph2hypergraph.py b/topobenchmarkx/transforms/liftings/graph2hypergraph.py index 7ed3bd2f..0ac80a82 100755 --- a/topobenchmarkx/transforms/liftings/graph2hypergraph.py +++ b/topobenchmarkx/transforms/liftings/graph2hypergraph.py @@ -57,7 +57,8 @@ def __init__(self, k_value=1, **kwargs): self.k = k_value def lift_topology(self, data: torch_geometric.data.Data) -> dict: - r"""Lifts the topology of a graph to hypergraph domain by considering k-hop neighborhoods. + r"""Lifts the topology of a graph to hypergraph domain by considering + k-hop neighborhoods. Parameters ---------- @@ -79,13 +80,17 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: edge_index = torch_geometric.utils.to_undirected(data.edge_index) # Detect isolated nodes - isolated_nodes = [i for i in range(num_nodes) if i not in edge_index[0]] + isolated_nodes = [ + i for i in range(num_nodes) if i not in edge_index[0] + ] if len(isolated_nodes) > 0: # Add completely isolated nodes to the edge_index edge_index = torch.cat( [ edge_index, - torch.tensor([isolated_nodes, isolated_nodes], dtype=torch.long), + torch.tensor( + [isolated_nodes, isolated_nodes], dtype=torch.long + ), ], dim=1, ) @@ -125,7 +130,8 @@ def __init__(self, k_value=1, loop=True, **kwargs): self.transform = torch_geometric.transforms.KNNGraph(self.k, self.loop) def lift_topology(self, data: torch_geometric.data.Data) -> dict: - r"""Lifts the topology of a graph to hypergraph domain by considering k-nearest neighbors. + r"""Lifts the topology of a graph to hypergraph domain by considering + k-nearest neighbors. 
Parameters ---------- diff --git a/topobenchmarkx/transforms/liftings/graph2simplicial.py b/topobenchmarkx/transforms/liftings/graph2simplicial.py index b647b9db..036ba976 100755 --- a/topobenchmarkx/transforms/liftings/graph2simplicial.py +++ b/topobenchmarkx/transforms/liftings/graph2simplicial.py @@ -71,20 +71,29 @@ def _get_lifted_topology( simplicial_complex, self.complex_dim, signed=self.signed ) lifted_topology["x_0"] = torch.stack( - list(simplicial_complex.get_simplex_attributes("features", 0).values()) + list( + simplicial_complex.get_simplex_attributes( + "features", 0 + ).values() + ) ) # If new edges have been added during the lifting process, we discard the edge attributes if self.contains_edge_attr and simplicial_complex.shape[1] == ( graph.number_of_edges() ): lifted_topology["x_1"] = torch.stack( - list(simplicial_complex.get_simplex_attributes("features", 1).values()) + list( + simplicial_complex.get_simplex_attributes( + "features", 1 + ).values() + ) ) return lifted_topology class SimplicialNeighborhoodLifting(Graph2SimplicialLifting): - r"""Lifts graphs to simplicial complex domain by considering k-hop neighborhoods. + r"""Lifts graphs to simplicial complex domain by considering k-hop + neighborhoods. Parameters ---------- @@ -99,7 +108,8 @@ def __init__(self, max_k_simplices=5000, **kwargs): self.max_k_simplices = max_k_simplices def lift_topology(self, data: torch_geometric.data.Data) -> dict: - r"""Lifts the topology of a graph to simplicial complex domain by considering k-hop neighborhoods. + r"""Lifts the topology of a graph to simplicial complex domain by + considering k-hop neighborhoods. Parameters ---------- @@ -117,7 +127,9 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: simplices = [set() for _ in range(2, self.complex_dim + 1)] for n in range(graph.number_of_nodes()): # Find 1-hop node n neighbors - neighbors, _, _, _ = torch_geometric.utils.k_hop_subgraph(n, 1, edge_index) + neighbors, _, _, _ = torch_geometric.utils.k_hop_subgraph( + n, 1, edge_index + ) if n not in neighbors: neighbors.append(n) neighbors = neighbors.numpy() @@ -135,7 +147,8 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: class SimplicialCliqueLifting(Graph2SimplicialLifting): - r"""Lifts graphs to simplicial complex domain by identifying the cliques as k-simplices. + r"""Lifts graphs to simplicial complex domain by identifying the cliques as + k-simplices. Parameters ---------- @@ -147,7 +160,8 @@ def __init__(self, **kwargs): super().__init__(**kwargs) def lift_topology(self, data: torch_geometric.data.Data) -> dict: - r"""Lifts the topology of a graph to a simplicial complex by identifying the cliques as k-simplices. + r"""Lifts the topology of a graph to a simplicial complex by identifying + the cliques as k-simplices. Parameters ---------- diff --git a/topobenchmarkx/transforms/liftings/graph_lifting.py b/topobenchmarkx/transforms/liftings/graph_lifting.py index ec25308b..0392f2dc 100644 --- a/topobenchmarkx/transforms/liftings/graph_lifting.py +++ b/topobenchmarkx/transforms/liftings/graph_lifting.py @@ -19,7 +19,8 @@ class GraphLifting(torch_geometric.transforms.BaseTransform): - r"""Abstract class for lifting graph topologies to higher-order topological domains. + r"""Abstract class for lifting graph topologies to higher-order topological + domains. 
Parameters
     ----------
@@ -54,7 +55,9 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict:
         """
         raise NotImplementedError
 
-    def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data:
+    def forward(
+        self, data: torch_geometric.data.Data
+    ) -> torch_geometric.data.Data:
         r"""Applies the full lifting (topology + features) to the input data.
 
         Parameters
         ----------
@@ -70,7 +73,9 @@ def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data:
         initial_data = data.to_dict()
         lifted_topology = self.lift_topology(data)
         lifted_topology = self.feature_lifting(lifted_topology)
-        lifted_data = torch_geometric.data.Data(**initial_data, **lifted_topology)
+        lifted_data = torch_geometric.data.Data(
+            **initial_data, **lifted_topology
+        )
         return lifted_data
 
     def _data_has_edge_attr(self, data: torch_geometric.data.Data) -> bool:
@@ -88,7 +93,9 @@ def _data_has_edge_attr(self, data: torch_geometric.data.Data) -> bool:
         """
         return hasattr(data, "edge_attr") and data.edge_attr is not None
 
-    def _generate_graph_from_data(self, data: torch_geometric.data.Data) -> nx.Graph:
+    def _generate_graph_from_data(
+        self, data: torch_geometric.data.Data
+    ) -> nx.Graph:
         r"""Generates a NetworkX graph from the input data object.
 
         Parameters
         ----------
@@ -102,7 +109,10 @@ def _generate_graph_from_data(self, data: torch_geometric.data.Data) -> nx.Graph
             The generated NetworkX graph.
         """
         # Check if data object has edge_attr, return list of tuples as [(node_id, {'features':data}, 'dim':1)] or ??
-        nodes = [(n, dict(features=data.x[n], dim=0)) for n in range(data.x.shape[0])]
+        nodes = [
+            (n, dict(features=data.x[n], dim=0))
+            for n in range(data.x.shape[0])
+        ]
 
         if self.preserve_edge_attr and self._data_has_edge_attr(data):
             # In case edge features are given, assign features to every edge
@@ -125,7 +135,9 @@ def _generate_graph_from_data(self, data: torch_geometric.data.Data) -> nx.Graph
         # If edge_attr is not present, return a list of edges
         edges = [
             (i.item(), j.item())
-            for i, j in zip(data.edge_index[0], data.edge_index[1], strict=False)
+            for i, j in zip(
+                data.edge_index[0], data.edge_index[1], strict=False
+            )
         ]
         self.contains_edge_attr = False
         graph = nx.Graph()
diff --git a/topobenchmarkx/utils/__init__.py b/topobenchmarkx/utils/__init__.py
index cb6b0bc7..b2beffb3 100755
--- a/topobenchmarkx/utils/__init__.py
+++ b/topobenchmarkx/utils/__init__.py
@@ -1,8 +1,17 @@
 from topobenchmarkx.utils.instantiators import (
     instantiate_callbacks,  # noqa: F401
-    instantiate_loggers, # noqa: F401
+    instantiate_loggers,  # noqa: F401
+)
+from topobenchmarkx.utils.logging_utils import (
+    log_hyperparameters,  # noqa: F401
 )
-from topobenchmarkx.utils.logging_utils import log_hyperparameters # noqa: F401
 from topobenchmarkx.utils.pylogger import RankedLogger  # noqa: F401
-from topobenchmarkx.utils.rich_utils import enforce_tags, print_config_tree # noqa: F401
-from topobenchmarkx.utils.utils import extras, get_metric_value, task_wrapper # noqa: F401
+from topobenchmarkx.utils.rich_utils import (
+    enforce_tags,
+    print_config_tree,
+)
+from topobenchmarkx.utils.utils import (
+    extras,
+    get_metric_value,
+    task_wrapper,
+)
diff --git a/topobenchmarkx/utils/config_resolvers.py b/topobenchmarkx/utils/config_resolvers.py
index 9af53491..f5aa1bf1 100644
--- a/topobenchmarkx/utils/config_resolvers.py
+++ b/topobenchmarkx/utils/config_resolvers.py
@@ -1,18 +1,18 @@
 def get_default_transform(data_domain, model):
     r"""Get default transform for a given data domain and model.
- + Parameters ---------- data_domain: str Data domain. model: str Model name. Should be in the format "model_domain/name". - + Returns ------- str Default transform. - + Raises ------ ValueError @@ -31,27 +31,25 @@ def get_default_transform(data_domain, model): def get_monitor_metric(task, metric): r"""Get monitor metric for a given task and loss. - + Parameters ---------- task: str Task, either "classification" or "regression". loss: str Name of the loss function. - + Returns ------- str Monitor metric. - + Raises ------ ValueError If the task is invalid. """ - if task == "classification": - return f"val/{metric}" - elif task == "regression": + if task == "classification" or task == "regression": return f"val/{metric}" else: raise ValueError(f"Invalid task {task}") @@ -59,17 +57,17 @@ def get_monitor_metric(task, metric): def get_monitor_mode(task): r"""Get monitor mode for a given task. - + Parameters ---------- task: str Task, either "classification" or "regression". - + Returns ------- str Monitor mode, either "max" or "min". - + Raises ------ ValueError @@ -85,25 +83,26 @@ def get_monitor_mode(task): def infer_in_channels(dataset): r"""Infer the number of input channels for a given dataset. - + Parameters ---------- dataset: torch_geometric.data.Dataset Input dataset. - + Returns ------- list List with dimensions of the input channels. """ + def find_complex_lifting(dataset): r"""Find if there is a complex lifting in the dataset. - + Parameters ---------- dataset: torch_geometric.data.Dataset Input dataset. - + Returns ------- bool @@ -125,14 +124,14 @@ def find_complex_lifting(dataset): def check_for_type_feature_lifting(dataset, lifting): r"""Check the type of feature lifting in the dataset. - + Parameters ---------- dataset: torch_geometric.data.Dataset Input dataset. lifting: str Name of the complex lifting. - + Returns ------- str @@ -171,19 +170,23 @@ def check_for_type_feature_lifting(dataset, lifting): else: if not dataset.transforms[lifting].preserve_edge_attr: if feature_lifting == "projection": - return [dataset.parameters.num_features[0]] * dataset.transforms[ - lifting - ].complex_dim + return [ + dataset.parameters.num_features[0] + ] * dataset.transforms[lifting].complex_dim elif feature_lifting == "concatenation": return_value = [dataset.parameters.num_features] - for i in range(2, dataset.transforms[lifting].complex_dim + 1): - return_value += [int(dataset.parameters.num_features * i)] + for i in range( + 2, dataset.transforms[lifting].complex_dim + 1 + ): + return_value += [ + int(dataset.parameters.num_features * i) + ] return return_value else: - return [dataset.parameters.num_features] * dataset.transforms[ - lifting - ].complex_dim + return [ + dataset.parameters.num_features + ] * dataset.transforms[lifting].complex_dim else: return list(dataset.parameters.num_features) + [ @@ -199,5 +202,6 @@ def check_for_type_feature_lifting(dataset, lifting): else: return [dataset.parameters.num_features[0]] + def infere_list_length(list): - return len(list) \ No newline at end of file + return len(list) diff --git a/topobenchmarkx/utils/instantiators.py b/topobenchmarkx/utils/instantiators.py index 3b94bf4e..2e01a1b6 100755 --- a/topobenchmarkx/utils/instantiators.py +++ b/topobenchmarkx/utils/instantiators.py @@ -11,7 +11,8 @@ def instantiate_callbacks(callbacks_cfg: DictConfig) -> list[Callback]: """Instantiates callbacks from config. - :param callbacks_cfg: A DictConfig object containing callback configurations. 
+ :param callbacks_cfg: A DictConfig object containing callback + configurations. :return: A list of instantiated callbacks. """ callbacks: list[Callback] = [] diff --git a/topobenchmarkx/utils/pylogger.py b/topobenchmarkx/utils/pylogger.py index 31a76c37..3b8222fd 100755 --- a/topobenchmarkx/utils/pylogger.py +++ b/topobenchmarkx/utils/pylogger.py @@ -1,7 +1,10 @@ import logging -from typing import Mapping, Optional +from collections.abc import Mapping -from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only +from lightning_utilities.core.rank_zero import ( + rank_prefixed_message, + rank_zero_only, +) class RankedLogger(logging.LoggerAdapter): @@ -11,10 +14,10 @@ def __init__( self, name: str = __name__, rank_zero_only: bool = False, - extra: Optional[Mapping[str, object]] = None, + extra: Mapping[str, object] | None = None, ) -> None: - """Initializes a multi-GPU-friendly python command line logger that logs on all processes - with their rank prefixed in the log message. + """Initializes a multi-GPU-friendly python command line logger that + logs on all processes with their rank prefixed in the log message. :param name: The name of the logger. Default is ``__name__``. :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. @@ -25,11 +28,12 @@ def __init__( self.rank_zero_only = rank_zero_only def log( - self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs + self, level: int, msg: str, rank: int | None = None, *args, **kwargs ) -> None: - """Delegate a log call to the underlying logger, after prefixing its message with the rank - of the process it's being logged from. If `'rank'` is provided, then the log will only - occur on that rank/process. + """Delegate a log call to the underlying logger, after prefixing its + message with the rank of the process it's being logged from. If + `'rank'` is provided, then the log will only occur on that + rank/process. :param level: The level to log at. Look at `logging.__init__.py` for more information. :param msg: The message to log. @@ -49,7 +53,5 @@ def log( if current_rank == 0: self.logger.log(level, msg, *args, **kwargs) else: - if rank is None: - self.logger.log(level, msg, *args, **kwargs) - elif current_rank == rank: + if rank is None or current_rank == rank: self.logger.log(level, msg, *args, **kwargs) diff --git a/topobenchmarkx/utils/rich_utils.py b/topobenchmarkx/utils/rich_utils.py index 692b8158..6cf5080c 100755 --- a/topobenchmarkx/utils/rich_utils.py +++ b/topobenchmarkx/utils/rich_utils.py @@ -29,7 +29,8 @@ def print_config_tree( resolve: bool = False, save_to_file: bool = False, ) -> None: - """Prints the contents of a DictConfig as a tree structure using the Rich library. + """Prints the contents of a DictConfig as a tree structure using the Rich + library. :param cfg: A DictConfig composed by Hydra. :param print_order: Determines in what order config components are printed. Default is ``("data", "model", @@ -80,7 +81,8 @@ def print_config_tree( @rank_zero_only def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: - """Prompts user to input tags from command line if no tags are provided in config. + """Prompts user to input tags from command line if no tags are provided in + config. :param cfg: A DictConfig composed by Hydra. :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. 
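For reference, the pylogger changes above only modernize the type hints (collections.abc.Mapping, PEP 604 unions) and collapse the final rank check; behaviour is unchanged. A minimal usage sketch, assuming it runs inside a Lightning job where rank_zero_only.rank has already been set:

import logging

from topobenchmarkx.utils.pylogger import RankedLogger

log = RankedLogger(__name__, rank_zero_only=True)
log.info("Printed once, by the rank-zero process only.")

# Target one specific rank; all other processes drop the record.
log.log(logging.INFO, "Printed by rank 1 only.", rank=1)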
@@ -89,8 +91,12 @@ def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
         if "id" in HydraConfig().cfg.hydra.job:
             raise ValueError("Specify tags before launching a multirun!")
 
-        log.warning("No tags provided in config. Prompting user to input tags...")
-        tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
+        log.warning(
+            "No tags provided in config. Prompting user to input tags..."
+        )
+        tags = Prompt.ask(
+            "Enter a list of comma separated tags", default="dev"
+        )
         tags = [t.strip() for t in tags.split(",") if t != ""]
 
         with open_dict(cfg):
diff --git a/topobenchmarkx/utils/utils.py b/topobenchmarkx/utils/utils.py
index 436a85ae..74fdfbb5 100755
--- a/topobenchmarkx/utils/utils.py
+++ b/topobenchmarkx/utils/utils.py
@@ -1,6 +1,7 @@
 import warnings
+from collections.abc import Callable
 from importlib.util import find_spec
-from typing import Any, Callable, Optional
+from typing import Any
 
 from omegaconf import DictConfig
 
@@ -26,7 +27,9 @@ def extras(cfg: DictConfig) -> None:
 
     # disable python warnings
     if cfg.extras.get("ignore_warnings"):
-        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
+        log.info(
+            "Disabling python warnings! <cfg.extras.ignore_warnings=True>"
+        )
         warnings.filterwarnings("ignore")
 
     # prompt user to input tags from command line if none are provided in the config
@@ -36,12 +39,15 @@ def extras(cfg: DictConfig) -> None:
 
     # pretty print config tree using Rich library
     if cfg.extras.get("print_config"):
-        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
+        log.info(
+            "Printing config tree with Rich! <cfg.extras.print_config=True>"
+        )
         rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
 
 
 def task_wrapper(task_func: Callable) -> Callable:
-    """Optional decorator that controls the failure behavior when executing the task function.
+    """Optional decorator that controls the failure behavior when executing the
+    task function.
 
     This wrapper can be used to:
         - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
@@ -96,8 +102,8 @@ def wrap(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
 
 
 def get_metric_value(
-    metric_dict: dict[str, Any], metric_name: Optional[str]
-) -> Optional[float]:
+    metric_dict: dict[str, Any], metric_name: str | None
+) -> float | None:
     """Safely retrieves value of the metric logged in LightningModule.
 
     :param metric_dict: A dict containing metric values.
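A small usage sketch for the refactored get_metric_value above; the dict mirrors what Lightning's trainer.callback_metrics holds, and the optional metric_name is how sweeps pick their optimization target (behaviour inferred from the surrounding code, so treat this as an assumption rather than a spec):

import torch

from topobenchmarkx.utils.utils import get_metric_value

metric_dict = {"val/accuracy": torch.tensor(0.87)}

# Returns the logged tensor's value as a plain float.
score = get_metric_value(metric_dict=metric_dict, metric_name="val/accuracy")

# With no metric name there is nothing to optimize, so None comes back.
assert get_metric_value(metric_dict=metric_dict, metric_name=None) is None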
diff --git a/tutorials/add_new_dataset.ipynb b/tutorials/add_new_dataset.ipynb index 75cc9e18..2486d5d9 100644 --- a/tutorials/add_new_dataset.ipynb +++ b/tutorials/add_new_dataset.ipynb @@ -227,9 +227,7 @@ "source": [ "import os.path as osp\n", "from collections.abc import Callable\n", - "from typing import Optional\n", "\n", - "import torch\n", "from omegaconf import DictConfig\n", "from torch_geometric.data import Data, InMemoryDataset\n", "from torch_geometric.io import fs\n", @@ -286,9 +284,9 @@ " root: str,\n", " name: str,\n", " parameters: DictConfig,\n", - " transform: Optional[Callable] = None,\n", - " pre_transform: Optional[Callable] = None,\n", - " pre_filter: Optional[Callable] = None,\n", + " transform: Callable | None = None,\n", + " pre_transform: Callable | None = None,\n", + " pre_filter: Callable | None = None,\n", " force_reload: bool = True,\n", " ) -> None:\n", " # Assign the class variables that would be needed for steps 1, 2, 4, and 3\n", @@ -405,57 +403,50 @@ "metadata": {}, "outputs": [], "source": [ - "import numpy as np\n", "import os\n", - "import torch\n", - "import torch_geometric\n", "import urllib.request\n", "\n", "\n", - "def hetero_load(name, path='./data/hetero_data'):\n", - " file_name = f'{name}.npz'\n", + "def hetero_load(name, path=\"./data/hetero_data\"):\n", + " file_name = f\"{name}.npz\"\n", "\n", " data = np.load(os.path.join(path, file_name))\n", "\n", - " x = torch.tensor(data['node_features'])\n", - " y = torch.tensor(data['node_labels'])\n", - " edge_index = torch.tensor(data['edges']).T\n", + " x = torch.tensor(data[\"node_features\"])\n", + " y = torch.tensor(data[\"node_labels\"])\n", + " edge_index = torch.tensor(data[\"edges\"]).T\n", "\n", " # Make edge_index undirected\n", " edge_index = torch_geometric.utils.to_undirected(edge_index)\n", "\n", " # Remove self-loops\n", " edge_index, _ = torch_geometric.utils.remove_self_loops(edge_index)\n", - " \n", + "\n", " data = torch_geometric.data.Data(x=x, y=y, edge_index=edge_index)\n", " return data\n", "\n", + "\n", "def download_hetero_datasets(name, path):\n", - " url = 'https://github.com/OpenGSL/HeterophilousDatasets/raw/main/data/'\n", - " name = f'{name}.npz'\n", + " url = \"https://github.com/OpenGSL/HeterophilousDatasets/raw/main/data/\"\n", + " name = f\"{name}.npz\"\n", " try:\n", - " print(f'Downloading {name}')\n", + " print(f\"Downloading {name}\")\n", " path2save = os.path.join(path, name)\n", " urllib.request.urlretrieve(url + name, path2save)\n", - " print('Done!')\n", + " print(\"Done!\")\n", " except:\n", - " raise Exception('''Download failed! Make sure you have stable Internet connection and enter the right name''')\n", - "\n", + " raise Exception(\n", + " \"\"\"Download failed! 
Make sure you have stable Internet connection and enter the right name\"\"\"\n",
+    "        )\n",
     "\n",
     "\n",
-    "import os.path as osp\n",
     "from collections.abc import Callable\n",
-    "from typing import Optional\n",
     "\n",
-    "import torch\n",
     "from omegaconf import DictConfig\n",
-    "from torch_geometric.data import Data, InMemoryDataset\n",
-    "from torch_geometric.io import fs\n",
+    "from torch_geometric.data import InMemoryDataset\n",
     "\n",
     "from topobenchmarkx.io.load.us_county_demos import load_us_county_demos\n",
     "\n",
-    "from topobenchmarkx.io.load.split_utils import random_splitting\n",
-    "\n",
     "\n",
     "class HeteroDataset(InMemoryDataset):\n",
     "    r\"\"\"\n",
@@ -496,14 +487,14 @@
     "        root: str,\n",
     "        name: str,\n",
     "        parameters: DictConfig,\n",
-    "        transform: Optional[Callable] = None,\n",
-    "        pre_transform: Optional[Callable] = None,\n",
-    "        pre_filter: Optional[Callable] = None,\n",
+    "        transform: Callable | None = None,\n",
+    "        pre_transform: Callable | None = None,\n",
+    "        pre_filter: Callable | None = None,\n",
     "        force_reload: bool = True,\n",
     "        use_node_attr: bool = False,\n",
     "        use_edge_attr: bool = False,\n",
     "    ) -> None:\n",
-    "        self.name = name #.replace(\"_\", \"-\")\n",
+    "        self.name = name  # .replace(\"_\", \"-\")\n",
     "        self.parameters = parameters\n",
     "        super().__init__(\n",
     "            root, transform, pre_transform, pre_filter, force_reload=force_reload\n",
     "        )\n",
@@ -542,7 +533,7 @@
     "    @property\n",
     "    def processed_file_names(self) -> str:\n",
     "        return \"data.pt\"\n",
-    "    \n",
+    "\n",
     "    @property\n",
     "    def raw_file_names(self) -> list[str]:\n",
     "        \"\"\"Specify the downloaded raw file name\"\"\"\n",
@@ -569,7 +560,7 @@
     "        Returns:\n",
     "            None\n",
     "        \"\"\"\n",
-    "        \n",
+    "\n",
     "        data = hetero_load(name=self.name, path=self.raw_dir)\n",
     "        data = data if self.pre_transform is None else self.pre_transform(data)\n",
     "        self.save([data], self.processed_paths[0])\n",
     "    def __repr__(self) -> str:\n",
     "        return f\"{self.name}()\"\n",
     "\n",
     "\n",
+    "data_dir = \"/home/lev/projects/TopoBenchmarkX/datasets\"\n",
+    "data_domain = \"graph\"\n",
+    "data_type = \"heterophilic\"\n",
+    "data_name = \"amazon_ratings\"\n",
     "\n",
-    "data_dir = '/home/lev/projects/TopoBenchmarkX/datasets'\n",
-    "data_domain = 'graph'\n",
-    "data_type = 'heterophilic'\n",
-    "data_name = 'amazon_ratings'\n",
-    "\n",
-    "data_dir = f'{data_dir}/{data_domain}/{data_type}'\n",
+    "data_dir = f\"{data_dir}/{data_domain}/{data_type}\"\n",
     "\n",
-    "parameters={\n",
-    "    'split_type': 'random',\n",
-    "    'k': 10,\n",
-    "    'train_prop': 0.5,\n",
-    "    'data_seed':0,\n",
-    "    'data_split_dir': f'/home/lev/projects/TopoBenchmarkX/datasets/data_splits/{data_name}'\n",
-    "    }\n",
+    "parameters = {\n",
+    "    \"split_type\": \"random\",\n",
+    "    \"k\": 10,\n",
+    "    \"train_prop\": 0.5,\n",
+    "    \"data_seed\": 0,\n",
+    "    \"data_split_dir\": f\"/home/lev/projects/TopoBenchmarkX/datasets/data_splits/{data_name}\",\n",
+    "}\n",
     "\n",
-    "dataset = HeteroDataset(name=data_name, root = data_dir, parameters=parameters, force_reload=True)"
+    "dataset = HeteroDataset(\n",
+    "    name=data_name, root=data_dir, parameters=parameters, force_reload=True\n",
+    ")"
    ]
   }
 ],
From 74fcc6820eb0e1ea8d82e81f8c5f3a189dc377ca Mon Sep 17 00:00:00 2001
From: levtelyatnikov
Date: Tue, 14 May 2024 01:56:53 +0200
Subject: [PATCH 02/32] updated wrapper modularity

---
 configs/model/cell/can.yaml | 3 +-
 configs/model/cell/ccxn.yaml | 3 +-
 configs/model/cell/cwn.yaml | 3 +-
 configs/model/cell/cwn_dcm.yaml | 3 +-
 configs/model/graph/gat.yaml | 4 +-
 configs/model/graph/gcn.yaml | 3 +-
 configs/model/graph/gin.yaml | 4
+- configs/model/hypergraph/alldeepset.yaml | 3 +- .../model/hypergraph/allsettransformer.yaml | 3 +- configs/model/hypergraph/edgnn.yaml | 3 +- configs/model/hypergraph/unignn.yaml | 3 +- configs/model/hypergraph/unignn2.yaml | 3 +- configs/model/simplicial/san.yaml | 3 +- configs/model/simplicial/sccn.yaml | 3 +- configs/model/simplicial/sccnn.yaml | 3 +- configs/model/simplicial/sccnn_custom.yaml | 3 +- configs/model/simplicial/scn.yaml | 3 +- configs/train.yaml | 2 +- topobenchmarkx/models/wrappers/__init__.py | 49 ++++++++++--------- .../models/wrappers/cell/__init__.py | 20 ++++++++ .../models/wrappers/cell/can_wrapper.py | 22 +++++++++ .../models/wrappers/cell/ccxn_wrapper.py | 21 ++++++++ .../models/wrappers/cell/cwn_wrapper.py | 23 +++++++++ .../models/wrappers/cell/cwndcm_wrapper.py | 21 ++++++++ .../models/wrappers/graph/__init__.py | 15 ++++++ .../models/wrappers/graph/gnn_wrapper.py | 14 ++++++ .../models/wrappers/hypergraph/__init__.py | 15 ++++++ .../wrappers/hypergraph/hypergraph_wrapper.py | 14 ++++++ .../{default_wrapper.py => old_wrapper.py} | 0 .../models/wrappers/simplicial/__init__.py | 21 ++++++++ .../models/wrappers/simplicial/san_wrapper.py | 17 +++++++ .../wrappers/simplicial/sccn_wrapper.py | 48 ++++++++++++++++++ .../wrappers/simplicial/sccnn_wrapper.py | 28 +++++++++++ .../models/wrappers/simplicial/scn_wrapper.py | 45 +++++++++++++++++ topobenchmarkx/models/wrappers/wrapper.py | 42 ++++++++++++++++ 35 files changed, 427 insertions(+), 43 deletions(-) create mode 100644 topobenchmarkx/models/wrappers/cell/__init__.py create mode 100644 topobenchmarkx/models/wrappers/cell/can_wrapper.py create mode 100644 topobenchmarkx/models/wrappers/cell/ccxn_wrapper.py create mode 100644 topobenchmarkx/models/wrappers/cell/cwn_wrapper.py create mode 100644 topobenchmarkx/models/wrappers/cell/cwndcm_wrapper.py create mode 100644 topobenchmarkx/models/wrappers/graph/__init__.py create mode 100644 topobenchmarkx/models/wrappers/graph/gnn_wrapper.py create mode 100644 topobenchmarkx/models/wrappers/hypergraph/__init__.py create mode 100644 topobenchmarkx/models/wrappers/hypergraph/hypergraph_wrapper.py rename topobenchmarkx/models/wrappers/{default_wrapper.py => old_wrapper.py} (100%) create mode 100644 topobenchmarkx/models/wrappers/simplicial/__init__.py create mode 100644 topobenchmarkx/models/wrappers/simplicial/san_wrapper.py create mode 100644 topobenchmarkx/models/wrappers/simplicial/sccn_wrapper.py create mode 100644 topobenchmarkx/models/wrappers/simplicial/sccnn_wrapper.py create mode 100644 topobenchmarkx/models/wrappers/simplicial/scn_wrapper.py create mode 100755 topobenchmarkx/models/wrappers/wrapper.py diff --git a/configs/model/cell/can.yaml b/configs/model/cell/can.yaml index 510ce09c..dc60d16a 100755 --- a/configs/model/cell/can.yaml +++ b/configs/model/cell/can.yaml @@ -21,8 +21,9 @@ backbone: att_lift: False backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.default_wrapper.CANWrapper + _target_: topobenchmarkx.models.wrappers.CANWrapper _partial_: true + wrapper_name: CANWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} diff --git a/configs/model/cell/ccxn.yaml b/configs/model/cell/ccxn.yaml index 7d4f336a..488ad52e 100755 --- a/configs/model/cell/ccxn.yaml +++ b/configs/model/cell/ccxn.yaml @@ -17,8 +17,9 @@ backbone_additional_params: hidden_channels: ${model.feature_encoder.out_channels} backbone_wrapper: - _target_: 
topobenchmarkx.models.wrappers.default_wrapper.CCXNWrapper + _target_: topobenchmarkx.models.wrappers.CCXNWrapper _partial_: true + wrapper_name: CCXNWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} diff --git a/configs/model/cell/cwn.yaml b/configs/model/cell/cwn.yaml index 85ab2cb7..2b80ae6e 100755 --- a/configs/model/cell/cwn.yaml +++ b/configs/model/cell/cwn.yaml @@ -15,8 +15,9 @@ backbone: n_layers: 1 backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.default_wrapper.CWNWrapper + _target_: topobenchmarkx.models.wrappers.CWNWrapper _partial_: true + wrapper_name: CWNWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} diff --git a/configs/model/cell/cwn_dcm.yaml b/configs/model/cell/cwn_dcm.yaml index 764602f7..4e573817 100755 --- a/configs/model/cell/cwn_dcm.yaml +++ b/configs/model/cell/cwn_dcm.yaml @@ -16,8 +16,9 @@ backbone: dropout: 0.0 backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.default_wrapper.CWNDCMWrapper + _target_: topobenchmarkx.models.wrappers.CWNDCMWrapper _partial_: true + wrapper_name: CWNDCMWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} diff --git a/configs/model/graph/gat.yaml b/configs/model/graph/gat.yaml index 92a74591..a733fa06 100755 --- a/configs/model/graph/gat.yaml +++ b/configs/model/graph/gat.yaml @@ -17,8 +17,9 @@ backbone: concat: true backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.default_wrapper.GNNWrapper + _target_: topobenchmarkx.models.wrappers.GNNWrapper _partial_: true + wrapper_name: GNNWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} @@ -35,7 +36,6 @@ head_model: out_channels: ${dataset.parameters.num_classes} pooling_type: sum - loss: _target_: topobenchmarkx.models.losses.loss.DefaultLoss task: ${dataset.parameters.task} diff --git a/configs/model/graph/gcn.yaml b/configs/model/graph/gcn.yaml index ffc152ad..49e6d90c 100755 --- a/configs/model/graph/gcn.yaml +++ b/configs/model/graph/gcn.yaml @@ -14,8 +14,9 @@ backbone: act: relu backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.default_wrapper.GNNWrapper + _target_: topobenchmarkx.models.wrappers.GNNWrapper _partial_: true + wrapper_name: GNNWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} diff --git a/configs/model/graph/gin.yaml b/configs/model/graph/gin.yaml index b3797cdf..4197fe7a 100755 --- a/configs/model/graph/gin.yaml +++ b/configs/model/graph/gin.yaml @@ -14,8 +14,9 @@ backbone: act: relu backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.default_wrapper.GNNWrapper + _target_: topobenchmarkx.models.wrappers.GNNWrapper _partial_: true + wrapper_name: GNNWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} @@ -32,7 +33,6 @@ head_model: out_channels: ${dataset.parameters.num_classes} pooling_type: sum - loss: _target_: topobenchmarkx.models.losses.loss.DefaultLoss task: ${dataset.parameters.task} diff --git a/configs/model/hypergraph/alldeepset.yaml b/configs/model/hypergraph/alldeepset.yaml index ff6d84e2..153fdf97 100755 --- a/configs/model/hypergraph/alldeepset.yaml +++ 
b/configs/model/hypergraph/alldeepset.yaml @@ -24,8 +24,9 @@ backbone: #num_features: ${model.backbone.hidden_channels} backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.default_wrapper.HypergraphWrapper + _target_: topobenchmarkx.models.wrappers.HypergraphWrapper _partial_: true + wrapper_name: HypergraphWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} diff --git a/configs/model/hypergraph/allsettransformer.yaml b/configs/model/hypergraph/allsettransformer.yaml index 6c589729..673e49b2 100755 --- a/configs/model/hypergraph/allsettransformer.yaml +++ b/configs/model/hypergraph/allsettransformer.yaml @@ -18,8 +18,9 @@ backbone: mlp_dropout: 0. backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.default_wrapper.HypergraphWrapper + _target_: topobenchmarkx.models.wrappers.HypergraphWrapper _partial_: true + wrapper_name: HypergraphWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} diff --git a/configs/model/hypergraph/edgnn.yaml b/configs/model/hypergraph/edgnn.yaml index 42088fff..dfb89041 100755 --- a/configs/model/hypergraph/edgnn.yaml +++ b/configs/model/hypergraph/edgnn.yaml @@ -20,8 +20,9 @@ backbone: aggregate: 'add' backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.default_wrapper.HypergraphWrapper + _target_: topobenchmarkx.models.wrappers.HypergraphWrapper _partial_: true + wrapper_name: HypergraphWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} diff --git a/configs/model/hypergraph/unignn.yaml b/configs/model/hypergraph/unignn.yaml index bfd56c5f..785a92c5 100755 --- a/configs/model/hypergraph/unignn.yaml +++ b/configs/model/hypergraph/unignn.yaml @@ -7,8 +7,9 @@ backbone: n_layers: 1 backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.default_wrapper.HypergraphWrapper + _target_: topobenchmarkx.models.wrappers.HypergraphWrapper _partial_: true + wrapper_name: HypergraphWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} diff --git a/configs/model/hypergraph/unignn2.yaml b/configs/model/hypergraph/unignn2.yaml index 31dd1d62..9d9d3cb3 100755 --- a/configs/model/hypergraph/unignn2.yaml +++ b/configs/model/hypergraph/unignn2.yaml @@ -18,8 +18,9 @@ backbone: layer_drop: 0.2 backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.default_wrapper.HypergraphWrapper + _target_: topobenchmarkx.models.wrappers.HypergraphWrapper _partial_: true + wrapper_name: HypergraphWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} diff --git a/configs/model/simplicial/san.yaml b/configs/model/simplicial/san.yaml index fd91137c..01806431 100755 --- a/configs/model/simplicial/san.yaml +++ b/configs/model/simplicial/san.yaml @@ -18,8 +18,9 @@ backbone: epsilon_harmonic: 1e-1 backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.default_wrapper.SANWrapper + _target_: topobenchmarkx.models.wrappers.SANWrapper _partial_: true + wrapper_name: SANWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} diff --git a/configs/model/simplicial/sccn.yaml b/configs/model/simplicial/sccn.yaml index 4b8b245b..a6f70919 100755 --- 
a/configs/model/simplicial/sccn.yaml +++ b/configs/model/simplicial/sccn.yaml @@ -13,8 +13,9 @@ backbone: update_func: "sigmoid" backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.default_wrapper.SCCNWrapper + _target_: topobenchmarkx.models.wrappers.SCCNWrapper _partial_: true + wrapper_name: SCCNWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} diff --git a/configs/model/simplicial/sccnn.yaml b/configs/model/simplicial/sccnn.yaml index 4e1d65ff..d9b8e501 100755 --- a/configs/model/simplicial/sccnn.yaml +++ b/configs/model/simplicial/sccnn.yaml @@ -26,8 +26,9 @@ backbone: n_layers: 1 backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.default_wrapper.SCCNNWrapper + _target_: topobenchmarkx.models.wrappers.SCCNNWrapper _partial_: true + wrapper_name: SCCNNWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} diff --git a/configs/model/simplicial/sccnn_custom.yaml b/configs/model/simplicial/sccnn_custom.yaml index 48a5f4a9..9b309146 100755 --- a/configs/model/simplicial/sccnn_custom.yaml +++ b/configs/model/simplicial/sccnn_custom.yaml @@ -27,8 +27,9 @@ backbone: n_layers: 1 backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.default_wrapper.SCCNNWrapper + _target_: topobenchmarkx.models.wrappers.SCCNNWrapper _partial_: true + wrapper_name: SCCNNWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} diff --git a/configs/model/simplicial/scn.yaml b/configs/model/simplicial/scn.yaml index bb2c2b28..f3e81b3f 100755 --- a/configs/model/simplicial/scn.yaml +++ b/configs/model/simplicial/scn.yaml @@ -17,8 +17,9 @@ backbone: n_layers: 1 backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.default_wrapper.SCNWrapper + _target_: topobenchmarkx.models.wrappers.SCNWrapper _partial_: true + wrapper_name: SCNWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} diff --git a/configs/train.yaml b/configs/train.yaml index 3b109e4c..afade5b8 100755 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -5,7 +5,7 @@ defaults: - _self_ - dataset: PROTEINS_TU #us_country_demos - - model: simplicial/sccn #hypergraph/unignn2 #allsettransformer + - model: graph/gcn #hypergraph/unignn2 #allsettransformer - evaluator: default - callbacks: default - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`) diff --git a/topobenchmarkx/models/wrappers/__init__.py b/topobenchmarkx/models/wrappers/__init__.py index 705d292e..5264dd99 100755 --- a/topobenchmarkx/models/wrappers/__init__.py +++ b/topobenchmarkx/models/wrappers/__init__.py @@ -1,26 +1,29 @@ -import hydra # noqa: F401 -import torch -from omegaconf import DictConfig # noqa: F401 +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper +from topobenchmarkx.models.wrappers.graph import GNNWrapper +from topobenchmarkx.models.wrappers.hypergraph import HypergraphWrapper +from topobenchmarkx.models.wrappers.simplicial import SANWrapper, SCNWrapper, SCCNNWrapper, SCCNWrapper +from topobenchmarkx.models.wrappers.cell import CANWrapper, CWNDCMWrapper, CWNWrapper, CCXNWrapper +# ... 
import other readout classes here +# For example: +# from topobenchmarkx.models.readouts.other_readout_1 import OtherWrapper1 +# from topobenchmarkx.models.readouts.other_readout_2 import OtherWrapper2 -class DefaultLoss: - """Abstract class that provides an interface to loss logic within - netowrk.""" - def __init__(self, task): - if task == "classification": - self.criterion = torch.nn.CrossEntropyLoss() - - elif task == "regression": - self.criterion = torch.nn.mse() - else: - raise Exception("Loss is not defined") - - def __call__(self, model_output): - """Loss logic based on model_output.""" - - logits = model_output["logits"] - target = model_output["labels"] - model_output["loss"] = self.criterion(logits, target) - - return model_output +# Export all wrappers +__all__ = [ + "DefaultWrapper", + "GNNWrapper", + "HypergraphWrapper", + "SANWrapper", + "SCNWrapper", + "SCCNNWrapper", + "SCCNWrapper", + "CANWrapper", + "CWNDCMWrapper", + "CWNWrapper", + "CCXNWrapper", + # "OtherWrapper1", + # "OtherWrapper2", + # ... add other readout classes here +] \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/cell/__init__.py b/topobenchmarkx/models/wrappers/cell/__init__.py new file mode 100644 index 00000000..06af3421 --- /dev/null +++ b/topobenchmarkx/models/wrappers/cell/__init__.py @@ -0,0 +1,20 @@ +from topobenchmarkx.models.wrappers.cell.can_wrapper import CANWrapper +from topobenchmarkx.models.wrappers.cell.cwndcm_wrapper import CWNDCMWrapper +from topobenchmarkx.models.wrappers.cell.cwn_wrapper import CWNWrapper +from topobenchmarkx.models.wrappers.cell.ccxn_wrapper import CCXNWrapper + +# ... import other readout classes here +# For example: +# from topobenchmarkx.models.readouts.other_readout_1 import OtherWrapper1 +# from topobenchmarkx.models.readouts.other_readout_2 import OtherWrapper2 + +__all__ = [ + "CANWrapper", + "CWNDCMWrapper", + "CWNWrapper", + "CCXNWrapper", + + # "OtherWrapper1", + # "OtherWrapper2", + # ... 
add other readout classes here +] \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/cell/can_wrapper.py b/topobenchmarkx/models/wrappers/cell/can_wrapper.py new file mode 100644 index 00000000..02d6bc60 --- /dev/null +++ b/topobenchmarkx/models/wrappers/cell/can_wrapper.py @@ -0,0 +1,22 @@ +import torch +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class CANWrapper(DefaultWrapper): + """Abstract class that provides an interface to loss logic within + network.""" + + def forward(self, batch): + """Define logic for forward pass.""" + + x_1 = self.backbone( + x_0=batch.x_0, + x_1=batch.x_1, + adjacency_0=batch.adjacency_0.coalesce(), + down_laplacian_1=batch.down_laplacian_1.coalesce(), + up_laplacian_1=batch.up_laplacian_1.coalesce(), + ) + + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + model_out["x_1"] = x_1 + model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1) + return model_out \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/cell/ccxn_wrapper.py b/topobenchmarkx/models/wrappers/cell/ccxn_wrapper.py new file mode 100644 index 00000000..ed159c9a --- /dev/null +++ b/topobenchmarkx/models/wrappers/cell/ccxn_wrapper.py @@ -0,0 +1,21 @@ +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class CCXNWrapper(DefaultWrapper): + """Abstract class that provides an interface to loss logic within + network.""" + + def forward(self, batch): + """Define logic for forward pass.""" + + x_0, x_1, x_2 = self.backbone( + x_0=batch.x_0, + x_1=batch.x_1, + adjacency_0=batch.adjacency_0, + incidence_2_t=batch.incidence_2.T, + ) + + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + model_out["x_0"] = x_0 + model_out["x_1"] = x_1 + model_out["x_2"] = x_2 + return model_out diff --git a/topobenchmarkx/models/wrappers/cell/cwn_wrapper.py b/topobenchmarkx/models/wrappers/cell/cwn_wrapper.py new file mode 100644 index 00000000..941ba2a7 --- /dev/null +++ b/topobenchmarkx/models/wrappers/cell/cwn_wrapper.py @@ -0,0 +1,23 @@ +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class CWNWrapper(DefaultWrapper): + """Abstract class that provides an interface to loss logic within + network.""" + + def forward(self, batch): + """Define logic for forward pass.""" + + x_0, x_1, x_2 = self.backbone( + x_0=batch.x_0, + x_1=batch.x_1, + x_2=batch.x_2, + incidence_1_t=batch.incidence_1.T, + adjacency_0=batch.adjacency_1, + incidence_2=batch.incidence_2, + ) + + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + model_out["x_0"] = x_0 + model_out["x_1"] = x_1 + model_out["x_2"] = x_2 + return model_out \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/cell/cwndcm_wrapper.py b/topobenchmarkx/models/wrappers/cell/cwndcm_wrapper.py new file mode 100644 index 00000000..3439282b --- /dev/null +++ b/topobenchmarkx/models/wrappers/cell/cwndcm_wrapper.py @@ -0,0 +1,21 @@ +import torch +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class CWNDCMWrapper(DefaultWrapper): + """Abstract class that provides an interface to loss logic within + network.""" + + def forward(self, batch): + """Define logic for forward pass.""" + + x_1 = self.backbone( + batch.x_1, + batch.down_laplacian_1.coalesce(), + batch.up_laplacian_1.coalesce(), + ) + + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + + model_out["x_1"] = x_1 + model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1) + return model_out \ No newline at end of file diff --git 
a/topobenchmarkx/models/wrappers/graph/__init__.py b/topobenchmarkx/models/wrappers/graph/__init__.py new file mode 100644 index 00000000..74d5787d --- /dev/null +++ b/topobenchmarkx/models/wrappers/graph/__init__.py @@ -0,0 +1,15 @@ +from topobenchmarkx.models.wrappers.graph.gnn_wrapper import GNNWrapper + +# ... import other readout classes here +# For example: +# from topobenchmarkx.models.readouts.other_readout_1 import OtherWrapper1 +# from topobenchmarkx.models.readouts.other_readout_2 import OtherWrapper2 + +# Export all wrappers +__all__ = [ + "GNNWrapper", + + # "OtherWrapper1", + # "OtherWrapper2", + # ... add other readout classes here +] diff --git a/topobenchmarkx/models/wrappers/graph/gnn_wrapper.py b/topobenchmarkx/models/wrappers/graph/gnn_wrapper.py new file mode 100644 index 00000000..59ffa7d3 --- /dev/null +++ b/topobenchmarkx/models/wrappers/graph/gnn_wrapper.py @@ -0,0 +1,14 @@ +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class GNNWrapper(DefaultWrapper): + """Abstract class that provides an interface to loss logic within + network.""" + + def forward(self, batch): + """Define logic for forward pass.""" + x_0 = self.backbone(batch.x_0, batch.edge_index) + + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + model_out["x_0"] = x_0 + + return model_out \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/hypergraph/__init__.py b/topobenchmarkx/models/wrappers/hypergraph/__init__.py new file mode 100644 index 00000000..869f46b0 --- /dev/null +++ b/topobenchmarkx/models/wrappers/hypergraph/__init__.py @@ -0,0 +1,15 @@ +from topobenchmarkx.models.wrappers.hypergraph.hypergraph_wrapper import HypergraphWrapper + +# ... import other readout classes here +# For example: +# from topobenchmarkx.models.readouts.other_readout_1 import OtherWrapper1 +# from topobenchmarkx.models.readouts.other_readout_2 import OtherWrapper2 + +# Export all wrappers +__all__ = [ + "HypergraphWrapper", + + # "OtherWrapper1", + # "OtherWrapper2", + # ... 
add other readout classes here +] \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/hypergraph/hypergraph_wrapper.py b/topobenchmarkx/models/wrappers/hypergraph/hypergraph_wrapper.py new file mode 100644 index 00000000..f458592b --- /dev/null +++ b/topobenchmarkx/models/wrappers/hypergraph/hypergraph_wrapper.py @@ -0,0 +1,14 @@ +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class HypergraphWrapper(DefaultWrapper): + """Abstract class that provides an interface to loss logic within + network.""" + + def forward(self, batch): + """Define logic for forward pass.""" + x_0, x_1 = self.backbone(batch.x_0, batch.incidence_hyperedges) + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + model_out["x_0"] = x_0 + model_out["hyperedge"] = x_1 + + return model_out \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/default_wrapper.py b/topobenchmarkx/models/wrappers/old_wrapper.py similarity index 100% rename from topobenchmarkx/models/wrappers/default_wrapper.py rename to topobenchmarkx/models/wrappers/old_wrapper.py diff --git a/topobenchmarkx/models/wrappers/simplicial/__init__.py b/topobenchmarkx/models/wrappers/simplicial/__init__.py new file mode 100644 index 00000000..7dd8b690 --- /dev/null +++ b/topobenchmarkx/models/wrappers/simplicial/__init__.py @@ -0,0 +1,21 @@ +from topobenchmarkx.models.wrappers.simplicial.san_wrapper import SANWrapper +from topobenchmarkx.models.wrappers.simplicial.scn_wrapper import SCNWrapper +from topobenchmarkx.models.wrappers.simplicial.sccnn_wrapper import SCCNNWrapper +from topobenchmarkx.models.wrappers.simplicial.sccn_wrapper import SCCNWrapper + +# ... import other readout classes here +# For example: +# from topobenchmarkx.models.readouts.other_readout_1 import OtherWrapper1 +# from topobenchmarkx.models.readouts.other_readout_2 import OtherWrapper2 + +# Export all wrappers and the dictionary +__all__ = [ + "SANWrapper", + "SCNWrapper", + "SCCNNWrapper", + "SCCNWrapper", + + # "OtherWrapper1", + # "OtherWrapper2", + # ... 
add other readout classes here +] \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/simplicial/san_wrapper.py b/topobenchmarkx/models/wrappers/simplicial/san_wrapper.py new file mode 100644 index 00000000..a55fb308 --- /dev/null +++ b/topobenchmarkx/models/wrappers/simplicial/san_wrapper.py @@ -0,0 +1,17 @@ +import torch +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class SANWrapper(DefaultWrapper): + """Abstract class that provides an interface to loss logic within + network.""" + + def forward(self, batch): + """Define logic for forward pass.""" + x_1 = self.backbone( + batch.x_1, batch.up_laplacian_1, batch.down_laplacian_1 + ) + + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + model_out["x_0"] = torch.sparse.mm(batch.incidence_1, x_1) + model_out["x_1"] = x_1 + return model_out \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/simplicial/sccn_wrapper.py b/topobenchmarkx/models/wrappers/simplicial/sccn_wrapper.py new file mode 100644 index 00000000..9e64112f --- /dev/null +++ b/topobenchmarkx/models/wrappers/simplicial/sccn_wrapper.py @@ -0,0 +1,48 @@ +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class SCCNWrapper(DefaultWrapper): + """Abstract class that provides an interface to loss logic within + network.""" + + def forward(self, batch): + """Define logic for forward pass.""" + + features = { + f"rank_{r}": batch[f"x_{r}"] + for r in range(self.backbone.layers[0].max_rank + 1) + } + incidences = { + f"rank_{r}": batch[f"incidence_{r}"] + for r in range(1, self.backbone.layers[0].max_rank + 1) + } + adjacencies = { + f"rank_{r}": batch[f"hodge_laplacian_{r}"] + for r in range(self.backbone.layers[0].max_rank + 1) + } + output = self.backbone(features, incidences, adjacencies) + + # TODO: First decide which strategy is the best then make code general + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + if len(output) == 3: + x_0, x_1, x_2 = ( + output["rank_0"], + output["rank_1"], + output["rank_2"], + ) + + model_out["x_2"] = x_2 + model_out["x_1"] = x_1 + model_out["x_0"] = x_0 + + elif len(output) == 2: + x_0, x_1 = output["rank_0"], output["rank_1"] + + model_out["x_1"] = x_1 + model_out["x_0"] = x_0 + + else: + raise ValueError( + f"Invalid number of output tensors: {len(output)}" + ) + + return model_out \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/simplicial/sccnn_wrapper.py b/topobenchmarkx/models/wrappers/simplicial/sccnn_wrapper.py new file mode 100644 index 00000000..d9eef982 --- /dev/null +++ b/topobenchmarkx/models/wrappers/simplicial/sccnn_wrapper.py @@ -0,0 +1,28 @@ +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class SCCNNWrapper(DefaultWrapper): + """Abstract class that provides an interface to loss logic within + network.""" + + def forward(self, batch): + """Define logic for forward pass.""" + + x_all = (batch.x_0, batch.x_1, batch.x_2) + laplacian_all = ( + batch.hodge_laplacian_0, + batch.down_laplacian_1, + batch.up_laplacian_1, + batch.down_laplacian_2, + batch.up_laplacian_2, + ) + + incidence_all = (batch.incidence_1, batch.incidence_2) + x_0, x_1, x_2 = self.backbone(x_all, laplacian_all, incidence_all) + + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + + model_out["x_0"] = x_0 + model_out["x_1"] = x_1 + model_out["x_2"] = x_2 + + return model_out \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/simplicial/scn_wrapper.py 
b/topobenchmarkx/models/wrappers/simplicial/scn_wrapper.py new file mode 100644 index 00000000..42e5f45a --- /dev/null +++ b/topobenchmarkx/models/wrappers/simplicial/scn_wrapper.py @@ -0,0 +1,45 @@ +import torch +from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper + +class SCNWrapper(DefaultWrapper): + """Abstract class that provides an interface to loss logic within + network.""" + + def forward(self, batch): + """Define logic for forward pass.""" + + laplacian_0 = self.normalize_matrix(batch.hodge_laplacian_0) + laplacian_1 = self.normalize_matrix(batch.hodge_laplacian_1) + laplacian_2 = self.normalize_matrix(batch.hodge_laplacian_2) + x_0, x_1, x_2 = self.backbone( + batch.x_0, + batch.x_1, + batch.x_2, + laplacian_0, + laplacian_1, + laplacian_2, + ) + + model_out = {"labels": batch.y, "batch_0": batch.batch_0} + model_out["x_2"] = x_2 + model_out["x_1"] = x_1 + model_out["x_0"] = x_0 + + return model_out + + def normalize_matrix(self, matrix): + matrix_ = matrix.to_dense() + n, _ = matrix_.shape + abs_matrix = abs(matrix_) + diag_sum = abs_matrix.sum(axis=1) + + # Handle division by zero + idxs = torch.where(diag_sum != 0) + diag_sum[idxs] = 1.0 / torch.sqrt(diag_sum[idxs]) + + diag_indices = torch.stack([torch.arange(n), torch.arange(n)]) + diag_matrix = torch.sparse_coo_tensor( + diag_indices, diag_sum, matrix_.shape, device=matrix.device + ).coalesce() + normalized_matrix = diag_matrix @ (matrix @ diag_matrix) + return normalized_matrix \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/wrapper.py b/topobenchmarkx/models/wrappers/wrapper.py new file mode 100755 index 00000000..372c9a07 --- /dev/null +++ b/topobenchmarkx/models/wrappers/wrapper.py @@ -0,0 +1,42 @@ +from abc import ABC, abstractmethod +import torch +import torch.nn as nn + +class DefaultWrapper(ABC, torch.nn.Module): + """Abstract class that provides an interface to handle the network + output.""" + + def __init__(self, backbone, **kwargs): + super().__init__() + self.backbone = backbone + out_channels = kwargs["out_channels"] + self.dimensions = range(kwargs["num_cell_dimensions"]) + + for i in self.dimensions: + setattr( + self, + f"ln_{i}", + nn.LayerNorm(out_channels), + ) + + def __call__(self, batch): + """Define logic for forward pass.""" + model_out = self.forward(batch) + model_out = self.residual_connection(model_out=model_out, batch=batch) + return model_out + + def residual_connection(self, model_out, batch): + for i in self.dimensions: + if ( + (f"x_{i}" in batch) + and hasattr(self, f"ln_{i}") + and (f"x_{i}" in model_out) + ): + residual = model_out[f"x_{i}"] + batch[f"x_{i}"] + model_out[f"x_{i}"] = getattr(self, f"ln_{i}")(residual) + return model_out + + @abstractmethod + def forward(self, batch): + """Define handling output here.""" + pass From c69c1f3808b6d52579a7506074830cb2558faf0e Mon Sep 17 00:00:00 2001 From: levtelyatnikov Date: Tue, 14 May 2024 05:35:52 +0200 Subject: [PATCH 03/32] enhanced modularity --- configs/loss/default.yaml | 0 configs/model/cell/can.yaml | 17 +- configs/model/cell/ccxn.yaml | 13 +- configs/model/cell/cwn.yaml | 13 +- configs/model/cell/cwn_dcm.yaml | 13 +- configs/model/graph/gat.yaml | 13 +- configs/model/graph/gcn.yaml | 15 +- configs/model/graph/gin.yaml | 13 +- configs/model/hypergraph/alldeepset.yaml | 11 +- .../model/hypergraph/allsettransformer.yaml | 12 +- configs/model/hypergraph/edgnn.yaml | 10 +- configs/model/hypergraph/unignn.yaml | 16 +- configs/model/hypergraph/unignn2.yaml | 11 +- configs/model/simplicial/san.yaml 
| 13 +- configs/model/simplicial/sccn.yaml | 12 +- configs/model/simplicial/sccnn.yaml | 13 +- configs/model/simplicial/sccnn_custom.yaml | 13 +- configs/model/simplicial/scn.yaml | 13 +- configs/train.yaml | 4 +- .../models/abstractions/__init__.py | 0 topobenchmarkx/models/encoders/__init__.py | 15 ++ ...efault_encoders.py => all_cell_encoder.py} | 177 +++++------------- .../{abstractions => encoders}/encoder.py | 13 +- topobenchmarkx/models/encoders/perceiver.py | 82 ++++++++ topobenchmarkx/models/head_model/__init__.py | 0 topobenchmarkx/models/head_models/__init__.py | 15 ++ .../models/head_models/head_model.py | 47 +++++ .../models.py => head_models/old_models.py} | 3 +- .../models/head_models/zero_cell_model.py | 65 +++++++ topobenchmarkx/models/losses/__init__.py | 15 ++ topobenchmarkx/models/losses/default_loss.py | 38 ++++ topobenchmarkx/models/losses/loss.py | 38 ++-- topobenchmarkx/models/losses/losses.py | 34 ---- topobenchmarkx/models/network_module.py | 72 +------ .../models/readouts/propagate_signal_down.py | 4 + topobenchmarkx/models/wrappers/__init__.py | 4 +- 36 files changed, 512 insertions(+), 335 deletions(-) delete mode 100755 configs/loss/default.yaml delete mode 100755 topobenchmarkx/models/abstractions/__init__.py rename topobenchmarkx/models/encoders/{default_encoders.py => all_cell_encoder.py} (50%) rename topobenchmarkx/models/{abstractions => encoders}/encoder.py (68%) delete mode 100644 topobenchmarkx/models/head_model/__init__.py create mode 100644 topobenchmarkx/models/head_models/__init__.py create mode 100644 topobenchmarkx/models/head_models/head_model.py rename topobenchmarkx/models/{head_model/models.py => head_models/old_models.py} (96%) create mode 100644 topobenchmarkx/models/head_models/zero_cell_model.py create mode 100644 topobenchmarkx/models/losses/default_loss.py delete mode 100755 topobenchmarkx/models/losses/losses.py diff --git a/configs/loss/default.yaml b/configs/loss/default.yaml deleted file mode 100755 index e69de29b..00000000 diff --git a/configs/model/cell/can.yaml b/configs/model/cell/can.yaml index dc60d16a..a8f3495a 100755 --- a/configs/model/cell/can.yaml +++ b/configs/model/cell/can.yaml @@ -1,7 +1,11 @@ _target_: topobenchmarkx.models.network_module.NetworkModule +model_name: can +model_domain: cell + feature_encoder: - _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder + _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} + encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 selected_dimensions: @@ -34,14 +38,15 @@ readout: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} head_model: - _target_: topobenchmarkx.models.head_model.models.DefaultHead - task_level: ${dataset.parameters.task_level} - in_channels: ${parameter_multiplication:${model.backbone.out_channels},${model.backbone.heads}} + _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name} + head_model_name: ZeroCellModel + in_channels: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} - pooling_type: sum + task_level: ${dataset.parameters.task_level} + pooling_type: sum loss: - _target_: topobenchmarkx.models.losses.loss.DefaultLoss + _target_: topobenchmarkx.models.losses.DefaultLoss task: ${dataset.parameters.task} loss_type: ${dataset.parameters.loss_type} diff --git a/configs/model/cell/ccxn.yaml b/configs/model/cell/ccxn.yaml index 
488ad52e..783f92c5 100755 --- a/configs/model/cell/ccxn.yaml +++ b/configs/model/cell/ccxn.yaml @@ -1,7 +1,11 @@ _target_: topobenchmarkx.models.network_module.NetworkModule +model_name: ccxn +model_domain: cell + feature_encoder: - _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder + _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} + encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 @@ -30,14 +34,15 @@ readout: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} head_model: - _target_: topobenchmarkx.models.head_model.models.DefaultHead - task_level: ${dataset.parameters.task_level} + _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name} + head_model_name: ZeroCellModel in_channels: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} + task_level: ${dataset.parameters.task_level} pooling_type: sum loss: - _target_: topobenchmarkx.models.losses.loss.DefaultLoss + _target_: topobenchmarkx.models.losses.DefaultLoss task: ${dataset.parameters.task} loss_type: ${dataset.parameters.loss_type} diff --git a/configs/model/cell/cwn.yaml b/configs/model/cell/cwn.yaml index 2b80ae6e..7d5e4edf 100755 --- a/configs/model/cell/cwn.yaml +++ b/configs/model/cell/cwn.yaml @@ -1,7 +1,11 @@ _target_: topobenchmarkx.models.network_module.NetworkModule +model_name: cwn +model_domain: cell + feature_encoder: - _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder + _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} + encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 proj_dropout: 0.0 @@ -29,14 +33,15 @@ readout: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} head_model: - _target_: topobenchmarkx.models.head_model.models.DefaultHead - task_level: ${dataset.parameters.task_level} + _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name} + head_model_name: ZeroCellModel in_channels: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} + task_level: ${dataset.parameters.task_level} pooling_type: sum loss: - _target_: topobenchmarkx.models.losses.loss.DefaultLoss + _target_: topobenchmarkx.models.losses.DefaultLoss task: ${dataset.parameters.task} loss_type: ${dataset.parameters.loss_type} diff --git a/configs/model/cell/cwn_dcm.yaml b/configs/model/cell/cwn_dcm.yaml index 4e573817..c411e025 100755 --- a/configs/model/cell/cwn_dcm.yaml +++ b/configs/model/cell/cwn_dcm.yaml @@ -1,7 +1,11 @@ _target_: topobenchmarkx.models.network_module.NetworkModule +model_name: cwn_dcm +model_domain: cell + feature_encoder: - _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder + _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} + encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 proj_dropout: 0.0 @@ -29,14 +33,15 @@ readout: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} head_model: - _target_: topobenchmarkx.models.head_model.models.DefaultHead - task_level: ${dataset.parameters.task_level} + _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name} + head_model_name: ZeroCellModel 
in_channels: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} + task_level: ${dataset.parameters.task_level} pooling_type: sum loss: - _target_: topobenchmarkx.models.losses.loss.DefaultLoss + _target_: topobenchmarkx.models.losses.DefaultLoss task: ${dataset.parameters.task} loss_type: ${dataset.parameters.loss_type} diff --git a/configs/model/graph/gat.yaml b/configs/model/graph/gat.yaml index a733fa06..fed38c36 100755 --- a/configs/model/graph/gat.yaml +++ b/configs/model/graph/gat.yaml @@ -1,7 +1,11 @@ _target_: topobenchmarkx.models.network_module.NetworkModule +model_name: gat +model_domain: graph + feature_encoder: - _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder + _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} + encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} out_channels: 32 @@ -30,14 +34,15 @@ readout: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} head_model: - _target_: topobenchmarkx.models.head_model.models.DefaultHead - task_level: ${dataset.parameters.task_level} + _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name} + head_model_name: ZeroCellModel in_channels: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} + task_level: ${dataset.parameters.task_level} pooling_type: sum loss: - _target_: topobenchmarkx.models.losses.loss.DefaultLoss + _target_: topobenchmarkx.models.losses.DefaultLoss task: ${dataset.parameters.task} loss_type: ${dataset.parameters.loss_type} diff --git a/configs/model/graph/gcn.yaml b/configs/model/graph/gcn.yaml index 49e6d90c..7bb1c695 100755 --- a/configs/model/graph/gcn.yaml +++ b/configs/model/graph/gcn.yaml @@ -1,7 +1,11 @@ _target_: topobenchmarkx.models.network_module.NetworkModule +model_name: gcn +model_domain: graph + feature_encoder: - _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder + _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} + encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 64 @@ -14,7 +18,7 @@ backbone: act: relu backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.GNNWrapper + _target_: topobenchmarkx.models.wrappers.GNNWrapper _partial_: true wrapper_name: GNNWrapper out_channels: ${model.feature_encoder.out_channels} @@ -27,14 +31,15 @@ readout: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} head_model: - _target_: topobenchmarkx.models.head_model.models.DefaultHead - task_level: ${dataset.parameters.task_level} + _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name} + head_model_name: ZeroCellModel in_channels: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} + task_level: ${dataset.parameters.task_level} pooling_type: sum loss: - _target_: topobenchmarkx.models.losses.loss.DefaultLoss + _target_: topobenchmarkx.models.losses.DefaultLoss task: ${dataset.parameters.task} loss_type: ${dataset.parameters.loss_type} diff --git a/configs/model/graph/gin.yaml b/configs/model/graph/gin.yaml index 4197fe7a..4b9ae61d 100755 --- a/configs/model/graph/gin.yaml +++ b/configs/model/graph/gin.yaml @@ -1,7 +1,11 @@ _target_: topobenchmarkx.models.network_module.NetworkModule +model_name: gin +model_domain: graph + feature_encoder: - _target_: 
topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder + _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} + encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 @@ -27,14 +31,15 @@ readout: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} head_model: - _target_: topobenchmarkx.models.head_model.models.DefaultHead - task_level: ${dataset.parameters.task_level} + _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name} + head_model_name: ZeroCellModel in_channels: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} + task_level: ${dataset.parameters.task_level} pooling_type: sum loss: - _target_: topobenchmarkx.models.losses.loss.DefaultLoss + _target_: topobenchmarkx.models.losses.DefaultLoss task: ${dataset.parameters.task} loss_type: ${dataset.parameters.loss_type} diff --git a/configs/model/hypergraph/alldeepset.yaml b/configs/model/hypergraph/alldeepset.yaml index 153fdf97..2780cea9 100755 --- a/configs/model/hypergraph/alldeepset.yaml +++ b/configs/model/hypergraph/alldeepset.yaml @@ -1,9 +1,11 @@ _target_: topobenchmarkx.models.network_module.NetworkModule model_name: alldeepset +model_domain: hypergraph feature_encoder: - _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder + _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} + encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 @@ -37,14 +39,15 @@ readout: num_cell_dimensions: None head_model: - _target_: topobenchmarkx.models.head_model.models.DefaultHead - task_level: ${dataset.parameters.task_level} + _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name} + head_model_name: ZeroCellModel in_channels: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} + task_level: ${dataset.parameters.task_level} pooling_type: sum loss: - _target_: topobenchmarkx.models.losses.loss.DefaultLoss + _target_: topobenchmarkx.models.losses.DefaultLoss task: ${dataset.parameters.task} loss_type: ${dataset.parameters.loss_type} diff --git a/configs/model/hypergraph/allsettransformer.yaml b/configs/model/hypergraph/allsettransformer.yaml index 673e49b2..817e1d8f 100755 --- a/configs/model/hypergraph/allsettransformer.yaml +++ b/configs/model/hypergraph/allsettransformer.yaml @@ -1,9 +1,11 @@ _target_: topobenchmarkx.models.network_module.NetworkModule model_name: allsettransformer +model_domain: hypergraph feature_encoder: - _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder + _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} + encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 @@ -31,15 +33,15 @@ readout: num_cell_dimensions: None head_model: - _target_: topobenchmarkx.models.head_model.models.DefaultHead - task_level: ${dataset.parameters.task_level} + _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name} + head_model_name: ZeroCellModel in_channels: ${model.feature_encoder.out_channels} out_channels: ${dataset.parameters.num_classes} + task_level: ${dataset.parameters.task_level} pooling_type: sum - loss: - _target_: topobenchmarkx.models.losses.loss.DefaultLoss + _target_: 
topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/hypergraph/edgnn.yaml b/configs/model/hypergraph/edgnn.yaml
index dfb89041..4a07ee48 100755
--- a/configs/model/hypergraph/edgnn.yaml
+++ b/configs/model/hypergraph/edgnn.yaml
@@ -4,7 +4,8 @@ model_name: edgnn
 model_domain: hypergraph
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
   out_channels: 32
 
@@ -33,14 +34,15 @@ readout:
   num_cell_dimensions: None
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/hypergraph/unignn.yaml b/configs/model/hypergraph/unignn.yaml
index 785a92c5..91fbc1f0 100755
--- a/configs/model/hypergraph/unignn.yaml
+++ b/configs/model/hypergraph/unignn.yaml
@@ -1,5 +1,14 @@
 _target_: topobenchmarkx.models.network_module.NetworkModule
 
+model_name: unignn
+model_domain: hypergraph
+
+feature_encoder:
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
+  in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
+  out_channels: 32
+
 backbone:
   _target_: topomodelx.nn.hypergraph.unigcn.UniGCN
   in_channels: ${data.num_features}
@@ -20,14 +29,15 @@ readout:
   num_cell_dimensions: None
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
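As a side note on the pattern above: every backbone_wrapper block uses Hydra's `_partial_: true`. A minimal sketch of how such a config is typically consumed (illustrative only, not part of the patch; the exact call site lives in NetworkModule, and the config path here is assumed):

# Sketch only: `instantiate` on a node with `_partial_: true` returns a
# functools.partial, so the wrapper can be completed with the backbone later.
from hydra import compose, initialize
from hydra.utils import instantiate

with initialize(config_path="../configs", version_base=None):  # path assumed
    cfg = compose(config_name="train")

backbone = instantiate(cfg.model.backbone)
wrapper_partial = instantiate(cfg.model.backbone_wrapper)  # functools.partial
wrapper = wrapper_partial(backbone)  # e.g. HypergraphWrapper(backbone, ...)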
diff --git a/configs/model/hypergraph/unignn2.yaml b/configs/model/hypergraph/unignn2.yaml
index 9d9d3cb3..ca080e27 100755
--- a/configs/model/hypergraph/unignn2.yaml
+++ b/configs/model/hypergraph/unignn2.yaml
@@ -1,9 +1,11 @@
 _target_: topobenchmarkx.models.network_module.NetworkModule
 
 model_name: unignn2
+model_domain: hypergraph
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
   out_channels: 32
@@ -31,14 +33,15 @@ readout:
   num_cell_dimensions: None
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/simplicial/san.yaml b/configs/model/simplicial/san.yaml
index 01806431..46aa3709 100755
--- a/configs/model/simplicial/san.yaml
+++ b/configs/model/simplicial/san.yaml
@@ -1,7 +1,11 @@
 _target_: topobenchmarkx.models.network_module.NetworkModule
 
+model_name: san
+model_domain: simplicial
+
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
   out_channels: 64
   selected_dimensions:
@@ -31,14 +35,15 @@ readout:
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/simplicial/sccn.yaml b/configs/model/simplicial/sccn.yaml
index a6f70919..09d3f4cf 100755
--- a/configs/model/simplicial/sccn.yaml
+++ b/configs/model/simplicial/sccn.yaml
@@ -1,7 +1,10 @@
 _target_: topobenchmarkx.models.network_module.NetworkModule
+model_name: sccn
+model_domain: simplicial
 
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} # ${dataset.parameters.num_features}
   out_channels: 32
@@ -26,14 +29,15 @@ readout:
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}}
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
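The `${model.feature_encoder.encoder_name}` indirection used throughout these configs can be checked in isolation. A minimal sketch with plain OmegaConf (illustrative, not part of the patch; only the two relevant keys are reproduced):

from omegaconf import OmegaConf

# Mirror of the feature_encoder block above, reduced to the keys that
# participate in the interpolation.
cfg = OmegaConf.create(
    {
        "model": {
            "feature_encoder": {
                "_target_": "topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}",
                "encoder_name": "AllCellFeatureEncoder",
            }
        }
    }
)
resolved = OmegaConf.to_container(cfg, resolve=True)
# resolved["model"]["feature_encoder"]["_target_"]
# -> "topobenchmarkx.models.encoders.AllCellFeatureEncoder"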
diff --git a/configs/model/simplicial/sccnn.yaml b/configs/model/simplicial/sccnn.yaml
index d9b8e501..a02853f0 100755
--- a/configs/model/simplicial/sccnn.yaml
+++ b/configs/model/simplicial/sccnn.yaml
@@ -1,7 +1,11 @@
 _target_: topobenchmarkx.models.network_module.NetworkModule
 
+model_name: sccnn
+model_domain: simplicial
+
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
   out_channels: 32
   selected_dimensions:
@@ -39,15 +43,16 @@ readout:
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/simplicial/sccnn_custom.yaml b/configs/model/simplicial/sccnn_custom.yaml
index 9b309146..cfa62f75 100755
--- a/configs/model/simplicial/sccnn_custom.yaml
+++ b/configs/model/simplicial/sccnn_custom.yaml
@@ -1,7 +1,11 @@
 _target_: topobenchmarkx.models.network_module.NetworkModule
 
+model_name: sccnn_custom
+model_domain: simplicial
+
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
   out_channels: 64
   selected_dimensions:
@@ -40,14 +44,15 @@ readout:
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
diff --git a/configs/model/simplicial/scn.yaml b/configs/model/simplicial/scn.yaml
index f3e81b3f..78adb09d 100755
--- a/configs/model/simplicial/scn.yaml
+++ b/configs/model/simplicial/scn.yaml
@@ -1,7 +1,11 @@
 _target_: topobenchmarkx.models.network_module.NetworkModule
 
+model_name: scn
+model_domain: simplicial
+
 feature_encoder:
-  _target_: topobenchmarkx.models.encoders.default_encoders.BaseFeatureEncoder
+  _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
+  encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
   out_channels: 32
   selected_dimensions:
@@ -30,14 +34,15 @@ readout:
   num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}}
 
 head_model:
-  _target_: topobenchmarkx.models.head_model.models.DefaultHead
-  task_level: ${dataset.parameters.task_level}
+  _target_: topobenchmarkx.models.head_models.${model.head_model.head_model_name}
+  head_model_name: ZeroCellModel
   in_channels: ${model.feature_encoder.out_channels}
   out_channels: ${dataset.parameters.num_classes}
+  task_level: ${dataset.parameters.task_level}
   pooling_type: sum
 
 loss:
-  _target_: topobenchmarkx.models.losses.loss.DefaultLoss
+  _target_: topobenchmarkx.models.losses.DefaultLoss
   task: ${dataset.parameters.task}
   loss_type: ${dataset.parameters.loss_type}
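All of the head_model blocks above now point to ZeroCellModel, whose graph-level readout (defined later in this patch) pools 0-cell embeddings per batched graph. A standalone sketch of that pooling step with toy shapes (illustrative only):

import torch
from torch_geometric.utils import scatter

# Toy inputs: six 0-cells split across two batched graphs.
x_0 = torch.randn(6, 32)
batch_0 = torch.tensor([0, 0, 0, 1, 1, 1])

# `pooling_type: sum` in the configs corresponds to reduce="sum".
graph_embedding = scatter(x_0, batch_0, dim=0, reduce="sum")  # shape [2, 32]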
diff --git a/configs/train.yaml b/configs/train.yaml
index afade5b8..cfef1ef5 100755
--- a/configs/train.yaml
+++ b/configs/train.yaml
@@ -4,8 +4,8 @@
 # order of defaults determines the order in which configs override each other
 defaults:
   - _self_
-  - dataset: PROTEINS_TU #us_country_demos
-  - model: graph/gcn #hypergraph/unignn2 #allsettransformer
+  - dataset: IMDB_BINARY #us_country_demos
+  - model: hypergraph/edgnn #hypergraph/unignn2 #allsettransformer
   - evaluator: default
   - callbacks: default
   - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
diff --git a/topobenchmarkx/models/abstractions/__init__.py b/topobenchmarkx/models/abstractions/__init__.py
deleted file mode 100755
index e69de29b..00000000
diff --git a/topobenchmarkx/models/encoders/__init__.py b/topobenchmarkx/models/encoders/__init__.py
index e69de29b..fec596d4 100644
--- a/topobenchmarkx/models/encoders/__init__.py
+++ b/topobenchmarkx/models/encoders/__init__.py
@@ -0,0 +1,15 @@
+from topobenchmarkx.models.encoders.encoder import AbstractFeatureEncoder
+from topobenchmarkx.models.encoders.all_cell_encoder import AllCellFeatureEncoder
+
+# ... import other encoder classes here
+# For example:
+# from topobenchmarkx.models.encoders.other_encoder_1 import OtherEncoder1
+# from topobenchmarkx.models.encoders.other_encoder_2 import OtherEncoder2
+
+__all__ = [
+    "AbstractFeatureEncoder",
+    "AllCellFeatureEncoder",
+    # "OtherEncoder1",
+    # "OtherEncoder2",
+    # ... add other encoder classes here
+]
\ No newline at end of file
diff --git a/topobenchmarkx/models/encoders/default_encoders.py b/topobenchmarkx/models/encoders/all_cell_encoder.py
similarity index 50%
rename from topobenchmarkx/models/encoders/default_encoders.py
rename to topobenchmarkx/models/encoders/all_cell_encoder.py
index 4ac0b8ac..469ad55d 100644
--- a/topobenchmarkx/models/encoders/default_encoders.py
+++ b/topobenchmarkx/models/encoders/all_cell_encoder.py
@@ -1,57 +1,10 @@
 import torch
 import torch_geometric
 from torch_geometric.nn.norm import GraphNorm
+from topobenchmarkx.models.encoders.encoder import AbstractFeatureEncoder
 
-from topobenchmarkx.models.abstractions.encoder import (
-    AbstractInitFeaturesEncoder,
-)
-
-class BaseEncoder(torch.nn.Module):
-    r"""Encoder class that uses two linear layers with GraphNorm, Relu
-    activation function, and dropout between the two layers.
-
-    Parameters
-    ----------
-    in_channels: int
-        Dimension of input features.
-    out_channels: int
-        Dimensions of output features.
-    dropout: float
-        Percentage of channels to discard between the two linear layers.
-    """
-
-    def __init__(self, in_channels, out_channels, dropout=0):
-        super().__init__()
-        self.linear1 = torch.nn.Linear(in_channels, out_channels)
-        self.linear2 = torch.nn.Linear(out_channels, out_channels)
-        self.relu = torch.nn.ReLU()
-        self.BN = GraphNorm(out_channels)
-        self.dropout = torch.nn.Dropout(dropout)
-
-    def forward(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
-        r"""Forward pass.
-
-        Parameters
-        ----------
-        x: torch.Tensor
-            Input tensor of dimensions [N, in_channels].
-        batch: torch.Tensor
-            The batch vector which assigns each element to a specific example.
-
-        Returns
-        -------
-        torch.Tensor
-            Output tensor of shape [N, out_channels].
- """ - x = self.linear1(x) - x = self.BN(x, batch=batch) if batch.shape[0] > 0 else self.BN(x) - x = self.dropout(self.relu(x)) - x = self.linear2(x) - return x - - -class BaseFeatureEncoder(AbstractInitFeaturesEncoder): +class AllCellFeatureEncoder(AbstractFeatureEncoder): r"""Encoder class to apply BaseEncoder to the features of higher order structures. @@ -73,8 +26,10 @@ def __init__( out_channels, proj_dropout=0, selected_dimensions=None, + **kwargs ): - super(AbstractInitFeaturesEncoder, self).__init__() + super().__init__(**kwargs) + self.in_channels = in_channels self.out_channels = out_channels self.dimensions = ( @@ -119,83 +74,47 @@ def forward( ) return data +class BaseEncoder(torch.nn.Module): + r"""Encoder class that uses two linear layers with GraphNorm, Relu + activation function, and dropout between the two layers. + + Parameters + ---------- + in_channels: int + Dimension of input features. + out_channels: int + Dimensions of output features. + dropout: float + Percentage of channels to discard between the two linear layers. + """ + + def __init__(self, in_channels, out_channels, dropout=0): + super().__init__() + self.linear1 = torch.nn.Linear(in_channels, out_channels) + self.linear2 = torch.nn.Linear(out_channels, out_channels) + self.relu = torch.nn.ReLU() + self.BN = GraphNorm(out_channels) + self.dropout = torch.nn.Dropout(dropout) + + def forward(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor: + r"""Forward pass. + + Parameters + ---------- + x: torch.Tensor + Input tensor of dimensions [N, in_channels]. + batch: torch.Tensor + The batch vector which assigns each element to a specific example. + + Returns + ------- + torch.Tensor + Output tensor of shape [N, out_channels]. + """ + x = self.linear1(x) + x = self.BN(x, batch=batch) if batch.shape[0] > 0 else self.BN(x) + x = self.dropout(self.relu(x)) + x = self.linear2(x) + return x + -# from topobenchmarkx.models.encoders.perceiver import Perceiver -# class SetFeatureEncoder(AbstractInitFeaturesEncoder): -# r"""Encoder class to apply BaseEncoder to the node features and Perceiver to the features of higher order structures. - -# Parameters -# ---------- -# in_channels: list(int) -# Input dimensions for the features. -# out_channels: list(int) -# Output dimensions for the features. -# proj_dropout: float -# Dropout for the BaseEncoders. -# selected_dimensions: list(int) -# List of indexes to apply the BaseEncoders to. -# """ -# def __init__( -# self, in_channels, out_channels, proj_dropout=0, selected_dimensions=None -# ): -# super(AbstractInitFeaturesEncoder, self).__init__() -# self.in_channels = in_channels -# self.out_channels = out_channels -# self.dimensions = ( -# selected_dimensions -# if selected_dimensions is not None -# else range(len(self.in_channels)) -# ) -# for idx, i in enumerate(self.dimensions): -# if idx == 0: -# setattr( -# self, -# f"encoder_{i}", -# BaseEncoder( -# self.in_channels[i], self.out_channels, dropout=proj_dropout -# ), -# ) -# else: -# setattr( -# self, -# f"encoder_{i}", -# Perceiver( -# dim=self.out_channels, -# depth=1, -# cross_heads=4, -# cross_dim_head=self.out_channels, -# latent_dim_head=self.out_channels, -# ), -# ) - -# def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data: -# r""" -# Forward pass - -# Parameters -# ---------- -# data: torch_geometric.data.Data -# Input data object which should contain x_{i} features for each i in the selected_dimensions. - -# Returns -# ------- -# torch_geometric.data.Data -# Output data object. 
-# """ -# if not hasattr(data, "x_0"): -# data.x_0 = data.x - -# for idx, i in enumerate(self.dimensions): -# if idx == 0: -# if hasattr(data, f"x_{i}") and hasattr(self, f"encoder_{i}"): -# batch = data.batch if i == 0 else getattr(data, f"batch_{i}") -# data[f"x_{i}"] = getattr(self, f"encoder_{i}")( -# data[f"x_{i}"], batch -# ) -# else: -# if hasattr(data, f"x_{i}") and hasattr(self, f"encoder_{i}"): -# cell_features = data["x_0"][data[f"x_{i}"].long()] -# data[f"x_{i}"] = getattr(self, f"encoder_{i}")(cell_features) -# else: -# data[f"x_{i}"] = torch.tensor([], device=data.x_0.device) -# return data diff --git a/topobenchmarkx/models/abstractions/encoder.py b/topobenchmarkx/models/encoders/encoder.py similarity index 68% rename from topobenchmarkx/models/abstractions/encoder.py rename to topobenchmarkx/models/encoders/encoder.py index 6a5e4617..b21cc512 100644 --- a/topobenchmarkx/models/abstractions/encoder.py +++ b/topobenchmarkx/models/encoders/encoder.py @@ -4,13 +4,16 @@ import torch_geometric -class AbstractInitFeaturesEncoder(torch.nn.Module): - """Abstract class that provides an interface to define a custom initial - feature encoders.""" +class AbstractFeatureEncoder(torch.nn.Module): + """Abstract class that provides an interface to define a custom feature encoder.""" - def __init__(self): + def __init__(self, **kwargs): + super().__init__() return + def __call__(self, data): + return self.forward(data) + @abstractmethod def forward( self, data: torch_geometric.data.Data @@ -22,4 +25,4 @@ def forward( Returns: :data: torch_geometric.data.Data - """ + """ \ No newline at end of file diff --git a/topobenchmarkx/models/encoders/perceiver.py b/topobenchmarkx/models/encoders/perceiver.py index fd218b62..42381d07 100644 --- a/topobenchmarkx/models/encoders/perceiver.py +++ b/topobenchmarkx/models/encoders/perceiver.py @@ -391,3 +391,85 @@ def forward(self, data, mask=None, queries=None): # return x #self.to_logits(latents) return None + + + +# from topobenchmarkx.models.encoders.perceiver import Perceiver +# class SetFeatureEncoder(AbstractInitFeaturesEncoder): +# r"""Encoder class to apply BaseEncoder to the node features and Perceiver to the features of higher order structures. + +# Parameters +# ---------- +# in_channels: list(int) +# Input dimensions for the features. +# out_channels: list(int) +# Output dimensions for the features. +# proj_dropout: float +# Dropout for the BaseEncoders. +# selected_dimensions: list(int) +# List of indexes to apply the BaseEncoders to. +# """ +# def __init__( +# self, in_channels, out_channels, proj_dropout=0, selected_dimensions=None +# ): +# super(AbstractInitFeaturesEncoder, self).__init__() +# self.in_channels = in_channels +# self.out_channels = out_channels +# self.dimensions = ( +# selected_dimensions +# if selected_dimensions is not None +# else range(len(self.in_channels)) +# ) +# for idx, i in enumerate(self.dimensions): +# if idx == 0: +# setattr( +# self, +# f"encoder_{i}", +# BaseEncoder( +# self.in_channels[i], self.out_channels, dropout=proj_dropout +# ), +# ) +# else: +# setattr( +# self, +# f"encoder_{i}", +# Perceiver( +# dim=self.out_channels, +# depth=1, +# cross_heads=4, +# cross_dim_head=self.out_channels, +# latent_dim_head=self.out_channels, +# ), +# ) + +# def forward(self, data: torch_geometric.data.Data) -> torch_geometric.data.Data: +# r""" +# Forward pass + +# Parameters +# ---------- +# data: torch_geometric.data.Data +# Input data object which should contain x_{i} features for each i in the selected_dimensions. 
+ +# Returns +# ------- +# torch_geometric.data.Data +# Output data object. +# """ +# if not hasattr(data, "x_0"): +# data.x_0 = data.x + +# for idx, i in enumerate(self.dimensions): +# if idx == 0: +# if hasattr(data, f"x_{i}") and hasattr(self, f"encoder_{i}"): +# batch = data.batch if i == 0 else getattr(data, f"batch_{i}") +# data[f"x_{i}"] = getattr(self, f"encoder_{i}")( +# data[f"x_{i}"], batch +# ) +# else: +# if hasattr(data, f"x_{i}") and hasattr(self, f"encoder_{i}"): +# cell_features = data["x_0"][data[f"x_{i}"].long()] +# data[f"x_{i}"] = getattr(self, f"encoder_{i}")(cell_features) +# else: +# data[f"x_{i}"] = torch.tensor([], device=data.x_0.device) +# return data diff --git a/topobenchmarkx/models/head_model/__init__.py b/topobenchmarkx/models/head_model/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/topobenchmarkx/models/head_models/__init__.py b/topobenchmarkx/models/head_models/__init__.py new file mode 100644 index 00000000..4e4577e5 --- /dev/null +++ b/topobenchmarkx/models/head_models/__init__.py @@ -0,0 +1,15 @@ +from topobenchmarkx.models.head_models.head_model import AbstractHeadModel +from topobenchmarkx.models.head_models.zero_cell_model import ZeroCellModel + +# ... import other readout classes here +# For example: +# from topobenchmarkx.models.readouts.other_readout_1 import OtherheadModel1 +# from topobenchmarkx.models.readouts.other_readout_2 import OtherheadModel2 + +__all__ = [ + "AbstractHeadModel", + "ZeroCellModel", + # "OtherheadModel1", + # "OtherheadModel2", + # ... add other readout classes here +] diff --git a/topobenchmarkx/models/head_models/head_model.py b/topobenchmarkx/models/head_models/head_model.py new file mode 100644 index 00000000..117dce5d --- /dev/null +++ b/topobenchmarkx/models/head_models/head_model.py @@ -0,0 +1,47 @@ +import torch +import torch_geometric +from abc import abstractmethod + +class AbstractHeadModel(torch.nn.Module): + r"""Head model. + + Parameters + ---------- + in_channels: int + Input dimension. + out_channels: int + Output dimension. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + + ): + super().__init__() + self.linear = torch.nn.Linear(in_channels, out_channels) + + def __call__(self, model_out: dict, batch: torch_geometric.data.Data) -> dict: + x = self.forward(model_out, batch) + model_out["logits"] = self.linear(x) + return model_out + + @abstractmethod + def forward(self, model_out: dict, batch: torch_geometric.data.Data): + r"""Forward pass. + + Parameters + ---------- + model_out: dict + Dictionary containing the model output. + batch: torch_geometric.data.Data + Batch object containing the batched domain data. + + Returns + ------- + x: torch.Tensor + Output tensor over which the final linear layer is applied. 
+ """ + pass + \ No newline at end of file diff --git a/topobenchmarkx/models/head_model/models.py b/topobenchmarkx/models/head_models/old_models.py similarity index 96% rename from topobenchmarkx/models/head_model/models.py rename to topobenchmarkx/models/head_models/old_models.py index 0c63f0be..ca5ff2b6 100644 --- a/topobenchmarkx/models/head_model/models.py +++ b/topobenchmarkx/models/head_models/old_models.py @@ -23,6 +23,7 @@ def __init__( out_channels: int, task_level: str, pooling_type: str = "sum", + **kwargs, ): super().__init__() self.linear = torch.nn.Linear(in_channels, out_channels) @@ -33,7 +34,7 @@ def __init__( assert pooling_type in ["max", "sum", "mean"], "Invalid pooling_type" self.pooling_type = pooling_type - def forward(self, model_out: dict): + def forward(self, model_out: dict, batch): r"""Forward pass. Parameters diff --git a/topobenchmarkx/models/head_models/zero_cell_model.py b/topobenchmarkx/models/head_models/zero_cell_model.py new file mode 100644 index 00000000..5b196821 --- /dev/null +++ b/topobenchmarkx/models/head_models/zero_cell_model.py @@ -0,0 +1,65 @@ +import torch +import torch_geometric +from torch_geometric.utils import scatter +from topobenchmarkx.models.head_models.head_model import AbstractHeadModel + +class ZeroCellModel(AbstractHeadModel): + r"""Head model. + + Parameters + ---------- + in_channels: int + Input dimension. + out_channels: int + Output dimension. + task_level: str + Task level, either "graph" or "node". If "graph", the readout layer will pool the node embeddings to the graph level to obtain a single graph embedding for each batched graph. If "node", the readout layer will return the node embeddings. + pooling_type: str + Pooling type, either "max", "sum", or "mean". Specifies the type of pooling operation to be used for the graph-level embedding. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + task_level: str, + pooling_type: str = "sum", + **kwargs, + ): + super().__init__(in_channels, out_channels) + + assert task_level in ["graph", "node"], "Invalid task_level" + self.task_level = task_level + + assert pooling_type in ["max", "sum", "mean"], "Invalid pooling_type" + self.pooling_type = pooling_type + + + def forward(self, model_out: dict, batch: torch_geometric.data.Data): + r"""Forward pass. + + Parameters + ---------- + model_out: dict + Dictionary containing the model output. + batch: torch_geometric.data.Data + Batch object containing the batched domain data. + + Returns + ------- + x: torch.Tensor + Output tensor over which the final linear layer is applied. + """ + x = model_out["x_0"] + batch = batch["batch_0"] + if self.task_level == "graph": + if self.pooling_type == "max": + x = scatter(x, batch, dim=0, reduce="max") + + elif self.pooling_type == "mean": + x = scatter(x, batch, dim=0, reduce="mean") + + elif self.pooling_type == "sum": + x = scatter(x, batch, dim=0, reduce="sum") + + return x \ No newline at end of file diff --git a/topobenchmarkx/models/losses/__init__.py b/topobenchmarkx/models/losses/__init__.py index e69de29b..bd2158c3 100755 --- a/topobenchmarkx/models/losses/__init__.py +++ b/topobenchmarkx/models/losses/__init__.py @@ -0,0 +1,15 @@ +from topobenchmarkx.models.losses.loss import AbstractltLoss +from topobenchmarkx.models.losses.default_loss import DefaultLoss + +# ... 
import other loss classes here
+# For example:
+# from topobenchmarkx.models.losses.other_loss_1 import OtherLoss1
+# from topobenchmarkx.models.losses.other_loss_2 import OtherLoss2
+
+__all__ = [
+    "AbstractltLoss",
+    "DefaultLoss",
+    # "OtherLoss1",
+    # "OtherLoss2",
+    # ... add other loss classes here
+]
\ No newline at end of file
diff --git a/topobenchmarkx/models/losses/default_loss.py b/topobenchmarkx/models/losses/default_loss.py
new file mode 100644
index 00000000..db4893cc
--- /dev/null
+++ b/topobenchmarkx/models/losses/default_loss.py
@@ -0,0 +1,38 @@
+import torch
+import torch_geometric
+from topobenchmarkx.models.losses.loss import AbstractltLoss
+
+class DefaultLoss(AbstractltLoss):
+    """Default loss, implementing the loss logic used within the
+    network."""
+
+    def __init__(self, task, loss_type=None):
+        super().__init__()
+        self.task = task
+        if task == "classification" and loss_type == "cross_entropy":
+            self.criterion = torch.nn.CrossEntropyLoss()
+
+        elif task == "regression" and loss_type == "mse":
+            self.criterion = torch.nn.MSELoss()
+
+        elif task == "regression" and loss_type == "mae":
+            self.criterion = torch.nn.L1Loss()
+
+        else:
+            raise Exception("Loss is not defined")
+
+    def forward(self, model_out: dict, batch: torch_geometric.data.Data):
+        """Compute the loss based on model_out."""
+
+        logits = model_out["logits"]
+        target = model_out["labels"]
+
+        if self.task == "regression":
+            target = target.unsqueeze(1)
+
+        model_out["loss"] = self.criterion(logits, target)
+
+        return model_out
+
+    def __repr__(self) -> str:
+        return f'{self.__class__.__name__}(task={self.task}, criterion={self.criterion.__class__.__name__})'
\ No newline at end of file
diff --git a/topobenchmarkx/models/losses/loss.py b/topobenchmarkx/models/losses/loss.py
index 8517082d..21bf170d 100755
--- a/topobenchmarkx/models/losses/loss.py
+++ b/topobenchmarkx/models/losses/loss.py
@@ -1,33 +1,17 @@
-import torch
+import torch_geometric
+from abc import ABC, abstractmethod
 
-
-class DefaultLoss:
+class AbstractltLoss(ABC):
     """Abstract class that provides an interface to loss logic within
     netowrk."""
 
-    def __init__(self, task, loss_type=None):
-        self.task = task
-        if task == "classification" and loss_type == "cross_entropy":
-            self.criterion = torch.nn.CrossEntropyLoss()
-
-        elif task == "regression" and loss_type == "mse":
-            self.criterion = torch.nn.MSELoss()
-
-        elif task == "regression" and loss_type == "mae":
-            self.criterion = torch.nn.L1Loss()
-
-        else:
-            raise Exception("Loss is not defined")
+    def __init__(self,):
+        super().__init__()
 
-    def __call__(self, model_output):
+    def __call__(self, model_out: dict, batch: torch_geometric.data.Data) -> dict:
        """Loss logic based on model_output."""
-
-        logits = model_output["logits"]
-        target = model_output["labels"]
-
-        if self.task == "regression":
-            target = target.unsqueeze(1)
-
-        model_output["loss"] = self.criterion(logits, target)
-
-        return model_output
+        return self.forward(model_out, batch)
+
+    @abstractmethod
+    def forward(self, model_out: dict, batch: torch_geometric.data.Data):
+        pass
diff --git a/topobenchmarkx/models/losses/losses.py b/topobenchmarkx/models/losses/losses.py
deleted file mode 100755
index fbdb7dc8..00000000
--- a/topobenchmarkx/models/losses/losses.py
+++ /dev/null
@@ -1,34 +0,0 @@
-# import hydra
-# import torch
-# from omegaconf import DictConfig
-
-# from topobenchmarkx.models.losses.loss import AbstractLoss
-
-
-# class DefaultLoss(AbstractLoss):
-#     """Abstract class that provides an interface to loss logic
within netowrk""" - -# def __init__(self, cfg: DictConfig): -# super().__init__(cfg) - -# def init_loss( -# self, -# ): -# if self.cfg.task == 'classification': -# self.criterion = torch.nn.CrossEntropyLoss() - -# elif self.cfg.task == 'regression': -# self.criterion == torch.nn.mse() - -# else: -# raise Exception("Loss is not defined") - - -# def forward(self, model_output): -# """Loss logic based on model_output""" - -# logits = model_output["logits"] -# target = model_output["labels"] -# model_output["loss"] = self.criterion(logits, target) - -# return model_output diff --git a/topobenchmarkx/models/network_module.py b/topobenchmarkx/models/network_module.py index ee3b5dcf..ec861871 100755 --- a/topobenchmarkx/models/network_module.py +++ b/topobenchmarkx/models/network_module.py @@ -3,6 +3,7 @@ import torch from lightning import LightningModule from torchmetrics import MeanMetric +from torch_geometric.data import Data # import topomodelx @@ -49,7 +50,7 @@ def __init__( # Loss function self.task_level = self.hparams["head_model"].task_level - self.criterion = loss + self.loss = loss # Tracking best so far validation accuracy self.val_acc_best = MeanMetric() @@ -82,12 +83,12 @@ def model_step( batch = self.feature_encoder(batch) model_out = self.forward(batch) - model_out = self.readout(model_out, batch) - model_out = self.head_model(model_out) + model_out = self.readout(model_out=model_out, batch=batch) + model_out = self.head_model(model_out=model_out, batch=batch) - # Criterion and metric - model_out = self.process_outputs(batch, model_out) - model_out = self.criterion(model_out) + # Loss and metric + model_out = self.process_outputs(model_out=model_out, batch=batch) + model_out = self.loss(model_out=model_out, batch=batch) self.evaluator.update(model_out) return model_out @@ -130,20 +131,6 @@ def validation_step( self.state_str = "Validation" model_out = self.model_step(batch) - # # Keep only validation data points - # if self.task_level == "node": - # for key, val in model_out.items(): - # # if key not in ["loss", "hyperedge"]: - # if key in ["logits", "labels"]: - # model_out[key] = val[batch.val_mask] - - # # Criterion - # model_out = self.criterion(model_out) - - # # Evaluation - # self.evaluator.update(model_out) - # self.metric_collector_val.append((model_out["logits"], model_out["labels"])) - # Log Loss self.log( "val/loss", @@ -166,15 +153,6 @@ def test_step( self.state_str = "Test" model_out = self.model_step(batch) - # if self.task_level == "node": - # # Keep only test data points - # for key, val in model_out.items(): - # if key in ["logits", "labels"]: - # model_out[key] = val[batch.test_mask] - - # # Criterion - # model_out = self.criterion(model_out) - # Log loss self.log( "test/loss", @@ -185,11 +163,7 @@ def test_step( batch_size=1, ) - # Evaluation - # self.evaluator.update(model_out) - # self.metric_collector_test.append((model_out["logits"], model_out["labels"])) - - def process_outputs(self, batch, model_out: dict) -> dict: + def process_outputs(self, model_out: dict, batch: Data) -> dict: """Process model outputs.""" # Get the correct mask @@ -308,35 +282,5 @@ def configure_optimizers(self) -> dict[str, Any]: return {"optimizer": optimizer} -# Collect validation statistics -# self.val_acc_best.update(model_out["metrics"]["acc"]) -# self.metric_collector.append(model_out["metrics"]["acc"]) - - -# def on_train_start(self) -> None: -# """Lightning hook that is called when training begins.""" -# # by default lightning executes validation step sanity checks before 
training starts,
-#     # so it's worth to make sure validation metrics don't store results from these checks
-#     # self.val_loss.reset()
-#     # self.val_acc.reset()
-#     self.val_acc_best.reset()
-
-
-# def on_validation_epoch_end(self) -> None:
-#     "Lightning hook that is called when a validation epoch ends."
-#     pass
-#         self.criterion = torch.nn.CrossEntropyLoss()
-
-#         self.evaluator = evaluator
-#         # metric objects for calculating and averaging accuracy across batches
-#         self.train_acc = Accuracy(task="multiclass", num_classes=7)
-#         self.val_acc = Accuracy(task="multiclass", num_classes=7)
-#         self.test_acc = Accuracy(task="multiclass", num_classes=7)
-
-# for averaging loss across batches
-#         self.train_loss = MeanMetric()
-#         self.val_loss = MeanMetric()
-#         self.test_loss = MeanMetric()
-
 if __name__ == "__main__":
     _ = NetworkModule(None, None, None, None)
diff --git a/topobenchmarkx/models/readouts/propagate_signal_down.py b/topobenchmarkx/models/readouts/propagate_signal_down.py
index 155d5a3b..95471249 100644
--- a/topobenchmarkx/models/readouts/propagate_signal_down.py
+++ b/topobenchmarkx/models/readouts/propagate_signal_down.py
@@ -6,6 +6,7 @@ class PropagateSignalDown(torch.nn.Module):
     def __init__(self, **kwargs):
         super().__init__()
 
+        self.name = kwargs["readout_name"]
         self.dimensions = range(kwargs["num_cell_dimensions"] - 1, 0, -1)
         hidden_dim = kwargs["hidden_dim"]
 
@@ -40,3 +41,6 @@ def forward(self, model_out, batch):
             )
 
         return model_out
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(num_cell_dimensions={len(self.dimensions)}, readout_name={self.name})"
diff --git a/topobenchmarkx/models/wrappers/__init__.py b/topobenchmarkx/models/wrappers/__init__.py
index 5264dd99..4c2099e7 100755
--- a/topobenchmarkx/models/wrappers/__init__.py
+++ b/topobenchmarkx/models/wrappers/__init__.py
@@ -6,8 +6,8 @@
 # ...
import other readout classes here # For example: -# from topobenchmarkx.models.readouts.other_readout_1 import OtherWrapper1 -# from topobenchmarkx.models.readouts.other_readout_2 import OtherWrapper2 +# from topobenchmarkx.models.wrappers.other_wrapper_1 import OtherWrapper1 +# from topobenchmarkx.models.wrappers.other_wrapper_2 import OtherWrapper2 # Export all wrappers From db353635e2081abe05e104818fe5852d9effa615 Mon Sep 17 00:00:00 2001 From: levtelyatnikov Date: Tue, 14 May 2024 19:46:43 +0200 Subject: [PATCH 04/32] updated collate --- configs/dataset/NCI1.yaml | 2 +- configs/model/hypergraph/edgnn.yaml | 2 +- configs/train.yaml | 2 +- topobenchmarkx/data/dataloaders.py | 47 +++++------------------------ 4 files changed, 11 insertions(+), 42 deletions(-) diff --git a/configs/dataset/NCI1.yaml b/configs/dataset/NCI1.yaml index 60572768..622481e3 100755 --- a/configs/dataset/NCI1.yaml +++ b/configs/dataset/NCI1.yaml @@ -21,7 +21,7 @@ parameters: monitor_metric: accuracy task_level: graph data_seed: 0 - split_type: random #'k-fold' # either "k-fold" or "random" strategies + split_type: k-fold #'k-fold' # either "k-fold" or "random" strategies k: 10 # for "k-fold" Cross-Validation train_prop: 0.5 # for "random" strategy splitting diff --git a/configs/model/hypergraph/edgnn.yaml b/configs/model/hypergraph/edgnn.yaml index 4a07ee48..047882cb 100755 --- a/configs/model/hypergraph/edgnn.yaml +++ b/configs/model/hypergraph/edgnn.yaml @@ -7,7 +7,7 @@ feature_encoder: _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} - out_channels: 32 + out_channels: 16 backbone: _target_: custom_models.hypergraph.edgnn.EDGNN diff --git a/configs/train.yaml b/configs/train.yaml index cfef1ef5..5f86b9ea 100755 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -4,7 +4,7 @@ # order of defaults determines the order in which configs override each other defaults: - _self_ - - dataset: IMDB_BINARY #us_country_demos + - dataset: NCI1 #us_country_demos - model: hypergraph/edgnn #hypergraph/unignn2 #allsettransformer - evaluator: default - callbacks: default diff --git a/topobenchmarkx/data/dataloaders.py b/topobenchmarkx/data/dataloaders.py index c58edb68..983e4310 100755 --- a/topobenchmarkx/data/dataloaders.py +++ b/topobenchmarkx/data/dataloaders.py @@ -9,7 +9,7 @@ from torch_sparse import SparseTensor -class MyData(Data): +class DomainData(Data): """Data object class that overwrites some methods from torch_geometric.data.Data so that not only sparse matrices with adj in the name can work with the torch_geometric dataloaders.""" @@ -59,7 +59,7 @@ def collate_fn(batch): for batch_idx, b in enumerate(batch): values, keys = b[0], b[1] - data = MyData() + data = DomainData() for key, value in zip(keys, values, strict=False): if is_sparse(value): value = value.coalesce() @@ -69,9 +69,12 @@ def collate_fn(batch): x_keys = [el for el in keys if ("x_" in el)] for x_key in x_keys: # current_number_of_nodes = data["x_0"].shape[0] + if x_key != "x_0": + if x_key != "x_hyperedges": + cell_dim = int(x_key.split("_")[1]) + else: + cell_dim = x_key.split("_")[1] - if x_key != "x_0" and x_key != "x_hyperedges": - cell_dim = int(x_key.split("_")[1]) current_number_of_cells = data[x_key].shape[0] batch_idx_dict[f"batch_{cell_dim}"].append( @@ -85,46 +88,12 @@ def collate_fn(batch): running_idx[f"cell_running_idx_number_{cell_dim}"] = ( current_number_of_cells # current_number_of_nodes ) 
+ else: - # Make sure the idx is contiguous - data[f"x_{cell_dim}"] = ( - data[f"x_{cell_dim}"] - + running_idx[f"cell_running_idx_number_{cell_dim}"] - ).long() - running_idx[f"cell_running_idx_number_{cell_dim}"] += ( current_number_of_cells # current_number_of_nodes ) - elif x_key == "x_hyperedges": - cell_dim = x_key.split("_")[1] - current_number_of_hyperedges = data[x_key].shape[0] - - batch_idx_dict["batch_hyperedges"].append( - torch.tensor([[batch_idx] * current_number_of_hyperedges]) - ) - - if ( - running_idx.get(f"cell_running_idx_number_{cell_dim}") - is None - ): - running_idx[f"cell_running_idx_number_{cell_dim}"] = ( - current_number_of_hyperedges - ) - else: - # Make sure the idx is contiguous - data[f"x_{cell_dim}"] = ( - data[f"x_{cell_dim}"] - + running_idx[f"cell_running_idx_number_{cell_dim}"] - ).long() - - running_idx[f"cell_running_idx_number_{cell_dim}"] += ( - current_number_of_hyperedges - ) - else: - # Function Batch.from_data_list creates a running index automatically - pass - data_list.append(data) batch = Batch.from_data_list(data_list) From ec42f1e0a146b2251278ab4d4674bc5bce8f4d3d Mon Sep 17 00:00:00 2001 From: levtelyatnikov Date: Tue, 14 May 2024 20:18:13 +0200 Subject: [PATCH 05/32] fixing to_data_list --- configs/train.yaml | 2 +- topobenchmarkx/data/dataloaders.py | 15 ++++++++++----- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/configs/train.yaml b/configs/train.yaml index 5f86b9ea..e54d45e7 100755 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -5,7 +5,7 @@ defaults: - _self_ - dataset: NCI1 #us_country_demos - - model: hypergraph/edgnn #hypergraph/unignn2 #allsettransformer + - model: simplicial/scn #hypergraph/unignn2 #allsettransformer - evaluator: default - callbacks: default - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`) diff --git a/topobenchmarkx/data/dataloaders.py b/topobenchmarkx/data/dataloaders.py index 983e4310..0b0612b0 100755 --- a/topobenchmarkx/data/dataloaders.py +++ b/topobenchmarkx/data/dataloaders.py @@ -30,11 +30,12 @@ def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any: def to_data_list(batch): """Workaround needed since torch_geometric doesn't work well with torch.sparse.""" - for key in batch: + for key in batch.keys(): if batch[key].is_sparse: sparse_data = batch[key].coalesce() batch[key] = SparseTensor.from_torch_sparse_coo_tensor(sparse_data) data_list = batch.to_data_list() + for i, data in enumerate(data_list): for key in data: if isinstance(data[key], SparseTensor): @@ -65,10 +66,9 @@ def collate_fn(batch): value = value.coalesce() data[key] = value - # Generate batch_slice values for x_2, x_3, ... + # Generate batch_slice values for x_1, x_2, x_3, ... 
x_keys = [el for el in keys if ("x_" in el)] for x_key in x_keys: - # current_number_of_nodes = data["x_0"].shape[0] if x_key != "x_0": if x_key != "x_hyperedges": cell_dim = int(x_key.split("_")[1]) @@ -86,12 +86,12 @@ def collate_fn(batch): is None ): running_idx[f"cell_running_idx_number_{cell_dim}"] = ( - current_number_of_cells # current_number_of_nodes + current_number_of_cells ) else: running_idx[f"cell_running_idx_number_{cell_dim}"] += ( - current_number_of_cells # current_number_of_nodes + current_number_of_cells ) data_list.append(data) @@ -104,6 +104,11 @@ def collate_fn(batch): # Add batch slices to batch for key, value in batch_idx_dict.items(): batch[key] = torch.cat(value, dim=1).squeeze(0).long() + + # Ensure shape is torch.Tensor + # "shape" describes the number of n_cells in each graph + batch["shape"] = torch.Tensor(batch["shape"]).long() + to_data_list(batch) return batch From 7721ac94442c88adf30efc0c522fe308dafeb486 Mon Sep 17 00:00:00 2001 From: levtelyatnikov Date: Tue, 14 May 2024 20:19:18 +0200 Subject: [PATCH 06/32] fixing to_data_list --- topobenchmarkx/data/dataloaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/topobenchmarkx/data/dataloaders.py b/topobenchmarkx/data/dataloaders.py index 0b0612b0..9fbc7863 100755 --- a/topobenchmarkx/data/dataloaders.py +++ b/topobenchmarkx/data/dataloaders.py @@ -108,7 +108,7 @@ def collate_fn(batch): # Ensure shape is torch.Tensor # "shape" describes the number of n_cells in each graph batch["shape"] = torch.Tensor(batch["shape"]).long() - to_data_list(batch) + return batch From 34da4f8ba3ebbed08f0756e1aeeec090102fbd46 Mon Sep 17 00:00:00 2001 From: levtelyatnikov Date: Tue, 14 May 2024 21:51:24 +0200 Subject: [PATCH 07/32] small corrections --- topobenchmarkx/data/dataloaders.py | 4 +- .../models/head_models/old_models.py | 63 ------------------- topobenchmarkx/models/network_module.py | 21 ++++--- topobenchmarkx/models/readouts/__init__.py | 18 ++---- topobenchmarkx/models/readouts/identical.py | 14 +++++ .../models/readouts/propagate_signal_down.py | 12 ++-- topobenchmarkx/models/readouts/readout.py | 25 ++++---- .../data_manipulations/manipulations.py | 10 ++- 8 files changed, 60 insertions(+), 107 deletions(-) delete mode 100644 topobenchmarkx/models/head_models/old_models.py create mode 100644 topobenchmarkx/models/readouts/identical.py diff --git a/topobenchmarkx/data/dataloaders.py b/topobenchmarkx/data/dataloaders.py index 9fbc7863..bddc1014 100755 --- a/topobenchmarkx/data/dataloaders.py +++ b/topobenchmarkx/data/dataloaders.py @@ -107,7 +107,9 @@ def collate_fn(batch): # Ensure shape is torch.Tensor # "shape" describes the number of n_cells in each graph - batch["shape"] = torch.Tensor(batch["shape"]).long() + if batch.get("shape") is not None: + cell_statistics = batch.pop("shape") + batch["cell_statistics"] = torch.Tensor(cell_statistics).long() return batch diff --git a/topobenchmarkx/models/head_models/old_models.py b/topobenchmarkx/models/head_models/old_models.py deleted file mode 100644 index ca5ff2b6..00000000 --- a/topobenchmarkx/models/head_models/old_models.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch -from torch_geometric.utils import scatter - - -class DefaultHead(torch.nn.Module): - r"""Head model. - - Parameters - ---------- - in_channels: int - Input dimension. - out_channels: int - Output dimension. - task_level: str - Task level, either "graph" or "node". 
If "graph", the readout layer will pool the node embeddings to the graph level to obtain a single graph embedding for each batched graph. If "node", the readout layer will return the node embeddings. - pooling_type: str - Pooling type, either "max", "sum", or "mean". Specifies the type of pooling operation to be used for the graph-level embedding. - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - task_level: str, - pooling_type: str = "sum", - **kwargs, - ): - super().__init__() - self.linear = torch.nn.Linear(in_channels, out_channels) - - assert task_level in ["graph", "node"], "Invalid task_level" - self.task_level = task_level - - assert pooling_type in ["max", "sum", "mean"], "Invalid pooling_type" - self.pooling_type = pooling_type - - def forward(self, model_out: dict, batch): - r"""Forward pass. - - Parameters - ---------- - model_out: dict - Dictionary containing the model output. - - Returns - ------- - dict - Dictionary containing the updated model output. Resulting key is "logits". - """ - x = model_out["x_0"] - batch = model_out["batch_0"] - if self.task_level == "graph": - if self.pooling_type == "max": - x = scatter(x, batch, dim=0, reduce="max") - - elif self.pooling_type == "mean": - x = scatter(x, batch, dim=0, reduce="mean") - - elif self.pooling_type == "sum": - x = scatter(x, batch, dim=0, reduce="sum") - - model_out["logits"] = self.linear(x) - return model_out diff --git a/topobenchmarkx/models/network_module.py b/topobenchmarkx/models/network_module.py index ec861871..989e4c6e 100755 --- a/topobenchmarkx/models/network_module.py +++ b/topobenchmarkx/models/network_module.py @@ -58,7 +58,7 @@ def __init__( self.metric_collector_val2 = [] self.metric_collector_test = [] - def forward(self, batch) -> dict: + def forward(self, batch: Data) -> dict: """Perform a forward pass through the model `self.backbone`. :param x: A tensor of images. @@ -67,7 +67,7 @@ def forward(self, batch) -> dict: return self.backbone(batch) def model_step( - self, batch + self, batch: Data ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Perform a single model step on a batch of data. @@ -78,22 +78,29 @@ def model_step( - A tensor of predictions. - A tensor of target labels. """ - # Pipeline - if self.feature_encoder: - batch = self.feature_encoder(batch) + # Feature Encoder + batch = self.feature_encoder(batch) + + # Domain model model_out = self.forward(batch) + + # Readout model_out = self.readout(model_out=model_out, batch=batch) + + # Head model model_out = self.head_model(model_out=model_out, batch=batch) - # Loss and metric + # Loss model_out = self.process_outputs(model_out=model_out, batch=batch) + + # Metric model_out = self.loss(model_out=model_out, batch=batch) self.evaluator.update(model_out) return model_out - def training_step(self, batch, batch_idx: int) -> torch.Tensor: + def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: """Perform a single training step on a batch of data from the training set. 
diff --git a/topobenchmarkx/models/readouts/__init__.py b/topobenchmarkx/models/readouts/__init__.py
index 3038e843..4280f106 100644
--- a/topobenchmarkx/models/readouts/__init__.py
+++ b/topobenchmarkx/models/readouts/__init__.py
@@ -1,26 +1,18 @@
-from topobenchmarkx.models.readouts.propagate_signal_down import (
-    PropagateSignalDown,
-)
+from topobenchmarkx.models.readouts.readout import AbstractReadOut
+from topobenchmarkx.models.readouts.propagate_signal_down import PropagateSignalDown
+from topobenchmarkx.models.readouts.identical import NoReadOut
 
 # ... import other readout classes here
 # For example:
 # from topobenchmarkx.models.readouts.other_readout_1 import OtherReadout1
 # from topobenchmarkx.models.readouts.other_readout_2 import OtherReadout2
 
-
-# Dictionary of all readouts
-READOUTS = {
-    "PropagateSignalDown": PropagateSignalDown,
-    # "OtherReadout1": OtherReadout1,
-    # "OtherReadout2": OtherReadout2,
-    # ... add other readout mappings here
-}
-
 # Export all readouts and the dictionary
 __all__ = [
+    "AbstractReadOut",
     "PropagateSignalDown",
+    "NoReadOut",
     # "OtherReadout1",
     # "OtherReadout2",
     # ... add other readout classes here
-    "READOUTS",
 ]
diff --git a/topobenchmarkx/models/readouts/identical.py b/topobenchmarkx/models/readouts/identical.py
new file mode 100644
index 00000000..639fb082
--- /dev/null
+++ b/topobenchmarkx/models/readouts/identical.py
@@ -0,0 +1,14 @@
+
+import torch_geometric
+from topobenchmarkx.models.readouts.readout import AbstractReadOut
+
+
+class NoReadOut(AbstractReadOut):
+    def __init__(self, **kwargs):
+        super().__init__()
+
+    def forward(self, model_out: dict, batch: torch_geometric.data.Data) -> dict:
+        return model_out
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
diff --git a/topobenchmarkx/models/readouts/propagate_signal_down.py b/topobenchmarkx/models/readouts/propagate_signal_down.py
index 95471249..dd4c3ede 100644
--- a/topobenchmarkx/models/readouts/propagate_signal_down.py
+++ b/topobenchmarkx/models/readouts/propagate_signal_down.py
@@ -1,8 +1,9 @@
-import topomodelx
 import torch
+import torch_geometric
+import topomodelx
+from topobenchmarkx.models.readouts.readout import AbstractReadOut
 
-
-class PropagateSignalDown(torch.nn.Module):
+class PropagateSignalDown(AbstractReadOut):
     def __init__(self, **kwargs):
         super().__init__()
 
@@ -27,10 +28,7 @@ def __init__(self, **kwargs):
             torch.nn.Linear(2 * hidden_dim, hidden_dim),
         )
 
-    def __call__(self, model_out, batch):
-        return self.forward(model_out, batch)
-
-    def forward(self, model_out, batch):
+    def forward(self, model_out: dict, batch: torch_geometric.data.Data):
         for i in self.dimensions:
             x_i = getattr(self, f"agg_conv_{i}")(
                 model_out[f"x_{i}"], batch[f"incidence_{i}"]
diff --git a/topobenchmarkx/models/readouts/readout.py b/topobenchmarkx/models/readouts/readout.py
index 7d618cab..7a90bdef 100755
--- a/topobenchmarkx/models/readouts/readout.py
+++ b/topobenchmarkx/models/readouts/readout.py
@@ -1,8 +1,7 @@
 import torch
 import torch_geometric
 
-from . import READOUTS
-
+from abc import abstractmethod
 
 class AbstractReadOut(torch.nn.Module):
     r"""Readout layer for GNNs that operates on the batch level.
 
     Parameters
     ----------
     in_channels: int
         Input dimension.
     out_channels: int
         Output dimension.
     task_level: str
         Task level, either "graph" or "node". If "graph", the readout layer will pool the node embeddings to the graph level to obtain a single graph embedding for each batched graph. If "node", the readout layer will return the node embeddings.
     pooling_type: str
         Pooling type, either "max", "sum", or "mean". Specifies the type of pooling operation to be used for the graph-level embedding.
""" - def __init__(self, **kwargs): + def __init__(self,): super().__init__() - self.signal_readout = kwargs["readout_name"] != "None" - if self.signal_readout: - signal_readout_name = kwargs.get("readout_name") - self.readout = READOUTS[signal_readout_name](**kwargs) - + def __call__(self, model_out: dict, batch: torch_geometric.data.Data) -> dict: + """Readout logic based on model_output.""" + return self.forward(model_out, batch) + + @abstractmethod def forward(self, model_out: dict, batch: torch_geometric.data.Data): r"""Forward pass. @@ -34,14 +33,12 @@ def forward(self, model_out: dict, batch: torch_geometric.data.Data): ---------- model_out: dict Dictionary containing the model output. + + batch: torch_geometric.data.Data + Batch object containing the batched domain data. Returns ------- dict Dictionary containing the updated model output. Resulting key is "logits". - """ - # Propagate signal - if self.signal_readout: - model_out = self.readout(model_out, batch) - - return model_out + """ \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/manipulations.py b/topobenchmarkx/transforms/data_manipulations/manipulations.py index b75e1509..7d3ebb63 100644 --- a/topobenchmarkx/transforms/data_manipulations/manipulations.py +++ b/topobenchmarkx/transforms/data_manipulations/manipulations.py @@ -180,7 +180,7 @@ def forward(self, data: torch_geometric.data.Data): """ field_to_process = [ key - for key in data + for key in data.keys() for field_substring in self.parameters["selected_fields"] if field_substring in key and key != "incidence_0" ] @@ -211,10 +211,16 @@ def calculate_node_degrees( assert ( field == "edge_index" ), "Following logic of finding degrees is only implemented for edge_index" + + # Get number of nodes + if data.get("num_nodes", None): + max_num_nodes = data["num_nodes"] + else: + max_num_nodes = data["x"].shape[0] degrees = ( torch_geometric.utils.to_dense_adj( data[field], - max_num_nodes=data["x"].shape[0], # data["num_nodes"] + max_num_nodes=max_num_nodes, ) .squeeze(0) .sum(1) From 9fc0a7fd53f6ff088821ffad3696a4728351aa18 Mon Sep 17 00:00:00 2001 From: Coerulatus Date: Tue, 14 May 2024 21:46:24 +0000 Subject: [PATCH 08/32] added pythonpath for pytest --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 6d1e47be..caf89fad 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -151,6 +151,10 @@ ignore_missing_imports = true [tool.pytest.ini_options] addopts = "--capture=no" +pythonpath = [ + "." 
+]
+
 
 [tool.numpydoc_validation]
 checks = [

From 8d41451ed59e2036eea57f3f9abd96e8516315b6 Mon Sep 17 00:00:00 2001
From: Coerulatus
Date: Tue, 14 May 2024 21:47:29 +0000
Subject: [PATCH 09/32] to_data_list works

---
 topobenchmarkx/data/dataloaders.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/topobenchmarkx/data/dataloaders.py b/topobenchmarkx/data/dataloaders.py
index bddc1014..2578c4af 100755
--- a/topobenchmarkx/data/dataloaders.py
+++ b/topobenchmarkx/data/dataloaders.py
@@ -35,11 +35,10 @@ def to_data_list(batch):
             sparse_data = batch[key].coalesce()
             batch[key] = SparseTensor.from_torch_sparse_coo_tensor(sparse_data)
     data_list = batch.to_data_list()
-
     for i, data in enumerate(data_list):
-        for key in data:
+        for key, d in data:
             if isinstance(data[key], SparseTensor):
-                data_list[i][key] = data[key].to_torch_sparse_coo_tensor()
+                data_list[i][key] = d.to_torch_sparse_coo_tensor()
     return data_list
 

From 79cf1b656f35d03ea8acace42fb6508ad5940e06 Mon Sep 17 00:00:00 2001
From: Coerulatus
Date: Tue, 14 May 2024 21:48:37 +0000
Subject: [PATCH 10/32] added test for collate_function

---
 test/data/test_Dataloaders.py | 115 ++++++++++++++++++++++++++++++++++
 1 file changed, 115 insertions(+)
 create mode 100644 test/data/test_Dataloaders.py

diff --git a/test/data/test_Dataloaders.py b/test/data/test_Dataloaders.py
new file mode 100644
index 00000000..c595ed13
--- /dev/null
+++ b/test/data/test_Dataloaders.py
@@ -0,0 +1,115 @@
+"""Test the collate function."""
+import hydra
+from hydra import compose, initialize
+from omegaconf import OmegaConf
+
+import torch
+
+from topobenchmarkx.data.dataloaders import to_data_list, DefaultDataModule
+
+from topobenchmarkx.utils.config_resolvers import (
+    get_default_transform,
+    get_monitor_metric,
+    get_monitor_mode,
+    infer_in_channels,
+)
+
+import rootutils
+
+rootutils.setup_root("./", indicator=".project-root", pythonpath=True)
+
+class TestCollateFunction:
+    """Test collate_fn."""
+
+    def setup_method(self):
+        """Set up the test.
+
+        For this test we load the MUTAG dataset.
+
+        Parameters
+        ----------
+        None
+        """
+        OmegaConf.register_new_resolver("get_default_transform", get_default_transform)
+        OmegaConf.register_new_resolver("get_monitor_metric", get_monitor_metric)
+        OmegaConf.register_new_resolver("get_monitor_mode", get_monitor_mode)
+        OmegaConf.register_new_resolver("infer_in_channels", infer_in_channels)
+        OmegaConf.register_new_resolver(
+            "parameter_multiplication", lambda x, y: int(int(x) * int(y))
+        )
+
+        initialize(version_base="1.3", config_path="../../configs", job_name="job")
+        cfg = compose(config_name="train.yaml")
+
+        graph_loader = hydra.utils.instantiate(cfg.dataset, _recursive_=False)
+        datasets = graph_loader.load()
+        self.batch_size = 2
+        datamodule = DefaultDataModule(
+            dataset_train=datasets[0],
+            dataset_val=datasets[1],
+            dataset_test=datasets[2],
+            batch_size=self.batch_size
+        )
+        self.val_dataloader = datamodule.val_dataloader()
+        self.val_dataset = datasets[1]
+
+    def test_collate_fn(self):
+        """Test the collate function.
+
+        To test the collate function we use the DefaultDataModule class to create a dataloader that uses the collate function. We first check that the batched data has the expected shape. We then convert the batched data back to a list and check that the data in the list is the same as the original data.
+
+        Parameters
+        ----------
+        None
+        """
+        def check_separation(matrix, n_elems_0_row, n_elems_0_col):
+            """Check that the matrix is separated into two parts diagonally concatenated."""
+            assert torch.all(matrix[:n_elems_0_row, n_elems_0_col:] == 0)
+            assert torch.all(matrix[n_elems_0_row:, :n_elems_0_col] == 0)
+
+
+        batch = next(iter(self.val_dataloader))
+        elems = []
+        for i in range(self.batch_size):
+            elems.append(self.val_dataset.data_lst[i])
+
+        # Check that the batched data has the expected shape
+        for key in batch.keys():
+            if key in elems[0].keys():
+                if 'x_' in key or 'x'==key:
+                    assert batch[key].shape[0] == elems[0][key].shape[0]+elems[1][key].shape[0]
+                    assert batch[key].shape[1] == elems[0][key].shape[1]
+                elif 'edge_index' in key:
+                    assert batch[key].shape[0] == 2
+                    assert batch[key].shape[1] == elems[0][key].shape[1]+elems[1][key].shape[1]
+                else:
+                    for i in range(len(batch[key].shape)):
+                        assert batch[key].shape[i] == elems[0][key].shape[i]+elems[1][key].shape[i]
+            else:
+                if 'batch_' in key:
+                    i = int(key.split('_')[1])
+                    assert batch[key].shape[0] == elems[0][f'x_{i}'].shape[0]+elems[1][f'x_{i}'].shape[0]
+
+        # Check that the batched data is separated correctly
+        for key in batch.keys():
+            if 'incidence_' in key:
+                i = int(key.split('_')[1])
+                if i==0:
+                    n0_row = 1
+                else:
+                    n0_row = torch.sum(batch[f'batch_{i-1}']==0)
+                n0_col = torch.sum(batch[f'batch_{i}']==0)
+                check_separation(batch[key].to_dense(), n0_row, n0_col)
+
+        # Check that going back to a list of data gives the same data
+        batch_list = to_data_list(batch)
+        assert len(batch_list) == len(elems)
+        for i in range(len(batch_list)):
+            for key in elems[i].keys():
+                if key in batch_list[i].keys():
+                    if batch_list[i][key].is_sparse:
+                        assert torch.all(batch_list[i][key].coalesce().indices() == elems[i][key].coalesce().indices())
+                        assert torch.allclose(batch_list[i][key].coalesce().values(), elems[i][key].coalesce().values())
+                        assert batch_list[i][key].shape == elems[i][key].shape
+                    else:
+                        assert torch.allclose(batch_list[i][key], elems[i][key])
\ No newline at end of file

From 84320f145160ba2e9284bdc7a744ce6e47dba472 Mon Sep 17 00:00:00 2001
From: levtelyatnikov
Date: Wed, 15 May 2024 00:23:00 +0200
Subject: [PATCH 11/32] cleaning

---
 .../transforms/data_manipulations/node_degrees.yaml |  2 +-
 configs/model/cell/can.yaml                         |  4 ++--
 configs/model/cell/{cwn_dcm.yaml => cccn.yaml}      | 10 +++++-----
 configs/model/cell/ccxn.yaml                        |  4 ++--
 configs/model/cell/cwn.yaml                         |  5 ++---
 configs/model/graph/gat.yaml                        |  4 ++--
 configs/model/graph/gcn.yaml                        |  4 ++--
 configs/model/graph/gin.yaml                        |  4 ++--
 configs/model/hypergraph/alldeepset.yaml            |  4 ++--
 configs/model/hypergraph/allsettransformer.yaml     |  4 ++--
 configs/model/hypergraph/edgnn.yaml                 |  4 ++--
 configs/model/hypergraph/unignn.yaml                |  4 ++--
 configs/model/hypergraph/unignn2.yaml               |  4 ++--
 configs/model/simplicial/san.yaml                   |  4 ++--
 configs/model/simplicial/sccn.yaml                  |  4 ++--
 configs/model/simplicial/sccnn.yaml                 |  4 ++--
 configs/model/simplicial/sccnn_custom.yaml          |  7 +++----
 configs/model/simplicial/scn.yaml                   |  4 ++--
 configs/train.yaml                                  |  4 ++--
 custom_models/cell/{cwn_dcm.py => cccn.py}          |  2 +-
 topobenchmarkx/models/wrappers/__init__.py          |  4 ++--
 topobenchmarkx/models/wrappers/cell/__init__.py     |  4 ++--
 .../cell/{cwndcm_wrapper.py => cccn_wrapper.py}     |  2 +-
 23 files changed, 47 insertions(+), 49 deletions(-)
 rename configs/model/cell/{cwn_dcm.yaml => cccn.yaml} (83%)
 rename custom_models/cell/{cwn_dcm.py => cccn.py} (98%)
 rename topobenchmarkx/models/wrappers/cell/{cwndcm_wrapper.py =>
cccn_wrapper.py} (94%) diff --git a/configs/dataset/transforms/data_manipulations/node_degrees.yaml b/configs/dataset/transforms/data_manipulations/node_degrees.yaml index c1775453..1d666d32 100755 --- a/configs/dataset/transforms/data_manipulations/node_degrees.yaml +++ b/configs/dataset/transforms/data_manipulations/node_degrees.yaml @@ -1,5 +1,5 @@ _target_: topobenchmarkx.transforms.data_transform.DataTransform transform_name: "NodeDegrees" transform_type: "data manipulation" -selected_fields: ["edge_index", "incidence"] #"incidence" +selected_fields: ["edge_index"] # "incidence" diff --git a/configs/model/cell/can.yaml b/configs/model/cell/can.yaml index a8f3495a..39df8535 100755 --- a/configs/model/cell/can.yaml +++ b/configs/model/cell/can.yaml @@ -32,8 +32,8 @@ backbone_wrapper: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} readout: - _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut - readout_name: PropagateSignalDown # Use in case readout is not needed + _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} + readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown hidden_dim: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} diff --git a/configs/model/cell/cwn_dcm.yaml b/configs/model/cell/cccn.yaml similarity index 83% rename from configs/model/cell/cwn_dcm.yaml rename to configs/model/cell/cccn.yaml index c411e025..ad196c1d 100755 --- a/configs/model/cell/cwn_dcm.yaml +++ b/configs/model/cell/cccn.yaml @@ -14,21 +14,21 @@ feature_encoder: - 1 backbone: - _target_: custom_models.cell.cwn_dcm.CWNDCM + _target_: custom_models.cell.cccn.CCCN in_channels: ${model.feature_encoder.out_channels} n_layers: 1 dropout: 0.0 backbone_wrapper: - _target_: topobenchmarkx.models.wrappers.CWNDCMWrapper + _target_: topobenchmarkx.models.wrappers.CCCNWrapper _partial_: true - wrapper_name: CWNDCMWrapper + wrapper_name: CCCNWrapper out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} readout: - _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut - readout_name: PropagateSignalDown # Use in case readout is not needed + _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} + readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown hidden_dim: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} diff --git a/configs/model/cell/ccxn.yaml b/configs/model/cell/ccxn.yaml index 783f92c5..38a884ec 100755 --- a/configs/model/cell/ccxn.yaml +++ b/configs/model/cell/ccxn.yaml @@ -28,8 +28,8 @@ backbone_wrapper: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} readout: - _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut - readout_name: PropagateSignalDown # Use in case readout is not needed + _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} + readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown hidden_dim: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} diff --git a/configs/model/cell/cwn.yaml b/configs/model/cell/cwn.yaml index 7d5e4edf..5a140b36 100755 --- a/configs/model/cell/cwn.yaml +++ 
b/configs/model/cell/cwn.yaml @@ -25,10 +25,9 @@ backbone_wrapper: out_channels: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} - readout: - _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut - readout_name: PropagateSignalDown # Use in case readout is not needed + _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} + readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown hidden_dim: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} diff --git a/configs/model/graph/gat.yaml b/configs/model/graph/gat.yaml index fed38c36..65e353db 100755 --- a/configs/model/graph/gat.yaml +++ b/configs/model/graph/gat.yaml @@ -28,8 +28,8 @@ backbone_wrapper: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} readout: - _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut - readout_name: PropagateSignalDown # Use in case readout is not needed + _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} + readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown hidden_dim: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} diff --git a/configs/model/graph/gcn.yaml b/configs/model/graph/gcn.yaml index 7bb1c695..abefe959 100755 --- a/configs/model/graph/gcn.yaml +++ b/configs/model/graph/gcn.yaml @@ -25,8 +25,8 @@ backbone_wrapper: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} readout: - _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut - readout_name: PropagateSignalDown # Use in case readout is not needed + _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} + readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown hidden_dim: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} diff --git a/configs/model/graph/gin.yaml b/configs/model/graph/gin.yaml index 4b9ae61d..e0920399 100755 --- a/configs/model/graph/gin.yaml +++ b/configs/model/graph/gin.yaml @@ -25,8 +25,8 @@ backbone_wrapper: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} readout: - _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut - readout_name: PropagateSignalDown # Use in case readout is not needed + _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} + readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown hidden_dim: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} diff --git a/configs/model/hypergraph/alldeepset.yaml b/configs/model/hypergraph/alldeepset.yaml index 2780cea9..99b0e3d9 100755 --- a/configs/model/hypergraph/alldeepset.yaml +++ b/configs/model/hypergraph/alldeepset.yaml @@ -33,8 +33,8 @@ backbone_wrapper: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} readout: - _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut - readout_name: None # Use in case readout is not needed + _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} + readout_name: NoReadOut # Use in case readout is not needed Options: PropagateSignalDown hidden_dim: None 
num_cell_dimensions: None diff --git a/configs/model/hypergraph/allsettransformer.yaml b/configs/model/hypergraph/allsettransformer.yaml index 817e1d8f..5b5a3e18 100755 --- a/configs/model/hypergraph/allsettransformer.yaml +++ b/configs/model/hypergraph/allsettransformer.yaml @@ -27,8 +27,8 @@ backbone_wrapper: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} readout: - _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut - readout_name: None # Use in case readout is not needed + _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} + readout_name: NoReadOut # Use in case readout is not needed Options: PropagateSignalDown hidden_dim: None num_cell_dimensions: None diff --git a/configs/model/hypergraph/edgnn.yaml b/configs/model/hypergraph/edgnn.yaml index 047882cb..6f7d6235 100755 --- a/configs/model/hypergraph/edgnn.yaml +++ b/configs/model/hypergraph/edgnn.yaml @@ -28,8 +28,8 @@ backbone_wrapper: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} readout: - _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut - readout_name: None # Use in case readout is not needed + _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} + readout_name: NoReadOut # Use in case readout is not needed Options: PropagateSignalDown, NoReadOut hidden_dim: None num_cell_dimensions: None diff --git a/configs/model/hypergraph/unignn.yaml b/configs/model/hypergraph/unignn.yaml index 91fbc1f0..68d71223 100755 --- a/configs/model/hypergraph/unignn.yaml +++ b/configs/model/hypergraph/unignn.yaml @@ -23,8 +23,8 @@ backbone_wrapper: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} readout: - _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut - readout_name: None # Use in case readout is not needed + _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} + readout_name: NoReadOut # Use in case readout is not needed Options: PropagateSignalDown hidden_dim: None num_cell_dimensions: None diff --git a/configs/model/hypergraph/unignn2.yaml b/configs/model/hypergraph/unignn2.yaml index ca080e27..1d216496 100755 --- a/configs/model/hypergraph/unignn2.yaml +++ b/configs/model/hypergraph/unignn2.yaml @@ -27,8 +27,8 @@ backbone_wrapper: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} readout: - _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut - readout_name: None # Use in case readout is not needed + _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} + readout_name: NoReadOut # Use in case readout is not needed Options: PropagateSignalDown hidden_dim: None num_cell_dimensions: None diff --git a/configs/model/simplicial/san.yaml b/configs/model/simplicial/san.yaml index 46aa3709..61c63567 100755 --- a/configs/model/simplicial/san.yaml +++ b/configs/model/simplicial/san.yaml @@ -29,8 +29,8 @@ backbone_wrapper: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} readout: - _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut - readout_name: PropagateSignalDown # Use in case readout is not needed + _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} + readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown hidden_dim: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} diff --git 
a/configs/model/simplicial/sccn.yaml b/configs/model/simplicial/sccn.yaml index 09d3f4cf..82e5592d 100755 --- a/configs/model/simplicial/sccn.yaml +++ b/configs/model/simplicial/sccn.yaml @@ -23,8 +23,8 @@ backbone_wrapper: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} readout: - _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut - readout_name: PropagateSignalDown # Use in case readout is not needed + _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} + readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown hidden_dim: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} diff --git a/configs/model/simplicial/sccnn.yaml b/configs/model/simplicial/sccnn.yaml index a02853f0..fdfc824c 100755 --- a/configs/model/simplicial/sccnn.yaml +++ b/configs/model/simplicial/sccnn.yaml @@ -37,8 +37,8 @@ backbone_wrapper: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} readout: - _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut - readout_name: PropagateSignalDown # Use in case readout is not needed + _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} + readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown hidden_dim: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} diff --git a/configs/model/simplicial/sccnn_custom.yaml b/configs/model/simplicial/sccnn_custom.yaml index cfa62f75..4bd0ebdc 100755 --- a/configs/model/simplicial/sccnn_custom.yaml +++ b/configs/model/simplicial/sccnn_custom.yaml @@ -7,13 +7,12 @@ feature_encoder: _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} - out_channels: 64 + out_channels: 32 selected_dimensions: - 0 - 1 - 2 - backbone: _target_: custom_models.simplicial.sccnn.SCCNNCusctom in_channels_all: @@ -38,8 +37,8 @@ backbone_wrapper: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} readout: - _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut - readout_name: PropagateSignalDown # Use in case readout is not needed + _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} + readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown hidden_dim: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} diff --git a/configs/model/simplicial/scn.yaml b/configs/model/simplicial/scn.yaml index 78adb09d..cb74351f 100755 --- a/configs/model/simplicial/scn.yaml +++ b/configs/model/simplicial/scn.yaml @@ -28,8 +28,8 @@ backbone_wrapper: num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} readout: - _target_: topobenchmarkx.models.readouts.readout.AbstractReadOut - readout_name: PropagateSignalDown # Use in case readout is not needed + _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} + readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown hidden_dim: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.selected_dimensions}} diff --git a/configs/train.yaml 
b/configs/train.yaml index e54d45e7..8de14308 100755 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -4,8 +4,8 @@ # order of defaults determines the order in which configs override each other defaults: - _self_ - - dataset: NCI1 #us_country_demos - - model: simplicial/scn #hypergraph/unignn2 #allsettransformer + - dataset: IMDB-BINARY #us_country_demos + - model: simplicial/sccnn_custom #hypergraph/unignn2 #allsettransformer - evaluator: default - callbacks: default - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`) diff --git a/custom_models/cell/cwn_dcm.py b/custom_models/cell/cccn.py similarity index 98% rename from custom_models/cell/cwn_dcm.py rename to custom_models/cell/cccn.py index eaa81cb9..c1e936a7 100644 --- a/custom_models/cell/cwn_dcm.py +++ b/custom_models/cell/cccn.py @@ -34,7 +34,7 @@ def forward(self, xe, Lu, Ld): return z_h + z_s + z_i -class CWNDCM(nn.Module): +class CCCN(nn.Module): def __init__(self, in_channels, n_layers=2, dropout=0.0, last_act=False): super().__init__() self.d = dropout diff --git a/topobenchmarkx/models/wrappers/__init__.py b/topobenchmarkx/models/wrappers/__init__.py index 4c2099e7..99d31d5c 100755 --- a/topobenchmarkx/models/wrappers/__init__.py +++ b/topobenchmarkx/models/wrappers/__init__.py @@ -2,7 +2,7 @@ from topobenchmarkx.models.wrappers.graph import GNNWrapper from topobenchmarkx.models.wrappers.hypergraph import HypergraphWrapper from topobenchmarkx.models.wrappers.simplicial import SANWrapper, SCNWrapper, SCCNNWrapper, SCCNWrapper -from topobenchmarkx.models.wrappers.cell import CANWrapper, CWNDCMWrapper, CWNWrapper, CCXNWrapper +from topobenchmarkx.models.wrappers.cell import CANWrapper, CCCNWrapper, CWNWrapper, CCXNWrapper # ... import other readout classes here # For example: @@ -20,7 +20,7 @@ "SCCNNWrapper", "SCCNWrapper", "CANWrapper", - "CWNDCMWrapper", + "CCCNWrapper", "CWNWrapper", "CCXNWrapper", # "OtherWrapper1", diff --git a/topobenchmarkx/models/wrappers/cell/__init__.py b/topobenchmarkx/models/wrappers/cell/__init__.py index 06af3421..efa0e5d9 100644 --- a/topobenchmarkx/models/wrappers/cell/__init__.py +++ b/topobenchmarkx/models/wrappers/cell/__init__.py @@ -1,5 +1,5 @@ from topobenchmarkx.models.wrappers.cell.can_wrapper import CANWrapper -from topobenchmarkx.models.wrappers.cell.cwndcm_wrapper import CWNDCMWrapper +from topobenchmarkx.models.wrappers.cell.cccn_wrapper import CCCNWrapper from topobenchmarkx.models.wrappers.cell.cwn_wrapper import CWNWrapper from topobenchmarkx.models.wrappers.cell.ccxn_wrapper import CCXNWrapper @@ -10,7 +10,7 @@ __all__ = [ "CANWrapper", - "CWNDCMWrapper", + "CCCNWrapper", "CWNWrapper", "CCXNWrapper", diff --git a/topobenchmarkx/models/wrappers/cell/cwndcm_wrapper.py b/topobenchmarkx/models/wrappers/cell/cccn_wrapper.py similarity index 94% rename from topobenchmarkx/models/wrappers/cell/cwndcm_wrapper.py rename to topobenchmarkx/models/wrappers/cell/cccn_wrapper.py index 3439282b..97638300 100644 --- a/topobenchmarkx/models/wrappers/cell/cwndcm_wrapper.py +++ b/topobenchmarkx/models/wrappers/cell/cccn_wrapper.py @@ -1,7 +1,7 @@ import torch from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper -class CWNDCMWrapper(DefaultWrapper): +class CCCNWrapper(DefaultWrapper): """Abstract class that provides an interface to loss logic within network.""" From 3b87ee937f44e9f9c28d7358e2ca839c29876d76 Mon Sep 17 00:00:00 2001 From: levtelyatnikov Date: Wed, 15 May 2024 01:50:14 +0200 Subject: [PATCH 12/32] cleaning --- 
topobenchmarkx/__init__.py | 21 + topobenchmarkx/data/cornel_dataset.ipynb | 432 ------------------ topobenchmarkx/data/datasets.py | 40 -- topobenchmarkx/models/__init__.py | 11 + .../{network_module.py => default_network.py} | 7 +- topobenchmarkx/transforms/__init__.py | 62 +++ .../transforms/data_manipulations/__init__.py | 26 ++ .../calculate_simplicial_curvature.py | 112 +++++ .../data_manipulations/equal_gaus_features.py | 40 ++ .../data_manipulations/identity_transform.py | 24 + .../infere_knn_connectivity.py | 30 ++ .../infere_radius_connectivity.py | 26 ++ .../keep_only_connected_component.py | 40 ++ .../keep_selected_data_fields.py | 39 ++ .../data_manipulations/node_degrees.py | 87 ++++ .../node_features_to_float.py | 30 ++ .../data_manipulations/one_hot_degree.py | 65 +++ .../one_hot_degree_features.py | 42 ++ topobenchmarkx/transforms/data_transform.py | 57 +-- .../transforms/feature_liftings/__init__.py | 11 + .../transforms/liftings/__init__.py | 18 + 21 files changed, 687 insertions(+), 533 deletions(-) delete mode 100644 topobenchmarkx/data/cornel_dataset.ipynb rename topobenchmarkx/models/{network_module.py => default_network.py} (98%) create mode 100644 topobenchmarkx/transforms/data_manipulations/calculate_simplicial_curvature.py create mode 100644 topobenchmarkx/transforms/data_manipulations/equal_gaus_features.py create mode 100644 topobenchmarkx/transforms/data_manipulations/identity_transform.py create mode 100644 topobenchmarkx/transforms/data_manipulations/infere_knn_connectivity.py create mode 100644 topobenchmarkx/transforms/data_manipulations/infere_radius_connectivity.py create mode 100644 topobenchmarkx/transforms/data_manipulations/keep_only_connected_component.py create mode 100644 topobenchmarkx/transforms/data_manipulations/keep_selected_data_fields.py create mode 100644 topobenchmarkx/transforms/data_manipulations/node_degrees.py create mode 100644 topobenchmarkx/transforms/data_manipulations/node_features_to_float.py create mode 100644 topobenchmarkx/transforms/data_manipulations/one_hot_degree.py create mode 100644 topobenchmarkx/transforms/data_manipulations/one_hot_degree_features.py diff --git a/topobenchmarkx/__init__.py b/topobenchmarkx/__init__.py index 895b4d03..04fac166 100755 --- a/topobenchmarkx/__init__.py +++ b/topobenchmarkx/__init__.py @@ -1,3 +1,24 @@ + +# Import submodules +from . import data +from . import evaluators +from . import hp_scripts +from . import io +from . import models +from . import transforms +from . 
import utils + +__all__ = [ + "data", + "evaluators", + "hp_scripts", + "io", + "models", + "transforms", + "utils", +] + + __version__ = "0.0.1" # from .io import * diff --git a/topobenchmarkx/data/cornel_dataset.ipynb b/topobenchmarkx/data/cornel_dataset.ipynb deleted file mode 100644 index e6eeae39..00000000 --- a/topobenchmarkx/data/cornel_dataset.ipynb +++ /dev/null @@ -1,432 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "\n", - "# Add manually root '/home/lev/projects/TopoBenchmarkX'\n", - "root_path = \"/home/lev/projects/TopoBenchmarkX\"\n", - "if root_path not in sys.path:\n", - " sys.path.append(root_path)\n", - "\n", - "import os.path as osp\n", - "from collections.abc import Callable\n", - "\n", - "from torch_geometric.data import Data, InMemoryDataset\n", - "from torch_geometric.io import fs\n", - "\n", - "from topobenchmarkx.io.load.download_utils import download_file_from_drive\n", - "\n", - "\n", - "class CornelDataset(InMemoryDataset):\n", - " r\"\"\" \"\"\"\n", - "\n", - " URLS = {\n", - " # 'contact-high-school': 'https://drive.google.com/open?id=1VA2P62awVYgluOIh1W4NZQQgkQCBk-Eu',\n", - " \"US-county-demos\": \"https://drive.google.com/file/d/1FNF_LbByhYNICPNdT6tMaJI9FxuSvvLK/view?usp=sharing\",\n", - " }\n", - "\n", - " FILE_FORMAT = {\n", - " # 'contact-high-school': 'tar.gz',\n", - " \"US-county-demos\": \"zip\",\n", - " }\n", - "\n", - " RAW_FILE_NAMES = {}\n", - "\n", - " def __init__(\n", - " self,\n", - " root: str,\n", - " name: str,\n", - " parameters: dict = None,\n", - " transform: Callable | None = None,\n", - " pre_transform: Callable | None = None,\n", - " pre_filter: Callable | None = None,\n", - " force_reload: bool = True,\n", - " use_node_attr: bool = False,\n", - " use_edge_attr: bool = False,\n", - " ) -> None:\n", - " self.name = name.replace(\"_\", \"-\")\n", - "\n", - " super().__init__(\n", - " root, transform, pre_transform, pre_filter, force_reload=force_reload\n", - " )\n", - "\n", - " # Step 3:Load the processed data\n", - " # After the data has been downloaded from source\n", - " # Then preprocessed to obtain x,y and saved into processed folder\n", - " # We can now load the processed data from processed folder\n", - "\n", - " # Load the processed data\n", - " data, _, _ = fs.torch_load(self.processed_paths[0])\n", - "\n", - " # Map the loaded data into\n", - " data = Data.from_dict(data)\n", - "\n", - " # Step 5: Create the splits and upload desired fold\n", - "\n", - " # split_idx = random_splitting(data.y, parameters=self.parameters)\n", - "\n", - " # Assign data object to self.data, to make it be prodessed by Dataset class\n", - " self.data = data\n", - "\n", - " @property\n", - " def raw_dir(self) -> str:\n", - " return osp.join(self.root, self.name, \"raw\")\n", - "\n", - " @property\n", - " def processed_dir(self) -> str:\n", - " return osp.join(self.root, self.name, \"processed\")\n", - "\n", - " @property\n", - " def raw_file_names(self) -> list[str]:\n", - " names = [\"\", \"_2012\"]\n", - " return [f\"{self.name}_{name}.txt\" for name in names]\n", - "\n", - " @property\n", - " def processed_file_names(self) -> str:\n", - " return \"data.pt\"\n", - "\n", - " def download(self) -> None:\n", - " \"\"\"\n", - " Downloads the dataset from the specified URL and saves it to the raw directory.\n", - "\n", - " Raises:\n", - " FileNotFoundError: If the dataset URL is not found.\n", - " \"\"\"\n", - "\n", - " # Step 1: Download data from the 
source\n", - " self.url = self.URLS[self.name]\n", - " self.file_format = self.FILE_FORMAT[self.name]\n", - "\n", - " download_file_from_drive(\n", - " file_link=self.url,\n", - " path_to_save=self.raw_dir,\n", - " dataset_name=self.name,\n", - " file_format=self.file_format,\n", - " )\n", - "\n", - " # Extract the downloaded file if it is compressed\n", - " fs.cp(\n", - " f\"{self.raw_dir}/{self.name}.{self.file_format}\", self.raw_dir, extract=True\n", - " )\n", - "\n", - " # Move the etracted files to the datasets/domain/dataset_name/raw/ directory\n", - " for filename in fs.ls(osp.join(self.raw_dir, self.name)):\n", - " fs.mv(filename, osp.join(self.raw_dir, osp.basename(filename)))\n", - " fs.rm(osp.join(self.raw_dir, self.name))\n", - "\n", - " # Delete also f'{self.raw_dir}/{self.name}.{self.file_format}'\n", - " fs.rm(f\"{self.raw_dir}/{self.name}.{self.file_format}\")\n", - "\n", - " def process(self) -> None:\n", - " \"\"\"\n", - " Process the data for the dataset.\n", - "\n", - " This method loads the US county demographics data, applies any pre-processing transformations if specified,\n", - " and saves the processed data to the appropriate location.\n", - "\n", - " Returns:\n", - " None\n", - " \"\"\"\n", - " data = load_us_county_demos(self.raw_dir, self.name)\n", - "\n", - " data = data if self.pre_transform is None else self.pre_transform(data)\n", - " self.save([data], self.processed_paths[0])\n", - "\n", - " def __repr__(self) -> str:\n", - " return f\"{self.name}({len(self)})\"" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import pandas as pd\n", - "import torch\n", - "import torch_geometric\n", - "\n", - "\n", - "def load_us_county_demos(path, dataset_name, year=2012):\n", - " edges_df = pd.read_csv(f\"{path}/county_graph.csv\")\n", - " stat = pd.read_csv(f\"{path}/county_stats_{year}.csv\", encoding=\"ISO-8859-1\")\n", - "\n", - " keep_cols = [\n", - " \"FIPS\",\n", - " \"DEM\",\n", - " \"GOP\",\n", - " \"MedianIncome\",\n", - " \"MigraRate\",\n", - " \"BirthRate\",\n", - " \"DeathRate\",\n", - " \"BachelorRate\",\n", - " \"UnemploymentRate\",\n", - " ]\n", - " # Drop rows with missing values\n", - " stat = stat[keep_cols].dropna()\n", - "\n", - " # Delete edges that are not present in stat df\n", - " unique_fips = stat[\"FIPS\"].unique()\n", - "\n", - " src_ = edges_df[\"SRC\"].apply(lambda x: x in unique_fips)\n", - " dst_ = edges_df[\"DST\"].apply(lambda x: x in unique_fips)\n", - "\n", - " edges_df = edges_df[src_ & dst_]\n", - "\n", - " # Remove rows from stat df where edges_df['SRC'] or edges_df['DST'] are not present\n", - " stat = stat[stat[\"FIPS\"].isin(edges_df[\"SRC\"]) & stat[\"FIPS\"].isin(edges_df[\"DST\"])]\n", - " stat = stat.reset_index(drop=True)\n", - "\n", - " # Remove rows where SRC == DST\n", - " edges_df = edges_df[edges_df[\"SRC\"] != edges_df[\"DST\"]]\n", - "\n", - " # Get torch_geometric edge_index format\n", - " edge_index = torch.tensor(\n", - " np.stack([edges_df[\"SRC\"].to_numpy(), edges_df[\"DST\"].to_numpy()])\n", - " )\n", - "\n", - " # Make edge_index undirected\n", - " edge_index = torch_geometric.utils.to_undirected(edge_index)\n", - "\n", - " # Convert edge_index back to pandas DataFrame\n", - " edges_df = pd.DataFrame(edge_index.numpy().T, columns=[\"SRC\", \"DST\"])\n", - "\n", - " del edge_index\n", - "\n", - " # Map stat['FIPS'].unique() to [0, ..., num_nodes]\n", - " fips_map = {fips: i for i, fips in 
enumerate(stat[\"FIPS\"].unique())}\n", - " stat[\"FIPS\"] = stat[\"FIPS\"].map(fips_map)\n", - "\n", - " # Map edges_df['SRC'] and edges_df['DST'] to [0, ..., num_nodes]\n", - " edges_df[\"SRC\"] = edges_df[\"SRC\"].map(fips_map)\n", - " edges_df[\"DST\"] = edges_df[\"DST\"].map(fips_map)\n", - "\n", - " # Get torch_geometric edge_index format\n", - " edge_index = torch.tensor(\n", - " np.stack([edges_df[\"SRC\"].to_numpy(), edges_df[\"DST\"].to_numpy()])\n", - " )\n", - "\n", - " # Remove isolated nodes (Note: this function maps the nodes to [0, ..., num_nodes] automatically)\n", - " edge_index, _, mask = torch_geometric.utils.remove_isolated_nodes(edge_index)\n", - "\n", - " # Conver mask to index\n", - " index = np.arange(mask.size(0))[mask]\n", - " stat = stat.iloc[index]\n", - " stat = stat.reset_index(drop=True)\n", - "\n", - " # Get new values for FIPS from current index\n", - " # To understand why please print stat.iloc[[516, 517, 518, 519, 520]] for 2012 year\n", - " # Basically the FIPS values has been shifted\n", - " stat[\"FIPS\"] = stat.reset_index()[\"index\"]\n", - "\n", - " # Create Election variable\n", - " stat[\"Election\"] = (stat[\"DEM\"] - stat[\"GOP\"]) / (stat[\"DEM\"] + stat[\"GOP\"])\n", - "\n", - " # Drop DEM and GOP columns and FIPS\n", - " stat = stat.drop(columns=[\"DEM\", \"GOP\", \"FIPS\"])\n", - "\n", - " # Prediction col\n", - " y_col = \"Election\" # TODO: Define through config file\n", - " x_col = list(set(stat.columns).difference(set([y_col])))\n", - "\n", - " stat[\"MedianIncome\"] = (\n", - " stat[\"MedianIncome\"]\n", - " .apply(lambda x: x.replace(\",\", \"\"))\n", - " .to_numpy()\n", - " .astype(float)\n", - " )\n", - "\n", - " x = stat[x_col].to_numpy()\n", - " y = stat[y_col].to_numpy()\n", - "\n", - " data = torch_geometric.data.Data(x=x, y=y, edge_index=edge_index)\n", - "\n", - " return data" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Download complete.\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Processing...\n", - "Done!\n" - ] - } - ], - "source": [ - "a = CornelDataset(\n", - " root=\"/home/lev/projects/TopoBenchmarkX/datasets/graph\", name=\"US-county-demos\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "parameters = {\n", - " \"data_seed\": 0,\n", - " \"data_split_dir\": \"/home/lev/projects/TopoBenchmarkX/datasets/data_splits/US-county-demos\",\n", - " \"train_prop\": 0.5,\n", - "}" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(3107, 6)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "a[0].x.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'dict' object has no attribute 'k'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[35], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m a \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mk\u001b[39m\u001b[38;5;124m'\u001b[39m:\u001b[38;5;241m1\u001b[39m}\n\u001b[0;32m----> 2\u001b[0m a\u001b[38;5;241m.\u001b[39mk\n", - 
"\u001b[0;31mAttributeError\u001b[0m: 'dict' object has no attribute 'k'" - ] - } - ], - "source": [ - "a = {\"k\": 1}\n", - "a.k" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "stat = pd.read_csv(\n", - " \"/home/lev/projects/TopoBenchmarkX/datasets/graph/US-county-demos-2012/raw/US-county-demos/county_stats_2016.csv\",\n", - " encoding=\"ISO-8859-1\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Index(['FIPS', 'County', 'DEM', 'GOP', 'MedianIncome', 'MigraRate',\n", - " 'BirthRate', 'DeathRate', 'BachelorRate', 'UnemploymentRate'],\n", - " dtype='object')" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "stat.columns" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "('Election',\n", - " 'MedianIncome',\n", - " 'MigraRate',\n", - " 'BirthRate',\n", - " 'DeathRate',\n", - " 'BachelorRate',\n", - " 'UnemploymentRate')" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "(\n", - " \"Election\",\n", - " \"MedianIncome\",\n", - " \"MigraRate\",\n", - " \"BirthRate\",\n", - " \"DeathRate\",\n", - " \"BachelorRate\",\n", - " \"UnemploymentRate\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "topo", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.3" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/topobenchmarkx/data/datasets.py b/topobenchmarkx/data/datasets.py index 0ed9159e..9ad8c047 100644 --- a/topobenchmarkx/data/datasets.py +++ b/topobenchmarkx/data/datasets.py @@ -40,43 +40,3 @@ def len(self): Length of the dataset. """ return len(self.data_lst) - - -class TorchGeometricDataset(torch_geometric.data.Dataset): - r"""Dataset to work with a list of data objects. - - Parameters - ---------- - data_lst: list - List of torch_geometric.data.Data objects . - """ - - def __init__(self, data_lst): - super().__init__() - self.data_lst = data_lst - - def get(self, idx): - r"""Get data object from data list. - - Parameters - ---------- - idx: int - Index of the data object to get. - - Returns - ------- - torch_geometric.data.Data - Data object of corresponding index. - """ - data = self.data_lst[idx] - return data - - def len(self): - r"""Return length of the dataset. - - Returns - ------- - int - Length of the dataset. 
-        """
-        return len(self.data_lst)
diff --git a/topobenchmarkx/models/__init__.py b/topobenchmarkx/models/__init__.py
index e69de29b..2e4f1c9a 100755
--- a/topobenchmarkx/models/__init__.py
+++ b/topobenchmarkx/models/__init__.py
@@ -0,0 +1,11 @@
+import topobenchmarkx.models.encoders
+import topobenchmarkx.models.head_models
+import topobenchmarkx.models.losses
+import topobenchmarkx.models.readouts
+import topobenchmarkx.models.wrappers
+
+from topobenchmarkx.models.default_network import TopologicalNetworkModule
+
+__all__ = [
+    "TopologicalNetworkModule",
+]
diff --git a/topobenchmarkx/models/network_module.py b/topobenchmarkx/models/default_network.py
similarity index 98%
rename from topobenchmarkx/models/network_module.py
rename to topobenchmarkx/models/default_network.py
index 989e4c6e..4181c285 100755
--- a/topobenchmarkx/models/network_module.py
+++ b/topobenchmarkx/models/default_network.py
@@ -5,10 +5,7 @@
 from torchmetrics import MeanMetric
 from torch_geometric.data import Data
 
-# import topomodelx
-
-
-class NetworkModule(LightningModule):
+class TopologicalNetworkModule(LightningModule):
     """A `LightningModule` implements 8 key methods:
 
     Docs:
@@ -290,4 +287,4 @@ def configure_optimizers(self) -> dict[str, Any]:
 
 if __name__ == "__main__":
-    _ = NetworkModule(None, None, None, None)
+    _ = TopologicalNetworkModule(None, None, None, None)
diff --git a/topobenchmarkx/transforms/__init__.py b/topobenchmarkx/transforms/__init__.py
index e69de29b..5baa9ac9 100755
--- a/topobenchmarkx/transforms/__init__.py
+++ b/topobenchmarkx/transforms/__init__.py
@@ -0,0 +1,62 @@
+# Data manipulation transforms
+from topobenchmarkx.transforms.data_manipulations import (
+    CalculateSimplicialCurvature,
+    EqualGausFeatures,
+    IdentityTransform,
+    InfereKNNConnectivity,
+    InfereRadiusConnectivity,
+    KeepOnlyConnectedComponent,
+    KeepSelectedDataFields,
+    NodeDegrees,
+    NodeFeaturesToFloat,
+    OneHotDegreeFeatures,
+)
+
+# Feature liftings
+from topobenchmarkx.transforms.feature_liftings import (
+    ConcatentionLifting,
+    ProjectionSum,
+    SetLifting,
+)
+
+# Topology Liftings
+from topobenchmarkx.transforms.liftings import (
+    CellCyclesLifting,
+    HypergraphKHopLifting,
+    HypergraphKNearestNeighborsLifting,
+    SimplicialCliqueLifting,
+    SimplicialNeighborhoodLifting,
+)
+
+# Dictionary of all available transforms
+TRANSFORMS = {
+    # Graph -> Hypergraph
+    "HypergraphKHopLifting": HypergraphKHopLifting,
+    "HypergraphKNearestNeighborsLifting": HypergraphKNearestNeighborsLifting,
+    # Graph -> Simplicial Complex
+    "SimplicialNeighborhoodLifting": SimplicialNeighborhoodLifting,
+    "SimplicialCliqueLifting": SimplicialCliqueLifting,
+    # Graph -> Cell Complex
+    "CellCyclesLifting": CellCyclesLifting,
+    # Feature Liftings
+    "ProjectionSum": ProjectionSum,
+    "ConcatentionLifting": ConcatentionLifting,
+    "SetLifting": SetLifting,
+    # Data Manipulations
+    "Identity": IdentityTransform,
+    "InfereKNNConnectivity": InfereKNNConnectivity,
+    "InfereRadiusConnectivity": InfereRadiusConnectivity,
+    "NodeDegrees": NodeDegrees,
+    "OneHotDegreeFeatures": OneHotDegreeFeatures,
+    "EqualGausFeatures": EqualGausFeatures,
+    "NodeFeaturesToFloat": NodeFeaturesToFloat,
+    "CalculateSimplicialCurvature": CalculateSimplicialCurvature,
+    "KeepOnlyConnectedComponent": KeepOnlyConnectedComponent,
+    "KeepSelectedDataFields": KeepSelectedDataFields,
+}
+
+
+
+__all__ = [
+    "TRANSFORMS",
+]
\ No newline at end of file
diff --git a/topobenchmarkx/transforms/data_manipulations/__init__.py
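The TRANSFORMS registry above makes every transform reachable by its string name. A rough usage sketch (the constructor kwargs are illustrative, not the real config schema; in practice DataTransform resolves the name from the Hydra config):

from topobenchmarkx.transforms import TRANSFORMS

# Hypothetical lookup: pick a transform class from its registry key and
# instantiate it; kwargs would normally come from the transform's YAML config.
transform_cls = TRANSFORMS["Identity"]
transform = transform_cls()
# data = transform(data)  # BaseTransform.__call__ dispatches to forward()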
b/topobenchmarkx/transforms/data_manipulations/__init__.py index e69de29b..e626400b 100644 --- a/topobenchmarkx/transforms/data_manipulations/__init__.py +++ b/topobenchmarkx/transforms/data_manipulations/__init__.py @@ -0,0 +1,26 @@ +from topobenchmarkx.transforms.data_manipulations.identity_transform import IdentityTransform +from topobenchmarkx.transforms.data_manipulations.infere_knn_connectivity import InfereKNNConnectivity +from topobenchmarkx.transforms.data_manipulations.infere_radius_connectivity import InfereRadiusConnectivity +from topobenchmarkx.transforms.data_manipulations.equal_gaus_features import EqualGausFeatures +from topobenchmarkx.transforms.data_manipulations.node_features_to_float import NodeFeaturesToFloat +from topobenchmarkx.transforms.data_manipulations.node_degrees import NodeDegrees +from topobenchmarkx.transforms.data_manipulations.keep_only_connected_component import KeepOnlyConnectedComponent +from topobenchmarkx.transforms.data_manipulations.calculate_simplicial_curvature import CalculateSimplicialCurvature +from topobenchmarkx.transforms.data_manipulations.one_hot_degree_features import OneHotDegreeFeatures +from topobenchmarkx.transforms.data_manipulations.keep_selected_data_fields import KeepSelectedDataFields + + + + +__all__ = [ + "IdentityTransform", + "InfereKNNConnectivity", + "InfereRadiusConnectivity", + "EqualGausFeatures", + "NodeFeaturesToFloat", + "NodeDegrees", + "KeepOnlyConnectedComponent", + "CalculateSimplicialCurvature", + "OneHotDegreeFeatures", + "KeepSelectedDataFields", +] \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/calculate_simplicial_curvature.py b/topobenchmarkx/transforms/data_manipulations/calculate_simplicial_curvature.py new file mode 100644 index 00000000..99ff5c05 --- /dev/null +++ b/topobenchmarkx/transforms/data_manipulations/calculate_simplicial_curvature.py @@ -0,0 +1,112 @@ +import torch +import torch_geometric + +class CalculateSimplicialCurvature(torch_geometric.transforms.BaseTransform): + """A transform that calculates the simplicial curvature of the input graph. + + Parameters + ---------- + **kwargs : optional + Parameters for the transform. + """ + + def __init__(self, **kwargs): + super().__init__() + self.type = "simplicial_curvature" + self.parameters = kwargs + + def forward(self, data: torch_geometric.data.Data): + """Apply the transform to the input data. + + Parameters + ---------- + data : torch_geometric.data.Data + The input data. + + Returns + ------- + torch_geometric.data.Data + The transformed data. + """ + data = self.one_cell_curvature(data) + data = self.zero_cell_curvature(data) + data = self.two_cell_curvature(data) + return data + + def zero_cell_curvature( + self, + data: torch_geometric.data.Data, + ) -> torch_geometric.data.Data: + """Calculate the zero cell curvature of the input data. + + Parameters + ---------- + data : torch_geometric.data.Data + The input data. + + Returns + ------- + torch_geometric.data.Data + Data with the zero cell curvature. + """ + data["0_cell_curvature"] = torch.mm( + abs(data["incidence_1"]), data["1_cell_curvature"] + ) + return data + + def one_cell_curvature( + self, + data: torch_geometric.data.Data, + ) -> torch_geometric.data.Data: + r"""Calculate the one cell curvature of the input data. + + Parameters + ---------- + data : torch_geometric.data.Data + The input data. + + Returns + ------- + torch_geometric.data.Data + Data with the one cell curvature. 
+        """
+        data["1_cell_curvature"] = (
+            4
+            - torch.mm(abs(data["incidence_1"]).T, data["0_cell_degrees"])
+            + 3 * data["1_cell_degrees"]
+        )
+        return data
+
+    def two_cell_curvature(
+        self,
+        data: torch_geometric.data.Data,
+    ) -> torch_geometric.data.Data:
+        r"""Calculate the two cell curvature of the input data.
+
+        Parameters
+        ----------
+        data : torch_geometric.data.Data
+            The input data.
+
+        Returns
+        -------
+        torch_geometric.data.Data
+            Data with the two cell curvature.
+        """
+        # Term 1 is simply the degree of the 2-cell (i.e., each triangle belongs to n tetrahedra)
+        term1 = data["2_cell_degrees"]
+        # Find triangles that belong to multiple tetrahedra
+        two_cell_degrees = data["2_cell_degrees"].clone()
+        idx = torch.where(data["2_cell_degrees"] > 1)[0]
+        two_cell_degrees[idx] = 0
+        up = data["incidence_3"].to_dense() @ data["incidence_3"].to_dense().T
+        down = (
+            data["incidence_2"].to_dense().T @ data["incidence_2"].to_dense()
+        )
+        mask = torch.eye(up.size()[0]).bool()
+        up.masked_fill_(mask, 0)
+        down.masked_fill_(mask, 0)
+        diff = (down - up) * 1
+        term2 = diff.sum(1, keepdim=True)
+        data["2_cell_curvature"] = 3 + term1 - term2
+        return data
\ No newline at end of file
diff --git a/topobenchmarkx/transforms/data_manipulations/equal_gaus_features.py b/topobenchmarkx/transforms/data_manipulations/equal_gaus_features.py
new file mode 100644
index 00000000..ad0e68b9
--- /dev/null
+++ b/topobenchmarkx/transforms/data_manipulations/equal_gaus_features.py
@@ -0,0 +1,40 @@
+import torch
+import torch_geometric
+
+class EqualGausFeatures(torch_geometric.transforms.BaseTransform):
+    r"""A transform that generates equal Gaussian features for all nodes in the
+    input graph.
+
+    Parameters
+    ----------
+    **kwargs : optional
+        Parameters for the transform.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.type = "generate_non_informative_features"
+
+        # Generate the shared feature vector from a Gaussian distribution
+        self.mean = kwargs["mean"]
+        self.std = kwargs["std"]
+        self.feature_vector = kwargs["num_features"]
+        self.feature_vector = torch.normal(
+            mean=self.mean, std=self.std, size=(1, self.feature_vector)
+        )
+
+    def forward(self, data: torch_geometric.data.Data):
+        r"""Apply the transform to the input data.
+
+        Parameters
+        ----------
+        data : torch_geometric.data.Data
+            The input data.
+
+        Returns
+        -------
+        torch_geometric.data.Data
+            The transformed data.
+        """
+        data.x = self.feature_vector.expand(data.num_nodes, -1)
+        return data
\ No newline at end of file
diff --git a/topobenchmarkx/transforms/data_manipulations/identity_transform.py b/topobenchmarkx/transforms/data_manipulations/identity_transform.py
new file mode 100644
index 00000000..d2462fbc
--- /dev/null
+++ b/topobenchmarkx/transforms/data_manipulations/identity_transform.py
@@ -0,0 +1,24 @@
+import torch_geometric
+
+class IdentityTransform(torch_geometric.transforms.BaseTransform):
+    r"""An identity transform that does nothing to the input data."""
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.type = "domain2domain"
+        self.parameters = kwargs
+
+    def forward(self, data: torch_geometric.data.Data):
+        r"""Apply the transform to the input data.
+
+        Parameters
+        ----------
+        data : torch_geometric.data.Data
+            The input data.
+
+        Returns
+        -------
+        torch_geometric.data.Data
+            The (un)transformed data.
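As a sanity check of the formulas in CalculateSimplicialCurvature above, the same arithmetic can be worked by hand on a single filled triangle (3 nodes, 3 edges, one 2-cell). The incidence and degree tensors below are toy values that NodeDegrees would normally precompute; note that the one-cell curvature must be computed first, since the zero-cell curvature depends on it, which is exactly the order used in forward():

import torch

# Signed node-to-edge incidence of a triangle: columns are edges (0,1), (0,2), (1,2).
incidence_1 = torch.tensor(
    [[-1.0, -1.0,  0.0],
     [ 1.0,  0.0, -1.0],
     [ 0.0,  1.0,  1.0]]
)
deg0 = incidence_1.abs().sum(1, keepdim=True)  # node degrees: [[2.], [2.], [2.]]
deg1 = torch.ones(3, 1)                        # each edge bounds exactly one triangle

curv1 = 4 - incidence_1.abs().T @ deg0 + 3 * deg1  # one-cell curvature: all 3.0
curv0 = incidence_1.abs() @ curv1                  # zero-cell curvature: all 6.0
print(curv1.squeeze(), curv0.squeeze())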
+        """
+        return data
\ No newline at end of file
diff --git a/topobenchmarkx/transforms/data_manipulations/infere_knn_connectivity.py b/topobenchmarkx/transforms/data_manipulations/infere_knn_connectivity.py
new file mode 100644
index 00000000..78892adf
--- /dev/null
+++ b/topobenchmarkx/transforms/data_manipulations/infere_knn_connectivity.py
@@ -0,0 +1,30 @@
+import torch_geometric
+from torch_geometric.nn import knn_graph
+
+class InfereKNNConnectivity(torch_geometric.transforms.BaseTransform):
+    r"""A transform that generates the k-nearest neighbor connectivity of the
+    input point cloud."""
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.type = "infere_knn_connectivity"
+        self.parameters = kwargs
+
+    def forward(self, data: torch_geometric.data.Data):
+        r"""Apply the transform to the input data.
+
+        Parameters
+        ----------
+        data : torch_geometric.data.Data
+            The input data.
+        Returns
+        -------
+        torch_geometric.data.Data
+            The transformed data.
+        """
+
+        edge_index = knn_graph(data.x, **self.parameters["args"])
+
+        # Overwrite the graph connectivity with the inferred kNN edges
+        data.edge_index = edge_index
+        return data
\ No newline at end of file
diff --git a/topobenchmarkx/transforms/data_manipulations/infere_radius_connectivity.py b/topobenchmarkx/transforms/data_manipulations/infere_radius_connectivity.py
new file mode 100644
index 00000000..ad249065
--- /dev/null
+++ b/topobenchmarkx/transforms/data_manipulations/infere_radius_connectivity.py
@@ -0,0 +1,26 @@
+import torch_geometric
+from torch_geometric.nn import radius_graph
+
+class InfereRadiusConnectivity(torch_geometric.transforms.BaseTransform):
+    r"""A transform that generates the radius connectivity of the input point
+    cloud."""
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.type = "infere_radius_connectivity"
+        self.parameters = kwargs
+
+    def forward(self, data: torch_geometric.data.Data):
+        r"""Apply the transform to the input data.
+
+        Parameters
+        ----------
+        data : torch_geometric.data.Data
+            The input data.
+        Returns
+        -------
+        torch_geometric.data.Data
+            The transformed data.
+        """
+        data.edge_index = radius_graph(data.x, **self.parameters["args"])
+        return data
\ No newline at end of file
diff --git a/topobenchmarkx/transforms/data_manipulations/keep_only_connected_component.py b/topobenchmarkx/transforms/data_manipulations/keep_only_connected_component.py
new file mode 100644
index 00000000..04ca2364
--- /dev/null
+++ b/topobenchmarkx/transforms/data_manipulations/keep_only_connected_component.py
@@ -0,0 +1,40 @@
+import torch_geometric
+from torch_geometric.transforms import LargestConnectedComponents
+
+class KeepOnlyConnectedComponent(torch_geometric.transforms.BaseTransform):
+    """A transform that keeps only the largest connected components of the
+    input graph.
+
+    Parameters
+    ----------
+    **kwargs : optional
+        Parameters for the transform.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.type = "keep_connected_component"
+        self.parameters = kwargs
+
+    def forward(self, data: torch_geometric.data.Data):
+        """Apply the transform to the input data.
+
+        Parameters
+        ----------
+        data : torch_geometric.data.Data
+            The input data.
+
+        Returns
+        -------
+        torch_geometric.data.Data
+            The transformed data.
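The kNN transform above is a thin wrapper around torch_geometric.nn.knn_graph. A minimal sketch of the underlying call (random features, k chosen arbitrarily; the transform itself reads k from its config's "args"):

import torch
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph

data = Data(x=torch.randn(8, 3))          # 8 points with 3-d features
data.edge_index = knn_graph(data.x, k=2)  # same call the transform performs
print(data.edge_index.shape)              # torch.Size([2, 16]): 8 nodes * 2 neighbors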
+        """
+
+
+        # Delegates to torch_geometric.transforms.LargestConnectedComponents
+        num_components = self.parameters["num_components"]
+        lcc = LargestConnectedComponents(
+            num_components=num_components, connection="strong"
+        )
+        data = lcc(data)
+        return data
\ No newline at end of file
diff --git a/topobenchmarkx/transforms/data_manipulations/keep_selected_data_fields.py b/topobenchmarkx/transforms/data_manipulations/keep_selected_data_fields.py
new file mode 100644
index 00000000..deae587e
--- /dev/null
+++ b/topobenchmarkx/transforms/data_manipulations/keep_selected_data_fields.py
@@ -0,0 +1,39 @@
+import torch_geometric
+
+class KeepSelectedDataFields(torch_geometric.transforms.BaseTransform):
+    r"""A transform that keeps only the selected fields of the input data.
+
+    Parameters
+    ----------
+    **kwargs : optional
+        Parameters for the transform.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.type = "keep_selected_data_fields"
+        self.parameters = kwargs
+
+    def forward(self, data: torch_geometric.data.Data):
+        r"""Apply the transform to the input data.
+
+        Parameters
+        ----------
+        data : torch_geometric.data.Data
+            The input data.
+
+        Returns
+        -------
+        torch_geometric.data.Data
+            The transformed data.
+        """
+        # Collect the fields that should be kept
+        fields_to_keep = (
+            self.parameters["base_fields"]
+            + self.parameters["preserved_fields"]
+        )
+
+        for key in data:
+            if key not in fields_to_keep:
+                del data[key]
+        return data
diff --git a/topobenchmarkx/transforms/data_manipulations/node_degrees.py b/topobenchmarkx/transforms/data_manipulations/node_degrees.py
new file mode 100644
index 00000000..7bd09843
--- /dev/null
+++ b/topobenchmarkx/transforms/data_manipulations/node_degrees.py
@@ -0,0 +1,87 @@
+import torch
+import torch_geometric
+
+class NodeDegrees(torch_geometric.transforms.BaseTransform):
+    r"""A transform that calculates the node degrees of the input graph.
+
+    Parameters
+    ----------
+    **kwargs : optional
+        Parameters for the transform.
+    """
+
+    def __init__(self, **kwargs):
+        super().__init__()
+        self.type = "node_degrees"
+        self.parameters = kwargs
+
+    def forward(self, data: torch_geometric.data.Data):
+        r"""Apply the transform to the input data.
+
+        Parameters
+        ----------
+        data : torch_geometric.data.Data
+            The input data.
+
+        Returns
+        -------
+        torch_geometric.data.Data
+            The transformed data.
+        """
+        field_to_process = [
+            key
+            for key in data.keys()
+            for field_substring in self.parameters["selected_fields"]
+            if field_substring in key and key != "incidence_0"
+        ]
+        for field in field_to_process:
+            data = self.calculate_node_degrees(data, field)
+
+        return data
+
+    def calculate_node_degrees(
+        self, data: torch_geometric.data.Data, field: str
+    ) -> torch_geometric.data.Data:
+        r"""Calculate the node degrees of the input data.
+
+        Parameters
+        ----------
+        data : torch_geometric.data.Data
+            The input data.
+        field : str
+            The field to calculate the node degrees.
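The KeepOnlyConnectedComponent transform above simply delegates to PyG's LargestConnectedComponents. A toy sketch of that underlying transform (using the default weak connectivity for brevity, while the file above passes connection="strong"):

import torch
from torch_geometric.data import Data
from torch_geometric.transforms import LargestConnectedComponents

# Two components: a triangle (0-1-2) and a single edge (3-4).
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 0, 4]])
data = Data(x=torch.randn(5, 2), edge_index=edge_index)
data = LargestConnectedComponents(num_components=1)(data)
print(data.num_nodes)  # 3: only the triangle survives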
+ + Returns + ------- + torch_geometric.data.Data + """ + if data[field].is_sparse: + degrees = abs(data[field].to_dense()).sum(1) + else: + assert ( + field == "edge_index" + ), "Following logic of finding degrees is only implemented for edge_index" + + # Get number of nodes + if data.get("num_nodes", None): + max_num_nodes = data["num_nodes"] + else: + max_num_nodes = data["x"].shape[0] + degrees = ( + torch_geometric.utils.to_dense_adj( + data[field], + max_num_nodes=max_num_nodes, + ) + .squeeze(0) + .sum(1) + ) + + if "incidence" in field: + field_name = ( + str(int(field.split("_")[1]) - 1) + "_cell" + "_degrees" + ) + else: + field_name = "node_degrees" + + data[field_name] = degrees.unsqueeze(1) + return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/node_features_to_float.py b/topobenchmarkx/transforms/data_manipulations/node_features_to_float.py new file mode 100644 index 00000000..4422d39e --- /dev/null +++ b/topobenchmarkx/transforms/data_manipulations/node_features_to_float.py @@ -0,0 +1,30 @@ +import torch_geometric + +class NodeFeaturesToFloat(torch_geometric.transforms.BaseTransform): + r"""A transform that converts the node features of the input graph to float. + + Parameters + ---------- + **kwargs : optional + Parameters for the transform. + """ + + def __init__(self, **kwargs): + super().__init__() + self.type = "map_node_features_to_float" + + def forward(self, data: torch_geometric.data.Data): + r"""Apply the transform to the input data. + + Parameters + ---------- + data : torch_geometric.data.Data + The input data. + + Returns + ------- + torch_geometric.data.Data + The transformed data. + """ + data.x = data.x.float() + return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/one_hot_degree.py b/topobenchmarkx/transforms/data_manipulations/one_hot_degree.py new file mode 100644 index 00000000..6ed0e333 --- /dev/null +++ b/topobenchmarkx/transforms/data_manipulations/one_hot_degree.py @@ -0,0 +1,65 @@ +import torch +import torch_geometric +from torch_geometric.utils import one_hot + +class OneHotDegree(torch_geometric.transforms.BaseTransform): + r"""Adds the node degree as one hot encodings to the node features. + + Parameters + ---------- + max_degree : int + The maximum degree of the graph. + cat : bool, optional + If set to `True`, the one hot encodings are concatenated to the node features. + """ + + def __init__( + self, + max_degree: int, + cat: bool = False, + ) -> None: + self.max_degree = max_degree + self.cat = cat + + def forward( + self, + data: torch_geometric.data.Data, + degrees_field: str, + features_field: str, + ) -> torch_geometric.data.Data: + r"""Apply the transform to the input data. + + Parameters + ---------- + data : torch_geometric.data.Data + The input data. + degrees_field : str + The field containing the node degrees. + features_field : str + The field containing the node features. + + Returns + ------- + torch_geometric.data.Data + The transformed data. 
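The NodeDegrees computation and the OneHotDegree helper above boil down to the following arithmetic; a toy sketch on a 3-node path graph (all values hand-built, not taken from a real dataset):

import torch
import torch_geometric
from torch_geometric.utils import one_hot

# Path graph 0-1-2, stored as an undirected edge list.
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

# NodeDegrees-style computation: row sums of the dense adjacency.
deg = (
    torch_geometric.utils.to_dense_adj(edge_index, max_num_nodes=3)
    .squeeze(0)
    .sum(1)
)
print(deg)  # tensor([1., 2., 1.])

# OneHotDegree-style encoding with max_degree = 2.
enc = one_hot(deg.to(torch.long), num_classes=3)
print(enc)  # one row per node, a single 1 at the node's degree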
+ """ + assert data.edge_index is not None + + deg = data[degrees_field].to(torch.long) + + if len(deg.shape) == 2: + deg = deg.squeeze(1) + + deg = one_hot(deg, num_classes=self.max_degree + 1) + + if self.cat: + x = data[features_field] + x = x.view(-1, 1) if x.dim() == 1 else x + data[features_field] = torch.cat([x, deg.to(x.dtype)], dim=-1) + else: + data[features_field] = deg + + return data + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.max_degree})" \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/one_hot_degree_features.py b/topobenchmarkx/transforms/data_manipulations/one_hot_degree_features.py new file mode 100644 index 00000000..13b75043 --- /dev/null +++ b/topobenchmarkx/transforms/data_manipulations/one_hot_degree_features.py @@ -0,0 +1,42 @@ +import torch_geometric +from topobenchmarkx.transforms.data_manipulations.one_hot_degree import OneHotDegree + + + +class OneHotDegreeFeatures(torch_geometric.transforms.BaseTransform): + r"""A transform that adds the node degree as one hot encodings to the node + features. + + Parameters + ---------- + **kwargs : optional + Parameters for the transform. + """ + + def __init__(self, **kwargs): + super().__init__() + self.type = "one_hot_degree_features" + self.deg_field = kwargs["degrees_fields"] + self.features_fields = kwargs["features_fields"] + self.transform = OneHotDegree(max_degree=kwargs["max_degrees"]) + + def forward(self, data: torch_geometric.data.Data): + r"""Apply the transform to the input data. + + Parameters + ---------- + data : torch_geometric.data.Data + The input data. + + Returns + ------- + torch_geometric.data.Data + The transformed data. + """ + data = self.transform.forward( + data, + degrees_field=self.deg_field, + features_field=self.features_fields, + ) + + return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_transform.py b/topobenchmarkx/transforms/data_transform.py index 89431637..3f557229 100755 --- a/topobenchmarkx/transforms/data_transform.py +++ b/topobenchmarkx/transforms/data_transform.py @@ -1,60 +1,5 @@ -# from abc import ABC, abstractmethod - import torch_geometric - -from topobenchmarkx.transforms.data_manipulations.manipulations import ( - CalculateSimplicialCurvature, - EqualGausFeatures, - IdentityTransform, - InfereKNNConnectivity, - InfereRadiusConnectivity, - KeepOnlyConnectedComponent, - KeepSelectedDataFields, - NodeDegrees, - NodeFeaturesToFloat, - OneHotDegreeFeatures, -) -from topobenchmarkx.transforms.feature_liftings.feature_liftings import ( - ConcatentionLifting, - ProjectionSum, - SetLifting, -) -from topobenchmarkx.transforms.liftings.graph2cell import CellCyclesLifting -from topobenchmarkx.transforms.liftings.graph2hypergraph import ( - HypergraphKHopLifting, - HypergraphKNearestNeighborsLifting, -) -from topobenchmarkx.transforms.liftings.graph2simplicial import ( - SimplicialCliqueLifting, - SimplicialNeighborhoodLifting, -) - -TRANSFORMS = { - # Graph -> Hypergraph - "HypergraphKHopLifting": HypergraphKHopLifting, - "HypergraphKNearestNeighborsLifting": HypergraphKNearestNeighborsLifting, - # Graph -> Simplicial Complex - "SimplicialNeighborhoodLifting": SimplicialNeighborhoodLifting, - "SimplicialCliqueLifting": SimplicialCliqueLifting, - # Graph -> Cell Complex - "CellCyclesLifting": CellCyclesLifting, - # Feature Liftings - "ProjectionSum": ProjectionSum, - "ConcatentionLifting": ConcatentionLifting, - "SetLifting": SetLifting, - # Data Manipulations - "Identity": 
IdentityTransform, - "InfereKNNConnectivity": InfereKNNConnectivity, - "InfereRadiusConnectivity": InfereRadiusConnectivity, - "NodeDegrees": NodeDegrees, - "OneHotDegreeFeatures": OneHotDegreeFeatures, - "EqualGausFeatures": EqualGausFeatures, - "NodeFeaturesToFloat": NodeFeaturesToFloat, - "CalculateSimplicialCurvature": CalculateSimplicialCurvature, - "KeepOnlyConnectedComponent": KeepOnlyConnectedComponent, - "KeepSelectedDataFields": KeepSelectedDataFields, -} - +from topobenchmarkx.transforms import TRANSFORMS class DataTransform(torch_geometric.transforms.BaseTransform): """Abstract class that provides an interface to define a custom data diff --git a/topobenchmarkx/transforms/feature_liftings/__init__.py b/topobenchmarkx/transforms/feature_liftings/__init__.py index e69de29b..c7c378c7 100644 --- a/topobenchmarkx/transforms/feature_liftings/__init__.py +++ b/topobenchmarkx/transforms/feature_liftings/__init__.py @@ -0,0 +1,11 @@ +from topobenchmarkx.transforms.feature_liftings.feature_liftings import ( + ConcatentionLifting, + ProjectionSum, + SetLifting, +) + +__all__ = [ + "ConcatentionLifting", + "ProjectionSum", + "SetLifting", +] \ No newline at end of file diff --git a/topobenchmarkx/transforms/liftings/__init__.py b/topobenchmarkx/transforms/liftings/__init__.py index e69de29b..ac9a7d4d 100755 --- a/topobenchmarkx/transforms/liftings/__init__.py +++ b/topobenchmarkx/transforms/liftings/__init__.py @@ -0,0 +1,18 @@ +from topobenchmarkx.transforms.liftings.graph2cell import CellCyclesLifting + +from topobenchmarkx.transforms.liftings.graph2hypergraph import ( + HypergraphKHopLifting, + HypergraphKNearestNeighborsLifting, +) +from topobenchmarkx.transforms.liftings.graph2simplicial import ( + SimplicialCliqueLifting, + SimplicialNeighborhoodLifting, +) + +__all__ = [ + "CellCyclesLifting", + "HypergraphKHopLifting", + "HypergraphKNearestNeighborsLifting", + "SimplicialCliqueLifting", + "SimplicialNeighborhoodLifting", +] \ No newline at end of file From 73ab47dca72ad995b1ab5353e279502edd887960 Mon Sep 17 00:00:00 2001 From: Coerulatus Date: Wed, 15 May 2024 08:45:24 +0000 Subject: [PATCH 13/32] check values of incidences --- test/data/test_Dataloaders.py | 73 +++++++++++++++++++++++------------ 1 file changed, 48 insertions(+), 25 deletions(-) diff --git a/test/data/test_Dataloaders.py b/test/data/test_Dataloaders.py index c595ed13..42b6b17b 100644 --- a/test/data/test_Dataloaders.py +++ b/test/data/test_Dataloaders.py @@ -62,44 +62,67 @@ def test_lift_features(self): ---------- None """ + def check_shape(batch, elems, key): + """Check that the batched data has the expected shape.""" + if 'x_' in key or 'x'==key: + rows = 0 + for i in range(len(elems)): + rows += elems[i][key].shape[0] + assert batch[key].shape[0] == rows + assert batch[key].shape[1] == elems[0][key].shape[1] + elif 'edge_index' in key: + cols = 0 + for i in range(len(elems)): + cols += elems[i][key].shape[1] + assert batch[key].shape[0] == 2 + assert batch[key].shape[1] == cols + elif 'batch_' in key: + rows = 0 + n = int(key.split('_')[1]) + for i in range(len(elems)): + rows += elems[i][f'x_{n}'].shape[0] + assert batch[key].shape[0] == rows + elif key in elems[0].keys(): + for i in range(len(batch[key].shape)): + i_elems = 0 + for j in range(len(elems)): + i_elems += elems[j][key].shape[i] + assert batch[key].shape[i] == i_elems + def check_separation(matrix, n_elems_0_row, n_elems_0_col): """Check that the matrix is separated into two parts diagonally concatenated.""" assert torch.all(matrix[:n_elems_0_row, 
n_elems_0_col:] == 0) assert torch.all(matrix[n_elems_0_row:, :n_elems_0_col] == 0) + def check_values(matrix, m1, m2): + """Check that the values in the matrix are the same as the values in the original data.""" + assert torch.allclose(matrix[:m1.shape[0], :m1.shape[1]], m1) + assert torch.allclose(matrix[m1.shape[0]:, m1.shape[1]:], m2) + batch = next(iter(self.val_dataloader)) elems = [] for i in range(self.batch_size): elems.append(self.val_dataset.data_lst[i]) - # Check that the batched data has the expected shape + # Check shape for key in batch.keys(): - if key in elems[0].keys(): - if 'x_' in key or 'x'==key: - assert batch[key].shape[0] == elems[0][key].shape[0]+elems[1][key].shape[0] - assert batch[key].shape[1] == elems[0][key].shape[1] - elif 'edge_index' in key: - assert batch[key].shape[0] == 2 - assert batch[key].shape[1] == elems[0][key].shape[1]+elems[1][key].shape[1] - else: - for i in range(len(batch[key].shape)): - assert batch[key].shape[i] == elems[0][key].shape[i]+elems[1][key].shape[i] - else: - if 'batch_' in key: - i = int(key.split('_')[1]) - assert batch[key].shape[0] == elems[0][f'x_{i}'].shape[0]+elems[1][f'x_{i}'].shape[0] + check_shape(batch, elems, key) - # Check that the batched data is separated correctly - for key in batch.keys(): - if 'incidence_' in key: - i = int(key.split('_')[1]) - if i==0: - n0_row = 1 - else: - n0_row = torch.sum(batch[f'batch_{i-1}']==0) - n0_col = torch.sum(batch[f'batch_{i}']==0) - check_separation(batch[key].to_dense(), n0_row, n0_col) + # Check that the batched data is separated correctly and the values are correct + if self.batch_size == 2: + for key in batch.keys(): + if 'incidence_' in key: + i = int(key.split('_')[1]) + if i==0: + n0_row = 1 + else: + n0_row = torch.sum(batch[f'batch_{i-1}']==0) + n0_col = torch.sum(batch[f'batch_{i}']==0) + check_separation(batch[key].to_dense(), n0_row, n0_col) + check_values(batch[key].to_dense(), + elems[0][key].to_dense(), + elems[1][key].to_dense()) # Check that going back to a list of data gives the same data batch_list = to_data_list(batch) From 7afbdca37b57dd9202b87038e83339e579996bc9 Mon Sep 17 00:00:00 2001 From: gbg141 Date: Wed, 15 May 2024 12:03:08 +0200 Subject: [PATCH 14/32] env setting --- env.bash | 31 +++++++++-------------------- pyproject.toml | 5 +++++ topobenchmarkx/=0.12.10 | 44 ----------------------------------------- 3 files changed, 14 insertions(+), 66 deletions(-) delete mode 100644 topobenchmarkx/=0.12.10 diff --git a/env.bash b/env.bash index b4225bbf..2d730954 100644 --- a/env.bash +++ b/env.bash @@ -1,33 +1,20 @@ # #!/bin/bash -# set -e - -# # Step 1: Upgrade pip -# pip install --upgrade pip - -# # Step 2: Install dependencies -# yes | pip install -e '.[all]' -# yes | pip install --no-dependencies git+https://github.com/pyt-team/TopoNetX.git -# yes | pip install --no-dependencies git+https://github.com/pyt-team/TopoModelX.git -# yes | pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/cu115 -# yes | pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cu115.html -# yes | pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+cu115.html -# yes | pip install lightning>=2.0.0 -# yes | pip install numpy pre-commit jupyterlab notebook ipykernel - - yes | conda create -n topox python=3.11.3 conda activate topox +pip install --upgrade pip pip install -e '.[all]' -yes | pip install --no-dependencies git+https://github.com/pyt-team/TopoNetX.git -yes | pip install --no-dependencies 
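The separation check in the dataloader test above relies on batched incidence matrices being block-diagonal across samples. The expected structure can be reproduced with torch.block_diag (toy shapes, not real incidences):

import torch

inc_a = torch.ones(2, 3)   # incidence of sample 0 (toy values)
inc_b = torch.ones(3, 4)   # incidence of sample 1
batched = torch.block_diag(inc_a, inc_b)

# Off-diagonal blocks are zero -- exactly what check_separation asserts.
assert torch.all(batched[:2, 3:] == 0)
assert torch.all(batched[2:, :3] == 0)
# Diagonal blocks carry the original values -- what check_values asserts.
assert torch.allclose(batched[:2, :3], inc_a)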
git+https://github.com/pyt-team/TopoModelX.git +yes | pip install git+https://github.com/pyt-team/TopoNetX.git +yes | pip install git+https://github.com/pyt-team/TopoModelX.git +yes | pip install git+https://github.com/pyt-team/TopoEmbedX.git -yes | pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/cu115 -yes | pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cu115.html -yes | pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+cu115.html -yes | pip install numpy pre-commit jupyterlab notebook ipykernel +CUDA="cu115" # if available, select the CUDA version suitable for your system + # e.g. cpu, cu102, cu111, cu113, cu115 +yes | pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/${CUDA} +yes | pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html +yes | pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html pytest diff --git a/pyproject.toml b/pyproject.toml index caf89fad..a6564426 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,10 @@ dependencies=[ "lightning==2.2.1", "einops==0.7.0", "wandb==0.16.4", + "tabulate", + "ipykernel", + "notebook", + "jupyterlab", "rich", "rootutils", "pytest", @@ -167,4 +171,5 @@ checks = [ exclude = [ '\.undocumented_method$', '\.__init__$', + '\.__repr__$', ] \ No newline at end of file diff --git a/topobenchmarkx/=0.12.10 b/topobenchmarkx/=0.12.10 deleted file mode 100644 index 3afb89fd..00000000 --- a/topobenchmarkx/=0.12.10 +++ /dev/null @@ -1,44 +0,0 @@ -Collecting wandb - Downloading wandb-0.16.6-py3-none-any.whl.metadata (10 kB) -Collecting Click!=8.0.0,>=7.1 (from wandb) - Using cached click-8.1.7-py3-none-any.whl.metadata (3.0 kB) -Collecting GitPython!=3.1.29,>=1.0.0 (from wandb) - Downloading GitPython-3.1.43-py3-none-any.whl.metadata (13 kB) -Requirement already satisfied: requests<3,>=2.0.0 in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from wandb) (2.31.0) -Requirement already satisfied: psutil>=5.0.0 in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from wandb) (5.9.8) -Collecting sentry-sdk>=1.0.0 (from wandb) - Downloading sentry_sdk-2.1.1-py2.py3-none-any.whl.metadata (10 kB) -Collecting docker-pycreds>=0.4.0 (from wandb) - Using cached docker_pycreds-0.4.0-py2.py3-none-any.whl.metadata (1.8 kB) -Requirement already satisfied: PyYAML in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from wandb) (6.0.1) -Collecting setproctitle (from wandb) - Using cached setproctitle-1.3.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.9 kB) -Requirement already satisfied: setuptools in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from wandb) (68.2.2) -Collecting appdirs>=1.4.3 (from wandb) - Using cached appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB) -Collecting protobuf!=4.21.0,<5,>=3.19.0 (from wandb) - Using cached protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl.metadata (541 bytes) -Requirement already satisfied: six>=1.4.0 in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0) -Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->wandb) - Using cached gitdb-4.0.11-py3-none-any.whl.metadata (1.2 kB) -Requirement already satisfied: charset-normalizer<4,>=2 in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from requests<3,>=2.0.0->wandb) (3.3.2) -Requirement already satisfied: 
idna<4,>=2.5 in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from requests<3,>=2.0.0->wandb) (3.7) -Requirement already satisfied: urllib3<3,>=1.21.1 in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from requests<3,>=2.0.0->wandb) (2.2.1) -Requirement already satisfied: certifi>=2017.4.17 in /home/lev/miniconda3/envs/topox/lib/python3.11/site-packages (from requests<3,>=2.0.0->wandb) (2024.2.2) -Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) - Using cached smmap-5.0.1-py3-none-any.whl.metadata (4.3 kB) -Downloading wandb-0.16.6-py3-none-any.whl (2.2 MB) - ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.2/2.2 MB 35.5 MB/s eta 0:00:00 -Using cached appdirs-1.4.4-py2.py3-none-any.whl (9.6 kB) -Using cached click-8.1.7-py3-none-any.whl (97 kB) -Using cached docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB) -Downloading GitPython-3.1.43-py3-none-any.whl (207 kB) - ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 207.3/207.3 kB 59.8 MB/s eta 0:00:00 -Using cached protobuf-4.25.3-cp37-abi3-manylinux2014_x86_64.whl (294 kB) -Downloading sentry_sdk-2.1.1-py2.py3-none-any.whl (277 kB) - ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 277.3/277.3 kB 66.9 MB/s eta 0:00:00 -Using cached setproctitle-1.3.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31 kB) -Using cached gitdb-4.0.11-py3-none-any.whl (62 kB) -Using cached smmap-5.0.1-py3-none-any.whl (24 kB) -Installing collected packages: appdirs, smmap, setproctitle, sentry-sdk, protobuf, docker-pycreds, Click, gitdb, GitPython, wandb -Successfully installed Click-8.1.7 GitPython-3.1.43 appdirs-1.4.4 docker-pycreds-0.4.0 gitdb-4.0.11 protobuf-4.25.3 sentry-sdk-2.1.1 setproctitle-1.3.3 smmap-5.0.1 wandb-0.16.6 From 83bcd8409bc4bc7b74169011cc2611478df7933f Mon Sep 17 00:00:00 2001 From: gbg141 Date: Wed, 15 May 2024 17:46:09 +0200 Subject: [PATCH 15/32] Update env.bash --- env.bash | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/env.bash b/env.bash index 2d730954..5f00c475 100644 --- a/env.bash +++ b/env.bash @@ -1,20 +1,21 @@ # #!/bin/bash -yes | conda create -n topox python=3.11.3 -conda activate topox +#conda create -n topox python=3.11.3 +#conda activate topox pip install --upgrade pip -pip install -e '.[all]' - -yes | pip install git+https://github.com/pyt-team/TopoNetX.git -yes | pip install git+https://github.com/pyt-team/TopoModelX.git -yes | pip install git+https://github.com/pyt-team/TopoEmbedX.git CUDA="cu115" # if available, select the CUDA version suitable for your system # e.g. 
cpu, cu102, cu111, cu113, cu115 -yes | pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/${CUDA} -yes | pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html -yes | pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html +pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/${CUDA} +pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html +pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html + +pip install -e '.[all]' + +pip install git+https://github.com/pyt-team/TopoNetX.git +pip install git+https://github.com/pyt-team/TopoModelX.git +pip install git+https://github.com/pyt-team/TopoEmbedX.git pytest From c5838de53212e887671620ba934ed17cf7ef88ee Mon Sep 17 00:00:00 2001 From: gbg141 Date: Wed, 15 May 2024 17:46:44 +0200 Subject: [PATCH 16/32] Update env.bash --- env.bash | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/env.bash b/env.bash index 5f00c475..39a413f8 100644 --- a/env.bash +++ b/env.bash @@ -4,6 +4,11 @@ #conda activate topox pip install --upgrade pip +pip install -e '.[all]' + +pip install git+https://github.com/pyt-team/TopoNetX.git +pip install git+https://github.com/pyt-team/TopoModelX.git +pip install git+https://github.com/pyt-team/TopoEmbedX.git CUDA="cu115" # if available, select the CUDA version suitable for your system # e.g. cpu, cu102, cu111, cu113, cu115 @@ -11,12 +16,6 @@ pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/${CU pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html -pip install -e '.[all]' - -pip install git+https://github.com/pyt-team/TopoNetX.git -pip install git+https://github.com/pyt-team/TopoModelX.git -pip install git+https://github.com/pyt-team/TopoEmbedX.git - pytest pre-commit install From 202d33e9ecb21e675eb48a988eac8221d1aab607 Mon Sep 17 00:00:00 2001 From: guille Date: Wed, 15 May 2024 16:03:56 +0000 Subject: [PATCH 17/32] Remove unused Dockerfile and devcontainer configuration files --- .devcontainer/Dockerfile | 16 ---- .devcontainer/devcontainer.json | 18 ----- .devcontainer/pyproject.toml | 129 -------------------------------- 3 files changed, 163 deletions(-) delete mode 100755 .devcontainer/Dockerfile delete mode 100755 .devcontainer/devcontainer.json delete mode 100755 .devcontainer/pyproject.toml diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile deleted file mode 100755 index afb6e0e4..00000000 --- a/.devcontainer/Dockerfile +++ /dev/null @@ -1,16 +0,0 @@ -FROM python:3.11.3 - -WORKDIR /TopoBenchmarkX - -COPY . . 
- -RUN pip install --upgrade pip - -RUN pip install -e '.[all]' -RUN pip install --no-dependencies git+https://github.com/pyt-team/TopoNetX.git -RUN pip install --no-dependencies git+https://github.com/pyt-team/TopoModelX.git -RUN pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/cu115 -RUN pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cu115.html -RUN pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+cu115.html -RUN pip install lightning>=2.0.0 -RUN pip install numpy pre-commit jupyterlab notebook ipykernel \ No newline at end of file diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json deleted file mode 100755 index 7a2a6bce..00000000 --- a/.devcontainer/devcontainer.json +++ /dev/null @@ -1,18 +0,0 @@ -//devcontainer.json -{ - "name": "TopoBenchmarkX:new", - "dockerFile": "./Dockerfile", - "customizations": { - "vscode": { - "settings": { - "terminal.integrated.shell.linux": "/bin/bash" - }, - "extensions": [ - "ms-python.python", - "ms-python.isort", - "ms-python.vscode-pylance", - "ms-toolsai.jupyter" - ] - } - } -} \ No newline at end of file diff --git a/.devcontainer/pyproject.toml b/.devcontainer/pyproject.toml deleted file mode 100755 index e9439fe3..00000000 --- a/.devcontainer/pyproject.toml +++ /dev/null @@ -1,129 +0,0 @@ -[build-system] -requires = ["setuptools", "wheel"] -build-backend = "setuptools.build_meta" - -[project] -name = "TopoBenchmarkX" -version = "0.0.1" -authors = [ - {name = "PyT-Team Authors", email = "tlscabinet@gmail.com"} -] -readme = "README.md" -description = "Topological Deep Learning" -license = {file = "LICENSE.txt"} -classifiers = [ - "License :: OSI Approved :: MIT License", - "Development Status :: 4 - Beta", - "Intended Audience :: Science/Research", - "Topic :: Scientific/Engineering", - "Topic :: Scientific/Engineering :: Mathematics", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Natural Language :: English", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11" -] -requires-python = ">= 3.10" -dependencies=[ - "tqdm", - "numpy", - "scipy", - "requests", - "scikit-learn", - "matplotlib", - "networkx", - "pandas", - "pyg-nightly", - "decorator", - "hypernetx < 2.0.0", - "trimesh", - "spharapy", - "hydra-core==1.3.2", - "hydra-colorlog==1.2.0", - "hydra-optuna-sweeper==1.2.0", - "rich", - "rootutils", - "pytest", - -] - -[project.optional-dependencies] -doc = [ - "jupyter", - "nbsphinx", - "nbsphinx_link", - "sphinx", - "sphinx_gallery", - "pydata-sphinx-theme" -] -lint = [ - "black", - "black[jupyter]", - "flake8", - "flake8-docstrings", - "Flake8-pyproject", - "isort", - "pre-commit" -] -test = [ - "pytest", - "pytest-cov", - "coverage", - "jupyter", - "mypy" -] - -dev = ["TopoBenchmarkX[test, lint]"] -all = ["TopoBenchmarkX[dev, doc]"] - -[project.urls] -homepage="https://github.com/pyt-team/TopoBenchmarkX" -repository="https://github.com/pyt-team/TopoBenchmarkX" - -[tool.setuptools.dynamic] -version = {attr = "topobenchmarkx.__version__"} - -[tool.setuptools.packages.find] -include = [ - "topobenchmarkx", - "topobenchmarkx.*" -] - -[tool.mypy] -warn_redundant_casts = true -warn_unused_ignores = true -show_error_codes = true -plugins = "numpy.typing.mypy_plugin" - -[[tool.mypy.overrides]] -module = [ - "torch_cluster.*","networkx.*","scipy.spatial","scipy.sparse","toponetx.classes.simplicial_complex" -] -ignore_missing_imports = true - 
-[tool.pytest.ini_options] -addopts = "--capture=no" - -[tool.black] -line-length = 88 - -[tool.isort] -line_length = 88 -multi_line_output = 3 -include_trailing_comma = true -skip = [".gitignore", "__init__.py"] - -[tool.flake8] -max-line-length = 88 -application_import_names = "topobenchmarkx" -docstring-convention = "numpy" -exclude = [ - "topobenchmarkx/__init__.py", - "docs/conf.py" -] - -import_order_style = "smarkets" -extend-ignore = ["E501", "E203"] -per-file-ignores = [ - "*/__init__.py: D104, F401", -] From 8173c96282982a90179acb19ca8c53a84b3e3b06 Mon Sep 17 00:00:00 2001 From: levtelyatnikov Date: Wed, 15 May 2024 18:16:31 +0200 Subject: [PATCH 18/32] refactor conf --- .../graph2cell_lifting/cell_cycles.yaml | 1 - configs/model/cell/can.yaml | 3 +- configs/model/cell/cccn.yaml | 6 +- configs/model/cell/ccxn.yaml | 3 +- configs/model/cell/cwn.yaml | 8 +- configs/model/graph/gat.yaml | 3 +- configs/model/graph/gcn.yaml | 7 +- configs/model/graph/gin.yaml | 3 +- configs/model/hypergraph/alldeepset.yaml | 3 +- .../model/hypergraph/allsettransformer.yaml | 3 +- configs/model/hypergraph/edgnn.yaml | 3 +- configs/model/hypergraph/unignn.yaml | 3 +- configs/model/hypergraph/unignn2.yaml | 3 +- configs/model/simplicial/san.yaml | 3 +- configs/model/simplicial/sccn.yaml | 4 +- configs/model/simplicial/sccnn.yaml | 3 +- configs/model/simplicial/sccnn_custom.yaml | 3 +- configs/model/simplicial/scn.yaml | 3 +- configs/train.yaml | 2 +- custom_models/cell/cin.py | 1 - custom_models/simplicial/sccnn.py | 1 - notebooks/cornel_dataset.ipynb | 432 ++++++++++++++++++ notebooks/result_processing.ipynb | 4 +- 23 files changed, 475 insertions(+), 30 deletions(-) create mode 100644 notebooks/cornel_dataset.ipynb diff --git a/configs/dataset/transforms/graph2cell_lifting/cell_cycles.yaml b/configs/dataset/transforms/graph2cell_lifting/cell_cycles.yaml index ba60f86c..79b91303 100644 --- a/configs/dataset/transforms/graph2cell_lifting/cell_cycles.yaml +++ b/configs/dataset/transforms/graph2cell_lifting/cell_cycles.yaml @@ -1,7 +1,6 @@ _target_: topobenchmarkx.transforms.data_transform.DataTransform transform_type: 'lifting' transform_name: "CellCyclesLifting" -k_value: 1 complex_dim: ${oc.select:dataset.parameters.max_dim_if_lifted,3} max_cell_length: 6 preserve_edge_attr: ${oc.select:dataset.parameters.preserve_edge_attr_if_lifted,False} diff --git a/configs/model/cell/can.yaml b/configs/model/cell/can.yaml index 39df8535..f401f622 100755 --- a/configs/model/cell/can.yaml +++ b/configs/model/cell/can.yaml @@ -1,4 +1,4 @@ -_target_: topobenchmarkx.models.network_module.NetworkModule +_target_: topobenchmarkx.models.TopologicalNetworkModule model_name: can model_domain: cell @@ -8,6 +8,7 @@ feature_encoder: encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 + proj_dropout: 0.0 selected_dimensions: - 0 - 1 diff --git a/configs/model/cell/cccn.yaml b/configs/model/cell/cccn.yaml index ad196c1d..26a3360f 100755 --- a/configs/model/cell/cccn.yaml +++ b/configs/model/cell/cccn.yaml @@ -1,4 +1,4 @@ -_target_: topobenchmarkx.models.network_module.NetworkModule +_target_: topobenchmarkx.models.TopologicalNetworkModule model_name: cwn_dcm model_domain: cell @@ -8,7 +8,7 @@ feature_encoder: encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 - proj_dropout: 0.0 + proj_dropout: 0. 
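
Besides retargeting every model config from NetworkModule to topobenchmarkx.models.TopologicalNetworkModule, this refactor threads a proj_dropout option through each feature_encoder block. The encoder implementation itself is not part of this diff; the sketch below only illustrates what such a knob typically gates (ProjectionEncoder is a hypothetical stand-in, not the actual AllCellFeatureEncoder):

import torch
import torch.nn as nn


class ProjectionEncoder(nn.Module):
    """Hypothetical sketch: linear feature projection with a dropout knob."""

    def __init__(self, in_channels: int, out_channels: int, proj_dropout: float = 0.0):
        super().__init__()
        self.proj = nn.Linear(in_channels, out_channels)
        self.dropout = nn.Dropout(proj_dropout)  # proj_dropout: 0.0 disables it

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Dropout is applied to the projected features before the backbone.
        return self.dropout(torch.relu(self.proj(x)))
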
selected_dimensions: - 0 - 1 @@ -16,7 +16,7 @@ feature_encoder: backbone: _target_: custom_models.cell.cccn.CCCN in_channels: ${model.feature_encoder.out_channels} - n_layers: 1 + n_layers: 4 dropout: 0.0 backbone_wrapper: diff --git a/configs/model/cell/ccxn.yaml b/configs/model/cell/ccxn.yaml index 38a884ec..851c4544 100755 --- a/configs/model/cell/ccxn.yaml +++ b/configs/model/cell/ccxn.yaml @@ -1,4 +1,4 @@ -_target_: topobenchmarkx.models.network_module.NetworkModule +_target_: topobenchmarkx.models.TopologicalNetworkModule model_name: ccxn model_domain: cell @@ -8,6 +8,7 @@ feature_encoder: encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 + proj_dropout: 0.0 backbone: _target_: topomodelx.nn.cell.ccxn.CCXN diff --git a/configs/model/cell/cwn.yaml b/configs/model/cell/cwn.yaml index 5a140b36..a5508dc7 100755 --- a/configs/model/cell/cwn.yaml +++ b/configs/model/cell/cwn.yaml @@ -1,4 +1,4 @@ -_target_: topobenchmarkx.models.network_module.NetworkModule +_target_: topobenchmarkx.models.TopologicalNetworkModule model_name: cwn model_domain: cell @@ -7,7 +7,7 @@ feature_encoder: _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} - out_channels: 32 + out_channels: 64 proj_dropout: 0.0 backbone: @@ -16,7 +16,7 @@ backbone: in_channels_1: ${model.feature_encoder.out_channels} in_channels_2: ${model.feature_encoder.out_channels} hid_channels: ${model.feature_encoder.out_channels} - n_layers: 1 + n_layers: 4 backbone_wrapper: _target_: topobenchmarkx.models.wrappers.CWNWrapper @@ -27,7 +27,7 @@ backbone_wrapper: readout: _target_: topobenchmarkx.models.readouts.${model.readout.readout_name} - readout_name: PropagateSignalDown # Use in case readout is not needed Options: PropagateSignalDown + readout_name: NoReadOut # Use in case readout is not needed Options: PropagateSignalDown hidden_dim: ${model.feature_encoder.out_channels} num_cell_dimensions: ${infere_list_length:${model.feature_encoder.in_channels}} diff --git a/configs/model/graph/gat.yaml b/configs/model/graph/gat.yaml index 65e353db..27763d6a 100755 --- a/configs/model/graph/gat.yaml +++ b/configs/model/graph/gat.yaml @@ -1,4 +1,4 @@ -_target_: topobenchmarkx.models.network_module.NetworkModule +_target_: topobenchmarkx.models.TopologicalNetworkModule model_name: gat model_domain: graph @@ -8,6 +8,7 @@ feature_encoder: encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} out_channels: 32 + proj_dropout: 0.0 backbone: _target_: torch_geometric.nn.models.GAT diff --git a/configs/model/graph/gcn.yaml b/configs/model/graph/gcn.yaml index abefe959..a0593356 100755 --- a/configs/model/graph/gcn.yaml +++ b/configs/model/graph/gcn.yaml @@ -1,4 +1,4 @@ -_target_: topobenchmarkx.models.network_module.NetworkModule +_target_: topobenchmarkx.models.TopologicalNetworkModule model_name: gcn model_domain: graph @@ -7,13 +7,14 @@ feature_encoder: _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} - out_channels: 64 + out_channels: 32 + proj_dropout: 0.0 backbone: _target_: torch_geometric.nn.models.GCN in_channels: ${model.feature_encoder.out_channels} hidden_channels: ${model.feature_encoder.out_channels} - num_layers: 1 + num_layers: 2 dropout: 0.0 act: 
relu diff --git a/configs/model/graph/gin.yaml b/configs/model/graph/gin.yaml index e0920399..76658ac8 100755 --- a/configs/model/graph/gin.yaml +++ b/configs/model/graph/gin.yaml @@ -1,4 +1,4 @@ -_target_: topobenchmarkx.models.network_module.NetworkModule +_target_: topobenchmarkx.models.TopologicalNetworkModule model_name: gin model_domain: graph @@ -8,6 +8,7 @@ feature_encoder: encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 + proj_dropout: 0.0 backbone: _target_: torch_geometric.nn.models.GIN diff --git a/configs/model/hypergraph/alldeepset.yaml b/configs/model/hypergraph/alldeepset.yaml index 99b0e3d9..58681a35 100755 --- a/configs/model/hypergraph/alldeepset.yaml +++ b/configs/model/hypergraph/alldeepset.yaml @@ -1,4 +1,4 @@ -_target_: topobenchmarkx.models.network_module.NetworkModule +_target_: topobenchmarkx.models.TopologicalNetworkModule model_name: alldeepset model_domain: hypergraph @@ -8,6 +8,7 @@ feature_encoder: encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 + proj_dropout: 0.0 backbone: _target_: topomodelx.nn.hypergraph.allset.AllSet diff --git a/configs/model/hypergraph/allsettransformer.yaml b/configs/model/hypergraph/allsettransformer.yaml index 5b5a3e18..850d4ad1 100755 --- a/configs/model/hypergraph/allsettransformer.yaml +++ b/configs/model/hypergraph/allsettransformer.yaml @@ -1,4 +1,4 @@ -_target_: topobenchmarkx.models.network_module.NetworkModule +_target_: topobenchmarkx.models.TopologicalNetworkModule model_name: allsettransformer model_domain: hypergraph @@ -8,6 +8,7 @@ feature_encoder: encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 + proj_dropout: 0.0 backbone: _target_: topomodelx.nn.hypergraph.allset_transformer.AllSetTransformer diff --git a/configs/model/hypergraph/edgnn.yaml b/configs/model/hypergraph/edgnn.yaml index 6f7d6235..8309c22a 100755 --- a/configs/model/hypergraph/edgnn.yaml +++ b/configs/model/hypergraph/edgnn.yaml @@ -1,4 +1,4 @@ -_target_: topobenchmarkx.models.network_module.NetworkModule +_target_: topobenchmarkx.models.TopologicalNetworkModule model_name: edgnn model_domain: hypergraph @@ -8,6 +8,7 @@ feature_encoder: encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 16 + proj_dropout: 0.0 backbone: _target_: custom_models.hypergraph.edgnn.EDGNN diff --git a/configs/model/hypergraph/unignn.yaml b/configs/model/hypergraph/unignn.yaml index 68d71223..27e2905a 100755 --- a/configs/model/hypergraph/unignn.yaml +++ b/configs/model/hypergraph/unignn.yaml @@ -1,4 +1,4 @@ -_target_: topobenchmarkx.models.network_module.NetworkModule +_target_: topobenchmarkx.models.TopologicalNetworkModule model_name: unignn2 mode_domain: hypergraph @@ -8,6 +8,7 @@ feature_encoder: encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 + proj_dropout: 0.0 backbone: _target_: topomodelx.nn.hypergraph.unigcn.UniGCN diff --git a/configs/model/hypergraph/unignn2.yaml b/configs/model/hypergraph/unignn2.yaml index 1d216496..185b1c8a 100755 --- a/configs/model/hypergraph/unignn2.yaml +++ b/configs/model/hypergraph/unignn2.yaml @@ -1,4 +1,4 @@ -_target_: topobenchmarkx.models.network_module.NetworkModule +_target_: 
topobenchmarkx.models.TopologicalNetworkModule model_name: unignn2 mode_domain: hypergraph @@ -8,6 +8,7 @@ feature_encoder: encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 + proj_dropout: 0.0 backbone: _target_: topomodelx.nn.hypergraph.unigcnii.UniGCNII diff --git a/configs/model/simplicial/san.yaml b/configs/model/simplicial/san.yaml index 61c63567..53f2eeb5 100755 --- a/configs/model/simplicial/san.yaml +++ b/configs/model/simplicial/san.yaml @@ -1,4 +1,4 @@ -_target_: topobenchmarkx.models.network_module.NetworkModule +_target_: topobenchmarkx.models.TopologicalNetworkModule model_name: san model_domain: simplicial @@ -8,6 +8,7 @@ feature_encoder: encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 64 + proj_dropout: 0.0 selected_dimensions: - 0 - 1 diff --git a/configs/model/simplicial/sccn.yaml b/configs/model/simplicial/sccn.yaml index 82e5592d..55e61dde 100755 --- a/configs/model/simplicial/sccn.yaml +++ b/configs/model/simplicial/sccn.yaml @@ -1,12 +1,14 @@ -_target_: topobenchmarkx.models.network_module.NetworkModule +_target_: topobenchmarkx.models.TopologicalNetworkModule model_name: sccnn model_domain: simplicial + feature_encoder: _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} # ${dataset.parameters.num_features} out_channels: 32 + proj_dropout: 0.0 backbone: _target_: topomodelx.nn.simplicial.sccn.SCCN diff --git a/configs/model/simplicial/sccnn.yaml b/configs/model/simplicial/sccnn.yaml index fdfc824c..2c15e115 100755 --- a/configs/model/simplicial/sccnn.yaml +++ b/configs/model/simplicial/sccnn.yaml @@ -1,4 +1,4 @@ -_target_: topobenchmarkx.models.network_module.NetworkModule +_target_: topobenchmarkx.models.TopologicalNetworkModule model_name: sccnn model_domain: simplicial @@ -8,6 +8,7 @@ feature_encoder: encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 + proj_dropout: 0.0 selected_dimensions: - 0 - 1 diff --git a/configs/model/simplicial/sccnn_custom.yaml b/configs/model/simplicial/sccnn_custom.yaml index 4bd0ebdc..29fab672 100755 --- a/configs/model/simplicial/sccnn_custom.yaml +++ b/configs/model/simplicial/sccnn_custom.yaml @@ -1,4 +1,4 @@ -_target_: topobenchmarkx.models.network_module.NetworkModule +_target_: topobenchmarkx.models.TopologicalNetworkModule model_name: sccnn_custom model_domain: simplicial @@ -8,6 +8,7 @@ feature_encoder: encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 + proj_dropout: 0.0 selected_dimensions: - 0 - 1 diff --git a/configs/model/simplicial/scn.yaml b/configs/model/simplicial/scn.yaml index cb74351f..f4063e4e 100755 --- a/configs/model/simplicial/scn.yaml +++ b/configs/model/simplicial/scn.yaml @@ -1,4 +1,4 @@ -_target_: topobenchmarkx.models.network_module.NetworkModule +_target_: topobenchmarkx.models.TopologicalNetworkModule mdoel_name: scn model_type: simplicial @@ -8,6 +8,7 @@ feature_encoder: encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} out_channels: 32 + proj_dropout: 0.0 selected_dimensions: - 0 - 1 diff --git a/configs/train.yaml b/configs/train.yaml index 8de14308..41127cef 100755 --- a/configs/train.yaml +++ 
b/configs/train.yaml @@ -4,7 +4,7 @@ # order of defaults determines the order in which configs override each other defaults: - _self_ - - dataset: IMDB-BINARY #us_country_demos + - dataset: REDDIT-BINARY #us_country_demos - model: simplicial/sccnn_custom #hypergraph/unignn2 #allsettransformer - evaluator: default - callbacks: default diff --git a/custom_models/cell/cin.py b/custom_models/cell/cin.py index db912b06..9b136bbb 100644 --- a/custom_models/cell/cin.py +++ b/custom_models/cell/cin.py @@ -4,7 +4,6 @@ import torch.nn as nn import torch.nn.functional as F from topomodelx.base.conv import Conv -from topomodelx.nn.cell.cwn_layer import CWNLayer from torch_geometric.nn.models import MLP diff --git a/custom_models/simplicial/sccnn.py b/custom_models/simplicial/sccnn.py index 586ca909..b9a816ab 100644 --- a/custom_models/simplicial/sccnn.py +++ b/custom_models/simplicial/sccnn.py @@ -1,7 +1,6 @@ """SCCNN implementation for complex classification.""" import torch -from topomodelx.nn.simplicial.sccnn_layer import SCCNNLayer from torch.nn.parameter import Parameter diff --git a/notebooks/cornel_dataset.ipynb b/notebooks/cornel_dataset.ipynb new file mode 100644 index 00000000..e6eeae39 --- /dev/null +++ b/notebooks/cornel_dataset.ipynb @@ -0,0 +1,432 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "\n", + "# Add manually root '/home/lev/projects/TopoBenchmarkX'\n", + "root_path = \"/home/lev/projects/TopoBenchmarkX\"\n", + "if root_path not in sys.path:\n", + " sys.path.append(root_path)\n", + "\n", + "import os.path as osp\n", + "from collections.abc import Callable\n", + "\n", + "from torch_geometric.data import Data, InMemoryDataset\n", + "from torch_geometric.io import fs\n", + "\n", + "from topobenchmarkx.io.load.download_utils import download_file_from_drive\n", + "\n", + "\n", + "class CornelDataset(InMemoryDataset):\n", + " r\"\"\" \"\"\"\n", + "\n", + " URLS = {\n", + " # 'contact-high-school': 'https://drive.google.com/open?id=1VA2P62awVYgluOIh1W4NZQQgkQCBk-Eu',\n", + " \"US-county-demos\": \"https://drive.google.com/file/d/1FNF_LbByhYNICPNdT6tMaJI9FxuSvvLK/view?usp=sharing\",\n", + " }\n", + "\n", + " FILE_FORMAT = {\n", + " # 'contact-high-school': 'tar.gz',\n", + " \"US-county-demos\": \"zip\",\n", + " }\n", + "\n", + " RAW_FILE_NAMES = {}\n", + "\n", + " def __init__(\n", + " self,\n", + " root: str,\n", + " name: str,\n", + " parameters: dict = None,\n", + " transform: Callable | None = None,\n", + " pre_transform: Callable | None = None,\n", + " pre_filter: Callable | None = None,\n", + " force_reload: bool = True,\n", + " use_node_attr: bool = False,\n", + " use_edge_attr: bool = False,\n", + " ) -> None:\n", + " self.name = name.replace(\"_\", \"-\")\n", + "\n", + " super().__init__(\n", + " root, transform, pre_transform, pre_filter, force_reload=force_reload\n", + " )\n", + "\n", + " # Step 3:Load the processed data\n", + " # After the data has been downloaded from source\n", + " # Then preprocessed to obtain x,y and saved into processed folder\n", + " # We can now load the processed data from processed folder\n", + "\n", + " # Load the processed data\n", + " data, _, _ = fs.torch_load(self.processed_paths[0])\n", + "\n", + " # Map the loaded data into\n", + " data = Data.from_dict(data)\n", + "\n", + " # Step 5: Create the splits and upload desired fold\n", + "\n", + " # split_idx = random_splitting(data.y, parameters=self.parameters)\n", + "\n", + " # Assign data object to 
self.data, to make it be prodessed by Dataset class\n", + " self.data = data\n", + "\n", + " @property\n", + " def raw_dir(self) -> str:\n", + " return osp.join(self.root, self.name, \"raw\")\n", + "\n", + " @property\n", + " def processed_dir(self) -> str:\n", + " return osp.join(self.root, self.name, \"processed\")\n", + "\n", + " @property\n", + " def raw_file_names(self) -> list[str]:\n", + " names = [\"\", \"_2012\"]\n", + " return [f\"{self.name}_{name}.txt\" for name in names]\n", + "\n", + " @property\n", + " def processed_file_names(self) -> str:\n", + " return \"data.pt\"\n", + "\n", + " def download(self) -> None:\n", + " \"\"\"\n", + " Downloads the dataset from the specified URL and saves it to the raw directory.\n", + "\n", + " Raises:\n", + " FileNotFoundError: If the dataset URL is not found.\n", + " \"\"\"\n", + "\n", + " # Step 1: Download data from the source\n", + " self.url = self.URLS[self.name]\n", + " self.file_format = self.FILE_FORMAT[self.name]\n", + "\n", + " download_file_from_drive(\n", + " file_link=self.url,\n", + " path_to_save=self.raw_dir,\n", + " dataset_name=self.name,\n", + " file_format=self.file_format,\n", + " )\n", + "\n", + " # Extract the downloaded file if it is compressed\n", + " fs.cp(\n", + " f\"{self.raw_dir}/{self.name}.{self.file_format}\", self.raw_dir, extract=True\n", + " )\n", + "\n", + " # Move the etracted files to the datasets/domain/dataset_name/raw/ directory\n", + " for filename in fs.ls(osp.join(self.raw_dir, self.name)):\n", + " fs.mv(filename, osp.join(self.raw_dir, osp.basename(filename)))\n", + " fs.rm(osp.join(self.raw_dir, self.name))\n", + "\n", + " # Delete also f'{self.raw_dir}/{self.name}.{self.file_format}'\n", + " fs.rm(f\"{self.raw_dir}/{self.name}.{self.file_format}\")\n", + "\n", + " def process(self) -> None:\n", + " \"\"\"\n", + " Process the data for the dataset.\n", + "\n", + " This method loads the US county demographics data, applies any pre-processing transformations if specified,\n", + " and saves the processed data to the appropriate location.\n", + "\n", + " Returns:\n", + " None\n", + " \"\"\"\n", + " data = load_us_county_demos(self.raw_dir, self.name)\n", + "\n", + " data = data if self.pre_transform is None else self.pre_transform(data)\n", + " self.save([data], self.processed_paths[0])\n", + "\n", + " def __repr__(self) -> str:\n", + " return f\"{self.name}({len(self)})\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "import torch_geometric\n", + "\n", + "\n", + "def load_us_county_demos(path, dataset_name, year=2012):\n", + " edges_df = pd.read_csv(f\"{path}/county_graph.csv\")\n", + " stat = pd.read_csv(f\"{path}/county_stats_{year}.csv\", encoding=\"ISO-8859-1\")\n", + "\n", + " keep_cols = [\n", + " \"FIPS\",\n", + " \"DEM\",\n", + " \"GOP\",\n", + " \"MedianIncome\",\n", + " \"MigraRate\",\n", + " \"BirthRate\",\n", + " \"DeathRate\",\n", + " \"BachelorRate\",\n", + " \"UnemploymentRate\",\n", + " ]\n", + " # Drop rows with missing values\n", + " stat = stat[keep_cols].dropna()\n", + "\n", + " # Delete edges that are not present in stat df\n", + " unique_fips = stat[\"FIPS\"].unique()\n", + "\n", + " src_ = edges_df[\"SRC\"].apply(lambda x: x in unique_fips)\n", + " dst_ = edges_df[\"DST\"].apply(lambda x: x in unique_fips)\n", + "\n", + " edges_df = edges_df[src_ & dst_]\n", + "\n", + " # Remove rows from stat df where edges_df['SRC'] or edges_df['DST'] are 
not present\n", + " stat = stat[stat[\"FIPS\"].isin(edges_df[\"SRC\"]) & stat[\"FIPS\"].isin(edges_df[\"DST\"])]\n", + " stat = stat.reset_index(drop=True)\n", + "\n", + " # Remove rows where SRC == DST\n", + " edges_df = edges_df[edges_df[\"SRC\"] != edges_df[\"DST\"]]\n", + "\n", + " # Get torch_geometric edge_index format\n", + " edge_index = torch.tensor(\n", + " np.stack([edges_df[\"SRC\"].to_numpy(), edges_df[\"DST\"].to_numpy()])\n", + " )\n", + "\n", + " # Make edge_index undirected\n", + " edge_index = torch_geometric.utils.to_undirected(edge_index)\n", + "\n", + " # Convert edge_index back to pandas DataFrame\n", + " edges_df = pd.DataFrame(edge_index.numpy().T, columns=[\"SRC\", \"DST\"])\n", + "\n", + " del edge_index\n", + "\n", + " # Map stat['FIPS'].unique() to [0, ..., num_nodes]\n", + " fips_map = {fips: i for i, fips in enumerate(stat[\"FIPS\"].unique())}\n", + " stat[\"FIPS\"] = stat[\"FIPS\"].map(fips_map)\n", + "\n", + " # Map edges_df['SRC'] and edges_df['DST'] to [0, ..., num_nodes]\n", + " edges_df[\"SRC\"] = edges_df[\"SRC\"].map(fips_map)\n", + " edges_df[\"DST\"] = edges_df[\"DST\"].map(fips_map)\n", + "\n", + " # Get torch_geometric edge_index format\n", + " edge_index = torch.tensor(\n", + " np.stack([edges_df[\"SRC\"].to_numpy(), edges_df[\"DST\"].to_numpy()])\n", + " )\n", + "\n", + " # Remove isolated nodes (Note: this function maps the nodes to [0, ..., num_nodes] automatically)\n", + " edge_index, _, mask = torch_geometric.utils.remove_isolated_nodes(edge_index)\n", + "\n", + " # Conver mask to index\n", + " index = np.arange(mask.size(0))[mask]\n", + " stat = stat.iloc[index]\n", + " stat = stat.reset_index(drop=True)\n", + "\n", + " # Get new values for FIPS from current index\n", + " # To understand why please print stat.iloc[[516, 517, 518, 519, 520]] for 2012 year\n", + " # Basically the FIPS values has been shifted\n", + " stat[\"FIPS\"] = stat.reset_index()[\"index\"]\n", + "\n", + " # Create Election variable\n", + " stat[\"Election\"] = (stat[\"DEM\"] - stat[\"GOP\"]) / (stat[\"DEM\"] + stat[\"GOP\"])\n", + "\n", + " # Drop DEM and GOP columns and FIPS\n", + " stat = stat.drop(columns=[\"DEM\", \"GOP\", \"FIPS\"])\n", + "\n", + " # Prediction col\n", + " y_col = \"Election\" # TODO: Define through config file\n", + " x_col = list(set(stat.columns).difference(set([y_col])))\n", + "\n", + " stat[\"MedianIncome\"] = (\n", + " stat[\"MedianIncome\"]\n", + " .apply(lambda x: x.replace(\",\", \"\"))\n", + " .to_numpy()\n", + " .astype(float)\n", + " )\n", + "\n", + " x = stat[x_col].to_numpy()\n", + " y = stat[y_col].to_numpy()\n", + "\n", + " data = torch_geometric.data.Data(x=x, y=y, edge_index=edge_index)\n", + "\n", + " return data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Download complete.\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Processing...\n", + "Done!\n" + ] + } + ], + "source": [ + "a = CornelDataset(\n", + " root=\"/home/lev/projects/TopoBenchmarkX/datasets/graph\", name=\"US-county-demos\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "parameters = {\n", + " \"data_seed\": 0,\n", + " \"data_split_dir\": \"/home/lev/projects/TopoBenchmarkX/datasets/data_splits/US-county-demos\",\n", + " \"train_prop\": 0.5,\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + 
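
The load_us_county_demos cell above rebuilds the FIPS column from the dataframe's positional index precisely because torch_geometric.utils.remove_isolated_nodes relabels the surviving nodes to a contiguous [0, num_nodes) range, so the old identifiers no longer line up. A minimal standalone illustration of that behaviour (not taken from the notebook):

import torch
from torch_geometric.utils import remove_isolated_nodes

edge_index = torch.tensor([[0, 2], [2, 0]])  # node 1 has no edges
edge_index, _, mask = remove_isolated_nodes(edge_index, num_nodes=3)
print(edge_index)  # tensor([[0, 1], [1, 0]]) -- node 2 was relabeled to 1
print(mask)        # tensor([ True, False,  True]) -- rows of `stat` to keep
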
"data": { + "text/plain": [ + "(3107, 6)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a[0].x.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "ename": "AttributeError", + "evalue": "'dict' object has no attribute 'k'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[35], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m a \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mk\u001b[39m\u001b[38;5;124m'\u001b[39m:\u001b[38;5;241m1\u001b[39m}\n\u001b[0;32m----> 2\u001b[0m a\u001b[38;5;241m.\u001b[39mk\n", + "\u001b[0;31mAttributeError\u001b[0m: 'dict' object has no attribute 'k'" + ] + } + ], + "source": [ + "a = {\"k\": 1}\n", + "a.k" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "stat = pd.read_csv(\n", + " \"/home/lev/projects/TopoBenchmarkX/datasets/graph/US-county-demos-2012/raw/US-county-demos/county_stats_2016.csv\",\n", + " encoding=\"ISO-8859-1\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Index(['FIPS', 'County', 'DEM', 'GOP', 'MedianIncome', 'MigraRate',\n", + " 'BirthRate', 'DeathRate', 'BachelorRate', 'UnemploymentRate'],\n", + " dtype='object')" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "stat.columns" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "('Election',\n", + " 'MedianIncome',\n", + " 'MigraRate',\n", + " 'BirthRate',\n", + " 'DeathRate',\n", + " 'BachelorRate',\n", + " 'UnemploymentRate')" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "(\n", + " \"Election\",\n", + " \"MedianIncome\",\n", + " \"MigraRate\",\n", + " \"BirthRate\",\n", + " \"DeathRate\",\n", + " \"BachelorRate\",\n", + " \"UnemploymentRate\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "topo", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/notebooks/result_processing.ipynb b/notebooks/result_processing.ipynb index c684a18d..15313a9b 100644 --- a/notebooks/result_processing.ipynb +++ b/notebooks/result_processing.ipynb @@ -71,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -87,7 +87,7 @@ " dtype='object')" ] }, - "execution_count": 2, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } From d1a0ffa03dcbb4766e726c624d2ae8fa8dc7879d Mon Sep 17 00:00:00 2001 From: levtelyatnikov Date: Wed, 15 May 2024 23:07:07 +0200 Subject: [PATCH 19/32] graph scripts --- configs/dataset/MUTAG.yaml | 3 +- configs/dataset/NCI1.yaml | 2 +- configs/dataset/PROTEINS_TU.yaml | 2 +- 
configs/dataset/ZINC.yaml | 8 +- configs/dataset/coauthorship_citeseer.yaml | 2 +- configs/dataset/coauthorship_cora.yaml | 2 +- configs/dataset/manual_dataset.yaml | 2 +- configs/logger/wandb.yaml | 2 +- configs/model/graph/gcn.yaml | 2 +- configs/train.yaml | 4 +- env.bash | 6 +- hp_scripts/main_exp/graph/gat.sh | 136 ++++++++++++++++++ hp_scripts/main_exp/graph/gcn.sh | 136 ++++++++++++++++++ hp_scripts/main_exp/graph/gin.sh | 136 ++++++++++++++++++ .../main_exp}/simplicial/SCN.sh | 0 notebooks/result_processing.ipynb | 55 ++++++- topobenchmarkx/__init__.py | 1 - topobenchmarkx/run_graph_scripts.sh | 5 + topobenchmarkx/simplicial.sh | 3 + topobenchmarkx/train.py | 1 + topobenchmarkx/utils/config_resolvers.py | 6 +- topobenchmarkx/utils/logging_utils.py | 1 + 22 files changed, 492 insertions(+), 23 deletions(-) create mode 100644 hp_scripts/main_exp/graph/gat.sh create mode 100644 hp_scripts/main_exp/graph/gcn.sh create mode 100644 hp_scripts/main_exp/graph/gin.sh rename {topobenchmarkx/hp_scripts => hp_scripts/main_exp}/simplicial/SCN.sh (100%) create mode 100644 topobenchmarkx/run_graph_scripts.sh diff --git a/configs/dataset/MUTAG.yaml b/configs/dataset/MUTAG.yaml index 654c7b33..e3cdbdb7 100755 --- a/configs/dataset/MUTAG.yaml +++ b/configs/dataset/MUTAG.yaml @@ -15,6 +15,7 @@ parameters: num_features: - 7 # initial node features - 4 # initial edge features + num_classes: 2 task: classification loss_type: cross_entropy @@ -26,7 +27,7 @@ parameters: train_prop: 0.5 # for "random" strategy splitting # Lifting parameters - max_dim_if_lifted: 2 + max_dim_if_lifted: 3 # This is the maximum dimension of the simplicial complex in the dataset preserve_edge_attr_if_lifted: False # Dataloader parameters diff --git a/configs/dataset/NCI1.yaml b/configs/dataset/NCI1.yaml index 622481e3..60572768 100755 --- a/configs/dataset/NCI1.yaml +++ b/configs/dataset/NCI1.yaml @@ -21,7 +21,7 @@ parameters: monitor_metric: accuracy task_level: graph data_seed: 0 - split_type: k-fold #'k-fold' # either "k-fold" or "random" strategies + split_type: random #'k-fold' # either "k-fold" or "random" strategies k: 10 # for "k-fold" Cross-Validation train_prop: 0.5 # for "random" strategy splitting diff --git a/configs/dataset/PROTEINS_TU.yaml b/configs/dataset/PROTEINS_TU.yaml index d69a2a31..1b4afb96 100755 --- a/configs/dataset/PROTEINS_TU.yaml +++ b/configs/dataset/PROTEINS_TU.yaml @@ -19,7 +19,7 @@ parameters: monitor_metric: accuracy task_level: graph data_seed: 9 - split_type: k-fold #'k-fold' # either "k-fold" or "random" strategies + split_type: random #'k-fold' # either "k-fold" or "random" strategies k: 10 # for "k-fold" Cross-Validation train_prop: 0.5 # for "random" strategy splitting diff --git a/configs/dataset/ZINC.yaml b/configs/dataset/ZINC.yaml index 62a39319..5a77747d 100644 --- a/configs/dataset/ZINC.yaml +++ b/configs/dataset/ZINC.yaml @@ -1,7 +1,10 @@ _target_: topobenchmarkx.io.load.loaders.GraphLoader +# USE python train.py dataset.transforms.one_hot_node_degree_features.degrees_fields=x to run this config + defaults: - - transforms/data_manipulations: node_feat_to_float + #- transforms/data_manipulations: node_feat_to_float + - transforms/data_manipulations@transforms.one_hot_node_degree_features: one_hot_node_degree_features - transforms: ${get_default_transform:graph,${model}} # Data definition @@ -13,7 +16,8 @@ parameters: data_split_dir: ${paths.data_dir}data_splits/${dataset.parameters.data_name} # Dataset parameters - num_features: 1 # here basically I specify the initial num features 
in mutang at x aka x_0 + num_features: 21 # torch_geometric ZINC dataset has 21 atom types + max_node_degree: 20 # Use it to one_hot encode node degrees. Additional parameter to run dataset.transforms.one_hot_node_degree_features.degrees_fields=x num_classes: 1 task: regression loss_type: mse diff --git a/configs/dataset/coauthorship_citeseer.yaml b/configs/dataset/coauthorship_citeseer.yaml index 2b9dc3c4..078b0e21 100755 --- a/configs/dataset/coauthorship_citeseer.yaml +++ b/configs/dataset/coauthorship_citeseer.yaml @@ -20,7 +20,7 @@ parameters: monitor_metric: accuracy task_level: node data_seed: 0 - split_type: k-fold #'k-fold' # either k-fold or test + split_type: random #'k-fold' # either k-fold or test k: 10 # for k-Fold Cross-Validation # Dataloader parameters diff --git a/configs/dataset/coauthorship_cora.yaml b/configs/dataset/coauthorship_cora.yaml index 0c626534..d48233de 100755 --- a/configs/dataset/coauthorship_cora.yaml +++ b/configs/dataset/coauthorship_cora.yaml @@ -19,7 +19,7 @@ parameters: monitor_metric: accuracy task_level: node data_seed: 0 - split_type: k-fold #'k-fold' # either k-fold or test + split_type: random #'k-fold' # either k-fold or test k: 10 # for k-Fold Cross-Validation # Dataloader parameters diff --git a/configs/dataset/manual_dataset.yaml b/configs/dataset/manual_dataset.yaml index 6d7191cc..861ad4c8 100755 --- a/configs/dataset/manual_dataset.yaml +++ b/configs/dataset/manual_dataset.yaml @@ -19,7 +19,7 @@ parameters: monitor_metric: accuracy task_level: node data_seed: 0 - split_type: k-fold #'k-fold' # either k-fold or test + split_type: random #'k-fold' # either k-fold or test k: 10 # for k-Fold Cross-Validation # Dataloader parameters diff --git a/configs/logger/wandb.yaml b/configs/logger/wandb.yaml index b40863f7..da285376 100755 --- a/configs/logger/wandb.yaml +++ b/configs/logger/wandb.yaml @@ -7,7 +7,7 @@ wandb: offline: False id: null # pass correct id to resume experiment! anonymous: null # enable anonymous logging - project: "topox_10fold_sweep" + project: "None" log_model: False # upload lightning ckpts prefix: "" # a string to put at the beginning of metric keys # entity: "" # set to name of your wandb team diff --git a/configs/model/graph/gcn.yaml b/configs/model/graph/gcn.yaml index a0593356..a40f2bd6 100755 --- a/configs/model/graph/gcn.yaml +++ b/configs/model/graph/gcn.yaml @@ -7,7 +7,7 @@ feature_encoder: _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} - out_channels: 32 + out_channels: 64 proj_dropout: 0.0 backbone: diff --git a/configs/train.yaml b/configs/train.yaml index 41127cef..d345c74c 100755 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -4,8 +4,8 @@ # order of defaults determines the order in which configs override each other defaults: - _self_ - - dataset: REDDIT-BINARY #us_country_demos - - model: simplicial/sccnn_custom #hypergraph/unignn2 #allsettransformer + - dataset: NCI109 #us_country_demos + - model: graph/gcn #hypergraph/unignn2 #allsettransformer - evaluator: default - callbacks: default - logger: wandb # set logger here or use command line (e.g. 
`python train.py logger=tensorboard`) diff --git a/env.bash b/env.bash index 39a413f8..3a5beb54 100644 --- a/env.bash +++ b/env.bash @@ -1,7 +1,7 @@ # #!/bin/bash -#conda create -n topox python=3.11.3 -#conda activate topox +conda create -n topoxx python=3.11.3 +conda activate topoxx pip install --upgrade pip pip install -e '.[all]' @@ -10,7 +10,7 @@ pip install git+https://github.com/pyt-team/TopoNetX.git pip install git+https://github.com/pyt-team/TopoModelX.git pip install git+https://github.com/pyt-team/TopoEmbedX.git -CUDA="cu115" # if available, select the CUDA version suitable for your system +CUDA="cu117" # if available, select the CUDA version suitable for your system # e.g. cpu, cu102, cu111, cu113, cu115 pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/${CUDA} pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html diff --git a/hp_scripts/main_exp/graph/gat.sh b/hp_scripts/main_exp/graph/gat.sh new file mode 100644 index 00000000..231afe7c --- /dev/null +++ b/hp_scripts/main_exp/graph/gat.sh @@ -0,0 +1,136 @@ +# Description: Main experiment script for GCN model. +# ----Node regression datasets: US County Demographics---- +task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' ) + +for task_variable in ${task_variables[*]} +do + python train.py \ + dataset=us_country_demos \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.task_variable=$task_variable \ + model=graph/gat \ + model.feature_encoder.out_channels="32,64,128" \ + model.feature_encoder.proj_dropout="0,0.25,0.5" \ + model.backbone.num_layers="1,2,3,4" \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=1000 \ + trainer.min_epochs=500 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + tags="[MainExperiment]" \ + --multirun + +done + +# ----Cocitation datasets---- +datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + dataset.parameters.data_seed=0,3,5,7,9 \ + model=graph/gat \ + model.feature_encoder.out_channels="32,64,128" \ + model.feature_encoder.proj_dropout="0,0.25,0.5" \ + model.backbone.num_layers="1,2" \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=25 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + tags="[MainExperiment]" \ + --multirun +done + +# ----Graph regression dataset---- +# Train on ZINC dataset +python train.py \ + dataset=ZINC \ + seed=42,3,5,23,150 \ + model=graph/gat \ + model.optimizer.lr=0.01,0.001 \ + model.optimizer.weight_decay=0 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=2,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.batch_size=128,256 \ + dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ + dataset.parameters.data_seed=0 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + callbacks.early_stopping.min_delta=0.005 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + tags="[MainExperiment]" \ + --multirun + +# ----Heterophilic datasets---- + +datasets=( roman_empire amazon_ratings tolokers questions minesweeper ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + 
model=graph/gat \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done + +# ----TU graph datasets---- +# MUTAG have very few samples, so we use a smaller batch size +# Train on MUTAG dataset +python train.py \ + dataset=MUTAG \ + model=graph/gat \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=32,64 \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + callbacks.early_stopping.patience=25 \ + tags="[MainExperiment]" \ + --multirun + +# Train rest of the TU graph datasets +datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'REDDIT-BINARY' 'IMDB-BINARY' 'IMDB-MULTI' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=graph/gat \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + --multirun +done + diff --git a/hp_scripts/main_exp/graph/gcn.sh b/hp_scripts/main_exp/graph/gcn.sh new file mode 100644 index 00000000..07107d16 --- /dev/null +++ b/hp_scripts/main_exp/graph/gcn.sh @@ -0,0 +1,136 @@ +# Description: Main experiment script for GCN model. 
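
gcn.sh (this file), gat.sh above and gin.sh below are copies of one another with only the model= override swapped; the header comments of gat.sh and gin.sh even still read "GCN model". Should the triplication become a maintenance burden, the same sweeps could be launched from one parameterized driver; a rough sketch assuming the Hydra CLI used throughout these scripts (only a couple of the overrides are reproduced here):

import subprocess

# Launch the same Hydra multirun sweep once per GNN baseline.
for model in ("graph/gcn", "graph/gin", "graph/gat"):
    subprocess.run(
        [
            "python", "train.py",
            "dataset=us_country_demos",
            f"model={model}",
            "model.optimizer.lr=0.01,0.001",
            "logger.wandb.project=TopoBenchmarkX_Graph",
            "--multirun",
        ],
        check=True,
    )
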
+# ----Node regression datasets: US County Demographics---- +task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' ) + +for task_variable in ${task_variables[*]} +do + python train.py \ + dataset=us_country_demos \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.task_variable=$task_variable \ + model=graph/gcn \ + model.feature_encoder.out_channels="32,64,128" \ + model.feature_encoder.proj_dropout="0,0.25,0.5" \ + model.backbone.num_layers="1,2,3,4" \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=1000 \ + trainer.min_epochs=500 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + tags="[MainExperiment]" \ + --multirun + +done + +# ----Cocitation datasets---- +datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + dataset.parameters.data_seed=0,3,5,7,9 \ + model=graph/gcn \ + model.feature_encoder.out_channels="32,64,128" \ + model.feature_encoder.proj_dropout="0,0.25,0.5" \ + model.backbone.num_layers="1,2" \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=25 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + tags="[MainExperiment]" \ + --multirun +done + +# ----Graph regression dataset---- +# Train on ZINC dataset +python train.py \ + dataset=ZINC \ + seed=42,3,5,23,150 \ + model=graph/gcn \ + model.optimizer.lr=0.01,0.001 \ + model.optimizer.weight_decay=0 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=2,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.batch_size=128,256 \ + dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ + dataset.parameters.data_seed=0 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + callbacks.early_stopping.min_delta=0.005 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + tags="[MainExperiment]" \ + --multirun + +# ----Heterophilic datasets---- + +datasets=( roman_empire amazon_ratings tolokers questions minesweeper ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=graph/gcn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done + +# ----TU graph datasets---- +# MUTAG have very few samples, so we use a smaller batch size +# Train on MUTAG dataset +python train.py \ + dataset=MUTAG \ + model=graph/gcn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=32,64 \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + callbacks.early_stopping.patience=25 \ + tags="[MainExperiment]" \ + --multirun + +# Train rest of the TU 
graph datasets +datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'REDDIT-BINARY' 'IMDB-BINARY' 'IMDB-MULTI' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=graph/gcn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + --multirun +done + diff --git a/hp_scripts/main_exp/graph/gin.sh b/hp_scripts/main_exp/graph/gin.sh new file mode 100644 index 00000000..a71f31fe --- /dev/null +++ b/hp_scripts/main_exp/graph/gin.sh @@ -0,0 +1,136 @@ +# Description: Main experiment script for GCN model. +# ----Node regression datasets: US County Demographics---- +task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' ) + +for task_variable in ${task_variables[*]} +do + python train.py \ + dataset=us_country_demos \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.task_variable=$task_variable \ + model=graph/gin \ + model.feature_encoder.out_channels="32,64,128" \ + model.feature_encoder.proj_dropout="0,0.25,0.5" \ + model.backbone.num_layers="1,2,3,4" \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=1000 \ + trainer.min_epochs=500 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + tags="[MainExperiment]" \ + --multirun + +done + +# ----Cocitation datasets---- +datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + dataset.parameters.data_seed=0,3,5,7,9 \ + model=graph/gin \ + model.feature_encoder.out_channels="32,64,128" \ + model.feature_encoder.proj_dropout="0,0.25,0.5" \ + model.backbone.num_layers="1,2" \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=25 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + tags="[MainExperiment]" \ + --multirun +done + +# ----Graph regression dataset---- +# Train on ZINC dataset +python train.py \ + dataset=ZINC \ + seed=42,3,5,23,150 \ + model=graph/gin \ + model.optimizer.lr=0.01,0.001 \ + model.optimizer.weight_decay=0 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=2,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.batch_size=128,256 \ + dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ + dataset.parameters.data_seed=0 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + callbacks.early_stopping.min_delta=0.005 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + tags="[MainExperiment]" \ + --multirun + +# ----Heterophilic datasets---- + +datasets=( roman_empire amazon_ratings tolokers questions minesweeper ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=graph/gin \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 
\ + logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done + +# ----TU graph datasets---- +# MUTAG have very few samples, so we use a smaller batch size +# Train on MUTAG dataset +python train.py \ + dataset=MUTAG \ + model=graph/gin \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=32,64 \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + callbacks.early_stopping.patience=25 \ + tags="[MainExperiment]" \ + --multirun + +# Train rest of the TU graph datasets +datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'REDDIT-BINARY' 'IMDB-BINARY' 'IMDB-MULTI' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=graph/gin \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Graph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + --multirun +done + diff --git a/topobenchmarkx/hp_scripts/simplicial/SCN.sh b/hp_scripts/main_exp/simplicial/SCN.sh similarity index 100% rename from topobenchmarkx/hp_scripts/simplicial/SCN.sh rename to hp_scripts/main_exp/simplicial/SCN.sh diff --git a/notebooks/result_processing.ipynb b/notebooks/result_processing.ipynb index 15313a9b..24ce72e4 100644 --- a/notebooks/result_processing.ipynb +++ b/notebooks/result_processing.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -71,7 +71,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -87,7 +87,7 @@ " dtype='object')" ] }, - "execution_count": 8, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -98,9 +98,52 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 14, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'enforce_tags': True, 'print_config': True, 'ignore_warnings': False}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df_init['extras'][0]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "ename": "KeyError", + "evalue": "'model'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "File \u001b[0;32m~/miniconda3/envs/topox/lib/python3.11/site-packages/pandas/core/indexes/base.py:3805\u001b[0m, in \u001b[0;36mIndex.get_loc\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 3804\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 3805\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m 
self._engine.get_loc(casted_key)\n   3806 except KeyError as err:\n",
+      "[pandas traceback trimmed for readability]\n",
+      "File index.pyx:167 and :196, in pandas._libs.index.IndexEngine.get_loc()\n",
+      "File pandas/_libs/hashtable_class_helper.pxi:7081 and :7089, in pandas._libs.hashtable.PyObjectHashTable.get_item()\n",
+      "KeyError: 'model'",
+      "\nThe above exception was the direct cause of the following exception:\n",
+      "KeyError                                  Traceback (most recent call last)",
+      "Cell In[7], line 23:  df, columns = normalize_column(df, column)\n",
+      "Cell In[7], line 3, in normalize_column:  flattened_df = pd.json_normalize(df[column_to_normalize])\n",
+      "File pandas/core/frame.py:4102, in DataFrame.__getitem__:  indexer = self.columns.get_loc(key)\n",
+      "File pandas/core/indexes/base.py:3812, in Index.get_loc:  raise KeyError(key) from err\n",
+      "KeyError: 'model'"
+     ]
+    }
+   ],
    "source": [
     "def normalize_column(df, column_to_normalize):\n",
@@ -119,7 +162,7 @@
    "\n",
    "\n",
    "# Config columns to normalize\n",
-   "columns_to_normalize = [\"model\", \"dataset\", \"callbacks\"]\n",
+   "columns_to_normalize = [\"model\", \"dataset\", \"callbacks\", \"paths\"]\n",
    "\n",
    "# Keep track of config columns added\n",
    "config_columns = []\n",
diff --git a/topobenchmarkx/__init__.py b/topobenchmarkx/__init__.py
index 04fac166..e1e887f3 100755
--- a/topobenchmarkx/__init__.py
+++ b/topobenchmarkx/__init__.py
@@ -2,7 +2,6 @@
 # Import submodules
 from . import data
 from . import evaluators
-from . import hp_scripts
 from . import io
 from . import models
 from . import transforms
diff --git a/topobenchmarkx/run_graph_scripts.sh b/topobenchmarkx/run_graph_scripts.sh
new file mode 100644
index 00000000..21685efb
--- /dev/null
+++ b/topobenchmarkx/run_graph_scripts.sh
@@ -0,0 +1,5 @@
+# Run the scripts from the graph directory
+bash ~/projects/TopoBenchmarkX/hp_scripts/main_exp/graph/gcn.sh
+bash ~/projects/TopoBenchmarkX/hp_scripts/main_exp/graph/gin.sh
+bash ~/projects/TopoBenchmarkX/hp_scripts/main_exp/graph/gat.sh
+
diff --git a/topobenchmarkx/simplicial.sh b/topobenchmarkx/simplicial.sh
index 2e20131e..9573e371 100644
--- a/topobenchmarkx/simplicial.sh
+++ b/topobenchmarkx/simplicial.sh
@@ -23,6 +23,9 @@
 # # python train.py dataset=IMDB-BINARY model=simplicial/sccn model.optimizer.lr=0.01,0.001 model.feature_encoder.out_channels=16,64 model.backbone.n_layers=1,2 dataset.parameters.batch_size=128 dataset.parameters.data_seed=0,3,5 trainer.check_val_every_n_epoch=5 callbacks.early_stopping.patience=10 trainer=default logger.wandb.project=topobenchmark_0503 model.backbone_wrapper.wrapper_readout=original,signal_prop_down model.readout.pooling_type=sum,mean --multirun
 # # python train.py dataset=IMDB-MULTI model=simplicial/sccn model.optimizer.lr=0.01,0.001 model.feature_encoder.out_channels=16,64 model.backbone.n_layers=1,2 dataset.parameters.batch_size=128 dataset.parameters.data_seed=0,3,5 trainer.check_val_every_n_epoch=5 callbacks.early_stopping.patience=10 trainer=default logger.wandb.project=topobenchmark_0503 model.backbone_wrapper.wrapper_readout=original,signal_prop_down model.readout.pooling_type=sum,mean --multirun
+# dataset.transforms.one_hot_node_degree_features.degrees_fields=x
+
+
 # # SCCNN
 # # Fixed split
 # python train.py dataset=ZINC model=simplicial/sccnn model.optimizer.lr=0.01 model.feature_encoder.out_channels=16,64 model.backbone.n_layers=2,4 dataset.parameters.batch_size=128 dataset.parameters.data_seed=0 trainer.check_val_every_n_epoch=5 callbacks.early_stopping.patience=10 trainer=default logger.wandb.project=topobenchmark_0503 dataset.transforms.graph2simplicial_lifting.complex_dim=3 model.backbone_wrapper.wrapper_readout=original,signal_prop_down model.readout.pooling_type=sum,mean callbacks.early_stopping.min_delta=0.005 dataset.transforms.graph2simplicial_lifting.signed=True,False --multirun
diff --git a/topobenchmarkx/train.py b/topobenchmarkx/train.py
index ff3430d3..d777de81 100755
--- a/topobenchmarkx/train.py
+++ b/topobenchmarkx/train.py
@@ -22,6 +22,7 @@
     log_hyperparameters,
     task_wrapper,
 )
+
 from topobenchmarkx.utils.config_resolvers import (
     get_default_transform,
     get_monitor_metric,
diff --git a/topobenchmarkx/utils/config_resolvers.py b/topobenchmarkx/utils/config_resolvers.py
index f5aa1bf1..da79d56b 100644
--- a/topobenchmarkx/utils/config_resolvers.py
+++ b/topobenchmarkx/utils/config_resolvers.py
@@ -168,11 +168,14 @@ def check_for_type_feature_lifting(dataset, lifting):
                 lifting
             ].complex_dim
     else:
-        if not dataset.transforms[lifting].preserve_edge_attr:
+        # Case when the dataset has no edge attributes
+        if dataset.transforms[lifting].preserve_edge_attr == False:
+
             if feature_lifting == "projection":
                 return [
                     dataset.parameters.num_features[0]
                 ] * dataset.transforms[lifting].complex_dim
+
             elif feature_lifting == "concatenation":
                 return_value = [dataset.parameters.num_features]
                 for i in range(
@@ -183,6 +186,7 @@ def check_for_type_feature_lifting(dataset, lifting):
                 ]
                 return return_value
+
         else:
             return [
                 dataset.parameters.num_features
diff --git a/topobenchmarkx/utils/logging_utils.py
b/topobenchmarkx/utils/logging_utils.py
index 84feb757..316d0191 100755
--- a/topobenchmarkx/utils/logging_utils.py
+++ b/topobenchmarkx/utils/logging_utils.py
@@ -51,6 +51,7 @@ def log_hyperparameters(object_dict: dict[str, Any]) -> None:
     hparams["tags"] = cfg.get("tags")
     hparams["ckpt_path"] = cfg.get("ckpt_path")
     hparams["seed"] = cfg.get("seed")
+    hparams["paths"] = cfg.get("paths")
 
     # send hparams to all loggers
     for logger in trainer.loggers:

From f782d1500af501270012692ed9c5fdf5cb1fb91f Mon Sep 17 00:00:00 2001
From: levtelyatnikov
Date: Thu, 16 May 2024 18:31:06 +0200
Subject: [PATCH 20/32] hypergraph initial sweep

---
 configs/model/hypergraph/allsettransformer.yaml |   2 +-
 configs/model/hypergraph/edgnn.yaml             |   2 +-
 configs/train.yaml                              |   2 +-
 env.bash                                        |  16 +--
 .../main_exp/hypergraph/allsettransformer.sh    | 136 ++++++++++++++++++
 hp_scripts/main_exp/hypergraph/unignn2.sh       | 136 ++++++++++++++++++
 topobenchmarkx/graph_search.sh                  |  14 --
 7 files changed, 283 insertions(+), 25 deletions(-)
 create mode 100644 hp_scripts/main_exp/hypergraph/allsettransformer.sh
 create mode 100644 hp_scripts/main_exp/hypergraph/unignn2.sh
 delete mode 100644 topobenchmarkx/graph_search.sh

diff --git a/configs/model/hypergraph/allsettransformer.yaml b/configs/model/hypergraph/allsettransformer.yaml
index 850d4ad1..04e62c5e 100755
--- a/configs/model/hypergraph/allsettransformer.yaml
+++ b/configs/model/hypergraph/allsettransformer.yaml
@@ -14,7 +14,7 @@ backbone:
   _target_: topomodelx.nn.hypergraph.allset_transformer.AllSetTransformer
   in_channels: ${model.feature_encoder.out_channels}
   hidden_channels: ${model.feature_encoder.out_channels}
-  n_layers: 1
+  n_layers: 4
   heads: 4
   dropout: 0.
   mlp_num_layers: 1
diff --git a/configs/model/hypergraph/edgnn.yaml b/configs/model/hypergraph/edgnn.yaml
index 8309c22a..6f84cce4 100755
--- a/configs/model/hypergraph/edgnn.yaml
+++ b/configs/model/hypergraph/edgnn.yaml
@@ -16,7 +16,7 @@ backbone:
   input_dropout: 0.2
   dropout: 0.2
   activation: relu
-  MLP_num_layers: 0
+  MLP_num_layers: 1
   All_num_layers: 1
   edconv_type: EquivSet
   aggregate: 'add'
diff --git a/configs/train.yaml b/configs/train.yaml
index d345c74c..5518144f 100755
--- a/configs/train.yaml
+++ b/configs/train.yaml
@@ -4,7 +4,7 @@
 # order of defaults determines the order in which configs override each other
 defaults:
   - _self_
-  - dataset: NCI109 #us_country_demos
+  - dataset: NCI1 #us_country_demos
   - model: graph/gcn #hypergraph/unignn2 #allsettransformer
   - evaluator: default
   - callbacks: default
diff --git a/env.bash b/env.bash
index 3a5beb54..32b5ab1d 100644
--- a/env.bash
+++ b/env.bash
@@ -1,21 +1,21 @@
 # #!/bin/bash
 
-conda create -n topoxx python=3.11.3
-conda activate topoxx
+conda create -n topox15 python=3.11.3
+conda activate topox15
 
 pip install --upgrade pip
 pip install -e '.[all]'
 
-pip install git+https://github.com/pyt-team/TopoNetX.git
-pip install git+https://github.com/pyt-team/TopoModelX.git
-pip install git+https://github.com/pyt-team/TopoEmbedX.git
+pip install --no-dependencies git+https://github.com/pyt-team/TopoNetX.git
+pip install --no-dependencies git+https://github.com/pyt-team/TopoModelX.git
 
-CUDA="cu117" # if available, select the CUDA version suitable for your system
+CUDA="cu115" # if available, select the CUDA version suitable for your system
             # e.g. cpu, cu102, cu111, cu113, cu115
 pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/${CUDA}
 pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html
 pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html
+#pip install torch-geometric -f https://data.pyg.org/whl/torch-2.6.0.dev20240506+${CUDA}.html
 
-pytest
+# pytest
 
-pre-commit install
+# pre-commit install
diff --git a/hp_scripts/main_exp/hypergraph/allsettransformer.sh b/hp_scripts/main_exp/hypergraph/allsettransformer.sh
new file mode 100644
index 00000000..b0d192ee
--- /dev/null
+++ b/hp_scripts/main_exp/hypergraph/allsettransformer.sh
@@ -0,0 +1,136 @@
+# Description: Main experiment script for the AllSetTransformer model.
+# ----Node regression datasets: US County Demographics----
+task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' )
+
+for task_variable in ${task_variables[*]}
+do
+    python train.py \
+    dataset=us_country_demos \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.task_variable=$task_variable \
+    model=hypergraph/allsettransformer \
+    model.feature_encoder.out_channels="32,64,128" \
+    model.feature_encoder.proj_dropout="0,0.25,0.5" \
+    model.backbone.n_layers="1,2,3,4" \
+    model.optimizer.lr="0.01,0.001" \
+    trainer.max_epochs=1000 \
+    trainer.min_epochs=500 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=50 \
+    logger.wandb.project=TopoBenchmarkX_Hypergraph \
+    tags="[MainExperiment]" \
+    --multirun
+
+done
+
+# ----Cocitation datasets----
+datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    model=hypergraph/allsettransformer \
+    model.feature_encoder.out_channels="32,64,128" \
+    model.feature_encoder.proj_dropout="0,0.25,0.5" \
+    model.backbone.n_layers="1,2" \
+    model.optimizer.lr="0.01,0.001" \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=25 \
+    logger.wandb.project=TopoBenchmarkX_Hypergraph \
+    tags="[MainExperiment]" \
+    --multirun
+done
+
+# ----Graph regression dataset----
+# Train on ZINC dataset
+python train.py \
+    dataset=ZINC \
+    seed=42,3,5,23,150 \
+    model=hypergraph/allsettransformer \
+    model.optimizer.lr=0.01,0.001 \
+    model.optimizer.weight_decay=0 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=2,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.batch_size=128,256 \
+    dataset.transforms.one_hot_node_degree_features.degrees_fields=x \
+    dataset.parameters.data_seed=0 \
+    logger.wandb.project=TopoBenchmarkX_Hypergraph \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    callbacks.early_stopping.min_delta=0.005 \
+    trainer.check_val_every_n_epoch=5 \
+    callbacks.early_stopping.patience=10 \
+    tags="[MainExperiment]" \
+    --multirun
+
+# ----Heterophilic datasets----
+
+datasets=( roman_empire amazon_ratings tolokers questions minesweeper )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    model=hypergraph/allsettransformer \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5 \
+    dataset.parameters.batch_size=128,256 \
+    logger.wandb.project=TopoBenchmarkX_Hypergraph \
+    trainer.max_epochs=1000 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=50 \
+    tags="[MainExperiment]" \
+    --multirun
+done
+
+# ----TU graph datasets----
+# MUTAG has very few samples, so we use a smaller batch size
+# Train on MUTAG dataset
+python train.py \
+    dataset=MUTAG \
+    model=hypergraph/allsettransformer \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5 \
+    dataset.parameters.batch_size=32,64 \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    logger.wandb.project=TopoBenchmarkX_Hypergraph \
+    callbacks.early_stopping.patience=25 \
+    tags="[MainExperiment]" \
+    --multirun
+
+# Train rest of the TU graph datasets
+datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'REDDIT-BINARY' 'IMDB-BINARY' 'IMDB-MULTI' )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    model=hypergraph/allsettransformer \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5 \
+    dataset.parameters.batch_size=128,256 \
+    logger.wandb.project=TopoBenchmarkX_Hypergraph \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=5 \
+    callbacks.early_stopping.patience=10 \
+    --multirun
+done
+
diff --git a/hp_scripts/main_exp/hypergraph/unignn2.sh b/hp_scripts/main_exp/hypergraph/unignn2.sh
new file mode 100644
index 00000000..b0d192ee
--- /dev/null
+++ b/hp_scripts/main_exp/hypergraph/unignn2.sh
@@ -0,0 +1,136 @@
+# Description: Main experiment script for the UniGNN2 model.
+# ----Node regression datasets: US County Demographics----
+task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' )
+
+for task_variable in ${task_variables[*]}
+do
+    python train.py \
+    dataset=us_country_demos \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.task_variable=$task_variable \
+    model=hypergraph/unignn2 \
+    model.feature_encoder.out_channels="32,64,128" \
+    model.feature_encoder.proj_dropout="0,0.25,0.5" \
+    model.backbone.n_layers="1,2,3,4" \
+    model.optimizer.lr="0.01,0.001" \
+    trainer.max_epochs=1000 \
+    trainer.min_epochs=500 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=50 \
+    logger.wandb.project=TopoBenchmarkX_Hypergraph \
+    tags="[MainExperiment]" \
+    --multirun
+
+done
+
+# ----Cocitation datasets----
+datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    model=hypergraph/unignn2 \
+    model.feature_encoder.out_channels="32,64,128" \
+    model.feature_encoder.proj_dropout="0,0.25,0.5" \
+    model.backbone.n_layers="1,2" \
+    model.optimizer.lr="0.01,0.001" \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=25 \
+    logger.wandb.project=TopoBenchmarkX_Hypergraph \
+    tags="[MainExperiment]" \
+    --multirun
+done
+
+# ----Graph regression dataset----
+# Train on ZINC dataset
+python train.py \
+    dataset=ZINC \
+    seed=42,3,5,23,150 \
+    model=hypergraph/unignn2 \
+    model.optimizer.lr=0.01,0.001 \
+    model.optimizer.weight_decay=0 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=2,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.batch_size=128,256 \
+    dataset.transforms.one_hot_node_degree_features.degrees_fields=x \
+    dataset.parameters.data_seed=0 \
+    logger.wandb.project=TopoBenchmarkX_Hypergraph \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    callbacks.early_stopping.min_delta=0.005 \
+    trainer.check_val_every_n_epoch=5 \
+    callbacks.early_stopping.patience=10 \
+    tags="[MainExperiment]" \
+    --multirun
+
+# ----Heterophilic datasets----
+
+datasets=( roman_empire amazon_ratings tolokers questions minesweeper )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    model=hypergraph/unignn2 \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5 \
+    dataset.parameters.batch_size=128,256 \
+    logger.wandb.project=TopoBenchmarkX_Hypergraph \
+    trainer.max_epochs=1000 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=50 \
+    tags="[MainExperiment]" \
+    --multirun
+done
+
+# ----TU graph datasets----
+# MUTAG has very few samples, so we use a smaller batch size
+# Train on MUTAG dataset
+python train.py \
+    dataset=MUTAG \
+    model=hypergraph/unignn2 \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5 \
+    dataset.parameters.batch_size=32,64 \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    logger.wandb.project=TopoBenchmarkX_Hypergraph \
+    callbacks.early_stopping.patience=25 \
+    tags="[MainExperiment]" \
+    --multirun
+
+# Train rest of the TU graph datasets
+datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'REDDIT-BINARY' 'IMDB-BINARY' 'IMDB-MULTI' )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    model=hypergraph/unignn2 \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5 \
+    dataset.parameters.batch_size=128,256 \
+    logger.wandb.project=TopoBenchmarkX_Hypergraph \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=5 \
+    callbacks.early_stopping.patience=10 \
+    --multirun
+done
+
diff --git a/topobenchmarkx/graph_search.sh b/topobenchmarkx/graph_search.sh
deleted file mode 100644
index d8f49210..00000000
--- a/topobenchmarkx/graph_search.sh
+++ /dev/null
@@ -1,14 +0,0 @@
-# GCN
-python train.py dataset=cocitation_cora model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5 model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 logger.wandb.project=topobenchmark_22Apr2024 trainer=cpu --multirun #tags "[first_tag, second_tag]"
-python train.py dataset=cocitation_citeseer model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5 model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 logger.wandb.project=topobenchmark_22Apr2024 trainer=cpu --multirun
-python train.py dataset=cocitation_pubmed model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5 model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 logger.wandb.project=topobenchmark_22Apr2024 trainer=cpu --multirun
-python train.py dataset=PROTEINS_TU model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5 model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 dataset.parameters.batch_size=128,256 logger.wandb.project=topobenchmark_22Apr2024 trainer=cpu --multirun
-python train.py dataset=NCI1 model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5 model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 dataset.parameters.batch_size=128,256 logger.wandb.project=topobenchmark_22Apr2024 trainer=cpu --multirun
-
-# python train.py dataset=IMDB-BINARY model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5 model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 dataset.parameters.batch_size=128,256 logger.wandb.project=topobenchmark_22Apr2024 trainer=cpu --multirun
-# python train.py dataset=IMDB-MULTI model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5
model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 dataset.parameters.batch_size=128,256 logger.wandb.project=topobenchmark_22Apr2024 trainer=cpu --multirun
-python train.py dataset=MUTAG model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5 model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 dataset.parameters.batch_size=32,64 logger.wandb.project=topobenchmark_22Apr2024 trainer=cpu --multirun
-python train.py dataset=ZINC model=graph/gcn model.optimizer.lr=0.01,0.001 model.optimizer.weight_decay=0 model.backbone.hidden_channels=16,32,64,128 model.backbone.num_layers=1,2,3,4 dataset.parameters.batch_size=128,256 dataset.parameters.data_seed=0 model.backbone.dropout=0,0.25,0.5 logger.wandb.project=topobenchmark_22Apr2024 callbacks.early_stopping.patience=10 trainer=default --multirun
-python train.py dataset=REDDIT-BINARY model=graph/gcn model.optimizer.lr=0.01,0.001 model.backbone.hidden_channels=64,128,256 model.backbone.num_layers=1,2,3,4 dataset.parameters.data_seed=0,3,5 model.backbone.dropout=0,0.25,0.5 callbacks.early_stopping.patience=10 dataset.parameters.data_seed=0,3,5 dataset.parameters.batch_size=128,256 logger.wandb.project=topobenchmark_22Apr2024 trainer=default --multirun
-
-

From bc604a966c03e4b5774b0a00fa2c496889a711fa Mon Sep 17 00:00:00 2001
From: guille
Date: Thu, 16 May 2024 16:33:27 +0000
Subject: [PATCH 21/32] env setting

---
 env.bash => env.sh   | 4 ++--
 test.bash => test.sh | 0
 2 files changed, 2 insertions(+), 2 deletions(-)
 rename env.bash => env.sh (91%)
 rename test.bash => test.sh (100%)

diff --git a/env.bash b/env.sh
similarity index 91%
rename from env.bash
rename to env.sh
index 3a5beb54..f817fb2e 100644
--- a/env.bash
+++ b/env.sh
@@ -1,7 +1,7 @@
 # #!/bin/bash
 
-conda create -n topoxx python=3.11.3
-conda activate topoxx
+#conda create -n topoxx python=3.11.3
+#conda activate topoxx
 
 pip install --upgrade pip
 pip install -e '.[all]'
diff --git a/test.bash b/test.sh
similarity index 100%
rename from test.bash
rename to test.sh

From 0c991c7d7d7d50a8882f5d6130f90d7fd73c40fa Mon Sep 17 00:00:00 2001
From: levtelyatnikov
Date: Thu, 16 May 2024 21:58:33 +0200
Subject: [PATCH 22/32] deleted table

---
 tables/dataset_statistics.csv | 7 -------
 1 file changed, 7 deletions(-)
 delete mode 100644 tables/dataset_statistics.csv

diff --git a/tables/dataset_statistics.csv b/tables/dataset_statistics.csv
deleted file mode 100644
index e0be3002..00000000
--- a/tables/dataset_statistics.csv
+++ /dev/null
@@ -1,7 +0,0 @@
-,num_hyperedges,zero_cell,one_cell,two_cell,three_cell,dataset,domain
-0,0,3224,9483,6266,0,US-county-demos,cell
-1,0,2708,5278,2648,0,Cora,cell
-2,0,3327,4552,1663,0,citeseer,cell
-3,0,19717,44324,23605,0,PubMed,cell
-4,0,277864,298985,33121,0,ZINC,cell
-5,0,22662,32927,10266,0,roman_empire,cell

From 5513c5e59527b21c2141b7d1d30c1bb59049f4ac Mon Sep 17 00:00:00 2001
From: Coerulatus
Date: Thu, 16 May 2024 20:19:53 +0000
Subject: [PATCH 23/32] check docstrings

---
 topobenchmarkx/data/dataloaders.py             |  75 ++++++-----
 topobenchmarkx/data/datasets.py                |  29 ++---
 topobenchmarkx/data/heteriphilic_dataset.py    |  49 +++----
 topobenchmarkx/data/us_county_demos_dataset.py |  38 +++---
 topobenchmarkx/evaluators/comparisons.py       |  46 ++++---
 topobenchmarkx/evaluators/evaluator.py         |  63 ++++-----
 topobenchmarkx/io/load/download_utils.py       |   4 +-
 topobenchmarkx/io/load/heterophilic.py         |  16 +++
 topobenchmarkx/io/load/loader.py               |  19 +--
 topobenchmarkx/io/load/loaders.py              | 115 ++++++++---------
 topobenchmarkx/io/load/preprocessor.py         |  57 +++------
 topobenchmarkx/io/load/split_utils.py          | 111 ++++++----------
 topobenchmarkx/io/load/us_county_demos.py      |  19 +--
 topobenchmarkx/io/load/utils.py                | 120 +++++++-----------
 14 files changed, 325 insertions(+), 436 deletions(-)

diff --git a/topobenchmarkx/data/dataloaders.py b/topobenchmarkx/data/dataloaders.py
index 2578c4af..1b836577 100755
--- a/topobenchmarkx/data/dataloaders.py
+++ b/topobenchmarkx/data/dataloaders.py
@@ -10,15 +10,17 @@
 class DomainData(Data):
-    """Data object class that overwrites some methods from
-    torch_geometric.data.Data so that not only sparse matrices with adj in the
-    name can work with the torch_geometric dataloaders."""
-
+    r"""Data object class that overwrites some methods from
+    `torch_geometric.data.Data` so that sparse matrices whose keys do not
+    contain `adj` (e.g., incidence and Laplacian matrices) also work with the
+    `torch_geometric` dataloaders."""
+
     def is_valid(self, string):
+        r"""Check if the string contains any of the valid names."""
         valid_names = ["adj", "incidence", "laplacian"]
         return any(name in string for name in valid_names)
 
     def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
+        r"""Overwrite the `__cat_dim__` method so that sparse matrices whose keys match `is_valid` are batched along both dimensions."""
         if is_sparse(value) and self.is_valid(key):
            return (0, 1)
         elif "index" in key or key == "face":
@@ -28,8 +30,7 @@ def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
 
 def to_data_list(batch):
-    """Workaround needed since torch_geometric doesn't work well with
-    torch.sparse."""
+    """Workaround needed since `torch_geometric` doesn't work when using `torch.sparse` instead of `torch_sparse`."""
     for key in batch.keys():
         if batch[key].is_sparse:
             sparse_data = batch[key].coalesce()
@@ -43,13 +44,13 @@ def to_data_list(batch):
 
 def collate_fn(batch):
-    """
-    args:
-        batch - list of (tensor, label)
+    r"""This function overwrites the `torch_geometric.data.DataLoader` collate function to use the `DomainData` class. This ensures that the `torch_geometric` dataloaders work with sparse matrices that are not necessarily named `adj`. The function also generates the batch slices for the different cell dimensions.
+
+    Args:
+        batch (list): List of data objects (e.g., `torch_geometric.data.Data`).
 
-    return:
-        xs - a tensor of all examples in 'batch' after padding
-        ys - a LongTensor of all labels in batch
+    Returns:
+        torch_geometric.data.Batch: A `torch_geometric.data.Batch` object.
     """
     data_list = []
     batch_idx_dict = defaultdict(list)
@@ -114,15 +115,15 @@
 class DefaultDataModule(LightningDataModule):
-    """Initializes the DefaultDataModule class.
+    r"""This class takes care of returning the dataloaders for the training, validation, and test datasets. It also handles the collate function. The class is designed to work with the `torch` dataloaders.
 
     Args:
-        dataset_train: The training dataset.
-        dataset_val: The validation dataset (optional).
-        dataset_test: The test dataset (optional).
-        batch_size: The batch size for the dataloader.
-        num_workers: The number of worker processes to use for data loading.
-        pin_memory: If True, the data loader will copy tensors into pinned memory before returning them.
+        dataset_train (CustomDataset): The training dataset.
+        dataset_val (CustomDataset, optional): The validation dataset. (default: None)
+        dataset_test (CustomDataset, optional): The test dataset.
(default: None)
+        batch_size (int, optional): The batch size for the dataloader. (default: 1)
+        num_workers (int, optional): The number of worker processes to use for data loading. (default: 0)
+        pin_memory (bool, optional): If True, the data loader will copy tensors into pinned memory before returning them. (default: False)
 
     Returns:
         None
@@ -162,11 +163,15 @@
         else:
             self.dataset_val = dataset_val
             self.dataset_test = dataset_test
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(dataset_train={self.dataset_train}, dataset_val={self.dataset_val}, dataset_test={self.dataset_test}, batch_size={self.batch_size}, num_workers={self.hparams.num_workers}, pin_memory={self.hparams.pin_memory})"
 
     def train_dataloader(self) -> DataLoader:
-        """Create and return the train dataloader.
+        r"""Create and return the train dataloader.
 
-        :return: The train dataloader.
+        Returns:
+            torch.utils.data.DataLoader: The train dataloader.
         """
         return DataLoader(
             dataset=self.dataset_train,
@@ -178,9 +183,10 @@
         )
 
     def val_dataloader(self) -> DataLoader:
-        """Create and return the validation dataloader.
+        r"""Create and return the validation dataloader.
 
-        :return: The validation dataloader.
+        Returns:
+            torch.utils.data.DataLoader: The validation dataloader.
         """
         return DataLoader(
             dataset=self.dataset_val,
@@ -192,9 +198,10 @@
         )
 
     def test_dataloader(self) -> DataLoader:
-        """Create and return the test dataloader.
+        r"""Create and return the test dataloader.
 
-        :return: The test dataloader.
+        Returns:
+            torch.utils.data.DataLoader: The test dataloader.
         """
         if self.dataset_test is None:
             raise ValueError("There is no test dataloader.")
@@ -208,26 +215,26 @@
         )
 
     def teardown(self, stage: str | None = None) -> None:
-        """Lightning hook for cleaning up after `trainer.fit()`,
+        r"""Lightning hook for cleaning up after `trainer.fit()`,
         `trainer.validate()`, `trainer.test()`, and `trainer.predict()`.
 
-        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
-            Defaults to ``None``.
+        Args:
+            stage (str, optional): The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. (default: None)
         """
 
     def state_dict(self) -> dict[Any, Any]:
-        """Called when saving a checkpoint. Implement to generate and save the
+        r"""Called when saving a checkpoint. Implement to generate and save the
         datamodule state.
 
-        :return: A dictionary containing the datamodule state that you want to
-            save.
+        Returns:
+            dict: A dictionary containing the datamodule state that you want to save.
         """
         return {}
 
     def load_state_dict(self, state_dict: dict[str, Any]) -> None:
-        """Called when loading a checkpoint. Implement to reload datamodule
+        r"""Called when loading a checkpoint. Implement to reload datamodule
         state given datamodule `state_dict()`.
 
-        :param state_dict: The datamodule state returned by
-            `self.state_dict()`.
+        Args:
+            state_dict (dict): The datamodule state. This is the object returned by `state_dict()`.
         """
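For context, a minimal sketch of how the pieces in this dataloaders diff fit together; the field names and shapes below are illustrative, not taken from the repository's configs:

import torch
from topobenchmarkx.data.dataloaders import DomainData, collate_fn
from topobenchmarkx.data.datasets import CustomDataset

# Toy complex: 3 nodes, 2 edges, with a sparse node-to-edge incidence matrix.
incidence = torch.sparse_coo_tensor(
    torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]]), torch.ones(4), (3, 2)
).coalesce()
data = DomainData(x=torch.randn(3, 4), incidence_1=incidence, y=torch.tensor([0]))

# CustomDataset.get yields (values, keys) pairs, which collate_fn reassembles.
dataset = CustomDataset([data, data])
batch = collate_fn([dataset[0], dataset[1]])
# incidence_1 matches is_valid, so __cat_dim__ returns (0, 1) and the two
# (3, 2) matrices are stacked block-diagonally into a (6, 4) sparse matrix.
print(batch.incidence_1.shape)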
diff --git a/topobenchmarkx/data/datasets.py b/topobenchmarkx/data/datasets.py
index 9ad8c047..d8d5b1bb 100644
--- a/topobenchmarkx/data/datasets.py
+++ b/topobenchmarkx/data/datasets.py
@@ -4,39 +4,34 @@
 class CustomDataset(torch_geometric.data.Dataset):
     r"""Custom dataset to return all the values added to the dataset object.
 
-    Parameters
-    ----------
-    data_lst: list
-        List of torch_geometric.data.Data objects .
+    Args:
+        data_lst (list[torch_geometric.data.Data]): List of torch_geometric.data.Data objects.
     """
 
     def __init__(self, data_lst):
         super().__init__()
         self.data_lst = data_lst
 
+    def __repr__(self):
+        return f"{self.__class__.__name__}(data_lst={self.data_lst})"
+
     def get(self, idx):
         r"""Get data object from data list.
 
-        Parameters
-        ----------
-        idx: int
-            Index of the data object to get.
+        Args:
+            idx (int): Index of the data object to get.
 
-        Returns
-        -------
-        tuple
-            tuple containing a list of all the values for the data and the keys corresponding to the values.
+        Returns:
+            tuple: tuple containing a list of all the values for the data and the corresponding keys.
         """
         data = self.data_lst[idx]
         keys = list(data.keys())
         return ([data[key] for key in keys], keys)
 
     def len(self):
-        r"""Return length of the dataset.
+        r"""Return the length of the dataset.
 
-        Returns
-        -------
-        int
-            Length of the dataset.
+        Returns:
+            int: Length of the dataset.
         """
         return len(self.data_lst)
diff --git a/topobenchmarkx/data/heteriphilic_dataset.py b/topobenchmarkx/data/heteriphilic_dataset.py
index 073d8b65..a073ab8d 100644
--- a/topobenchmarkx/data/heteriphilic_dataset.py
+++ b/topobenchmarkx/data/heteriphilic_dataset.py
@@ -15,32 +15,30 @@
 class HeteroDataset(InMemoryDataset):
-    r"""Dataset class for US County Demographics dataset.
+    r"""Dataset class for heterophilic datasets.
 
     Args:
         root (str): Root directory where the dataset will be saved.
         name (str): Name of the dataset.
         parameters (DictConfig): Configuration parameters for the dataset.
-        transform (Optional[Callable]): A function/transform that takes in an
+        transform (Callable, optional): A function/transform that takes in an
             `torch_geometric.data.Data` object and returns a transformed version.
-            The transform function is applied to the loaded data before saving it.
-        pre_transform (Optional[Callable]): A function/transform that takes in an
+            The transform function is applied to the loaded data before saving it. (default: None)
+        pre_transform (Callable, optional): A function/transform that takes in an
             `torch_geometric.data.Data` object and returns a transformed version.
             The pre_transform function is applied to the data before the transform
-            function is applied.
-        pre_filter (Optional[Callable]): A function that takes in an
+            function is applied. (default: None)
+        pre_filter (Callable, optional): A function that takes in an
             `torch_geometric.data.Data` object and returns a boolean value
-            indicating whether the data object should be included in the dataset.
-        force_reload (bool): If set to True, the dataset will be re-downloaded
+            indicating whether the data object should be included in the dataset. (default: None)
+        force_reload (bool, optional): If set to True, the dataset will be re-downloaded
             even if it already exists on disk. (default: True)
-        use_node_attr (bool): If set to True, the node attributes will be included
+        use_node_attr (bool, optional): If set to True, the node attributes will be included
             in the dataset. (default: False)
-        use_edge_attr (bool): If set to True, the edge attributes will be included
+        use_edge_attr (bool, optional): If set to True, the edge attributes will be included
             in the dataset. (default: False)
 
     Attributes:
-        URLS (dict): Dictionary containing the URLs for downloading the dataset.
-        FILE_FORMAT (dict): Dictionary containing the file formats for the dataset.
         RAW_FILE_NAMES (dict): Dictionary containing the raw file names for the dataset.
     """
@@ -68,18 +66,13 @@
             force_reload=force_reload,
         )
 
-        # Step 3:Load the processed data
-        # After the data has been downloaded from source
-        # Then preprocessed to obtain x,y and saved into processed folder
-        # We can now load the processed data from processed folder
-
         # Load the processed data
         data, _, _ = fs.torch_load(self.processed_paths[0])
 
         # Map the loaded data into
         data = Data.from_dict(data)
 
-        # Step 5: Create the splits and upload desired fold
+        # Create the splits and upload desired fold
         splits = random_splitting(data.y, parameters=self.parameters)
 
         # Assign train val test masks to the graph
@@ -90,7 +83,9 @@
         # Assign data object to self.data, to make it be prodessed by Dataset class
         self.data, self.slices = self.collate([data])
 
-        # Do not forget to take care of properties
+    def __repr__(self) -> str:
+        return f"{self.name}(self.root={self.root}, self.name={self.name}, self.parameters={self.parameters}, self.transform={self.transform}, self.pre_transform={self.pre_transform}, self.pre_filter={self.pre_filter}, self.force_reload={self.force_reload}, self.use_node_attr={self.use_node_attr}, self.use_edge_attr={self.use_edge_attr})"
+
     @property
     def raw_dir(self) -> str:
         return osp.join(self.root, self.name, "raw")
@@ -105,33 +100,23 @@
     @property
     def processed_file_names(self) -> str:
 
     @property
     def raw_file_names(self) -> list[str]:
-        """Spefify the downloaded raw fine name."""
         return [f"{self.name}.npz"]
 
     def download(self) -> None:
-        """Downloads the dataset from the specified URL and saves it to the raw
+        r"""Downloads the dataset from the specified URL and saves it to the raw
         directory.
 
         Raises:
             FileNotFoundError: If the dataset URL is not found.
         """
-
-        # Step 1: Download data from the source
         download_hetero_datasets(name=self.name, path=self.raw_dir)
 
     def process(self) -> None:
-        """Process the data for the dataset.
+        r"""Process the data for the dataset.
 
-        This method loads the US county demographics data, applies any pre-processing transformations if specified,
+        This method loads the heterophilic data, applies any pre-processing transformations if specified,
         and saves the processed data to the appropriate location.
-
-        Returns:
-            None
         """
-
         data = load_heterophilic_data(name=self.name, path=self.raw_dir)
         data = data if self.pre_transform is None else self.pre_transform(data)
         self.save([data], self.processed_paths[0])
-
-    def __repr__(self) -> str:
-        return f"{self.name}()"
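A hypothetical instantiation sketch for the dataset class above; the parameter keys are stand-ins for what the dataset yaml actually provides (the splitting utilities read at least `data_seed` and `data_split_dir`):

from omegaconf import OmegaConf
from topobenchmarkx.data.heteriphilic_dataset import HeteroDataset

# Hypothetical parameter values; real ones come from configs/dataset/*.yaml.
parameters = OmegaConf.create(
    {"data_seed": 0, "data_split_dir": "./data_splits", "train_prop": 0.5}
)
dataset = HeteroDataset(root="./datasets", name="roman_empire", parameters=parameters)
data = dataset[0]  # a single graph with train/val/test masks attached
print(data.train_mask.sum().item(), data.val_mask.sum().item())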
""" @@ -68,18 +66,13 @@ def __init__( force_reload=force_reload, ) - # Step 3:Load the processed data - # After the data has been downloaded from source - # Then preprocessed to obtain x,y and saved into processed folder - # We can now load the processed data from processed folder - # Load the processed data data, _, _ = fs.torch_load(self.processed_paths[0]) # Map the loaded data into data = Data.from_dict(data) - # Step 5: Create the splits and upload desired fold + # Create the splits and upload desired fold splits = random_splitting(data.y, parameters=self.parameters) # Assign train val test masks to the graph @@ -90,7 +83,9 @@ def __init__( # Assign data object to self.data, to make it be prodessed by Dataset class self.data, self.slices = self.collate([data]) - # Do not forget to take care of properties + def __repr__(self) -> str: + return f"{self.name}(self.root={self.root}, self.name={self.name}, self.parameters={self.parameters}, self.transform={self.transform}, self.pre_transform={self.pre_transform}, self.pre_filter={self.pre_filter}, self.force_reload={self.force_reload}, self.use_node_attr={self.use_node_attr}, self.use_edge_attr={self.use_edge_attr})" + @property def raw_dir(self) -> str: return osp.join(self.root, self.name, "raw") @@ -105,33 +100,23 @@ def processed_file_names(self) -> str: @property def raw_file_names(self) -> list[str]: - """Spefify the downloaded raw fine name.""" return [f"{self.name}.npz"] def download(self) -> None: - """Downloads the dataset from the specified URL and saves it to the raw + r"""Downloads the dataset from the specified URL and saves it to the raw directory. Raises: FileNotFoundError: If the dataset URL is not found. """ - - # Step 1: Download data from the source download_hetero_datasets(name=self.name, path=self.raw_dir) def process(self) -> None: - """Process the data for the dataset. + r"""Process the data for the dataset. - This method loads the US county demographics data, applies any pre-processing transformations if specified, + This method loads the heterophilic data, applies any pre-processing transformations if specified, and saves the processed data to the appropriate location. - - Returns: - None """ - data = load_heterophilic_data(name=self.name, path=self.raw_dir) data = data if self.pre_transform is None else self.pre_transform(data) self.save([data], self.processed_paths[0]) - - def __repr__(self) -> str: - return f"{self.name}()" diff --git a/topobenchmarkx/data/us_county_demos_dataset.py b/topobenchmarkx/data/us_county_demos_dataset.py index 118e7816..df1fb632 100644 --- a/topobenchmarkx/data/us_county_demos_dataset.py +++ b/topobenchmarkx/data/us_county_demos_dataset.py @@ -19,21 +19,21 @@ class USCountyDemosDataset(InMemoryDataset): root (str): Root directory where the dataset will be saved. name (str): Name of the dataset. parameters (DictConfig): Configuration parameters for the dataset. - transform (Optional[Callable]): A function/transform that takes in an + transform (Callable, optional): A function/transform that takes in an `torch_geometric.data.Data` object and returns a transformed version. - The transform function is applied to the loaded data before saving it. - pre_transform (Optional[Callable]): A function/transform that takes in an + The transform function is applied to the loaded data before saving it. (default: None) + pre_transform (Callable, optional): A function/transform that takes in an `torch_geometric.data.Data` object and returns a transformed version. 
The pre_transform function is applied to the data before the transform - function is applied. - pre_filter (Optional[Callable]): A function that takes in an + function is applied. (default: None) + pre_filter (Callable, optional): A function that takes in an `torch_geometric.data.Data` object and returns a boolean value - indicating whether the data object should be included in the dataset. - force_reload (bool): If set to True, the dataset will be re-downloaded + indicating whether the data object should be included in the dataset. (default: None) + force_reload (bool, optional): If set to True, the dataset will be re-downloaded even if it already exists on disk. (default: True) - use_node_attr (bool): If set to True, the node attributes will be included + use_node_attr (bool, optional): If set to True, the node attributes will be included in the dataset. (default: False) - use_edge_attr (bool): If set to True, the edge attributes will be included + use_edge_attr (bool, optional): If set to True, the edge attributes will be included in the dataset. (default: False) Attributes: @@ -76,18 +76,13 @@ def __init__( force_reload=force_reload, ) - # Step 3:Load the processed data - # After the data has been downloaded from source - # Then preprocessed to obtain x,y and saved into processed folder - # We can now load the processed data from processed folder - # Load the processed data data, _, _ = fs.torch_load(self.processed_paths[0]) # Map the loaded data into data = Data.from_dict(data) - # Step 5: Create the splits and upload desired fold + # Create the splits and upload desired fold splits = random_splitting(data.y, parameters=self.parameters) # Assign train val test masks to the graph data.train_mask = torch.from_numpy(splits["train"]) @@ -104,6 +99,9 @@ def __init__( # Assign data object to self.data, to make it be prodessed by Dataset class self.data, self.slices = self.collate([data]) + + def __repr__(self) -> str: + return f"{self.name}(self.root={self.root}, self.name={self.name}, self.parameters={self.parameters}, self.transform={self.transform}, self.pre_transform={self.pre_transform}, self.pre_filter={self.pre_filter}, self.force_reload={self.force_reload}, self.use_node_attr={self.use_node_attr}, self.use_edge_attr={self.use_edge_attr})" @property def raw_dir(self) -> str: @@ -123,7 +121,7 @@ def processed_file_names(self) -> str: return "data.pt" def download(self) -> None: - """Downloads the dataset from the specified URL and saves it to the raw + r"""Downloads the dataset from the specified URL and saves it to the raw directory. Raises: @@ -157,13 +155,10 @@ def download(self) -> None: fs.rm(f"{self.raw_dir}/{self.name}.{self.file_format}") def process(self) -> None: - """Process the data for the dataset. + r"""Process the data for the dataset. This method loads the US county demographics data, applies any pre-processing transformations if specified, and saves the processed data to the appropriate location. 
-
-        Returns:
-            None
         """
         data = load_us_county_demos(
@@ -173,6 +168,3 @@ def process(self) -> None:
 
         data = data if self.pre_transform is None else self.pre_transform(data)
         self.save([data], self.processed_paths[0])
-
-    def __repr__(self) -> str:
-        return f"{self.name}()"
diff --git a/topobenchmarkx/evaluators/comparisons.py b/topobenchmarkx/evaluators/comparisons.py
index 27b09b50..ac3980e9 100644
--- a/topobenchmarkx/evaluators/comparisons.py
+++ b/topobenchmarkx/evaluators/comparisons.py
@@ -2,29 +2,33 @@
 from scipy.stats import friedmanchisquare, wilcoxon
 
-def signed_ranks_test(result1, result2):
-    """Calculates the p-value for the Wilcoxon signed-rank test between the
+def signed_ranks_test(results_1, results_2):
+    r"""Calculates the p-value for the Wilcoxon signed-rank test between the
     results of two models.
 
     https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wilcoxon.html
 
-    :param results: A 2xN numpy array with the results from the two models. N
+    Args:
+        results_1 (numpy.array): A numpy array with the results from the first model. N
        is the number of datasets on which the models have been tested.
-    :return: The p-value of the test
+        results_2 (numpy.array): A numpy array with the results from the second model. Needs to have the same shape as results_1.
+    Returns:
+        float: The p-value of the test.
     """
-    xs = result1 - result2
+    xs = results_1 - results_2
     return wilcoxon(xs[xs != 0])[1]
 
 def friedman_test(results):
-    """Calculates the p-value of the Friedman test between M models on N
+    r"""Calculates the p-value of the Friedman test between M models on N
     datasets.
 
     https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.friedmanchisquare.html
 
-    :param results: A MxN numpy array with the results of M models over N
-        dataset
-    :return: The p-value of the test
+    Args:
+        results (numpy.array): A MxN numpy array with the results of M models over N datasets.
+    Returns:
+        float: The p-value of the test.
     """
     res = [r for r in results]
     return friedmanchisquare(*res)[1]
@@ -35,13 +39,14 @@ def compare_models(results, p_limit=0.05, verbose=False):
     the models are significantly different, then it uses pairwise comparisons
     to study the ranking of the models.
 
-    :param results: A MxN numpy array with the results of M models over N
-        dataset
-    :param p_limit: The limit below which a hypothesis is considered false
-    :param verbose: Whether to print the results of the tests or not
-    :return average_rank: The average ranks of the models
-    :return groups: List of lists with the groups of models that are
-        statistically similar
+    Args:
+        results (numpy.array): A MxN numpy array with the results of M models
+            over N datasets.
+        p_limit (float, optional): The limit below which a hypothesis is considered false. (default: 0.05)
+        verbose (bool, optional): Whether to print the results of the tests or not. (default: False)
+    Returns:
+        numpy.array: The average ranks of the models
+        list: A list of lists with the indices of the models that are in the same group. The first group is the best one.
     """
     M = results.shape[0]
@@ -86,6 +91,7 @@
         [0.6, 0.65, 0.7, 0.9, 0.5, 0.552, 0.843, 0.78, 0.665, 0.876],
         ]
     )
+    print("Signed ranks test:")
     print(signed_ranks_test(results[0, :], results[1, :]))
     results2 = np.array(
         [
@@ -94,6 +100,7 @@
             [0.1, 0.2, 0.2, 0.3, 0.4, 0.5, 0.22, 0.32, 0.11, 0.4],
         ]
     )
+    print("Friedman test with very different results:")
     print(friedman_test(results2))
     results3 = np.array(
         [
@@ -102,7 +109,12 @@
             [0.89, 0.91, 0.79, 0.81, 0.69],
         ]
     )
+    print("Friedman test with similar results:")
     print(friedman_test(results3))
+    print("-"*50)
+    print("Compare models with different results:")
     print(compare_models(results2, verbose=True))
-    print(compare_models(results3))
+    print("-"*50)
+    print("Compare models with similar results:")
+    print(compare_models(results3, verbose=True))
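To make the tie-handling in `signed_ranks_test` concrete, a small worked example with made-up accuracies; ties between the two models are dropped before ranking, mirroring the `xs[xs != 0]` filter above:

import numpy as np
from scipy.stats import wilcoxon

# Made-up accuracies of two models on 10 datasets.
acc_a = np.array([0.81, 0.79, 0.84, 0.90, 0.66, 0.75, 0.82, 0.70, 0.73, 0.88])
acc_b = np.array([0.78, 0.79, 0.80, 0.85, 0.60, 0.74, 0.80, 0.71, 0.70, 0.84])
xs = acc_a - acc_b
p_value = wilcoxon(xs[xs != 0])[1]  # the tied dataset (index 1) is excluded
print(p_value)  # p < 0.05 would reject "both models perform the same"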
""" M = results.shape[0] @@ -86,6 +91,7 @@ def compare_models(results, p_limit=0.05, verbose=False): [0.6, 0.65, 0.7, 0.9, 0.5, 0.552, 0.843, 0.78, 0.665, 0.876], ] ) + print("Signed ranks test:") print(signed_ranks_test(results[0, :], results[1, :])) results2 = np.array( [ @@ -94,6 +100,7 @@ def compare_models(results, p_limit=0.05, verbose=False): [0.1, 0.2, 0.2, 0.3, 0.4, 0.5, 0.22, 0.32, 0.11, 0.4], ] ) + print("Friedman test with very different results:") print(friedman_test(results2)) results3 = np.array( [ @@ -102,7 +109,12 @@ def compare_models(results, p_limit=0.05, verbose=False): [0.89, 0.91, 0.79, 0.81, 0.69], ] ) + print("Friedman test with similar results:") print(friedman_test(results3)) + print("-"*50) + print("Compare models with different results:") print(compare_models(results2, verbose=True)) - print(compare_models(results3)) + print("-"*50) + print("Compare models with similar results:") + print(compare_models(results3, verbose=True)) diff --git a/topobenchmarkx/evaluators/evaluator.py b/topobenchmarkx/evaluators/evaluator.py index 35cec944..7d94a19e 100755 --- a/topobenchmarkx/evaluators/evaluator.py +++ b/topobenchmarkx/evaluators/evaluator.py @@ -9,22 +9,15 @@ class TorchEvaluator: r"""Evaluator class that is responsible for computing the metrics for a given task. - Parameters - ---------- - task : str - The task type. It can be either "classification" or "regression". - - **kwargs : - Additional arguments for the class. The arguments depend on the task. - In "classification" scenario, the following arguments are expected: - - num_classes : int - The number of classes. - - classification_metrics : list - A list of classification metrics to be computed. - - In "regression" scenario, the following arguments are expected: - - regression_metrics : list - A list of regression metrics to be computed. + Args: + task (str): The task type. It can be either "classification" or "regression". + **kwargs : Additional arguments for the class. The arguments depend on the task. + In "classification" scenario, the following arguments are expected: + - num_classes (int): The number of classes. + - classification_metrics (list[str]): A list of classification metrics to be computed. + + In "regression" scenario, the following arguments are expected: + - regression_metrics (list[str]): A list of regression metrics to be computed. """ def __init__(self, task, **kwargs): @@ -67,17 +60,18 @@ def __init__(self, task, **kwargs): self.best_metric = {} + def __repr__(self) -> str: + return f"{self.__class__.__name__}(task={self.task})" + def update(self, model_out: dict): - """Update the metrics with the model output. - - Parameters - ---------- - model_out : dict - The model output. It should contain the following keys: - - logits : torch.Tensor - The model predictions. - - labels : torch.Tensor - The ground truth labels. + r"""Update the metrics with the model output. + + Args: + model_out (dict): The model output. It should contain the following keys: + - logits : torch.Tensor + The model predictions. + - labels : torch.Tensor + The ground truth labels. """ preds = model_out["logits"].cpu() target = model_out["labels"].cpu() @@ -92,21 +86,14 @@ def update(self, model_out: dict): raise ValueError(f"Invalid task {self.task}") def compute(self): - """Compute the metrics. + r"""Compute the metrics. - Returns - ------- - res_dict : dict - A dictionary containing the computed metrics. + Args: + res_dict (dict): A dictionary containing the computed metrics. 
""" + return self.metrics.compute() - res_dict = self.metrics.compute() - - return res_dict - - def reset( - self, - ): + def reset(self): """Reset the metrics. This method should be called after each epoch diff --git a/topobenchmarkx/io/load/download_utils.py b/topobenchmarkx/io/load/download_utils.py index 6ffca758..66c71159 100644 --- a/topobenchmarkx/io/load/download_utils.py +++ b/topobenchmarkx/io/load/download_utils.py @@ -5,7 +5,7 @@ # Function to extract file ID from Google Drive URL def get_file_id_from_url(url): - """Extracts the file ID from a Google Drive file URL. + r"""Extracts the file ID from a Google Drive file URL. Args: url (str): The Google Drive file URL. @@ -35,7 +35,7 @@ def get_file_id_from_url(url): def download_file_from_drive( file_link, path_to_save, dataset_name, file_format="tar.gz" ): - """Downloads a file from a Google Drive link and saves it to the specified + r"""Downloads a file from a Google Drive link and saves it to the specified path. Args: diff --git a/topobenchmarkx/io/load/heterophilic.py b/topobenchmarkx/io/load/heterophilic.py index f8b9c7b8..67b06e4e 100644 --- a/topobenchmarkx/io/load/heterophilic.py +++ b/topobenchmarkx/io/load/heterophilic.py @@ -7,6 +7,14 @@ def load_heterophilic_data(name, path): + r"""Load a heterophilic dataset from a .npz file. + + Args: + name (str): The name of the dataset. + path (str): The path to the directory containing the dataset file. + Returns: + torch_geometric.data.Data: The dataset. + """ file_name = f"{name}.npz" data = np.load(os.path.join(path, file_name)) @@ -26,6 +34,14 @@ def load_heterophilic_data(name, path): def download_hetero_datasets(name, path): + r"""Download a heterophilic dataset from the OpenGSL repository. + + Args: + name (str): The name of the dataset. + path (str): The path to the directory where the dataset will be saved. + Raises: + Exception: If the download fails. + """ url = "https://github.com/OpenGSL/HeterophilousDatasets/raw/main/data/" name = f"{name}.npz" try: diff --git a/topobenchmarkx/io/load/loader.py b/topobenchmarkx/io/load/loader.py index c0eb168e..1f2b7fb4 100755 --- a/topobenchmarkx/io/load/loader.py +++ b/topobenchmarkx/io/load/loader.py @@ -9,28 +9,23 @@ class AbstractLoader(ABC): """Abstract class that provides an interface to load data. - Parameters - ---------- - parameters : DictConfig - Configuration parameters. + Args: + parameters (DictConfig): Configuration parameters. """ def __init__(self, parameters: DictConfig): self.cfg = parameters + def __repr__(self) -> str: + return f"{self.__class__.__name__}(parameters={self.cfg})" + @abstractmethod def load( self, ) -> torch_geometric.data.Data: """Load data into Data. - Parameters - ---------- - None - - Returns - ------- - Data - Data object containing the loaded data. + Raise: + NotImplementedError: If the method is not implemented. """ raise NotImplementedError diff --git a/topobenchmarkx/io/load/loaders.py b/topobenchmarkx/io/load/loaders.py index f24eb30d..05a6b0b0 100755 --- a/topobenchmarkx/io/load/loaders.py +++ b/topobenchmarkx/io/load/loaders.py @@ -28,29 +28,24 @@ class CellComplexLoader(AbstractLoader): r"""Loader for cell complex datasets. - Parameters - ---------- - parameters : DictConfig - Configuration parameters. + Args: + parameters (DictConfig): Configuration parameters. 
""" def __init__(self, parameters: DictConfig): super().__init__(parameters) self.parameters = parameters + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(parameters={self.parameters})" def load( self, ) -> CustomDataset: r"""Load cell complex dataset. - Parameters - ---------- - None - - Returns - ------- - CustomDataset - CustomDataset object containing the loaded data. + Returns: + CustomDataset: CustomDataset object containing the loaded data. """ data = load_cell_complex_dataset(self.parameters) dataset = CustomDataset([data]) @@ -60,29 +55,24 @@ def load( class SimplicialLoader(AbstractLoader): r"""Loader for simplicial datasets. - Parameters - ---------- - parameters : DictConfig - Configuration parameters. + Args: + parameters (DictConfig): Configuration parameters. """ def __init__(self, parameters: DictConfig): super().__init__(parameters) self.parameters = parameters + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(parameters={self.parameters})" def load( self, ) -> CustomDataset: r"""Load simplicial dataset. - Parameters - ---------- - None - - Returns - ------- - CustomDataset - CustomDataset object containing the loaded data. + Returns: + CustomDataset: CustomDataset object containing the loaded data. """ data = load_simplicial_dataset(self.parameters) dataset = CustomDataset([data]) @@ -92,30 +82,26 @@ def load( class HypergraphLoader(AbstractLoader): r"""Loader for hypergraph datasets. - Parameters - ---------- - parameters : DictConfig - Configuration parameters. + Args: + parameters (DictConfig): Configuration parameters. + transforms (DictConfig, optional): The parameters for the transforms to be applied to the dataset. (default: None) """ def __init__(self, parameters: DictConfig, transforms=None): super().__init__(parameters) self.parameters = parameters self.transforms_config = transforms + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(parameters={self.parameters}, transforms={self.transforms_config})" def load( self, ) -> CustomDataset: r"""Load hypergraph dataset. - - Parameters - ---------- - None - - Returns - ------- - CustomDataset - CustomDataset object containing the loaded data. + + Returns: + CustomDataset: CustomDataset object containing the loaded data. """ data = load_hypergraph_pickle_dataset(self.parameters) data = load_split(data, self.parameters) @@ -126,29 +112,38 @@ def load( class GraphLoader(AbstractLoader): r"""Loader for graph datasets. - Parameters - ---------- - parameters : DictConfig - Configuration parameters. + Args: + parameters (DictConfig): Configuration parameters. The parameters must contain the following keys: + - data_dir (str): The directory where the dataset is stored. + - data_name (str): The name of the dataset. + - data_type (str): The type of the dataset. + - split_type (str): The type of split to be used. It can be "fixed", "random", or "k-fold". + + If split_type is "random", the parameters must also contain the following keys: + - data_seed (int): The seed for the split. + - data_split_dir (str): The directory where the split is stored. + - train_prop (float): The proportion of the training set. + If split_type is "k-fold", the parameters must also contain the following keys: + - data_split_dir (str): The directory where the split is stored. + - k (int): The number of folds. + - data_seed (int): The seed for the split. + The parameters can be defined in a yaml file and then loaded using `omegaconf.OmegaConf.load('path/to/dataset/config.yaml')`. 
+        transforms (DictConfig, optional): The parameters for the transforms to be applied to the dataset. The parameters for a transformation can be defined in a yaml file and then loaded using `omegaconf.OmegaConf.load('path/to/transform/config.yaml')`. (default: None)
     """
-
     def __init__(self, parameters: DictConfig, transforms=None):
         super().__init__(parameters)
         self.parameters = parameters
         # Still not instantiated
         self.transforms_config = transforms
 
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(parameters={self.parameters}, transforms={self.transforms_config})"
+
     def load(self) -> CustomDataset:
         r"""Load graph dataset.
 
-        Parameters
-        ----------
-        None
-
-        Returns
-        -------
-        CustomDataset
-            CustomDataset object containing the loaded data.
+        Returns:
+            CustomDataset: CustomDataset object containing the loaded data.
         """
         data_dir = os.path.join(
             self.parameters["data_dir"], self.parameters["data_name"]
         )
@@ -322,10 +317,9 @@
 class ManualGraphLoader(AbstractLoader):
     r"""Loader for manual graph datasets.
 
-    Parameters
-    ----------
-    parameters : DictConfig
-        Configuration parameters.
+    Args:
+        parameters (DictConfig): Configuration parameters.
+        transforms (DictConfig, optional): The parameters for the transforms to be applied to the dataset. (default: None)
     """
 
     def __init__(self, parameters: DictConfig, transforms=None):
@@ -333,18 +327,15 @@ def __init__(self, parameters: DictConfig, transforms=None):
         self.parameters = parameters
         # Still not instantiated
         self.transforms_config = transforms
 
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(parameters={self.parameters}, transforms={self.transforms_config})"
+
     def load(self) -> CustomDataset:
         r"""Load manual graph dataset.
 
-        Parameters
-        ----------
-        None
-
-        Returns
-        -------
-        CustomDataset
-            CustomDataset object containing the loaded data.
+        Returns:
+            CustomDataset: CustomDataset object containing the loaded data.
         """
         data = manual_graph()
diff --git a/topobenchmarkx/io/load/preprocessor.py b/topobenchmarkx/io/load/preprocessor.py
index feb331e2..e1ebd639 100644
--- a/topobenchmarkx/io/load/preprocessor.py
+++ b/topobenchmarkx/io/load/preprocessor.py
@@ -10,16 +10,12 @@
 class Preprocessor(torch_geometric.data.InMemoryDataset):
     r"""Preprocessor for datasets.
 
-    Parameters
-    ----------
-    data_dir : str
-        Path to the directory containing the data.
-    data_list : list
-        List of data objects.
-    transforms_config : DictConfig
-        Configuration parameters for the transforms.
-    **kwargs: optional
-        Additional arguments.
+    Args:
+        data_dir (str): Path to the directory containing the data.
+        data_list (list): List of data objects.
+        transforms_config (DictConfig): Configuration parameters for the transforms.
+        force_reload (bool): Whether to force reload the data. (default: False)
+        **kwargs: Optional additional arguments.
     """
 
     def __init__(
@@ -47,15 +43,16 @@ def __init__(
         )
         self.save_transform_parameters()
         self.load(self.processed_paths[0])
 
+    def __repr__(self):
+        return f"{self.__class__.__name__}(data_dir={self.root}, data_list={self.data_list}, processed_data_dir={self.processed_data_dir}, processed_file_names={self.processed_file_names})"
+
     @property
     def processed_dir(self) -> str:
         r"""Return the path to the processed directory.
 
-        Returns
-        -------
-        str
-            Path to the processed directory.
+        Returns:
+            str: Path to the processed directory.
         """
         return self.root
 
     @property
     def processed_file_names(self) -> str:
         r"""Return the name of the processed file.
         Returns
-        -------
-        str
-            Name of the processed file.
+            str: Name of the processed file.
         """
         return "data.pt"
 
     def instantiate_pre_transform(
         self, data_dir, transforms_config
     ) -> torch_geometric.transforms.Compose:
         r"""Instantiate the pre-transforms.
 
-        Parameters
-        ----------
-        data_dir : str
-            Path to the directory containing the data.
-        transforms_config : DictConfig
-            Configuration parameters for the transforms.
-
-        Returns
-        -------
-        torch_geometric.transforms.Compose
-            Pre-transform object.
+        Args:
+            data_dir (str): Path to the directory containing the data.
+            transforms_config (DictConfig): Configuration parameters for the transforms.
+        Returns:
+            torch_geometric.transforms.Compose: Pre-transform object.
         """
         pre_transforms_dict = hydra.utils.instantiate(transforms_config)
         pre_transforms = torch_geometric.transforms.Compose(
@@ -101,14 +90,10 @@ def set_processed_data_dir(
     ) -> None:
         r"""Set the processed data directory.
 
-        Parameters
-        ----------
-        pre_transforms_dict : dict
-            Dictionary containing the pre-transforms.
-        data_dir : str
-            Path to the directory containing the data.
-        transforms_config : DictConfig
-            Configuration parameters for the transforms.
+        Args:
+            pre_transforms_dict (dict): Dictionary containing the pre-transforms.
+            data_dir (str): Path to the directory containing the data.
+            transforms_config (DictConfig): Configuration parameters for the transforms.
         """
         # Use self.transform_parameters to define unique save/load path for each transform parameters
         repo_name = "_".join(list(transforms_config.keys()))
diff --git a/topobenchmarkx/io/load/split_utils.py b/topobenchmarkx/io/load/split_utils.py
index ce662679..df89ce09 100644
--- a/topobenchmarkx/io/load/split_utils.py
+++ b/topobenchmarkx/io/load/split_utils.py
@@ -9,21 +9,16 @@
 # Generate splits in different fasions
 def k_fold_split(labels, parameters):
-    """Returns train and valid indices as in K-Fold Cross-Validation. If the
+    r"""Returns train and valid indices as in K-Fold Cross-Validation. If the
     split already exists it loads it automatically, otherwise it creates the
     split file for the subsequent runs.
 
-    Parameters
-    ----------
-    labels : torch.Tensor
-        Label tensor.
-    parameters : DictConfig
-        Configuration parameters.
-
-    Returns
-    -------
-    dict
-        Dictionary containing the train, validation and test indices.
+    Args:
+        labels (torch.Tensor): Label tensor.
+        parameters (DictConfig): Configuration parameters.
+
+    Returns:
+        dict: Dictionary containing the train, validation and test indices, with keys "train", "valid", and "test".
     """
     data_dir = parameters.data_split_dir
@@ -90,22 +85,15 @@ def k_fold_split(labels, parameters):
 
 def random_splitting(labels, parameters, global_data_seed=42):
-    """Adapted from https://github.com/CUAI/Non-Homophily-Benchmarks
+    r"""Adapted from https://github.com/CUAI/Non-Homophily-Benchmarks
     randomly splits label into train/valid/test splits.
 
-    Parameters
-    ----------
-    labels : torch.Tensor
-        Label tensor.
-    parameters : DictConfig
-        Configuration parameters.
-    global_data_seed : int
-        Seed for the random number generator.
-
-    Returns
-    -------
-    dict
-        Dictionary containing the train, validation and test indices.
+    Args:
+        labels (torch.Tensor): Label tensor.
+        parameters (DictConfig): Configuration parameters.
+        global_data_seed (int, optional): Seed for the random number generator. (default: 42)
+    Returns:
+        dict: Dictionary containing the train, validation and test indices with keys "train", "valid", and "test".
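A sketch of calling the splitter directly; the config keys mirror the ones read in the function body (`data_seed`, `data_split_dir`), and the remaining values are hypothetical stand-ins for the dataset config:

import torch
from omegaconf import OmegaConf
from topobenchmarkx.io.load.split_utils import random_splitting

parameters = OmegaConf.create(
    {"data_seed": 0, "data_split_dir": "./data_splits", "train_prop": 0.5}
)
labels = torch.randint(0, 2, (100,))
# Writes the split file under data_split_dir on first use, then reloads it.
splits = random_splitting(labels, parameters)
print(len(splits["train"]), len(splits["valid"]), len(splits["test"]))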
""" fold = parameters["data_seed"] data_dir = parameters["data_split_dir"] @@ -169,21 +157,14 @@ def random_splitting(labels, parameters, global_data_seed=42): def load_split(data, cfg, train_prop=0.5): - r"""Loads the split for generated by rand_train_test_idx function. - - Parameters - ---------- - data : torch_geometric.data.Data - Graph dataset. - cfg : DictConfig - Configuration parameters. - train_prop : float - Proportion of training data. - - Returns - ------- - torch_geometric.data.Data - Graph dataset with the specified split. + r"""Loads the split generated by rand_train_test_idx function. + + Args: + data (torch_geometric.data.Data): Graph dataset. + cfg (DictConfig): Configuration parameters. + train_prop (float): Proportion of training data. + Returns: + torch_geometric.data.Data: Graph dataset with the specified split. """ data_dir = os.path.join(cfg["data_split_dir"], f"train_prop={train_prop}") @@ -208,17 +189,11 @@ def load_split(data, cfg, train_prop=0.5): def assing_train_val_test_mask_to_graphs(dataset, split_idx): r"""Splits the graph dataset into train, validation, and test datasets. - Parameters - ---------- - dataset : torch_geometric.data.Dataset - Graph dataset. - split_idx : dict - Dictionary containing the indices for the train, validation, and test splits. - - Returns - ------- - datasets : list - List containing the train, validation, and test datasets. + Args: + dataset (torch_geometric.data.Dataset): Graph dataset. + split_idx (dict): Dictionary containing the indices for the train, validation, and test splits. + Returns: + list: List containing the train, validation, and test datasets. """ data_train_lst, data_val_lst, data_test_lst = [], [], [] @@ -262,17 +237,11 @@ def assing_train_val_test_mask_to_graphs(dataset, split_idx): def load_graph_tudataset_split(dataset, cfg): r"""Loads the graph dataset with the specified split. - Parameters - ---------- - dataset : torch_geometric.data.Dataset - Graph dataset. - cfg : DictConfig - Configuration parameters. - - Returns - ------- - list - List containing the train, validation, and test splits. + Args: + dataset (torch_geometric.data.Dataset): Graph dataset. + cfg (DictConfig): Configuration parameters. + Returns: + list: List containing the train, validation, and test splits. """ # Extract labels from dataset object assert ( @@ -301,17 +270,11 @@ def load_graph_tudataset_split(dataset, cfg): def load_graph_cocitation_split(dataset, cfg): r"""Loads cocitation graph datasets with the specified split. - Parameters - ---------- - dataset : torch_geometric.data.Dataset - Graph dataset. - cfg : DictConfig - Configuration parameters. - - Returns - ------- - list - List containing the train, validation, and test splits. + Args: + dataset (torch_geometric.data.Dataset): Graph dataset. + cfg (DictConfig): Configuration parameters. + Returns: + list: List containing the train, validation, and test splits. """ # Extract labels from dataset object diff --git a/topobenchmarkx/io/load/us_county_demos.py b/topobenchmarkx/io/load/us_county_demos.py index 516694d0..6828692a 100644 --- a/topobenchmarkx/io/load/us_county_demos.py +++ b/topobenchmarkx/io/load/us_county_demos.py @@ -7,19 +7,12 @@ def load_us_county_demos(path, year=2012, y_col="Election"): r"""Load US County Demos dataset. - Parameters - ---------- - path: str - Path to the dataset. - year: int - Year to load the features. - y_col: str - Column to use as label. 
-
-    Returns
-    -------
-    torch_geometric.data.Data
-        Data object of the graph for the US County Demos dataset.
+    Args:
+        path (str): Path to the dataset.
+        year (int, optional): Year to load the features. (default: 2012)
+        y_col (str, optional): Column to use as label. Can be one of ['Election', 'MedianIncome', 'MigraRate', 'BirthRate', 'DeathRate', 'BachelorRate', 'UnemploymentRate']. (default: "Election")
+    Returns:
+        torch_geometric.data.Data: Data object of the graph for the US County Demos dataset.
     """
     edges_df = pd.read_csv(f"{path}/county_graph.csv")
diff --git a/topobenchmarkx/io/load/utils.py b/topobenchmarkx/io/load/utils.py
index eb5414eb..024f8817 100755
--- a/topobenchmarkx/io/load/utils.py
+++ b/topobenchmarkx/io/load/utils.py
@@ -18,19 +18,12 @@
 def get_complex_connectivity(complex, max_rank, signed=False):
     r"""Gets the connectivity matrices for the complex.

-    Parameters
-    ----------
-    complex : topnetx.CellComplex, topnetx.SimplicialComplex
-        Cell complex.
-    max_rank : int
-        Maximum rank of the complex.
-    signed : bool
-        If True, returns signed connectivity matrices.
-
-    Returns
-    -------
-    dict
-        Dictionary containing the connectivity matrices.
+    Args:
+        complex (toponetx.CellComplex, toponetx.SimplicialComplex): Cell complex.
+        max_rank (int): Maximum rank of the complex.
+        signed (bool, optional): If True, returns signed connectivity matrices. (default: False)
+    Returns:
+        dict: Dictionary containing the connectivity matrices.
     """
     practical_shape = list(
         np.pad(list(complex.shape), (0, max_rank + 1 - len(complex.shape)))
     )
@@ -72,37 +65,32 @@ def get_complex_connectivity(complex, max_rank, signed=False):
 def generate_zero_sparse_connectivity(m, n):
     r"""Generates a zero sparse connectivity matrix.

-    Parameters
-    ----------
-    m : int
-        Number of rows.
-    n : int
-        Number of columns.
-
-    Returns
-    -------
-    torch.sparse_coo_tensor
-        Zero sparse connectivity matrix.
+    Args:
+        m (int): Number of rows.
+        n (int): Number of columns.
+    Returns:
+        torch.sparse_coo_tensor: Zero sparse connectivity matrix.
     """
     return torch.sparse_coo_tensor((m, n)).coalesce()


 def load_cell_complex_dataset(cfg):
-    r"""Loads cell complex datasets."""
+    r"""Loads cell complex datasets.
+
+    Args:
+        cfg (DictConfig): Configuration parameters.
+    """


 def load_simplicial_dataset(cfg):
     r"""Loads simplicial datasets.

-    Parameters
-    ----------
-    cfg : DictConfig
-        Configuration parameters.
+    Args:
+        cfg (DictConfig): Configuration parameters. It needs to contain the following keys:
+            - data_name (str): Name of the dataset.

-    Returns
-    -------
-    torch_geometric.data.Data
-        Simplicial dataset.
+    Returns:
+        torch_geometric.data.Data: Simplicial dataset.
     """
     if cfg["data_name"] != "KarateClub":
         return NotImplementedError
@@ -186,15 +174,11 @@ def load_simplicial_dataset(cfg):
 def load_hypergraph_pickle_dataset(cfg):
     r"""Loads hypergraph datasets from pickle files.

-    Parameters
-    ----------
-    cfg : DictConfig
-        Configuration parameters.
-
-    Returns
-    -------
-    torch_geometric.data.Data
-        Hypergraph dataset.
+    Args:
+        cfg (DictConfig): Configuration parameters.
+
+    Returns:
+        torch_geometric.data.Data: Hypergraph dataset.
     """
     data_dir = cfg["data_dir"]
     print(f"Loading {cfg['data_domain']} dataset name: {cfg['data_name']}")
@@ -294,15 +278,12 @@ def load_hypergraph_pickle_dataset(cfg):
 def get_Planetoid_pyg(cfg):
     r"""Loads Planetoid graph datasets from torch_geometric.

-    Parameters
-    ----------
-    cfg : DictConfig
-        Configuration parameters.
-
-    Returns
-    -------
-    torch_geometric.data.Data
-        Graph dataset.
+ Args: + cfg (DictConfig): Configuration parameters. It needs to contain the following keys: + - data_dir (str): Path to the directory containing the data. + - data_name (str): Name of the dataset. + Returns: + torch_geometric.data.Data: Graph dataset. """ data_dir, data_name = cfg["data_dir"], cfg["data_name"] dataset = torch_geometric.datasets.Planetoid(data_dir, data_name) @@ -314,15 +295,12 @@ def get_Planetoid_pyg(cfg): def get_TUDataset_pyg(cfg): r"""Loads TU graph datasets from torch_geometric. - Parameters - ---------- - cfg : DictConfig - Configuration parameters. - - Returns - ------- - list - List containing the graph dataset. + Args: + cfg (DictConfig): Configuration parameters. It needs to contain the following keys: + - data_dir (str): Path to the directory containing the data. + - data_name (str): Name of the dataset. + Returns: + list: List containing the graphs in the dataset. """ data_dir, data_name = cfg["data_dir"], cfg["data_name"] dataset = torch_geometric.datasets.TUDataset(root=data_dir, name=data_name) @@ -333,15 +311,10 @@ def get_TUDataset_pyg(cfg): def ensure_serializable(obj): r"""Ensures that the object is serializable. - Parameters - ---------- - obj : object - Object to ensure serializability. - - Returns - ------- - object - Object that is serializable. + Args: + obj (object): Object to ensure serializability. + Returns: + object: Object that is serializable. """ if isinstance(obj, dict): for key, value in obj.items(): @@ -364,15 +337,10 @@ def make_hash(o): contains only other hashable types (including any lists, tuples, sets, and dictionaries). - Parameters - ---------- - o : dict, list, tuple, set - Object to hash. - - Returns - ------- - int - Hash of the object. + Args: + o (dict, list, tuple, set): Object to hash. + Returns: + int: Hash of the object. 
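+
+    A minimal sketch of the intended use (the argument is illustrative; any nesting of hashable types is accepted, and the int return follows the contract documented above):
+
+    >>> h = make_hash({"lifting": "clique", "complex_dim": 3})
+    >>> isinstance(h, int)
+    True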
""" sha1 = hashlib.sha1() sha1.update(str.encode(str(o))) From 9f3700d2efa325dbd23aae6424d1d5fc03cbb8a2 Mon Sep 17 00:00:00 2001 From: levtelyatnikov Date: Fri, 17 May 2024 06:10:30 +0200 Subject: [PATCH 24/32] hypergraphs --- configs/dataset/ZINC.yaml | 2 +- .../model/hypergraph/allsettransformer.yaml | 2 +- configs/model/hypergraph/edgnn.yaml | 6 +- configs/model/hypergraph/unignn2.yaml | 10 +- configs/train.yaml | 4 +- .../main_exp/hypergraph/allsettransformer.sh | 52 ++--- hp_scripts/main_exp/hypergraph/edgnn.sh | 135 +++++++++++ hp_scripts/main_exp/hypergraph/left_out.sh | 63 ++++++ hp_scripts/main_exp/hypergraph/unignn2.sh | 69 +++--- hp_scripts/main_exp/simplicial/SCN.sh | 211 ++++++++++-------- topobenchmarkx/io/load/loaders.py | 2 +- topobenchmarkx/models/default_network.py | 2 +- topobenchmarkx/run_hypergraph_scripts.sh | 7 + 13 files changed, 396 insertions(+), 169 deletions(-) create mode 100644 hp_scripts/main_exp/hypergraph/edgnn.sh create mode 100644 hp_scripts/main_exp/hypergraph/left_out.sh create mode 100644 topobenchmarkx/run_hypergraph_scripts.sh diff --git a/configs/dataset/ZINC.yaml b/configs/dataset/ZINC.yaml index 5a77747d..f39a9c4c 100644 --- a/configs/dataset/ZINC.yaml +++ b/configs/dataset/ZINC.yaml @@ -24,7 +24,7 @@ parameters: monitor_metric: mae task_level: graph data_seed: 0 - split_type: 'fixed' # either k-fold or test + split_type: 'fixed' # ZINC accept only split #k: 10 # for k-Fold Cross-Validation # Dataloader parameters diff --git a/configs/model/hypergraph/allsettransformer.yaml b/configs/model/hypergraph/allsettransformer.yaml index 04e62c5e..984d78ec 100755 --- a/configs/model/hypergraph/allsettransformer.yaml +++ b/configs/model/hypergraph/allsettransformer.yaml @@ -7,7 +7,7 @@ feature_encoder: _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} - out_channels: 32 + out_channels: 128 proj_dropout: 0.0 backbone: diff --git a/configs/model/hypergraph/edgnn.yaml b/configs/model/hypergraph/edgnn.yaml index 6f84cce4..2512dcf3 100755 --- a/configs/model/hypergraph/edgnn.yaml +++ b/configs/model/hypergraph/edgnn.yaml @@ -7,14 +7,14 @@ feature_encoder: _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} - out_channels: 16 + out_channels: 128 proj_dropout: 0.0 backbone: _target_: custom_models.hypergraph.edgnn.EDGNN num_features: ${model.feature_encoder.out_channels} # ${dataset.parameters.num_features} - input_dropout: 0.2 - dropout: 0.2 + input_dropout: 0. + dropout: 0. 
activation: relu MLP_num_layers: 1 All_num_layers: 1 diff --git a/configs/model/hypergraph/unignn2.yaml b/configs/model/hypergraph/unignn2.yaml index bfe5e4be..1d380cfa 100755 --- a/configs/model/hypergraph/unignn2.yaml +++ b/configs/model/hypergraph/unignn2.yaml @@ -7,18 +7,18 @@ feature_encoder: _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name} encoder_name: AllCellFeatureEncoder in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features} - out_channels: 32 + out_channels: 128 proj_dropout: 0.0 backbone: _target_: topomodelx.nn.hypergraph.unigcnii.UniGCNII - in_channels: ${model.feature_encoder.out_channels} # ${dataset.parameters.num_features} + in_channels: ${model.feature_encoder.out_channels} hidden_channels: ${model.feature_encoder.out_channels} - n_layers: 1 + n_layers: 4 alpha: 0.5 beta: 0.5 - input_drop: 0.2 - layer_drop: 0.2 + input_drop: 0.0 + layer_drop: 0.0 backbone_wrapper: _target_: topobenchmarkx.models.wrappers.HypergraphWrapper diff --git a/configs/train.yaml b/configs/train.yaml index 56828c07..4acf13bb 100755 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -4,8 +4,8 @@ # order of defaults determines the order in which configs override each other defaults: - _self_ - - dataset: cocitation_cora #us_country_demos - - model: simplicial/scn #hypergraph/unignn2 #allsettransformer + - dataset: amazon_ratings #us_country_demos + - model: hypergraph/allsettransformer #hypergraph/unignn2 #allsettransformer - evaluator: default - callbacks: default - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`) diff --git a/hp_scripts/main_exp/hypergraph/allsettransformer.sh b/hp_scripts/main_exp/hypergraph/allsettransformer.sh index b0d192ee..3f7987c0 100644 --- a/hp_scripts/main_exp/hypergraph/allsettransformer.sh +++ b/hp_scripts/main_exp/hypergraph/allsettransformer.sh @@ -9,8 +9,8 @@ do dataset.parameters.data_seed=0,3,5,7,9 \ dataset.parameters.task_variable=$task_variable \ model=hypergraph/allsettransformer \ - model.feature_encoder.out_channels="32,64,128" \ - model.feature_encoder.proj_dropout="0,0.25,0.5" \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ model.backbone.n_layers="1,2,3,4" \ model.optimizer.lr="0.01,0.001" \ trainer.max_epochs=1000 \ @@ -33,7 +33,7 @@ do dataset.parameters.data_seed=0,3,5,7,9 \ model=hypergraph/allsettransformer \ model.feature_encoder.out_channels="32,64,128" \ - model.feature_encoder.proj_dropout="0,0.25,0.5" \ + model.feature_encoder.proj_dropout=0.25,0.5 \ model.backbone.n_layers="1,2" \ model.optimizer.lr="0.01,0.001" \ trainer.max_epochs=500 \ @@ -68,29 +68,6 @@ python train.py \ tags="[MainExperiment]" \ --multirun -# ----Heterophilic datasets---- - -datasets=( roman_empire amazon_ratings tolokers questions minesweeper ) - -for dataset in ${datasets[*]} -do - python train.py \ - dataset=$dataset \ - model=hypergraph/allsettransformer \ - model.optimizer.lr=0.01,0.001 \ - model.feature_encoder.out_channels=32,64,128 \ - model.backbone.n_layers=1,2,3,4 \ - model.feature_encoder.proj_dropout=0.25,0.5 \ - dataset.parameters.data_seed=0,3,5 \ - dataset.parameters.batch_size=128,256 \ - logger.wandb.project=TopoBenchmarkX_Hypergraph \ - trainer.max_epochs=1000 \ - trainer.min_epochs=50 \ - trainer.check_val_every_n_epoch=1 \ - callbacks.early_stopping.patience=50 \ - tags="[MainExperiment]" \ - --multirun -done # ----TU graph datasets---- # MUTAG have very few samples, so we use a smaller batch size 
@@ -134,3 +111,26 @@ do
   --multirun
 done

+# ----Heterophilic datasets----
+
+datasets=( roman_empire amazon_ratings tolokers minesweeper )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    model=hypergraph/allsettransformer \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5 \
+    dataset.parameters.batch_size=128,256 \
+    logger.wandb.project=TopoBenchmarkX_Hypergraph \
+    trainer.max_epochs=1000 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=50 \
+    tags="[MainExperiment]" \
+    --multirun
+done
\ No newline at end of file
diff --git a/hp_scripts/main_exp/hypergraph/edgnn.sh b/hp_scripts/main_exp/hypergraph/edgnn.sh
new file mode 100644
index 00000000..87a28e90
--- /dev/null
+++ b/hp_scripts/main_exp/hypergraph/edgnn.sh
@@ -0,0 +1,135 @@
+# Description: Main experiment script for the EDGNN model.
+# ----Node regression datasets: US County Demographics----
+task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' )
+
+for task_variable in ${task_variables[*]}
+do
+    python train.py \
+    dataset=us_country_demos \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.task_variable=$task_variable \
+    model=hypergraph/edgnn \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    model.backbone.All_num_layers=1,2,3,4 \
+    model.optimizer.lr="0.01,0.001" \
+    trainer.max_epochs=1000 \
+    trainer.min_epochs=500 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=50 \
+    logger.wandb.project=TopoBenchmarkX_Hypergraph \
+    tags="[MainExperiment]" \
+    --multirun
+
+done
+
+# ----Cocitation datasets----
+datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    model=hypergraph/edgnn \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    model.backbone.All_num_layers="1,2" \
+    model.optimizer.lr="0.01,0.001" \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=25 \
+    logger.wandb.project=TopoBenchmarkX_Hypergraph \
+    tags="[MainExperiment]" \
+    --multirun
+done
+
+# ----Graph regression dataset----
+# Train on ZINC dataset
+python train.py \
+    dataset=ZINC \
+    seed=42,3,5,23,150 \
+    model=hypergraph/edgnn \
+    model.optimizer.lr=0.01,0.001 \
+    model.optimizer.weight_decay=0 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.All_num_layers=2,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.batch_size=128,256 \
+    dataset.transforms.one_hot_node_degree_features.degrees_fields=x \
+    dataset.parameters.data_seed=0 \
+    logger.wandb.project=TopoBenchmarkX_Hypergraph \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    callbacks.early_stopping.min_delta=0.005 \
+    trainer.check_val_every_n_epoch=5 \
+    callbacks.early_stopping.patience=10 \
+    tags="[MainExperiment]" \
+    --multirun
+
+# ----TU graph datasets----
+# MUTAG has very few samples, so we use a smaller batch size
+# Train on MUTAG dataset
+python train.py \
+    dataset=MUTAG \
+    model=hypergraph/edgnn \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
model.backbone.All_num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=32,64 \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + callbacks.early_stopping.patience=25 \ + tags="[MainExperiment]" \ + --multirun + +# Train rest of the TU graph datasets +datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'REDDIT-BINARY' 'IMDB-BINARY' 'IMDB-MULTI' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/edgnn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.All_num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + --multirun +done + +# ----Heterophilic datasets---- + +datasets=( roman_empire amazon_ratings tolokers minesweeper ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/edgnn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.All_num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done diff --git a/hp_scripts/main_exp/hypergraph/left_out.sh b/hp_scripts/main_exp/hypergraph/left_out.sh new file mode 100644 index 00000000..94ac96e9 --- /dev/null +++ b/hp_scripts/main_exp/hypergraph/left_out.sh @@ -0,0 +1,63 @@ +# ----Heterophilic datasets---- + +datasets=( questions ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/unignn2 \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/edgnn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.All_num_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/allsettransformer \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ 
+ dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done \ No newline at end of file diff --git a/hp_scripts/main_exp/hypergraph/unignn2.sh b/hp_scripts/main_exp/hypergraph/unignn2.sh index b0d192ee..f62d8ac0 100644 --- a/hp_scripts/main_exp/hypergraph/unignn2.sh +++ b/hp_scripts/main_exp/hypergraph/unignn2.sh @@ -8,10 +8,10 @@ do dataset=us_country_demos \ dataset.parameters.data_seed=0,3,5,7,9 \ dataset.parameters.task_variable=$task_variable \ - model=hypergraph/allsettransformer \ - model.feature_encoder.out_channels="32,64,128" \ - model.feature_encoder.proj_dropout="0,0.25,0.5" \ - model.backbone.n_layers="1,2,3,4" \ + model=hypergraph/unignn2 \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.n_layers=1,2,3,4 \ model.optimizer.lr="0.01,0.001" \ trainer.max_epochs=1000 \ trainer.min_epochs=500 \ @@ -31,10 +31,10 @@ do python train.py \ dataset=$dataset \ dataset.parameters.data_seed=0,3,5,7,9 \ - model=hypergraph/allsettransformer \ - model.feature_encoder.out_channels="32,64,128" \ - model.feature_encoder.proj_dropout="0,0.25,0.5" \ - model.backbone.n_layers="1,2" \ + model=hypergraph/unignn2 \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.n_layers=1,2 \ model.optimizer.lr="0.01,0.001" \ trainer.max_epochs=500 \ trainer.min_epochs=50 \ @@ -50,7 +50,7 @@ done python train.py \ dataset=ZINC \ seed=42,3,5,23,150 \ - model=hypergraph/allsettransformer \ + model=hypergraph/unignn2 \ model.optimizer.lr=0.01,0.001 \ model.optimizer.weight_decay=0 \ model.feature_encoder.out_channels=32,64,128 \ @@ -68,36 +68,12 @@ python train.py \ tags="[MainExperiment]" \ --multirun -# ----Heterophilic datasets---- - -datasets=( roman_empire amazon_ratings tolokers questions minesweeper ) - -for dataset in ${datasets[*]} -do - python train.py \ - dataset=$dataset \ - model=hypergraph/allsettransformer \ - model.optimizer.lr=0.01,0.001 \ - model.feature_encoder.out_channels=32,64,128 \ - model.backbone.n_layers=1,2,3,4 \ - model.feature_encoder.proj_dropout=0.25,0.5 \ - dataset.parameters.data_seed=0,3,5 \ - dataset.parameters.batch_size=128,256 \ - logger.wandb.project=TopoBenchmarkX_Hypergraph \ - trainer.max_epochs=1000 \ - trainer.min_epochs=50 \ - trainer.check_val_every_n_epoch=1 \ - callbacks.early_stopping.patience=50 \ - tags="[MainExperiment]" \ - --multirun -done - # ----TU graph datasets---- # MUTAG have very few samples, so we use a smaller batch size # Train on MUTAG dataset python train.py \ dataset=MUTAG \ - model=hypergraph/allsettransformer \ + model=hypergraph/unignn2 \ model.optimizer.lr=0.01,0.001 \ model.feature_encoder.out_channels=32,64,128 \ model.backbone.n_layers=1,2,3,4 \ @@ -119,7 +95,7 @@ for dataset in ${datasets[*]} do python train.py \ dataset=$dataset \ - model=hypergraph/allsettransformer \ + model=hypergraph/unignn2 \ model.optimizer.lr=0.01,0.001 \ model.feature_encoder.out_channels=32,64,128 \ model.backbone.n_layers=1,2,3,4 \ @@ -134,3 +110,26 @@ do --multirun done +# ----Heterophilic datasets---- + +datasets=( roman_empire amazon_ratings tolokers minesweeper ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/unignn2 \ + model.optimizer.lr=0.01,0.001 \ + 
model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Hypergraph \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ + --multirun +done diff --git a/hp_scripts/main_exp/simplicial/SCN.sh b/hp_scripts/main_exp/simplicial/SCN.sh index 63c9045e..9c139fcd 100644 --- a/hp_scripts/main_exp/simplicial/SCN.sh +++ b/hp_scripts/main_exp/simplicial/SCN.sh @@ -1,114 +1,137 @@ -# Create a logger file in the same repo to keep track of the experiments executed +# Description: Main experiment script for GCN model. +# ----Node regression datasets: US County Demographics---- +task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' ) -# SCN model - Fixed split -python train.py \ - dataset=ZINC \ - model=simplicial/scn \ - model.backbone.n_layers=1,2,4 \ - model.feature_encoder.out_channels=16,64 \ - model.optimizer.lr=0.01,0.001 \ - dataset.parameters.batch_size=128 \ - dataset.parameters.data_seed=0,3 \ - dataset.transforms.graph2simplicial_lifting.complex_dim=3 \ - dataset.transforms.graph2simplicial_lifting.signed=False \ - trainer=default \ - trainer.check_val_every_n_epoch=5 \ - callbacks.early_stopping.patience=10 \ - callbacks.early_stopping.min_delta=0.005 \ - logger.wandb.project=topobenchmark_22Apr2024 \ +for task_variable in ${task_variables[*]} +do + python train.py \ + dataset=us_country_demos \ + dataset.parameters.data_seed=0,3,5,7,9 \ + dataset.parameters.task_variable=$task_variable \ + model=hypergraph/edgnn \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + model.backbone.n_layers=1,2,3,4 \ + model.optimizer.lr="0.01,0.001" \ + model.readout.readout_name=NoReadOut \ + dataset.transforms.graph2simplicial_lifting.signed=True \ + trainer.max_epochs=1000 \ + trainer.min_epochs=500 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + logger.wandb.project=TopoBenchmarkX_Simplicial \ + tags="[MainExperiment]" \ --multirun + +done -# Batch size = 1 -python train.py \ - dataset=cocitation_cora \ - model=simplicial/scn \ - model.optimizer.lr=0.01,0.001 \ - model.feature_encoder.out_channels=32,64 \ - model.backbone.n_layers=1,2 \ - dataset.parameters.data_seed=0,3,5 \ - dataset.transforms.graph2simplicial_lifting.complex_dim=3 \ - dataset.transforms.graph2simplicial_lifting.signed=False \ - trainer=default \ - trainer.check_val_every_n_epoch=5 \ - callbacks.early_stopping.patience=10 \ - logger.wandb.project=topobenchmark_22Apr2024 \ - --multirun +# ----Cocitation datasets---- +datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' ) -python train.py \ - dataset=cocitation_citeseer \ - model=simplicial/scn \ - model.optimizer.lr=0.01,0.001 \ - model.feature_encoder.out_channels=32,64 \ +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + dataset.parameters.data_seed=0,3,5,7,9 \ + model=hypergraph/edgnn \ + model.feature_encoder.out_channels=32,64,128 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ model.backbone.n_layers=1,2 \ - dataset.parameters.data_seed=0,3,5 \ - dataset.transforms.graph2simplicial_lifting.complex_dim=3 \ - dataset.transforms.graph2simplicial_lifting.signed=False \ - trainer=default \ - 
trainer.check_val_every_n_epoch=5 \ - callbacks.early_stopping.patience=10 \ - logger.wandb.project=topobenchmark_22Apr2024 \ + model.optimizer.lr="0.01,0.001" \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=25 \ + logger.wandb.project=TopoBenchmarkX_Simplicial \ + tags="[MainExperiment]" \ --multirun +done +# ----Graph regression dataset---- +# Train on ZINC dataset python train.py \ - dataset=cocitation_pubmed \ - model=simplicial/scn \ - model.optimizer.lr=0.01,0.001 \ - model.feature_encoder.out_channels=32,64 \ - model.backbone.n_layers=1,2 \ - dataset.parameters.data_seed=0,3,5 \ - dataset.transforms.graph2simplicial_lifting.complex_dim=3 \ - dataset.transforms.graph2simplicial_lifting.signed=False \ - trainer=default \ - trainer.check_val_every_n_epoch=5 \ - callbacks.early_stopping.patience=10 \ - logger.wandb.project=topobenchmark_22Apr2024 \ - --multirun + dataset=ZINC \ + seed=42,3,5,23,150 \ + model=hypergraph/edgnn \ + model.optimizer.lr=0.01,0.001 \ + model.optimizer.weight_decay=0 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=2,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.batch_size=128,256 \ + dataset.transforms.one_hot_node_degree_features.degrees_fields=x \ + dataset.parameters.data_seed=0 \ + logger.wandb.project=TopoBenchmarkX_Simplicial \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + callbacks.early_stopping.min_delta=0.005 \ + trainer.check_val_every_n_epoch=5 \ + callbacks.early_stopping.patience=10 \ + tags="[MainExperiment]" \ + --multirun -# Vary batch size +# ----TU graph datasets---- +# MUTAG have very few samples, so we use a smaller batch size +# Train on MUTAG dataset python train.py \ - dataset=PROTEINS_TU \ - model=simplicial/scn \ - model.optimizer.lr=0.01,0.001 \ - model.feature_encoder.out_channels=16,64 \ - model.backbone.n_layers=1,2 \ - dataset.parameters.batch_size=32 \ - dataset.parameters.data_seed=0,3,5 \ - dataset.transforms.graph2simplicial_lifting.complex_dim=3 \ - dataset.transforms.graph2simplicial_lifting.signed=False \ - trainer=default \ - trainer.check_val_every_n_epoch=5 \ - callbacks.early_stopping.patience=10 \ - logger.wandb.project=topobenchmark_22Apr2024 \ - --multirun + dataset=MUTAG \ + model=hypergraph/edgnn \ + model.optimizer.lr=0.01,0.001 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ + dataset.parameters.data_seed=0,3,5 \ + dataset.parameters.batch_size=32,64 \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + logger.wandb.project=TopoBenchmarkX_Simplicial \ + callbacks.early_stopping.patience=25 \ + tags="[MainExperiment]" \ + --multirun -python train.py \ - dataset=NCI1 \ - model=simplicial/scn \ +# Train rest of the TU graph datasets +datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'REDDIT-BINARY' 'IMDB-BINARY' 'IMDB-MULTI' ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/edgnn \ model.optimizer.lr=0.01,0.001 \ - model.feature_encoder.out_channels=16,64 \ - model.backbone.n_layers=1,2 \ - dataset.parameters.batch_size=32 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ dataset.parameters.data_seed=0,3,5 \ - dataset.transforms.graph2simplicial_lifting.complex_dim=3 \ - 
dataset.transforms.graph2simplicial_lifting.signed=False \ - trainer=default \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Simplicial \ + trainer.max_epochs=500 \ + trainer.min_epochs=50 \ trainer.check_val_every_n_epoch=5 \ callbacks.early_stopping.patience=10 \ - logger.wandb.project=topobenchmark_22Apr2024 \ --multirun +done -python train.py \ - dataset=MUTAG \ - model=simplicial/scn \ +# ----Heterophilic datasets---- + +datasets=( roman_empire amazon_ratings tolokers minesweeper questions ) + +for dataset in ${datasets[*]} +do + python train.py \ + dataset=$dataset \ + model=hypergraph/edgnn \ model.optimizer.lr=0.01,0.001 \ - model.feature_encoder.out_channels=16,64 \ - model.backbone.n_layers=1,2 \ - dataset.parameters.batch_size=32 \ + model.feature_encoder.out_channels=32,64,128 \ + model.backbone.n_layers=1,2,3,4 \ + model.feature_encoder.proj_dropout=0.25,0.5 \ dataset.parameters.data_seed=0,3,5 \ - dataset.transforms.graph2simplicial_lifting.complex_dim=3 \ - dataset.transforms.graph2simplicial_lifting.signed=False \ - trainer=default \ - trainer.check_val_every_n_epoch=5 \ - callbacks.early_stopping.patience=10 \ - logger.wandb.project=topobenchmark_22Apr2024 \ + dataset.parameters.batch_size=128,256 \ + logger.wandb.project=TopoBenchmarkX_Simplicial \ + trainer.max_epochs=1000 \ + trainer.min_epochs=50 \ + trainer.check_val_every_n_epoch=1 \ + callbacks.early_stopping.patience=50 \ + tags="[MainExperiment]" \ --multirun +done diff --git a/topobenchmarkx/io/load/loaders.py b/topobenchmarkx/io/load/loaders.py index f24eb30d..d1e37857 100755 --- a/topobenchmarkx/io/load/loaders.py +++ b/topobenchmarkx/io/load/loaders.py @@ -305,7 +305,7 @@ def load(self) -> CustomDataset: data_dir, dataset, self.transforms_config, - force_reload=True, + force_reload=False, ) # We need to map original dataset into custom one to make batching work diff --git a/topobenchmarkx/models/default_network.py b/topobenchmarkx/models/default_network.py index 4181c285..c338d416 100755 --- a/topobenchmarkx/models/default_network.py +++ b/topobenchmarkx/models/default_network.py @@ -68,7 +68,7 @@ def model_step( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Perform a single model step on a batch of data. - :param batch: A batch of data (a tuple) containing the input tensor of images and target labels. + :param batch: A batch of domain data. :return: A tuple containing (in order): - A tensor of losses. 
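The sweep scripts added in this patch lean entirely on Hydra's `--multirun` flag: every comma-separated override is parsed as a list of values, and the basic sweeper launches one `python train.py` run per element of the Cartesian product of all the lists. A minimal sketch of that expansion, in plain Python with an illustrative subset of the overrides used above (a reading aid only, not code from the repository):

    from itertools import product

    # Illustrative subset of the comma-separated overrides swept above.
    grid = {
        "model.optimizer.lr": [0.01, 0.001],
        "model.feature_encoder.out_channels": [32, 64, 128],
        "model.backbone.n_layers": [1, 2, 3, 4],
    }

    # Hydra's basic sweeper enumerates the Cartesian product of the value
    # lists, i.e. 2 * 3 * 4 = 24 separate training runs for this grid.
    for values in product(*grid.values()):
        overrides = " ".join(f"{k}={v}" for k, v in zip(grid.keys(), values))
        print(f"python train.py {overrides}")
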
diff --git a/topobenchmarkx/run_hypergraph_scripts.sh b/topobenchmarkx/run_hypergraph_scripts.sh new file mode 100644 index 00000000..42c205e2 --- /dev/null +++ b/topobenchmarkx/run_hypergraph_scripts.sh @@ -0,0 +1,7 @@ +# Run the scripts from the hypergraph directory +bash ~/projects/TopoBenchmarkX/hp_scripts/main_exp/hypergraph/edgnn.sh +bash ~/projects/TopoBenchmarkX/hp_scripts/main_exp/hypergraph/allsettransformer.sh +bash ~/projects/TopoBenchmarkX/hp_scripts/main_exp/hypergraph/unignn2.sh + +# Run in case we have time +# bash ~/projects/TopoBenchmarkX/hp_scripts/main_exp/hypergraph/left_out.sh \ No newline at end of file From 84d1322cda2c0dde84b6c335464ca0b95ff5e254 Mon Sep 17 00:00:00 2001 From: Coerulatus Date: Fri, 17 May 2024 16:23:46 +0000 Subject: [PATCH 25/32] fixed documentation --- topobenchmarkx/data/dataloaders.py | 2 +- topobenchmarkx/data/heteriphilic_dataset.py | 2 +- .../data/us_county_demos_dataset.py | 2 +- topobenchmarkx/evaluators/evaluator.py | 2 +- topobenchmarkx/models/default_network.py | 121 ++++--- .../models/encoders/all_cell_encoder.py | 74 ++--- topobenchmarkx/models/encoders/encoder.py | 14 +- topobenchmarkx/models/encoders/perceiver.py | 187 +++++------ .../models/head_models/head_model.py | 33 +- .../models/head_models/zero_cell_model.py | 37 +-- topobenchmarkx/models/losses/default_loss.py | 27 +- topobenchmarkx/models/losses/loss.py | 9 +- topobenchmarkx/models/readouts/identical.py | 11 +- .../models/readouts/propagate_signal_down.py | 15 + topobenchmarkx/models/readouts/readout.py | 31 +- .../models/wrappers/cell/can_wrapper.py | 11 +- .../models/wrappers/cell/cccn_wrapper.py | 11 +- .../models/wrappers/cell/ccxn_wrapper.py | 11 +- .../models/wrappers/cell/cwn_wrapper.py | 11 +- .../models/wrappers/graph/gnn_wrapper.py | 11 +- .../wrappers/hypergraph/hypergraph_wrapper.py | 11 +- .../models/wrappers/simplicial/san_wrapper.py | 11 +- .../wrappers/simplicial/sccn_wrapper.py | 11 +- .../wrappers/simplicial/sccnn_wrapper.py | 11 +- .../models/wrappers/simplicial/scn_wrapper.py | 18 +- topobenchmarkx/models/wrappers/wrapper.py | 25 +- .../calculate_simplicial_curvature.py | 67 ++-- .../data_manipulations/equal_gaus_features.py | 24 +- .../data_manipulations/identity_transform.py | 22 +- .../infere_knn_connectivity.py | 20 +- .../infere_radius_connectivity.py | 20 +- .../keep_only_connected_component.py | 25 +- .../keep_selected_data_fields.py | 22 +- .../data_manipulations/manipulations.py | 310 ++++++++---------- .../data_manipulations/node_degrees.py | 37 +-- .../node_features_to_float.py | 22 +- .../data_manipulations/one_hot_degree.py | 37 +-- .../one_hot_degree_features.py | 23 +- topobenchmarkx/transforms/data_transform.py | 26 +- .../feature_liftings/feature_liftings.py | 109 +++--- .../transforms/liftings/graph2cell.py | 76 ++--- .../transforms/liftings/graph2hypergraph.py | 92 +++--- .../transforms/liftings/graph2simplicial.py | 94 ++---- .../transforms/liftings/graph_lifting.py | 68 ++-- topobenchmarkx/utils/config_resolvers.py | 119 +++---- topobenchmarkx/utils/instantiators.py | 17 +- topobenchmarkx/utils/logging_utils.py | 11 +- topobenchmarkx/utils/pylogger.py | 25 +- topobenchmarkx/utils/rich_utils.py | 19 +- topobenchmarkx/utils/utils.py | 25 +- 50 files changed, 918 insertions(+), 1101 deletions(-) diff --git a/topobenchmarkx/data/dataloaders.py b/topobenchmarkx/data/dataloaders.py index 1b836577..d474a7bf 100755 --- a/topobenchmarkx/data/dataloaders.py +++ b/topobenchmarkx/data/dataloaders.py @@ -165,7 +165,7 @@ def __init__( 
self.dataset_test = dataset_test def __repr__(self) -> str: - return f"{self.__class__.__name__}(dataset_train={self.dataset_train}, dataset_val={self.dataset_val}, dataset_test={self.dataset_test}, batch_size={self.batch_size}, num_workers={self.hparams.num_workers}, pin_memory={self.hparams.pin_memory})" + return f"{self.__class__.__name__}(dataset_train={self.dataset_train}, dataset_val={self.dataset_val}, dataset_test={self.dataset_test}, batch_size={self.batch_size})" def train_dataloader(self) -> DataLoader: r"""Create and return the train dataloader. diff --git a/topobenchmarkx/data/heteriphilic_dataset.py b/topobenchmarkx/data/heteriphilic_dataset.py index a073ab8d..7df490ec 100644 --- a/topobenchmarkx/data/heteriphilic_dataset.py +++ b/topobenchmarkx/data/heteriphilic_dataset.py @@ -84,7 +84,7 @@ def __init__( self.data, self.slices = self.collate([data]) def __repr__(self) -> str: - return f"{self.name}(self.root={self.root}, self.name={self.name}, self.parameters={self.parameters}, self.transform={self.transform}, self.pre_transform={self.pre_transform}, self.pre_filter={self.pre_filter}, self.force_reload={self.force_reload}, self.use_node_attr={self.use_node_attr}, self.use_edge_attr={self.use_edge_attr})" + return f"{self.name}(self.root={self.root}, self.name={self.name}, self.parameters={self.parameters}, self.transform={self.transform}, self.pre_transform={self.pre_transform}, self.pre_filter={self.pre_filter}, self.force_reload={self.force_reload})" @property def raw_dir(self) -> str: diff --git a/topobenchmarkx/data/us_county_demos_dataset.py b/topobenchmarkx/data/us_county_demos_dataset.py index df1fb632..b22883f9 100644 --- a/topobenchmarkx/data/us_county_demos_dataset.py +++ b/topobenchmarkx/data/us_county_demos_dataset.py @@ -101,7 +101,7 @@ def __init__( self.data, self.slices = self.collate([data]) def __repr__(self) -> str: - return f"{self.name}(self.root={self.root}, self.name={self.name}, self.parameters={self.parameters}, self.transform={self.transform}, self.pre_transform={self.pre_transform}, self.pre_filter={self.pre_filter}, self.force_reload={self.force_reload}, self.use_node_attr={self.use_node_attr}, self.use_edge_attr={self.use_edge_attr})" + return f"{self.name}(self.root={self.root}, self.name={self.name}, self.parameters={self.parameters}, self.transform={self.transform}, self.pre_transform={self.pre_transform}, self.pre_filter={self.pre_filter}, self.force_reload={self.force_reload})" @property def raw_dir(self) -> str: diff --git a/topobenchmarkx/evaluators/evaluator.py b/topobenchmarkx/evaluators/evaluator.py index 7d94a19e..3f1771f5 100755 --- a/topobenchmarkx/evaluators/evaluator.py +++ b/topobenchmarkx/evaluators/evaluator.py @@ -61,7 +61,7 @@ def __init__(self, task, **kwargs): self.best_metric = {} def __repr__(self) -> str: - return f"{self.__class__.__name__}(task={self.task})" + return f"{self.__class__.__name__}(task={self.task}, metrics={self.metrics})" def update(self, model_out: dict): r"""Update the metrics with the model output. diff --git a/topobenchmarkx/models/default_network.py b/topobenchmarkx/models/default_network.py index 4181c285..49c4fb45 100755 --- a/topobenchmarkx/models/default_network.py +++ b/topobenchmarkx/models/default_network.py @@ -6,12 +6,16 @@ from torch_geometric.data import Data class TopologicalNetworkModule(LightningModule): - """A `LightningModule` implements 8 key methods: - - Docs: - https://lightning.ai/docs/pytorch/latest/common/lightning_module.html + r"""A `LightningModule` to define a network. 
+ + Args: + backbone (torch.nn.Module): The backbone model to train. + backbone_wrapper (torch.nn.Module): The backbone wrapper class. + readout (torch.nn.Module): The readout class. + head_model (torch.nn.Module): The head model. + loss (torch.nn.Module): The loss class. + feature_encoder (torch.nn.Module, optional): The feature encoder. (default: None) """ - def __init__( self, backbone: torch.nn.Module, @@ -22,14 +26,6 @@ def __init__( feature_encoder: torch.nn.Module | None = None, **kwargs, ) -> None: - """Initialize a `NetworkModule`. - - :param backbone: The backbone model to train. - :param readout: The readout class. - :param loss: The loss class. - :param optimizer: The optimizer to use for training. - :param scheduler: The learning rate scheduler to use for training. - """ super().__init__() # This line allows to access init params with 'self.hparams' attribute @@ -55,25 +51,28 @@ def __init__( self.metric_collector_val2 = [] self.metric_collector_test = [] + def __repr__(self) -> str: + return f"{self.__class__.__name__}(backbone={self.backbone}, readout={self.readout}, head_model={self.head_model}, loss={self.loss}, feature_encoder={self.feature_encoder})" + def forward(self, batch: Data) -> dict: - """Perform a forward pass through the model `self.backbone`. + r"""Perform a forward pass through the model `self.backbone`. - :param x: A tensor of images. - :return: A tensor of logits. + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + torch.Tensor: A tensor of logits. """ return self.backbone(batch) def model_step( self, batch: Data - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Perform a single model step on a batch of data. - - :param batch: A batch of data (a tuple) containing the input tensor of images and target labels. + ) -> dict: + r"""Perform a single model step on a batch of data. - :return: A tuple containing (in order): - - A tensor of losses. - - A tensor of predictions. - - A tensor of target labels. + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the model output. """ # Feature Encoder @@ -97,14 +96,15 @@ def model_step( return model_out - def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> torch.Tensor: - """Perform a single training step on a batch of data from the training + def training_step(self, batch: Data, batch_idx: int) -> torch.Tensor: + r"""Perform a single training step on a batch of data from the training set. - :param batch: A batch of data (a tuple) containing the input tensor of - images and target labels. - :param batch_idx: The index of the current batch. - :return: A tensor of losses between model predictions and targets. + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + batch_idx (int): The index of the current batch. + Returns: + torch.Tensor: A tensor of losses between model predictions and targets. """ self.state_str = "Training" model_out = self.model_step(batch) @@ -123,14 +123,14 @@ def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int return model_out["loss"] def validation_step( - self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int + self, batch: Data, batch_idx: int ) -> None: - """Perform a single validation step on a batch of data from the - validation set. + r"""Perform a single validation step on a batch of data from the validation + set. 
- :param batch: A batch of data (a tuple) containing the input tensor of - images and target labels. - :param batch_idx: The index of the current batch. + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + batch_idx (int): The index of the current batch. """ self.state_str = "Validation" model_out = self.model_step(batch) @@ -146,13 +146,14 @@ def validation_step( ) def test_step( - self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int + self, batch: Data, batch_idx: int ) -> None: - """Perform a single test step on a batch of data from the test set. + r"""Perform a single test step on a batch of data from the test + set. - :param batch: A batch of data (a tuple) containing the input tensor of - images and target labels. - :param batch_idx: The index of the current batch. + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + batch_idx (int): The index of the current batch. """ self.state_str = "Test" model_out = self.model_step(batch) @@ -168,7 +169,14 @@ def test_step( ) def process_outputs(self, model_out: dict, batch: Data) -> dict: - """Process model outputs.""" + r"""Process model outputs. + + Args: + model_out (dict): Dictionary containing the model output. + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. + """ # Get the correct mask if self.state_str == "Training": @@ -189,7 +197,11 @@ def process_outputs(self, model_out: dict, batch: Data) -> dict: return model_out def log_metrics(self, mode=None): - """Log metrics.""" + r"""Log metrics. + + Args: + mode (str, optional): The mode of the model, either "train", "val", or "test". (default: None) + """ metrics_dict = self.evaluator.compute() for key in metrics_dict: self.log( @@ -203,7 +215,7 @@ def log_metrics(self, mode=None): self.evaluator.reset() def on_validation_epoch_start(self) -> None: - """According pytorch lightning documentation, this hook is called at + r"""According pytorch lightning documentation, this hook is called at the beginning of the validation epoch. https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks @@ -217,56 +229,59 @@ def on_validation_epoch_start(self) -> None: self.train_metrics_logged = True def on_train_epoch_end(self) -> None: + r"""Lightning hook that is called when a train epoch ends. This hook is used to log the train metrics.""" # Log train metrics and reset evaluator if not self.train_metrics_logged: self.log_metrics(mode="train") self.train_metrics_logged = True def on_validation_epoch_end(self) -> None: - """Lightning hook that is called when a test epoch ends.""" + r"""Lightning hook that is called when a validation epoch ends. This hook is used to log the validation metrics.""" # Log validation metrics and reset evaluator self.log_metrics(mode="val") def on_test_epoch_end(self) -> None: - """Lightning hook that is called when a test epoch ends.""" + r"""Lightning hook that is called when a test epoch ends. This hook is used to log the test metrics.""" self.log_metrics(mode="test") print() def on_train_epoch_start(self) -> None: - """Lightning hook that is called when a test epoch ends.""" + r"""Lightning hook that is called when a train epoch begins. 
This hook is used to reset the train metrics.""" self.evaluator.reset() self.train_metrics_logged = False def on_val_epoch_start(self) -> None: - """Lightning hook that is called when a test epoch ends.""" + r"""Lightning hook that is called when a validation epoch begins. This hook is used to reset the validation metrics.""" self.evaluator.reset() def on_test_epoch_start(self) -> None: - """Lightning hook that is called when a test epoch ends.""" + r"""Lightning hook that is called when a test epoch begins. This hook is used to reset the test metrics.""" self.evaluator.reset() def setup(self, stage: str) -> None: - """Lightning hook that is called at the beginning of fit (train + + r"""Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP. - :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. + Args: + stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. """ if self.hparams.compile and stage == "fit": self.net = torch.compile(self.net) def configure_optimizers(self) -> dict[str, Any]: - """Choose what optimizers and learning-rate schedulers to use in your + r"""Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you'd need one. But in the case of GANs or similar you might have multiple. Examples: https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers - :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training. + Returns: + dict: A dict containing the configured optimizers and learning-rate schedulers to be used for training. """ optimizer = self.hparams.optimizer( params=list(self.trainer.model.parameters()) diff --git a/topobenchmarkx/models/encoders/all_cell_encoder.py b/topobenchmarkx/models/encoders/all_cell_encoder.py index 469ad55d..cf086c4a 100644 --- a/topobenchmarkx/models/encoders/all_cell_encoder.py +++ b/topobenchmarkx/models/encoders/all_cell_encoder.py @@ -5,21 +5,15 @@ class AllCellFeatureEncoder(AbstractFeatureEncoder): - r"""Encoder class to apply BaseEncoder to the features of higher order - structures. - - Parameters - ---------- - in_channels: list(int) - Input dimensions for the features. - out_channels: list(int) - Output dimensions for the features. - proj_dropout: float - Dropout for the BaseEncoders. - selected_dimensions: list(int) - List of indexes to apply the BaseEncoders to. + r"""Encoder class to apply BaseEncoder to the features of higher order structures. The class creates a BaseEncoder for each dimension specified in selected_dimensions. Then during the forward pass, the BaseEncoders are applied to the features of the corresponding dimensions. + + Args: + in_channels (list[int]): Input dimensions for the features. + out_channels (list[int]): Output dimensions for the features. + proj_dropout (float, optional): Dropout for the BaseEncoders. (default: 0) + selected_dimensions (list[int], optional): List of indexes to apply the BaseEncoders to. (default: None) + **kwargs: Additional arguments. 
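+
+    A hedged construction sketch (channel sizes are illustrative; one input dimension is expected per encoded rank, and the forward pass then looks up ``x_0``, ``x_1``, ... on the data object):
+
+    >>> encoder = AllCellFeatureEncoder(in_channels=[7, 5, 3], out_channels=64)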
""" - def __init__( self, in_channels, @@ -47,21 +41,19 @@ def __init__( dropout=proj_dropout, ), ) - + def __repr__(self): + return f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, dimensions={self.dimensions})" + def forward( self, data: torch_geometric.data.Data ) -> torch_geometric.data.Data: - r"""Forward pass. + r"""Forward pass. The method applies the BaseEncoders to the features of the selected_dimensions. - Parameters - ---------- - data: torch_geometric.data.Data - Input data object which should contain x_{i} features for each i in the selected_dimensions. + Args: + data (torch_geometric.data.Data): Input data object which should contain x_{i} features for each i in the selected_dimensions. - Returns - ------- - torch_geometric.data.Data - Output data object. + Returns: + torch_geometric.data.Data: Output data object with updated x_{i} features. """ if not hasattr(data, "x_0"): data.x_0 = data.x @@ -78,16 +70,11 @@ class BaseEncoder(torch.nn.Module): r"""Encoder class that uses two linear layers with GraphNorm, Relu activation function, and dropout between the two layers. - Parameters - ---------- - in_channels: int - Dimension of input features. - out_channels: int - Dimensions of output features. - dropout: float - Percentage of channels to discard between the two linear layers. + Args: + in_channels (int): Dimension of input features. + out_channels (int): Dimensions of output features. + dropout (float, optional): Percentage of channels to discard between the two linear layers. (default: 0) """ - def __init__(self, in_channels, out_channels, dropout=0): super().__init__() self.linear1 = torch.nn.Linear(in_channels, out_channels) @@ -95,21 +82,18 @@ def __init__(self, in_channels, out_channels, dropout=0): self.relu = torch.nn.ReLU() self.BN = GraphNorm(out_channels) self.dropout = torch.nn.Dropout(dropout) + + def __repr__(self): + return f"{self.__class__.__name__}(in_channels={self.linear1.in_features}, out_channels={self.linear1.out_features})" def forward(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor: - r"""Forward pass. - - Parameters - ---------- - x: torch.Tensor - Input tensor of dimensions [N, in_channels]. - batch: torch.Tensor - The batch vector which assigns each element to a specific example. + r"""Forward pass of the encoder. It applies two linear layers with GraphNorm, Relu activation function, and dropout between the two layers. - Returns - ------- - torch.Tensor - Output tensor of shape [N, out_channels]. + Args: + x (torch.Tensor): Input tensor of dimensions [N, in_channels]. + batch (torch.Tensor): The batch vector which assigns each element to a specific example. + Returns: + torch.Tensor: Output tensor of shape [N, out_channels]. 
""" x = self.linear1(x) x = self.BN(x, batch=batch) if batch.shape[0] > 0 else self.BN(x) diff --git a/topobenchmarkx/models/encoders/encoder.py b/topobenchmarkx/models/encoders/encoder.py index b21cc512..6426ef6b 100644 --- a/topobenchmarkx/models/encoders/encoder.py +++ b/topobenchmarkx/models/encoders/encoder.py @@ -5,11 +5,14 @@ class AbstractFeatureEncoder(torch.nn.Module): - """Abstract class that provides an interface to define a custom feature encoder.""" + r"""Abstract class that provides an interface to define a custom feature encoder.""" def __init__(self, **kwargs): super().__init__() return + + def __repr__(self): + return f"{self.__class__.__name__}()" def __call__(self, data): return self.forward(data) @@ -18,11 +21,10 @@ def __call__(self, data): def forward( self, data: torch_geometric.data.Data ) -> torch_geometric.data.Data: - """Forward pass of the feature encoder model. - - Parameters: - :data: torch_geometric.data.Data + r"""Forward pass of the feature encoder model. + Args: + data (torch_geometric.data.Data): Input data object which should contain x features. Returns: - :data: torch_geometric.data.Data + torch_geometric.data.Data: Output data object with updated x features. """ \ No newline at end of file diff --git a/topobenchmarkx/models/encoders/perceiver.py b/topobenchmarkx/models/encoders/perceiver.py index 42381d07..b12bb36f 100644 --- a/topobenchmarkx/models/encoders/perceiver.py +++ b/topobenchmarkx/models/encoders/perceiver.py @@ -8,7 +8,6 @@ # helpers - def exists(val): return val is not None @@ -66,14 +65,10 @@ def cached_fn(*args, _cache=True, **kwargs): class PreNorm(nn.Module): r"""Class to wrap together LayerNorm and a specified function. - Parameters - ---------- - dim: int - Size of the dimension to normalize. - fn: torch.nn.Module - Function after LayerNorm. - context_dim: int - Size of the context to normalize. + Args: + dim (int): Size of the dimension to normalize. + fn (torch.nn.Module): Function after LayerNorm. + context_dim (int, optional): Size of the context to normalize. (default: None) """ def __init__(self, dim, fn, context_dim=None): @@ -83,21 +78,18 @@ def __init__(self, dim, fn, context_dim=None): self.norm_context = ( nn.LayerNorm(context_dim) if exists(context_dim) else None ) - + + def __repr__(self): + return f"{self.__class__.__name__}(dim={self.norm.normalized_shape[0]}, fn={self.fn}, context_dim={self.norm_context.normalized_shape[0] if exists(self.norm_context) else None})" + def forward(self, x, **kwargs): - r"""Forward pass. - - Parameters - ---------- - x: torch.Tensor - Input tensor. - kwargs: dict - Dictionary of keyword arguments. - - Returns - ------- - torch.Tensor - Output tensor. + r"""Forward pass of the PreNorm class. + + Args: + x (torch.Tensor): Input tensor. + **kwargs: Additional arguments. If context_dim is not None the context tensor should be passed. + Returns: + torch.Tensor: Output tensor. """ x = self.norm(x) @@ -113,60 +105,53 @@ class GEGLU(nn.Module): r"""GEGLU activation function.""" def forward(self, x): - r"""Forward pass. + r"""Forward pass of the GEGLU activation function. - Parameters - ---------- - x: torch.Tensor - Input tensor. + Args: + x (torch.Tensor): Input tensor. + Returns: + torch.Tensor: Output tensor. """ x, gates = x.chunk(2, dim=-1) return x * F.gelu(gates) class FeedForward(nn.Module): - r"""Feedforward network. - - Parameters - ---------- - dim: int - Size of the input dimension. - mult: int - Multiplier for the hidden dimension. 
- """ + r"""Feedforward network with two linear layers and GEGLU activation function in between. + Args: + dim (int): Size of the input dimension. + mult (int, optional): Multiplier for the hidden dimension. (default: 4) + """ def __init__(self, dim, mult=4): super().__init__() self.net = nn.Sequential( nn.Linear(dim, dim * mult * 2), GEGLU(), nn.Linear(dim * mult, dim) ) - + + def __repr__(self): + return f"{self.__class__.__name__}(dim={self.net[0].in_features}, mult={self.net[0].out_features // self.net[0].in_features})" + def forward(self, x): - r"""Forward pass. + r"""Forward pass of the FeedForward class. - Parameters - ---------- - x: torch.Tensor - Input tensor. + Args: + x (torch.Tensor): Input tensor. + Returns: + torch.Tensor: Output tensor. """ return self.net(x) class Attention(nn.Module): - r"""Attention function. - - Parameters - ---------- - query_dim: int - Size of the query dimension. - context_dim: int - Size of the context dimension. - heads: int - Number of heads. - dim_head: int - Size for each head. - """ + r"""Attention class to calculate the attention weights. + Args: + query_dim (int): Size of the query dimension. + context_dim (int, optional): Size of the context dimension. (default: None) + heads (int, optional): Number of heads. (default: 8) + dim_head (int, optional): Size for each head. (default: 64) + """ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64): super().__init__() inner_dim = dim_head * heads @@ -178,22 +163,18 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64): self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) self.to_out = nn.Linear(inner_dim, query_dim) + def __repr__(self): + return f"{self.__class__.__name__}(query_dim={self.to_q.in_features}, context_dim={self.to_kv.in_features // 2}, heads={self.heads}, dim_head={self.to_q.out_features // self.heads})" + def forward(self, x, context=None, mask=None): - r"""Forward pass. - - Parameters - ---------- - x: torch.Tensor - Input tensor. - context: torch.Tensor - Context tensor. - mask: torch.Tensor - Mask for attention calculation purposes. - - Returns - ------- - torch.Tensor - Output tensor. + r"""Forward pass of the Attention class. + + Args: + x (torch.Tensor): Input tensor. + context (torch.Tensor, optional): Context tensor. (default: None) + mask (torch.Tensor, optional): Mask for attention calculation purposes. (default: None) + Returns: + torch.Tensor: Output tensor. """ h = self.heads @@ -225,28 +206,18 @@ def forward(self, x, context=None, mask=None): class Perceiver(nn.Module): - r"""Perceiver model. - - Parameters - ---------- - depth: int - Number of layers to add to the model. - dim: int - Size of the input dimension. - num_latents: int - Number of latent vectors. - cross_heads: int - Number of heads for cross attention. - latent_heads: int - Number of heads for latent attention. - cross_dim_head: int - Size of the cross attention head. - latent_dim_head: int - Size of the latent attention head. - weight_tie_layers: bool - Whether to tie the weights of the layers. - decoder_ff: bool - Whether to use a feedforward network in the decoder. + r"""Perceiver model. For more information https://arxiv.org/abs/2103.03206. + + Args: + depth (int): Number of layers to add to the model. + dim (int): Size of the input dimension. + num_latents (int, optional): Number of latent vectors. (default: 1) + cross_heads (int, optional): Number of heads for cross attention. 
(default: 1)
+        latent_heads (int, optional): Number of heads for latent attention. (default: 8)
+        cross_dim_head (int, optional): Size of the cross attention head. (default: 64)
+        latent_dim_head (int, optional): Size of the latent attention head. (default: 64)
+        weight_tie_layers (bool, optional): Whether to tie the weights of the layers. (default: False)
+        decoder_ff (bool, optional): Whether to use a feedforward network in the decoder. (default: False)
     """

     def __init__(
@@ -332,21 +303,31 @@ def get_latent_ff():
             else None
         )

+        self.dim = dim
+        self.num_latents = num_latents
+        self.cross_heads = cross_heads
+        self.latent_heads = latent_heads
+        self.cross_dim_head = cross_dim_head
+        self.latent_dim_head = latent_dim_head
+        self.weight_tie_layers = weight_tie_layers
+        self.decoder_ff = decoder_ff
+
         # self.to_logits = (
         #     nn.Linear(queries_dim, logits_dim) if exists(logits_dim) else nn.Identity()
         # )
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(depth={len(self.layers)}, dim={self.dim}, num_latents={self.num_latents}, cross_heads={self.cross_heads}, latent_heads={self.latent_heads}, cross_dim_head={self.cross_dim_head}, latent_dim_head={self.latent_dim_head}, weight_tie_layers={self.weight_tie_layers}, decoder_ff={self.decoder_ff})"

     def forward(self, data, mask=None, queries=None):
-        r"""Forward pass.
-
-        Parameters
-        ----------
-        data: torch.Tensor
-            Input tensor.
-        mask: torch.Tensor
-            Mask for attention calculation purposes.
-        queries: torch.Tensor
-            Queries tensor.
+        r"""Forward pass of the Perceiver model.
+
+        Args:
+            data (torch.Tensor): Input tensor.
+            mask (torch.Tensor, optional): Mask for attention calculation purposes. (default: None)
+            queries (torch.Tensor, optional): Queries tensor. (default: None)
+        Returns:
+            torch.Tensor: Output tensor.
         """
         b, *_ = *data.shape
diff --git a/topobenchmarkx/models/head_models/head_model.py b/topobenchmarkx/models/head_models/head_model.py
index 117dce5d..bd26baea 100644
--- a/topobenchmarkx/models/head_models/head_model.py
+++ b/topobenchmarkx/models/head_models/head_model.py
@@ -3,16 +3,12 @@ from abc import abstractmethod

 class AbstractHeadModel(torch.nn.Module):
-    r"""Head model.
+    r"""Abstract head model class.

-    Parameters
-    ----------
-    in_channels: int
-        Input dimension.
-    out_channels: int
-        Output dimension.
+    Args:
+        in_channels (int): Input dimension.
+        out_channels (int): Output dimension.
     """
-
     def __init__(
         self,
         in_channels: int,
         out_channels: int,
     ):
         super().__init__()
         self.linear = torch.nn.Linear(in_channels, out_channels)
+
+    def __repr__(self):
+        return f"{self.__class__.__name__}(in_channels={self.linear.in_features}, out_channels={self.linear.out_features})"

     def __call__(self, model_out: dict, batch: torch_geometric.data.Data) -> dict:
         x = self.forward(model_out, batch)
@@ -29,19 +28,13 @@ def __call__(self, model_out: dict, batch: torch_geometric.data.Data) -> dict:

     @abstractmethod
     def forward(self, model_out: dict, batch: torch_geometric.data.Data):
-        r"""Forward pass.
-
-        Parameters
-        ----------
-        model_out: dict
-            Dictionary containing the model output.
-        batch: torch_geometric.data.Data
-            Batch object containing the batched domain data.
+        r"""Forward pass of the head model.

-        Returns
-        -------
-        x: torch.Tensor
-            Output tensor over which the final linear layer is applied.
+        Args:
+            model_out (dict): Dictionary containing the model output.
+            batch (torch_geometric.data.Data): Batch object containing the batched domain data.
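
A hedged usage sketch based on the constructor arguments documented above; the argument names come from the docstring, and the exact call signature and output shape should be verified against the class itself:

    import torch

    model = Perceiver(depth=2, dim=64, num_latents=1, cross_heads=1,
                      latent_heads=8, cross_dim_head=64, latent_dim_head=64)
    tokens = torch.randn(8, 100, 64)  # (batch, set size, dim)
    latents = model(tokens)           # latent summary of the input set
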
+        Returns:
+            torch.Tensor: Output tensor over which the final linear layer is applied.
         """
         pass
\ No newline at end of file
diff --git a/topobenchmarkx/models/head_models/zero_cell_model.py b/topobenchmarkx/models/head_models/zero_cell_model.py
index 5b196821..8c6dacec 100644
--- a/topobenchmarkx/models/head_models/zero_cell_model.py
+++ b/topobenchmarkx/models/head_models/zero_cell_model.py
@@ -4,18 +4,13 @@ from topobenchmarkx.models.head_models.head_model import AbstractHeadModel

 class ZeroCellModel(AbstractHeadModel):
-    r"""Head model.
+    r"""Zero cell head model. This model produces an output based only on the features of the nodes (the zero cells). The output is obtained by applying a linear layer to the input features. Based on the task level, the readout layer will pool the node embeddings to the graph level to obtain a single graph embedding for each batched graph or return a value for each node.

-    Parameters
-    ----------
-    in_channels: int
-        Input dimension.
-    out_channels: int
-        Output dimension.
-    task_level: str
-        Task level, either "graph" or "node". If "graph", the readout layer will pool the node embeddings to the graph level to obtain a single graph embedding for each batched graph. If "node", the readout layer will return the node embeddings.
-    pooling_type: str
-        Pooling type, either "max", "sum", or "mean". Specifies the type of pooling operation to be used for the graph-level embedding.
+    Args:
+        in_channels (int): Input dimension.
+        out_channels (int): Output dimension.
+        task_level (str): Task level, either "graph" or "node". If "graph", the readout layer will pool the node embeddings to the graph level to obtain a single graph embedding for each batched graph. If "node", the readout layer will return the node embeddings.
+        pooling_type (str, optional): Pooling type, either "max", "sum", or "mean". Specifies the type of pooling operation to be used for the graph-level embedding. (default: "sum")
     """

     def __init__(
@@ -34,21 +29,17 @@ def __init__(
         assert pooling_type in ["max", "sum", "mean"], "Invalid pooling_type"
         self.pooling_type = pooling_type

+    def __repr__(self):
+        return f"{self.__class__.__name__}(in_channels={self.linear.in_features}, out_channels={self.linear.out_features}, task_level={self.task_level}, pooling_type={self.pooling_type})"

     def forward(self, model_out: dict, batch: torch_geometric.data.Data):
-        r"""Forward pass.
+        r"""Forward pass of the zero cell head model.

-        Parameters
-        ----------
-        model_out: dict
-            Dictionary containing the model output.
-        batch: torch_geometric.data.Data
-            Batch object containing the batched domain data.
-
-        Returns
-        -------
-        x: torch.Tensor
-            Output tensor over which the final linear layer is applied.
+        Args:
+            model_out (dict): Dictionary containing the model output.
+            batch (torch_geometric.data.Data): Batch object containing the batched domain data.
+        Returns:
+            torch.Tensor: Output tensor.
         """
         x = model_out["x_0"]
         batch = batch["batch_0"]
diff --git a/topobenchmarkx/models/losses/default_loss.py b/topobenchmarkx/models/losses/default_loss.py
index db4893cc..3f96383e 100644
--- a/topobenchmarkx/models/losses/default_loss.py
+++ b/topobenchmarkx/models/losses/default_loss.py
@@ -3,8 +3,13 @@ from topobenchmarkx.models.losses.loss import AbstractltLoss

 class DefaultLoss(AbstractltLoss):
-    """Abstract class that provides an interface to loss logic within
-    netowrk."""
+    r"""Default loss class that provides an interface to the loss logic within
+    the network.
+ + Args: + task (str): Task type, either "classification" or "regression". + loss_type (str, optional): Loss type, either "cross_entropy", "mse", or "mae". (default: None) + """ def __init__(self, task, loss_type=None): super().__init__() @@ -20,10 +25,20 @@ def __init__(self, task, loss_type=None): else: raise Exception("Loss is not defined") + self.loss_type = loss_type + def __repr__(self) -> str: + return f'{self.__class__.__name__}(task={self.task}, loss_type={self.loss_type})' + def forward(self, model_out: dict, batch: torch_geometric.data.Data): - """Loss logic based on model_out.""" - + r"""Forward pass of the loss function. + + Args: + model_out (dict): Dictionary containing the model output. + batch (torch_geometric.data.Data): Batch object containing the batched domain data. + Returns: + model_out (dict): Dictionary containing the model output with the loss. + """ logits = model_out["logits"] target = model_out["labels"] @@ -33,6 +48,4 @@ def forward(self, model_out: dict, batch: torch_geometric.data.Data): model_out["loss"] = self.criterion(logits, target) return model_out - - def __repr__(self) -> str: - return f'{self.__class__.__name__}(task={self.task}, criterion={self.criterion.__class__.__name__})' \ No newline at end of file + \ No newline at end of file diff --git a/topobenchmarkx/models/losses/loss.py b/topobenchmarkx/models/losses/loss.py index 21bf170d..e41aca40 100755 --- a/topobenchmarkx/models/losses/loss.py +++ b/topobenchmarkx/models/losses/loss.py @@ -2,16 +2,17 @@ from abc import ABC, abstractmethod class AbstractltLoss(ABC): - """Abstract class that provides an interface to loss logic within - netowrk.""" - + r"""Abstract class for the loss class.""" def __init__(self,): super().__init__() def __call__(self, model_out: dict, batch: torch_geometric.data.Data) -> dict: - """Loss logic based on model_output.""" + r"""Loss logic based on model_output.""" return self.forward(model_out, batch) @abstractmethod def forward(self, model_out: dict, batch: torch_geometric.data.Data): pass + + def __repr__(self) -> str: + return f'{self.__class__.__name__}()' \ No newline at end of file diff --git a/topobenchmarkx/models/readouts/identical.py b/topobenchmarkx/models/readouts/identical.py index 639fb082..32c96620 100644 --- a/topobenchmarkx/models/readouts/identical.py +++ b/topobenchmarkx/models/readouts/identical.py @@ -4,11 +4,20 @@ class NoReadOut(AbstractReadOut): + r"""No readout layer. This readout layer does not perform any operation on the node embeddings.""" def __init__(self, **kwargs): super().__init__() def forward(self, model_out: dict, batch: torch_geometric.data.Data) -> dict: + r"""Forward pass of the no readout layer. It returns the model output without any modification. + + Args: + model_out (dict): Dictionary containing the model output. + batch (torch_geometric.data.Data): Batch object containing the batched domain data. + Returns: + model_out (dict): Dictionary containing the model output. 
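
The task-to-criterion dispatch that DefaultLoss documents can be summarized in a short sketch (reconstructed from the docstring and the error branch shown above; the actual branching lives in __init__):

    import torch

    def build_criterion(task, loss_type=None):
        # classification uses cross-entropy; regression selects mse or mae
        if task == "classification":
            return torch.nn.CrossEntropyLoss()
        if task == "regression" and loss_type == "mse":
            return torch.nn.MSELoss()
        if task == "regression" and loss_type == "mae":
            return torch.nn.L1Loss()
        raise Exception("Loss is not defined")
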
+ """ return model_out def __repr__(self) -> str: - return f"{self.__class__.__name__}(num_cell_dimensions={len(self.dimensions)}, hidden_dim={self.hidden_dim}, readout_name={self.name}" + return f"{self.__class__.__name__}()" diff --git a/topobenchmarkx/models/readouts/propagate_signal_down.py b/topobenchmarkx/models/readouts/propagate_signal_down.py index dd4c3ede..4088dbbf 100644 --- a/topobenchmarkx/models/readouts/propagate_signal_down.py +++ b/topobenchmarkx/models/readouts/propagate_signal_down.py @@ -4,6 +4,13 @@ from topobenchmarkx.models.readouts.readout import AbstractReadOut class PropagateSignalDown(AbstractReadOut): + r"""Propagate signal down readout layer. This readout layer propagates the signal from cells of a certain order to the cells of the lower order. + + Args: + num_cell_dimensions (int): Highest order of cells considered by the model. + hidden_dim (int): Dimension of the cells representations. + readout_name (str): Readout name. + """ def __init__(self, **kwargs): super().__init__() @@ -29,6 +36,14 @@ def __init__(self, **kwargs): ) def forward(self, model_out: dict, batch: torch_geometric.data.Data): + r"""Forward pass of the propagate signal down readout layer. The layer takes the embeddings of the cells of a certain order and applies a convolutional layer to them. Layer normalization is then applied to the features. The output is concatenated with the initial embeddings of the cells and the result is projected with the use of a linear layer to the dimensions of the cells of lower rank. The process is repeated until the nodes embeddings, which are the cells of rank 0, are reached. + + Args: + model_out (dict): Dictionary containing the model output. + batch (torch_geometric.data.Data): Batch object containing the batched domain data. + Returns: + model_out (dict): Dictionary containing the model output. + """ for i in self.dimensions: x_i = getattr(self, f"agg_conv_{i}")( model_out[f"x_{i}"], batch[f"incidence_{i}"] diff --git a/topobenchmarkx/models/readouts/readout.py b/topobenchmarkx/models/readouts/readout.py index 7a90bdef..d663da16 100755 --- a/topobenchmarkx/models/readouts/readout.py +++ b/topobenchmarkx/models/readouts/readout.py @@ -5,21 +5,13 @@ class AbstractReadOut(torch.nn.Module): r"""Readout layer for GNNs that operates on the batch level. - - Parameters - ---------- - in_channels: int - Input dimension. - out_channels: int - Output dimension. - task_level: str - Task level, either "graph" or "node". If "graph", the readout layer will pool the node embeddings to the graph level to obtain a single graph embedding for each batched graph. If "node", the readout layer will return the node embeddings. - pooling_type: str - Pooling type, either "max", "sum", or "mean". Specifies the type of pooling operation to be used for the graph-level embedding. """ def __init__(self,): super().__init__() + + def __repr__(self): + return f"{self.__class__.__name__}()" def __call__(self, model_out: dict, batch: torch_geometric.data.Data) -> dict: """Readout logic based on model_output.""" @@ -29,16 +21,9 @@ def __call__(self, model_out: dict, batch: torch_geometric.data.Data) -> dict: def forward(self, model_out: dict, batch: torch_geometric.data.Data): r"""Forward pass. - Parameters - ---------- - model_out: dict - Dictionary containing the model output. - - batch: torch_geometric.data.Data - Batch object containing the batched domain data. - - Returns - ------- - dict - Dictionary containing the updated model output. Resulting key is "logits". 
+ Args: + model_out (dict): Dictionary containing the model output. + batch (torch_geometric.data.Data): Batch object containing the batched domain data. + Returns: + dict: Dictionary containing the updated model output. """ \ No newline at end of file diff --git a/topobenchmarkx/models/wrappers/cell/can_wrapper.py b/topobenchmarkx/models/wrappers/cell/can_wrapper.py index 02d6bc60..7b3da4d3 100644 --- a/topobenchmarkx/models/wrappers/cell/can_wrapper.py +++ b/topobenchmarkx/models/wrappers/cell/can_wrapper.py @@ -2,11 +2,16 @@ from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper class CANWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within - network.""" + r"""Wrapper for the CAN model. This wrapper defines the forward pass of the model. The CAN model returns the embeddings of the cells of rank 1. The embeddings of the cells of rank 0 are computed as the sum of the embeddings of the cells of rank 1 connected to them.""" def forward(self, batch): - """Define logic for forward pass.""" + r"""Forward pass for the CAN wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. + """ x_1 = self.backbone( x_0=batch.x_0, diff --git a/topobenchmarkx/models/wrappers/cell/cccn_wrapper.py b/topobenchmarkx/models/wrappers/cell/cccn_wrapper.py index 97638300..8f275d94 100644 --- a/topobenchmarkx/models/wrappers/cell/cccn_wrapper.py +++ b/topobenchmarkx/models/wrappers/cell/cccn_wrapper.py @@ -2,11 +2,16 @@ from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper class CCCNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within - network.""" + r"""Wrapper for the CCCN model. This wrapper defines the forward pass of the model. The CCCN model returns the embeddings of the cells of rank 1. The embeddings of the cells of rank 0 are computed as the sum of the embeddings of the cells of rank 1 connected to them.""" def forward(self, batch): - """Define logic for forward pass.""" + r"""Forward pass for the CCCN wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. + """ x_1 = self.backbone( batch.x_1, diff --git a/topobenchmarkx/models/wrappers/cell/ccxn_wrapper.py b/topobenchmarkx/models/wrappers/cell/ccxn_wrapper.py index ed159c9a..0c80c813 100644 --- a/topobenchmarkx/models/wrappers/cell/ccxn_wrapper.py +++ b/topobenchmarkx/models/wrappers/cell/ccxn_wrapper.py @@ -1,11 +1,16 @@ from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper class CCXNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within - network.""" + r"""Wrapper for the CCXN model. This wrapper defines the forward pass of the model. The CCXN model returns the embeddings of the cells of rank 0, 1, and 2.""" def forward(self, batch): - """Define logic for forward pass.""" + r"""Forward pass for the CCXN wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched domain data. + Returns: + dict: Dictionary containing the updated model output. 
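
The rank-0 reconstruction that the CAN and CCCN wrappers describe, summing the embeddings of the rank-1 cells incident to each node, amounts to a single sparse matrix product; a minimal sketch, assuming an unsigned sparse incidence matrix of shape (num_nodes, num_edges):

    import torch

    num_nodes, num_edges, dim = 5, 7, 16
    incidence_1 = torch.randint(0, 2, (num_nodes, num_edges)).float().to_sparse()
    x_1 = torch.randn(num_edges, dim)
    x_0 = torch.sparse.mm(incidence_1, x_1)  # each node sums its incident edge embeddings
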
+ """ x_0, x_1, x_2 = self.backbone( x_0=batch.x_0, diff --git a/topobenchmarkx/models/wrappers/cell/cwn_wrapper.py b/topobenchmarkx/models/wrappers/cell/cwn_wrapper.py index 941ba2a7..a4efd697 100644 --- a/topobenchmarkx/models/wrappers/cell/cwn_wrapper.py +++ b/topobenchmarkx/models/wrappers/cell/cwn_wrapper.py @@ -1,11 +1,16 @@ from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper class CWNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within - network.""" + r"""Wrapper for the CWN model. This wrapper defines the forward pass of the model. The CWN model returns the embeddings of the cells of rank 0, 1, and 2.""" def forward(self, batch): - """Define logic for forward pass.""" + r"""Forward pass for the CWN wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched domain data. + Returns: + dict: Dictionary containing the updated model output. + """ x_0, x_1, x_2 = self.backbone( x_0=batch.x_0, diff --git a/topobenchmarkx/models/wrappers/graph/gnn_wrapper.py b/topobenchmarkx/models/wrappers/graph/gnn_wrapper.py index 59ffa7d3..0e7b816f 100644 --- a/topobenchmarkx/models/wrappers/graph/gnn_wrapper.py +++ b/topobenchmarkx/models/wrappers/graph/gnn_wrapper.py @@ -1,11 +1,16 @@ from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper class GNNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within - network.""" + r"""Wrapper for the GNN models. This wrapper defines the forward pass of the model. The GNN models return the embeddings of the cells of rank 0.""" def forward(self, batch): - """Define logic for forward pass.""" + r"""Forward pass for the GNN wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. + """ x_0 = self.backbone(batch.x_0, batch.edge_index) model_out = {"labels": batch.y, "batch_0": batch.batch_0} diff --git a/topobenchmarkx/models/wrappers/hypergraph/hypergraph_wrapper.py b/topobenchmarkx/models/wrappers/hypergraph/hypergraph_wrapper.py index f458592b..c98a1572 100644 --- a/topobenchmarkx/models/wrappers/hypergraph/hypergraph_wrapper.py +++ b/topobenchmarkx/models/wrappers/hypergraph/hypergraph_wrapper.py @@ -1,11 +1,16 @@ from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper class HypergraphWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within - network.""" + r"""Wrapper for the hypergraph models. This wrapper defines the forward pass of the model. The hypergraph model return the embeddings of the cells of rank 0, and 1 (the hyperedges).""" def forward(self, batch): - """Define logic for forward pass.""" + r"""Forward pass for the hypergraph wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. 
+ """ x_0, x_1 = self.backbone(batch.x_0, batch.incidence_hyperedges) model_out = {"labels": batch.y, "batch_0": batch.batch_0} model_out["x_0"] = x_0 diff --git a/topobenchmarkx/models/wrappers/simplicial/san_wrapper.py b/topobenchmarkx/models/wrappers/simplicial/san_wrapper.py index a55fb308..022fc8f5 100644 --- a/topobenchmarkx/models/wrappers/simplicial/san_wrapper.py +++ b/topobenchmarkx/models/wrappers/simplicial/san_wrapper.py @@ -2,11 +2,16 @@ from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper class SANWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within - network.""" + r"""Wrapper for the SAN model. This wrapper defines the forward pass of the model. The SAN model returns the embeddings of the cells of rank 1. The embeddings of the cells of rank 0 are computed as the sum of the embeddings of the cells of rank 1 connected to them.""" def forward(self, batch): - """Define logic for forward pass.""" + r"""Forward pass for the SAN wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. + """ x_1 = self.backbone( batch.x_1, batch.up_laplacian_1, batch.down_laplacian_1 ) diff --git a/topobenchmarkx/models/wrappers/simplicial/sccn_wrapper.py b/topobenchmarkx/models/wrappers/simplicial/sccn_wrapper.py index 9e64112f..9433274b 100644 --- a/topobenchmarkx/models/wrappers/simplicial/sccn_wrapper.py +++ b/topobenchmarkx/models/wrappers/simplicial/sccn_wrapper.py @@ -1,11 +1,16 @@ from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper class SCCNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within - network.""" + r"""Wrapper for the SCCN model. This wrapper defines the forward pass of the model. The SCCN model returns the embeddings of the cells of any rank.""" def forward(self, batch): - """Define logic for forward pass.""" + r"""Forward pass for the SCCN wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. + """ features = { f"rank_{r}": batch[f"x_{r}"] diff --git a/topobenchmarkx/models/wrappers/simplicial/sccnn_wrapper.py b/topobenchmarkx/models/wrappers/simplicial/sccnn_wrapper.py index d9eef982..07e7687e 100644 --- a/topobenchmarkx/models/wrappers/simplicial/sccnn_wrapper.py +++ b/topobenchmarkx/models/wrappers/simplicial/sccnn_wrapper.py @@ -1,11 +1,16 @@ from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper class SCCNNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within - network.""" + r"""Wrapper for the SCCNN model. This wrapper defines the forward pass of the model. The SCCNN model returns the embeddings of the cells of rank 0, 1, and 2.""" def forward(self, batch): - """Define logic for forward pass.""" + r"""Forward pass for the SCCNN wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. 
+ """ x_all = (batch.x_0, batch.x_1, batch.x_2) laplacian_all = ( diff --git a/topobenchmarkx/models/wrappers/simplicial/scn_wrapper.py b/topobenchmarkx/models/wrappers/simplicial/scn_wrapper.py index 42e5f45a..75cb5005 100644 --- a/topobenchmarkx/models/wrappers/simplicial/scn_wrapper.py +++ b/topobenchmarkx/models/wrappers/simplicial/scn_wrapper.py @@ -2,11 +2,16 @@ from topobenchmarkx.models.wrappers.wrapper import DefaultWrapper class SCNWrapper(DefaultWrapper): - """Abstract class that provides an interface to loss logic within - network.""" + r"""Wrapper for the SCNW model. This wrapper defines the forward pass of the model. The SCNW model returns the embeddings of the cells of rank 0, 1, and 2.""" def forward(self, batch): - """Define logic for forward pass.""" + r"""Forward pass for the SCNW wrapper. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. + """ laplacian_0 = self.normalize_matrix(batch.hodge_laplacian_0) laplacian_1 = self.normalize_matrix(batch.hodge_laplacian_1) @@ -28,6 +33,13 @@ def forward(self, batch): return model_out def normalize_matrix(self, matrix): + r"""Normalize the input matrix. The normalization is performed using the diagonal matrix of the inverse square root of the sum of the absolute values of the rows. + + Args: + matrix (torch.sparse.FloatTensor): Input matrix to be normalized. + Returns: + torch.sparse.FloatTensor: Normalized matrix. + """ matrix_ = matrix.to_dense() n, _ = matrix_.shape abs_matrix = abs(matrix_) diff --git a/topobenchmarkx/models/wrappers/wrapper.py b/topobenchmarkx/models/wrappers/wrapper.py index 372c9a07..b9d6fd9e 100755 --- a/topobenchmarkx/models/wrappers/wrapper.py +++ b/topobenchmarkx/models/wrappers/wrapper.py @@ -3,9 +3,14 @@ import torch.nn as nn class DefaultWrapper(ABC, torch.nn.Module): - """Abstract class that provides an interface to handle the network - output.""" - + r"""Abstract class that provides an interface to handle the network + output. + + Args: + backbone (torch.nn.Module): Backbone model. + out_channels (int): Number of output channels. + num_cell_dimensions (int): Number of cell dimensions. + """ def __init__(self, backbone, **kwargs): super().__init__() self.backbone = backbone @@ -19,13 +24,17 @@ def __init__(self, backbone, **kwargs): nn.LayerNorm(out_channels), ) + def __repr__(self): + return f"{self.__class__.__name__}(backbone={self.backbone}, out_channels={self.backbone.out_channels}, dimensions={self.dimensions})" + def __call__(self, batch): - """Define logic for forward pass.""" + r"""Forward pass for the model. This method calls the forward method and adds the residual connection.""" model_out = self.forward(batch) model_out = self.residual_connection(model_out=model_out, batch=batch) return model_out def residual_connection(self, model_out, batch): + r"""Residual connection for the model. This method sums, for the embeddings of the cells of any rank, the output of the model with the input embeddings and applies layer normalization.""" for i in self.dimensions: if ( (f"x_{i}" in batch) @@ -38,5 +47,11 @@ def residual_connection(self, model_out, batch): @abstractmethod def forward(self, batch): - """Define handling output here.""" + r"""Forward pass for the model. This method should be implemented by the child class. + + Args: + batch (torch_geometric.data.Data): Batch object containing the batched data. + Returns: + dict: Dictionary containing the updated model output. 
+ """ pass diff --git a/topobenchmarkx/transforms/data_manipulations/calculate_simplicial_curvature.py b/topobenchmarkx/transforms/data_manipulations/calculate_simplicial_curvature.py index 99ff5c05..d0e9e572 100644 --- a/topobenchmarkx/transforms/data_manipulations/calculate_simplicial_curvature.py +++ b/topobenchmarkx/transforms/data_manipulations/calculate_simplicial_curvature.py @@ -2,31 +2,27 @@ import torch_geometric class CalculateSimplicialCurvature(torch_geometric.transforms.BaseTransform): - """A transform that calculates the simplicial curvature of the input graph. + r"""A transform that calculates the simplicial curvature of the input graph. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + kwargs (optional): Parameters for the transform. """ def __init__(self, **kwargs): super().__init__() self.type = "simplicial_curvature" self.parameters = kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" def forward(self, data: torch_geometric.data.Data): - """Apply the transform to the input data. + r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ data = self.one_cell_curvature(data) data = self.zero_cell_curvature(data) @@ -37,17 +33,12 @@ def zero_cell_curvature( self, data: torch_geometric.data.Data, ) -> torch_geometric.data.Data: - """Calculate the zero cell curvature of the input data. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data. + r"""Calculate the zero cell curvature of the input data. - Returns - ------- - torch_geometric.data.Data - Data with the zero cell curvature. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: Data with the zero cell curvature. """ data["0_cell_curvature"] = torch.mm( abs(data["incidence_1"]), data["1_cell_curvature"] @@ -60,15 +51,10 @@ def one_cell_curvature( ) -> torch_geometric.data.Data: r"""Calculate the one cell curvature of the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - Data with the one cell curvature. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: Data with the one cell curvature. """ data["1_cell_curvature"] = ( 4 @@ -83,15 +69,10 @@ def two_cell_curvature( ) -> torch_geometric.data.Data: r"""Calculate the two cell curvature of the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - Data with the two cell curvature. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: Data with the two cell curvature. """ # Term 1 is simply the degree of the 2-cell (i.e. 
each triangle belong to n tetrahedrons) term1 = data["2_cell_degrees"] diff --git a/topobenchmarkx/transforms/data_manipulations/equal_gaus_features.py b/topobenchmarkx/transforms/data_manipulations/equal_gaus_features.py index ad0e68b9..1926d47d 100644 --- a/topobenchmarkx/transforms/data_manipulations/equal_gaus_features.py +++ b/topobenchmarkx/transforms/data_manipulations/equal_gaus_features.py @@ -5,10 +5,10 @@ class EqualGausFeatures(torch_geometric.transforms.BaseTransform): r"""A transform that generates equal Gaussian features for all nodes in the input graph. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + mean (float): The mean of the Gaussian distribution. + std (float): The standard deviation of the Gaussian distribution. + num_features (int): The number of features to generate. """ def __init__(self, **kwargs): @@ -22,19 +22,17 @@ def __init__(self, **kwargs): self.feature_vector = torch.normal( mean=self.mean, std=self.std, size=(1, self.feature_vector) ) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, mean={self.mean!r}, std={self.std!r}, feature_vector={self.feature_vector!r})" def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ data.x = self.feature_vector.expand(data.num_nodes, -1) return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/identity_transform.py b/topobenchmarkx/transforms/data_manipulations/identity_transform.py index d2462fbc..14b7bdc1 100644 --- a/topobenchmarkx/transforms/data_manipulations/identity_transform.py +++ b/topobenchmarkx/transforms/data_manipulations/identity_transform.py @@ -1,24 +1,26 @@ import torch_geometric class IdentityTransform(torch_geometric.transforms.BaseTransform): - r"""An identity transform that does nothing to the input data.""" + r"""An identity transform that does nothing to the input data. + + Args: + kwargs (optional): Parameters for the base transform. + """ def __init__(self, **kwargs): super().__init__() self.type = "domain2domain" self.parameters = kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The (un)transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The same data. """ return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/infere_knn_connectivity.py b/topobenchmarkx/transforms/data_manipulations/infere_knn_connectivity.py index 78892adf..403c280f 100644 --- a/topobenchmarkx/transforms/data_manipulations/infere_knn_connectivity.py +++ b/topobenchmarkx/transforms/data_manipulations/infere_knn_connectivity.py @@ -3,24 +3,26 @@ class InfereKNNConnectivity(torch_geometric.transforms.BaseTransform): r"""A transform that generates the k-nearest neighbor connectivity of the - input point cloud.""" + input point cloud. 
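
The transform above draws a single Gaussian row once, at construction time, and reuses it for every node, so all nodes share identical features; schematically:

    import torch

    num_features, num_nodes = 8, 100
    feature_vector = torch.normal(mean=0.0, std=1.0, size=(1, num_features))
    x = feature_vector.expand(num_nodes, -1)  # every node gets the same row (a view, not a copy)
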
+ + Args: + kwargs (optional): Parameters for the base transform.""" def __init__(self, **kwargs): super().__init__() self.type = "infere_knn_connectivity" self.parameters = kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ edge_index = knn_graph(data.x, **self.parameters["args"]) diff --git a/topobenchmarkx/transforms/data_manipulations/infere_radius_connectivity.py b/topobenchmarkx/transforms/data_manipulations/infere_radius_connectivity.py index ad249065..d4481984 100644 --- a/topobenchmarkx/transforms/data_manipulations/infere_radius_connectivity.py +++ b/topobenchmarkx/transforms/data_manipulations/infere_radius_connectivity.py @@ -3,24 +3,26 @@ class InfereRadiusConnectivity(torch_geometric.transforms.BaseTransform): r"""A transform that generates the radius connectivity of the input point - cloud.""" + cloud. + + Args: + kwargs (optional): Parameters for the base transform.""" def __init__(self, **kwargs): super().__init__() self.type = "infere_radius_connectivity" self.parameters = kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ data.edge_index = radius_graph(data.x, **self.parameters["args"]) return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/keep_only_connected_component.py b/topobenchmarkx/transforms/data_manipulations/keep_only_connected_component.py index 04ca2364..d65643a4 100644 --- a/topobenchmarkx/transforms/data_manipulations/keep_only_connected_component.py +++ b/topobenchmarkx/transforms/data_manipulations/keep_only_connected_component.py @@ -5,31 +5,26 @@ class KeepOnlyConnectedComponent(torch_geometric.transforms.BaseTransform): """A transform that keeps only the largest connected components of the input graph. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + kwargs (optional): Parameters for the base transform. """ def __init__(self, **kwargs): super().__init__() self.type = "keep_connected_component" self.parameters = kwargs - + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" + def forward(self, data: torch_geometric.data.Data): """Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. 
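
Both connectivity transforms delegate to PyTorch Geometric's graph constructors; a minimal standalone sketch (knn_graph and radius_graph require the torch-cluster extension):

    import torch
    from torch_geometric.nn import knn_graph, radius_graph

    pos = torch.randn(50, 3)
    edge_index_knn = knn_graph(pos, k=6)       # 6 nearest neighbors per point
    edge_index_rad = radius_graph(pos, r=0.5)  # all pairs within distance 0.5
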
""" - # torch_geometric.transforms.largest_connected_components() num_components = self.parameters["num_components"] diff --git a/topobenchmarkx/transforms/data_manipulations/keep_selected_data_fields.py b/topobenchmarkx/transforms/data_manipulations/keep_selected_data_fields.py index deae587e..c9d400a8 100644 --- a/topobenchmarkx/transforms/data_manipulations/keep_selected_data_fields.py +++ b/topobenchmarkx/transforms/data_manipulations/keep_selected_data_fields.py @@ -3,10 +3,8 @@ class KeepSelectedDataFields(torch_geometric.transforms.BaseTransform): r"""A transform that keeps only the selected fields of the input data. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + kwargs (optional): Parameters for the base transform. """ def __init__(self, **kwargs): @@ -14,18 +12,16 @@ def __init__(self, **kwargs): self.type = "keep_selected_data_fields" self.parameters = kwargs + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" + def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ # Keeps all the fields fields_to_keep = ( diff --git a/topobenchmarkx/transforms/data_manipulations/manipulations.py b/topobenchmarkx/transforms/data_manipulations/manipulations.py index 7d3ebb63..7cbc147a 100644 --- a/topobenchmarkx/transforms/data_manipulations/manipulations.py +++ b/topobenchmarkx/transforms/data_manipulations/manipulations.py @@ -5,51 +5,55 @@ class IdentityTransform(torch_geometric.transforms.BaseTransform): - r"""An identity transform that does nothing to the input data.""" + r"""An identity transform that does nothing to the input data. + + Args: + kwargs (optional): Parameters for the base transform. + """ def __init__(self, **kwargs): super().__init__() self.type = "domain2domain" self.parameters = kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The (un)transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The same data. """ return data class InfereKNNConnectivity(torch_geometric.transforms.BaseTransform): r"""A transform that generates the k-nearest neighbor connectivity of the - input point cloud.""" + input point cloud. + + Args: + kwargs (optional): Parameters for the base transform. + """ def __init__(self, **kwargs): super().__init__() self.type = "infere_knn_connectivity" self.parameters = kwargs + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" + def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - Returns - ------- - torch_geometric.data.Data - The transformed data. + + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. 
""" - edge_index = knn_graph(data.x, **self.parameters["args"]) # Remove duplicates @@ -59,24 +63,27 @@ def forward(self, data: torch_geometric.data.Data): class InfereRadiusConnectivity(torch_geometric.transforms.BaseTransform): r"""A transform that generates the radius connectivity of the input point - cloud.""" + cloud. + + Args: + kwargs (optional): Parameters for the base transform. + """ def __init__(self, **kwargs): super().__init__() self.type = "infere_radius_connectivity" self.parameters = kwargs + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" + def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ data.edge_index = radius_graph(data.x, **self.parameters["args"]) return data @@ -86,10 +93,10 @@ class EqualGausFeatures(torch_geometric.transforms.BaseTransform): r"""A transform that generates equal Gaussian features for all nodes in the input graph. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + mean (float): The mean of the Gaussian distribution. + std (float): The standard deviation of the Gaussian distribution. + num_features (int): The number of features to generate. """ def __init__(self, **kwargs): @@ -104,18 +111,16 @@ def __init__(self, **kwargs): mean=self.mean, std=self.std, size=(1, self.feature_vector) ) + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, mean={self.mean!r}, std={self.std!r}, feature_vector={self.feature_vector!r}" + def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ data.x = self.feature_vector.expand(data.num_nodes, -1) return data @@ -123,29 +128,25 @@ def forward(self, data: torch_geometric.data.Data): class NodeFeaturesToFloat(torch_geometric.transforms.BaseTransform): r"""A transform that converts the node features of the input graph to float. - - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + + Args: + kwargs (optional): Parameters for the base transform. """ def __init__(self, **kwargs): super().__init__() self.type = "map_node_features_to_float" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r})" def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ data.x = data.x.float() return data @@ -154,29 +155,25 @@ def forward(self, data: torch_geometric.data.Data): class NodeDegrees(torch_geometric.transforms.BaseTransform): r"""A transform that calculates the node degrees of the input graph. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. 
+ Args: + kwargs (optional): Parameters for the base transform. """ def __init__(self, **kwargs): super().__init__() self.type = "node_degrees" self.parameters = kwargs + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r}" def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ field_to_process = [ key @@ -194,16 +191,11 @@ def calculate_node_degrees( ) -> torch_geometric.data.Data: r"""Calculate the node degrees of the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - field : str - The field to calculate the node degrees. - - Returns - ------- - torch_geometric.data.Data + Args: + data (torch_geometric.data.Data): The input data. + field (str): The field to calculate the node degrees. + Returns: + torch_geometric.data.Data: The transformed data. """ if data[field].is_sparse: degrees = abs(data[field].to_dense()).sum(1) @@ -241,10 +233,8 @@ class KeepOnlyConnectedComponent(torch_geometric.transforms.BaseTransform): """A transform that keeps only the largest connected components of the input graph. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + kwargs (optional): Parameters for the base transform. """ def __init__(self, **kwargs): @@ -252,18 +242,16 @@ def __init__(self, **kwargs): self.type = "keep_connected_component" self.parameters = kwargs + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" + def forward(self, data: torch_geometric.data.Data): """Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ from torch_geometric.transforms import LargestConnectedComponents @@ -279,10 +267,8 @@ def forward(self, data: torch_geometric.data.Data): class CalculateSimplicialCurvature(torch_geometric.transforms.BaseTransform): """A transform that calculates the simplicial curvature of the input graph. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + kwargs (optional): Parameters for the base transform. """ def __init__(self, **kwargs): @@ -290,18 +276,16 @@ def __init__(self, **kwargs): self.type = "simplicial_curvature" self.parameters = kwargs + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r}" + def forward(self, data: torch_geometric.data.Data): """Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ data = self.one_cell_curvature(data) data = self.zero_cell_curvature(data) @@ -314,15 +298,10 @@ def zero_cell_curvature( ) -> torch_geometric.data.Data: """Calculate the zero cell curvature of the input data. 
- Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - Data with the zero cell curvature. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: Data with the zero cell curvature added as a field. """ data["0_cell_curvature"] = torch.mm( abs(data["incidence_1"]), data["1_cell_curvature"] @@ -335,15 +314,10 @@ def one_cell_curvature( ) -> torch_geometric.data.Data: r"""Calculate the one cell curvature of the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - Data with the one cell curvature. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: Data with the one cell curvature added as a field. """ data["1_cell_curvature"] = ( 4 @@ -358,15 +332,10 @@ def two_cell_curvature( ) -> torch_geometric.data.Data: r"""Calculate the two cell curvature of the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - Data with the two cell curvature. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: Data with the two cell curvature added as a field. """ # Term 1 is simply the degree of the 2-cell (i.e. each triangle belong to n tetrahedrons) term1 = data["2_cell_degrees"] @@ -391,10 +360,8 @@ class OneHotDegreeFeatures(torch_geometric.transforms.BaseTransform): r"""A transform that adds the node degree as one hot encodings to the node features. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + kwargs (optional): Parameters for the base transform. """ def __init__(self, **kwargs): @@ -404,18 +371,16 @@ def __init__(self, **kwargs): self.features_fields = kwargs["features_fields"] self.transform = OneHotDegree(max_degree=kwargs["max_degrees"]) + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, degrees_field={self.deg_field!r}, features_field={self.features_fields!r}" + def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ data = self.transform.forward( data, @@ -429,12 +394,9 @@ def forward(self, data: torch_geometric.data.Data): class OneHotDegree(torch_geometric.transforms.BaseTransform): r"""Adds the node degree as one hot encodings to the node features. - Parameters - ---------- - max_degree : int - The maximum degree of the graph. - cat : bool, optional - If set to `True`, the one hot encodings are concatenated to the node features. + Args: + max_degree (int): The maximum degree of the graph. + cat (bool, optional): Whether to concatenate the one hot encoding to the node features. (default: False) """ def __init__( @@ -445,6 +407,9 @@ def __init__( self.max_degree = max_degree self.cat = cat + def __repr__(self) -> str: + return f"{self.__class__.__name__}(max_degree={self.max_degree}, cat={self.cat})" + def forward( self, data: torch_geometric.data.Data, @@ -453,19 +418,12 @@ def forward( ) -> torch_geometric.data.Data: r"""Apply the transform to the input data. 
- Parameters - ---------- - data : torch_geometric.data.Data - The input data. - degrees_field : str - The field containing the node degrees. - features_field : str - The field containing the node features. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + degrees_field (str): The field containing the node degrees. + features_field (str): The field containing the node features. + Returns: + torch_geometric.data.Data: The transformed data. """ assert data.edge_index is not None @@ -492,10 +450,8 @@ def __repr__(self) -> str: class KeepSelectedDataFields(torch_geometric.transforms.BaseTransform): r"""A transform that keeps only the selected fields of the input data. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + kwargs (optional): Parameters for the base transform. """ def __init__(self, **kwargs): @@ -503,18 +459,16 @@ def __init__(self, **kwargs): self.type = "keep_selected_data_fields" self.parameters = kwargs + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r}" + def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ # Keeps all the fields fields_to_keep = ( diff --git a/topobenchmarkx/transforms/data_manipulations/node_degrees.py b/topobenchmarkx/transforms/data_manipulations/node_degrees.py index 7bd09843..d8f07aa9 100644 --- a/topobenchmarkx/transforms/data_manipulations/node_degrees.py +++ b/topobenchmarkx/transforms/data_manipulations/node_degrees.py @@ -4,10 +4,8 @@ class NodeDegrees(torch_geometric.transforms.BaseTransform): r"""A transform that calculates the node degrees of the input graph. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + kwargs (optional): Parameters for the base transform. """ def __init__(self, **kwargs): @@ -15,18 +13,16 @@ def __init__(self, **kwargs): self.type = "node_degrees" self.parameters = kwargs + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})" + def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ field_to_process = [ key @@ -44,16 +40,11 @@ def calculate_node_degrees( ) -> torch_geometric.data.Data: r"""Calculate the node degrees of the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - field : str - The field to calculate the node degrees. - - Returns - ------- - torch_geometric.data.Data + Args: + data (torch_geometric.data.Data): The input data. + field (str): The field to calculate the node degrees. + Returns: + torch_geometric.data.Data: The transformed data. 
""" if data[field].is_sparse: degrees = abs(data[field].to_dense()).sum(1) diff --git a/topobenchmarkx/transforms/data_manipulations/node_features_to_float.py b/topobenchmarkx/transforms/data_manipulations/node_features_to_float.py index 4422d39e..a49689f4 100644 --- a/topobenchmarkx/transforms/data_manipulations/node_features_to_float.py +++ b/topobenchmarkx/transforms/data_manipulations/node_features_to_float.py @@ -3,28 +3,24 @@ class NodeFeaturesToFloat(torch_geometric.transforms.BaseTransform): r"""A transform that converts the node features of the input graph to float. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + kwargs (optional): Parameters for the base transform. """ def __init__(self, **kwargs): super().__init__() self.type = "map_node_features_to_float" + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r})" + def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ data.x = data.x.float() return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/one_hot_degree.py b/topobenchmarkx/transforms/data_manipulations/one_hot_degree.py index 6ed0e333..580f87c1 100644 --- a/topobenchmarkx/transforms/data_manipulations/one_hot_degree.py +++ b/topobenchmarkx/transforms/data_manipulations/one_hot_degree.py @@ -5,14 +5,10 @@ class OneHotDegree(torch_geometric.transforms.BaseTransform): r"""Adds the node degree as one hot encodings to the node features. - Parameters - ---------- - max_degree : int - The maximum degree of the graph. - cat : bool, optional - If set to `True`, the one hot encodings are concatenated to the node features. + Args: + max_degree (int): The maximum degree of the graph. + cat (bool, optional): If set to `True`, the one hot encodings are concatenated to the node features. (default: False) """ - def __init__( self, max_degree: int, @@ -20,6 +16,9 @@ def __init__( ) -> None: self.max_degree = max_degree self.cat = cat + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(max_degree={self.max_degree}, cat={self.cat})" def forward( self, @@ -29,19 +28,12 @@ def forward( ) -> torch_geometric.data.Data: r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - degrees_field : str - The field containing the node degrees. - features_field : str - The field containing the node features. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + degrees_field (str): The field containing the node degrees. + features_field (str): The field containing the node features. + Returns: + torch_geometric.data.Data: The transformed data. 
""" assert data.edge_index is not None @@ -59,7 +51,4 @@ def forward( else: data[features_field] = deg - return data - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.max_degree})" \ No newline at end of file + return data \ No newline at end of file diff --git a/topobenchmarkx/transforms/data_manipulations/one_hot_degree_features.py b/topobenchmarkx/transforms/data_manipulations/one_hot_degree_features.py index 13b75043..2347fb79 100644 --- a/topobenchmarkx/transforms/data_manipulations/one_hot_degree_features.py +++ b/topobenchmarkx/transforms/data_manipulations/one_hot_degree_features.py @@ -7,12 +7,9 @@ class OneHotDegreeFeatures(torch_geometric.transforms.BaseTransform): r"""A transform that adds the node degree as one hot encodings to the node features. - Parameters - ---------- - **kwargs : optional - Parameters for the transform. + Args: + kwargs (optional): Parameters for the base transform. """ - def __init__(self, **kwargs): super().__init__() self.type = "one_hot_degree_features" @@ -20,18 +17,16 @@ def __init__(self, **kwargs): self.features_fields = kwargs["features_fields"] self.transform = OneHotDegree(max_degree=kwargs["max_degrees"]) + def __repr__(self) -> str: + return f"{self.__class__.__name__}(type={self.type!r}, degrees_field={self.deg_field!r}, features_field={self.features_fields!r})" + def forward(self, data: torch_geometric.data.Data): r"""Apply the transform to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - torch_geometric.data.Data - The transformed data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The transformed data. """ data = self.transform.forward( data, diff --git a/topobenchmarkx/transforms/data_transform.py b/topobenchmarkx/transforms/data_transform.py index 3f557229..701cd58a 100755 --- a/topobenchmarkx/transforms/data_transform.py +++ b/topobenchmarkx/transforms/data_transform.py @@ -2,15 +2,12 @@ from topobenchmarkx.transforms import TRANSFORMS class DataTransform(torch_geometric.transforms.BaseTransform): - """Abstract class that provides an interface to define a custom data + r"""Abstract class that provides an interface to define a custom data lifting. - Parameters - ---------- - transform_name : str - The name of the transform to be used. - **kwargs : optional - Additional arguments for the class. + Args: + transform_name (str): The name of the transform to be used. + **kwargs: Additional arguments for the class. """ def __init__(self, transform_name, **kwargs): @@ -28,17 +25,12 @@ def __init__(self, transform_name, **kwargs): def forward( self, data: torch_geometric.data.Data ) -> torch_geometric.data.Data: - """Forward pass of the lifting. + r"""Forward pass of the lifting. - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. - - Returns - ------- - transformed_data : torch_geometric.data.Data - The lifted data. + Args: + data (torch_geometric.data.Data): The input data to be lifted. + Returns: + transformed_data (torch_geometric.data.Data): The lifted data. 
""" transformed_data = self.transform(data) return transformed_data diff --git a/topobenchmarkx/transforms/feature_liftings/feature_liftings.py b/topobenchmarkx/transforms/feature_liftings/feature_liftings.py index 90c9f99e..70ee93b2 100644 --- a/topobenchmarkx/transforms/feature_liftings/feature_liftings.py +++ b/topobenchmarkx/transforms/feature_liftings/feature_liftings.py @@ -5,14 +5,14 @@ class ProjectionSum(torch_geometric.transforms.BaseTransform): r"""Lifts r-cell features to r+1-cells by projection. - Parameters - ---------- - **kwargs : optional - Additional arguments for the class. + Args: + kwargs (optional): Additional arguments for the class. """ - def __init__(self, **kwargs): super().__init__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" def lift_features( self, data: torch_geometric.data.Data | dict @@ -20,15 +20,10 @@ def lift_features( r"""Projects r-cell features of a graph to r+1-cell structures using the incidence matrix. - Parameters - ---------- - data : torch_geometric.data.Data | dict - The input data to be lifted. - - Returns - ------- - torch_geometric.data.Data | dict - The lifted data. + Args: + data (torch_geometric.data.Data | dict): The input data to be lifted. + Returns: + torch_geometric.data.Data | dict: The data with the lifted features. """ keys = sorted( [key.split("_")[1] for key in data if "incidence" in key] @@ -47,15 +42,10 @@ def forward( ) -> torch_geometric.data.Data | dict: r"""Applies the lifting to the input data. - Parameters - ---------- - data : torch_geometric.data.Data | dict - The input data to be lifted. - - Returns - ------- - torch_geometric.data.Data | dict - The lifted data. + Args: + data (torch_geometric.data.Data | dict): The input data to be lifted. + Returns: + torch_geometric.data.Data | dict: The lifted data. """ data = self.lift_features(data) return data @@ -64,14 +54,14 @@ def forward( class ConcatentionLifting(torch_geometric.transforms.BaseTransform): r"""Lifts r-cell features to r+1-cells by concatenation. - Parameters - ---------- - **kwargs : optional - Additional arguments for the class. + Args: + kwargs (optional): Additional arguments for the class. """ - def __init__(self, **kwargs): super().__init__() + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" def lift_features( self, data: torch_geometric.data.Data | dict @@ -79,17 +69,11 @@ def lift_features( r"""Concatenates r-cell features to r+1-cell structures using the incidence matrix. - Parameters - ---------- - data : torch_geometric.data.Data | dict - The input data to be lifted. - - Returns - ------- - torch_geometric.data.Data | dict - The lifted data. + Args: + data (torch_geometric.data.Data | dict): The input data to be lifted. + Returns: + torch_geometric.data.Data | dict: The lifted data. """ - keys = sorted( [key.split("_")[1] for key in data if "incidence" in key] ) @@ -120,15 +104,10 @@ def forward( ) -> torch_geometric.data.Data | dict: r"""Applies the lifting to the input data. - Parameters - ---------- - data : torch_geometric.data.Data | dict - The input data to be lifted. - - Returns - ------- - torch_geometric.data.Data | dict - The lifted data. + Args: + data (torch_geometric.data.Data | dict): The input data to be lifted. + Returns: + torch_geometric.data.Data | dict: The lifted data. """ data = self.lift_features(data) return data @@ -137,30 +116,25 @@ def forward( class SetLifting(torch_geometric.transforms.BaseTransform): r"""Lifts r-cell features to r+1-cells by set operations. 
-    Parameters
-    ----------
-    **kwargs : optional
-        Additional arguments for the class.
+    Args:
+        kwargs (optional): Additional arguments for the class.
     """
-
    def __init__(self, **kwargs):
        super().__init__()

+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}()"
+
    def lift_features(
        self, data: torch_geometric.data.Data | dict
    ) -> torch_geometric.data.Data | dict:
        r"""Lifts r-cell features to r+1-cell structures by set operations
        using the incidence matrix.

-        Parameters
-        ----------
-        data : torch_geometric.data.Data | dict
-            The input data to be lifted.
-
-        Returns
-        -------
-        torch_geometric.data.Data | dict
-            The lifted data.
+        Args:
+            data (torch_geometric.data.Data | dict): The input data to be lifted.
+        Returns:
+            torch_geometric.data.Data | dict: The lifted data.
         """
 
         keys = sorted(
@@ -204,15 +178,10 @@ def forward(
     ) -> torch_geometric.data.Data | dict:
         r"""Applies the lifting to the input data.
 
-        Parameters
-        ----------
-        data : torch_geometric.data.Data | dict
-            The input data to be lifted.
-
-        Returns
-        -------
-        torch_geometric.data.Data | dict
-            The lifted data.
+        Args:
+            data (torch_geometric.data.Data | dict): The input data to be lifted.
+        Returns:
+            torch_geometric.data.Data | dict: The lifted data.
         """
         data = self.lift_features(data)
         return data
diff --git a/topobenchmarkx/transforms/liftings/graph2cell.py b/topobenchmarkx/transforms/liftings/graph2cell.py
index 2f2224a4..84a86066 100755
--- a/topobenchmarkx/transforms/liftings/graph2cell.py
+++ b/topobenchmarkx/transforms/liftings/graph2cell.py
@@ -17,32 +17,26 @@
 class Graph2CellLifting(GraphLifting):
     r"""Abstract class for lifting graphs to cell complexes.
 
-    Parameters
-    ----------
-    complex_dim : int, optional
-        The dimension of the cell complex to be generated. Default is 2.
-    **kwargs : optional
-        Additional arguments for the class.
+    Args:
+        complex_dim (int, optional): The dimension of the cell complex to be generated. (default: 2)
+        kwargs (optional): Additional arguments for the class.
     """
-
     def __init__(self, complex_dim=2, **kwargs):
         super().__init__(**kwargs)
         self.complex_dim = complex_dim
         self.type = "graph2cell"
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(complex_dim={self.complex_dim!r}, type={self.type!r})"
 
     @abstractmethod
     def lift_topology(self, data: torch_geometric.data.Data) -> dict:
         r"""Lifts the topology of a graph to cell complex domain.
 
-        Parameters
-        ----------
-        data : torch_geometric.data.Data
-            The input data to be lifted.
-
-        Returns
-        -------
-        dict
-            The lifted topology.
+        Args:
+            data (torch_geometric.data.Data): The input data to be lifted.
+        Returns:
+            dict: The lifted topology.
         """
         raise NotImplementedError
 
@@ -51,17 +45,11 @@ def _get_lifted_topology(
     ) -> dict:
         r"""Returns the lifted topology.
 
-        Parameters
-        ----------
-        cell_complex : CellComplex
-            The cell complex.
-        graph : nx.Graph
-            The input graph.
-
-        Returns
-        -------
-        dict
-            The lifted topology.
+        Args:
+            cell_complex (CellComplex): The cell complex.
+            graph (nx.Graph): The input graph.
+        Returns:
+            dict: The lifted topology.
         """
         lifted_topology = get_complex_connectivity(
             cell_complex, self.complex_dim
@@ -80,34 +68,28 @@ def _get_lifted_topology(
 
 class CellCyclesLifting(Graph2CellLifting):
-    r"""Lifts graphs to cell complexes by identifying the cycles as 2-cells.
+    r"""Lifts graphs to cell complexes by taking as 2-cells a cycle basis of the graph.
 
-    Parameters
-    ----------
-    max_cell_length : int, optional
-        The maximum length of the cycles to be lifted. Default is None.
-    **kwargs : optional
-        Additional arguments for the class.
+    Args:
+        max_cell_length (int, optional): The maximum length of the cycles to be lifted. Default is None.
+        kwargs (optional): Additional arguments for the class.
     """
-
    def __init__(self, max_cell_length=None, **kwargs):
        super().__init__(**kwargs)
        self.complex_dim = 2
        self.max_cell_length = max_cell_length
-
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(max_cell_length={self.max_cell_length!r}, complex_dim={self.complex_dim!r})"
+
    def lift_topology(self, data: torch_geometric.data.Data) -> dict:
-        r"""Finds the cycles of a graph and lifts them to 2-cells.
-
-        Parameters
-        ----------
-        data : torch_geometric.data.Data
-            The input data to be lifted.
+        r"""Finds a cycle basis of the graph and lifts its cycles to 2-cells.
 
-        Returns
-        -------
-        dict
-            The lifted topology.
-        """
+        Args:
+            data (torch_geometric.data.Data): The input data to be lifted.
+        Returns:
+            dict: The lifted topology.
+        """
         G = self._generate_graph_from_data(data)
         cycles = nx.cycle_basis(G)
         cell_complex = CellComplex(G)
diff --git a/topobenchmarkx/transforms/liftings/graph2hypergraph.py b/topobenchmarkx/transforms/liftings/graph2hypergraph.py
index 0ac80a82..62307eda 100755
--- a/topobenchmarkx/transforms/liftings/graph2hypergraph.py
+++ b/topobenchmarkx/transforms/liftings/graph2hypergraph.py
@@ -14,61 +14,51 @@
 class Graph2HypergraphLifting(GraphLifting):
     r"""Abstract class for lifting graphs to hypergraphs.
 
-    Parameters
-    ----------
-    **kwargs : optional
-        Additional arguments for the class.
+    Args:
+        kwargs (optional): Additional arguments for the class.
     """
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
         self.type = "graph2hypergraph"
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(type={self.type!r})"
 
     @abstractmethod
     def lift_topology(self, data: torch_geometric.data.Data) -> dict:
         r"""Lifts the topology of a graph to hypergraph domain.
 
-        Parameters
-        ----------
-        data : torch_geometric.data.Data
-            The input data to be lifted.
-
-        Returns
-        -------
-        dict
-            The lifted topology.
+        Args:
+            data (torch_geometric.data.Data): The input data to be lifted.
+        Returns:
+            dict: The lifted topology.
         """
         raise NotImplementedError
 
 
 class HypergraphKHopLifting(Graph2HypergraphLifting):
-    r"""Lifts graphs to hypergraph domain by considering k-hop neighborhoods.
-
-    Parameters
-    ----------
-    k_value : int, optional
-        The number of hops to consider. Default is 1.
-    **kwargs : optional
-        Additional arguments for the class.
+    r"""Lifts graphs to hypergraph domain by considering the k-hop neighborhood of each node. This lifting extracts a number of hyperedges equal to the number of nodes in the graph.
+
+    Args:
+        k_value (int, optional): The number of hops to consider. (default: 1)
+        kwargs (optional): Additional arguments for the class.
     """
-
    def __init__(self, k_value=1, **kwargs):
        super().__init__(**kwargs)
        self.k = k_value
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(k={self.k!r})"
 
    def lift_topology(self, data: torch_geometric.data.Data) -> dict:
        r"""Lifts the topology of a graph to hypergraph domain by considering
        k-hop neighborhoods.

-        Parameters
-        ----------
-        data : torch_geometric.data.Data
-            The input data to be lifted.
-
-        Returns
-        -------
-        dict
-            The lifted topology.
+        Args:
+            data (torch_geometric.data.Data): The input data to be lifted.
+        Returns:
+            dict: The lifted topology.
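+
+        Example (illustrative sketch; the exact keys of the returned dict are
+        an assumption made for illustration):
+
+            >>> lifting = HypergraphKHopLifting(k_value=1)
+            >>> lifted = lifting.lift_topology(graph_data)
+            >>> # expected to contain the hypergraph connectivity fields,
+            >>> # e.g. "incidence_hyperedges" and "num_hyperedges"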
""" # Check if data has instance x: if hasattr(data, "x") and data.x is not None: @@ -111,37 +101,31 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict: class HypergraphKNearestNeighborsLifting(Graph2HypergraphLifting): - r"""Lifts graphs to hypergraph domain by considering k-nearest neighbors. - - Parameters - ---------- - k_value : int, optional - The number of nearest neighbors to consider. Default is 1. - loop: boolean, optional - If True the hyperedges will contain the node they were created from. - **kwargs : optional - Additional arguments for the class. + r"""Lifts graphs to hypergraph domain by considering k-nearest neighbors. This lifting extracts a number of hyperedges equal to the number of nodes in the graph. The hyperedges all contain the same number of nodes, which is equal to the number of nearest neighbors considered. + + Args: + k_value (int, optional): The number of nearest neighbors to consider. (default: 1) + loop (bool, optional): If True the hyperedges will contain the node they were created from. (default: True) + cosine (bool, optional): If True the cosine distance will be used instead of the Euclidean distance. (default: False) + kwargs (optional): Additional arguments for the class. """ - - def __init__(self, k_value=1, loop=True, **kwargs): + def __init__(self, k_value=1, loop=True, cosine=False, **kwargs): super().__init__() self.k = k_value self.loop = loop - self.transform = torch_geometric.transforms.KNNGraph(self.k, self.loop) + self.transform = torch_geometric.transforms.KNNGraph(self.k, self.loop, cosine=cosine) + def __repr__(self) -> str: + return f"{self.__class__.__name__}(k={self.k!r}, loop={self.loop!r})" + def lift_topology(self, data: torch_geometric.data.Data) -> dict: r"""Lifts the topology of a graph to hypergraph domain by considering k-nearest neighbors. - - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. - - Returns - ------- - dict - The lifted topology. + + Args: + data (torch_geometric.data.Data): The input data to be lifted. + Returns: + dict: The lifted topology. """ num_nodes = data.x.shape[0] data.pos = data.x diff --git a/topobenchmarkx/transforms/liftings/graph2simplicial.py b/topobenchmarkx/transforms/liftings/graph2simplicial.py index 036ba976..89f961a5 100755 --- a/topobenchmarkx/transforms/liftings/graph2simplicial.py +++ b/topobenchmarkx/transforms/liftings/graph2simplicial.py @@ -20,33 +20,27 @@ class Graph2SimplicialLifting(GraphLifting): r"""Abstract class for lifting graphs to simplicial complexes. - Parameters - ---------- - complex_dim : int, optional - The dimension of the simplicial complex to be generated. Default is 2. - **kwargs : optional - Additional arguments for the class. + Args: + complex_dim (int, optional): The maximum dimension of the simplicial complex to be generated. (default: 2) + kwargs (optional): Additional arguments for the class. """ - def __init__(self, complex_dim=2, **kwargs): super().__init__(**kwargs) self.complex_dim = complex_dim self.type = "graph2simplicial" self.signed = kwargs.get("signed", False) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(complex_dim={self.complex_dim!r}, type={self.type!r}, signed={self.signed!r})" @abstractmethod def lift_topology(self, data: torch_geometric.data.Data) -> dict: r"""Lifts the topology of a graph to simplicial complex domain. - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. - - Returns - ------- - dict - The lifted topology. 
+        Args:
+            data (torch_geometric.data.Data): The input data to be lifted.
+        Returns:
+            dict: The lifted topology.
         """
         raise NotImplementedError
 
@@ -55,17 +49,11 @@ def _get_lifted_topology(
     ) -> dict:
         r"""Returns the lifted topology.
 
-        Parameters
-        ----------
-        simplicial_complex : SimplicialComplex
-            The simplicial complex.
-        graph : nx.Graph
-            The input graph.
-
-        Returns
-        -------
-        dict
-            The lifted topology.
+        Args:
+            simplicial_complex (SimplicialComplex): The simplicial complex.
+            graph (nx.Graph): The input graph.
+        Returns:
+            dict: The lifted topology.
         """
         lifted_topology = get_complex_connectivity(
             simplicial_complex, self.complex_dim, signed=self.signed
@@ -93,33 +81,27 @@ def _get_lifted_topology(
 
 class SimplicialNeighborhoodLifting(Graph2SimplicialLifting):
     r"""Lifts graphs to simplicial complex domain by considering k-hop
-    neighborhoods.
-
-    Parameters
-    ----------
-    max_k_simplices : int, optional
-        The maximum number of k-simplices to consider. Default is 5000.
-    **kwargs : optional
-        Additional arguments for the class.
-    """
+    neighborhoods. For each node, its k-hop neighborhood is treated as a clique, and all the possible simplices it induces are added to the simplicial complex. For this reason, this lifting does not preserve the initial graph topology.
 
+    Args:
+        max_k_simplices (int, optional): The maximum number of k-simplices to consider. (default: 5000)
+        kwargs (optional): Additional arguments for the class.
+    """
    def __init__(self, max_k_simplices=5000, **kwargs):
        super().__init__(**kwargs)
        self.max_k_simplices = max_k_simplices
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(max_k_simplices={self.max_k_simplices!r})"
 
    def lift_topology(self, data: torch_geometric.data.Data) -> dict:
        r"""Lifts the topology of a graph to simplicial complex domain by
        considering k-hop neighborhoods.

-        Parameters
-        ----------
-        data : torch_geometric.data.Data
-            The input data to be lifted.
-
-        Returns
-        -------
-        dict
-            The lifted topology.
+        Args:
+            data (torch_geometric.data.Data): The input data to be lifted.
+        Returns:
+            dict: The lifted topology.
         """
         graph = self._generate_graph_from_data(data)
         simplicial_complex = SimplicialComplex(graph)
@@ -147,13 +129,10 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict:
 
 class SimplicialCliqueLifting(Graph2SimplicialLifting):
-    r"""Lifts graphs to simplicial complex domain by identifying the cliques as
-    k-simplices.
+    r"""Lifts graphs to simplicial complex domain by identifying the cliques as k-simplices, together with all of their lower-dimensional faces.
 
-    Parameters
-    ----------
-    **kwargs : optional
-        Additional arguments for the class.
+    Args:
+        kwargs (optional): Additional arguments for the class.
     """
 
     def __init__(self, **kwargs):
@@ -163,15 +142,10 @@ def lift_topology(self, data: torch_geometric.data.Data) -> dict:
         r"""Lifts the topology of a graph to a simplicial complex by
         identifying the cliques as k-simplices.
 
-        Parameters
-        ----------
-        data : torch_geometric.data.Data
-            The input data to be lifted.
-
-        Returns
-        -------
-        dict
-            The lifted topology.
+        Args:
+            data (torch_geometric.data.Data): The input data to be lifted.
+        Returns:
+            dict: The lifted topology.
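+
+        Example (illustrative sketch; ``triangle_data`` stands for any
+        torch_geometric.data.Data graph containing a 3-clique and is an
+        assumption made for illustration):
+
+            >>> lifting = SimplicialCliqueLifting(complex_dim=2)
+            >>> lifted = lifting.lift_topology(triangle_data)
+            >>> # the 3-clique becomes a 2-simplex, together with its three
+            >>> # edges (1-simplices) and three nodes (0-simplices)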
""" graph = self._generate_graph_from_data(data) simplicial_complex = SimplicialComplex(graph) diff --git a/topobenchmarkx/transforms/liftings/graph_lifting.py b/topobenchmarkx/transforms/liftings/graph_lifting.py index 0392f2dc..4dab940b 100644 --- a/topobenchmarkx/transforms/liftings/graph_lifting.py +++ b/topobenchmarkx/transforms/liftings/graph_lifting.py @@ -22,36 +22,29 @@ class GraphLifting(torch_geometric.transforms.BaseTransform): r"""Abstract class for lifting graph topologies to higher-order topological domains. - Parameters - ---------- - feature_lifting : str, optional - The feature lifting method to be used. Default is 'projection'. - preserve_edge_attr : bool, optional - Whether to preserve edge attributes. Default is False. - **kwargs : optional - Additional arguments for the class. + Args: + feature_lifting (str, optional): The feature lifting method to be used. (default: 'projection') + preserve_edge_attr (bool, optional): Whether to preserve edge attributes. (default: False) + kwargs (optional): Additional arguments for the class. """ - def __init__( self, feature_lifting="projection", preserve_edge_attr=False, **kwargs ): super().__init__() self.feature_lifting = FEATURE_LIFTINGS[feature_lifting]() self.preserve_edge_attr = preserve_edge_attr + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(feature_lifting={self.feature_lifting!r}, preserve_edge_attr={self.preserve_edge_attr!r})" @abstractmethod def lift_topology(self, data: torch_geometric.data.Data) -> dict: r"""Lifts the topology of a graph to higher-order topological domains. - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. - - Returns - ------- - dict - The lifted topology. + Args: + data (torch_geometric.data.Data): The input data to be lifted. + Returns: + dict: The lifted topology. """ raise NotImplementedError @@ -60,15 +53,10 @@ def forward( ) -> torch_geometric.data.Data: r"""Applies the full lifting (topology + features) to the input data. - Parameters - ---------- - data : torch_geometric.data.Data - The input data to be lifted. - - Returns - ------- - torch_geometric.data.Data - The lifted data. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + torch_geometric.data.Data: The output data. """ initial_data = data.to_dict() lifted_topology = self.lift_topology(data) @@ -81,15 +69,10 @@ def forward( def _data_has_edge_attr(self, data: torch_geometric.data.Data) -> bool: r"""Checks if the input data object has edge attributes. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - bool - Whether the data object has edge attributes. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + bool: Whether the data object has edge attributes. """ return hasattr(data, "edge_attr") and data.edge_attr is not None @@ -98,15 +81,10 @@ def _generate_graph_from_data( ) -> nx.Graph: r"""Generates a NetworkX graph from the input data object. - Parameters - ---------- - data : torch_geometric.data.Data - The input data. - - Returns - ------- - nx.Graph - The generated NetworkX graph. + Args: + data (torch_geometric.data.Data): The input data. + Returns: + nx.Graph: The generated NetworkX graph. """ # Check if data object have edge_attr, return list of tuples as [(node_id, {'features':data}, 'dim':1)] or ?? 
         nodes = [
diff --git a/topobenchmarkx/utils/config_resolvers.py b/topobenchmarkx/utils/config_resolvers.py
index da79d56b..f6fcab17 100644
--- a/topobenchmarkx/utils/config_resolvers.py
+++ b/topobenchmarkx/utils/config_resolvers.py
@@ -1,22 +1,13 @@
 def get_default_transform(data_domain, model):
     r"""Get default transform for a given data domain and model.
 
-    Parameters
-    ----------
-    data_domain: str
-        Data domain.
-    model: str
-        Model name. Should be in the format "model_domain/name".
-
-    Returns
-    -------
-    str
-        Default transform.
-
-    Raises
-    ------
-    ValueError
-        If the combination of data_domain and model is invalid.
+    Args:
+        data_domain (str): Data domain.
+        model (str): Model name. Should be in the format "model_domain/name".
+    Returns:
+        str: Default transform.
+    Raises:
+        ValueError: If the combination of data_domain and model is invalid.
     """
     model_domain = model.split("/")[0]
     if data_domain == model_domain:
@@ -32,22 +23,13 @@ def get_default_transform(data_domain, model):
 def get_monitor_metric(task, metric):
     r"""Get monitor metric for a given task and metric.
 
-    Parameters
-    ----------
-    task: str
-        Task, either "classification" or "regression".
-    loss: str
-        Name of the loss function.
-
-    Returns
-    -------
-    str
-        Monitor metric.
-
-    Raises
-    ------
-    ValueError
-        If the task is invalid.
+    Args:
+        task (str): Task, either "classification" or "regression".
+        metric (str): Name of the metric to monitor.
+    Returns:
+        str: Monitor metric.
+    Raises:
+        ValueError: If the task is invalid.
     """
     if task == "classification" or task == "regression":
         return f"val/{metric}"
@@ -58,20 +40,12 @@ def get_monitor_metric(task, metric):
 def get_monitor_mode(task):
     r"""Get monitor mode for a given task.
 
-    Parameters
-    ----------
-    task: str
-        Task, either "classification" or "regression".
-
-    Returns
-    -------
-    str
-        Monitor mode, either "max" or "min".
-
-    Raises
-    ------
-    ValueError
-        If the task is invalid.
+    Args:
+        task (str): Task, either "classification" or "regression".
+    Returns:
+        str: Monitor mode, either "max" or "min".
+    Raises:
+        ValueError: If the task is invalid.
     """
     if task == "classification":
         return "max"
@@ -84,31 +58,19 @@ def get_monitor_mode(task):
 def infer_in_channels(dataset):
     r"""Infer the number of input channels for a given dataset.
 
-    Parameters
-    ----------
-    dataset: torch_geometric.data.Dataset
-        Input dataset.
-
-    Returns
-    -------
-    list
-        List with dimensions of the input channels.
+    Args:
+        dataset (torch_geometric.data.Dataset): Input dataset.
+    Returns:
+        list: List with dimensions of the input channels.
     """
-
     def find_complex_lifting(dataset):
         r"""Find if there is a complex lifting in the dataset.
 
-        Parameters
-        ----------
-        dataset: torch_geometric.data.Dataset
-            Input dataset.
-
-        Returns
-        -------
-        bool
-            True if there is a complex lifting, False otherwise.
-        str
-            Name of the complex lifting, if it exists.
+        Args:
+            dataset (torch_geometric.data.Dataset): Input dataset.
+        Returns:
+            bool: True if there is a complex lifting, False otherwise.
+            str: Name of the complex lifting, if it exists.
         """
         if "transforms" not in dataset:
             return False, None
@@ -125,17 +87,11 @@ def find_complex_lifting(dataset):
     def check_for_type_feature_lifting(dataset, lifting):
         r"""Check the type of feature lifting in the dataset.
 
-        Parameters
-        ----------
-        dataset: torch_geometric.data.Dataset
-            Input dataset.
-        lifting: str
-            Name of the complex lifting.
-
-        Returns
-        -------
-        str
-            Type of feature lifting.
+        Args:
+            dataset (torch_geometric.data.Dataset): Input dataset.
+ lifting (str): Name of the complex lifting. + Returns: + str: Type of feature lifting. """ lifting_params_keys = dataset.transforms[lifting].keys() if "feature_lifting" in lifting_params_keys: @@ -208,4 +164,11 @@ def check_for_type_feature_lifting(dataset, lifting): def infere_list_length(list): + r"""Infer the length of a list. + + Args: + list (list): Input list. + Returns: + int: Length of the input list. + """ return len(list) diff --git a/topobenchmarkx/utils/instantiators.py b/topobenchmarkx/utils/instantiators.py index 2e01a1b6..8f926cc8 100755 --- a/topobenchmarkx/utils/instantiators.py +++ b/topobenchmarkx/utils/instantiators.py @@ -9,11 +9,12 @@ def instantiate_callbacks(callbacks_cfg: DictConfig) -> list[Callback]: - """Instantiates callbacks from config. + r"""Instantiates callbacks from config. - :param callbacks_cfg: A DictConfig object containing callback - configurations. - :return: A list of instantiated callbacks. + Args: + callbacks_cfg (DictConfig): A DictConfig object containing callback configurations. + Returns: + list[Callback]: A list of instantiated callbacks. """ callbacks: list[Callback] = [] @@ -33,10 +34,12 @@ def instantiate_callbacks(callbacks_cfg: DictConfig) -> list[Callback]: def instantiate_loggers(logger_cfg: DictConfig) -> list[Logger]: - """Instantiates loggers from config. + r"""Instantiates loggers from config. - :param logger_cfg: A DictConfig object containing logger configurations. - :return: A list of instantiated loggers. + Args: + logger_cfg (DictConfig): A DictConfig object containing logger configurations. + Returns: + list[Logger]: A list of instantiated loggers. """ logger: list[Logger] = [] diff --git a/topobenchmarkx/utils/logging_utils.py b/topobenchmarkx/utils/logging_utils.py index 316d0191..459675fe 100755 --- a/topobenchmarkx/utils/logging_utils.py +++ b/topobenchmarkx/utils/logging_utils.py @@ -10,15 +10,16 @@ @rank_zero_only def log_hyperparameters(object_dict: dict[str, Any]) -> None: - """Controls which config parts are saved by Lightning loggers. + r"""Controls which config parts are saved by Lightning loggers. Additionally saves: - Number of model parameters - :param object_dict: A dictionary containing the following objects: - - `"cfg"`: A DictConfig object containing the main config. - - `"model"`: The Lightning model. - - `"trainer"`: The Lightning trainer. + Args: + object_dict (dict[str, Any]): A dictionary containing the following objects: + - `"cfg"`: A DictConfig object containing the main config. + - `"model"`: The Lightning model. + - `"trainer"`: The Lightning trainer. """ hparams = {} diff --git a/topobenchmarkx/utils/pylogger.py b/topobenchmarkx/utils/pylogger.py index 3b8222fd..3a24e0df 100755 --- a/topobenchmarkx/utils/pylogger.py +++ b/topobenchmarkx/utils/pylogger.py @@ -16,30 +16,35 @@ def __init__( rank_zero_only: bool = False, extra: Mapping[str, object] | None = None, ) -> None: - """Initializes a multi-GPU-friendly python command line logger that + r"""Initializes a multi-GPU-friendly python command line logger that logs on all processes with their rank prefixed in the log message. - :param name: The name of the logger. Default is ``__name__``. - :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`. - :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`. + Args: + name (str, optional): The name of the logger. 
(default: __name__) + rank_zero_only (bool, optional): Whether to force all logs to only occur on the rank zero process. (default: False) + extra (Mapping[str, object], optional): A dict-like object which provides contextual information. See `logging.LoggerAdapter`. (default: None) """ logger = logging.getLogger(name) super().__init__(logger=logger, extra=extra) self.rank_zero_only = rank_zero_only + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(name={self.logger.name!r}, rank_zero_only={self.rank_zero_only!r}, extra={self.extra})" def log( self, level: int, msg: str, rank: int | None = None, *args, **kwargs ) -> None: - """Delegate a log call to the underlying logger, after prefixing its + r"""Delegate a log call to the underlying logger, after prefixing its message with the rank of the process it's being logged from. If `'rank'` is provided, then the log will only occur on that rank/process. - :param level: The level to log at. Look at `logging.__init__.py` for more information. - :param msg: The message to log. - :param rank: The rank to log at. - :param args: Additional args to pass to the underlying logging function. - :param kwargs: Any additional keyword args to pass to the underlying logging function. + Args: + level (int): The level to log at. Look at `logging.__init__.py` for more information. + msg (str): The message to log. + rank (int, optional): The rank to log at. (default: None) + args: Additional args to pass to the underlying logging function. + kwargs: Any additional keyword args to pass to the underlying logging function. """ if self.isEnabledFor(level): msg, kwargs = self.process(msg, kwargs) diff --git a/topobenchmarkx/utils/rich_utils.py b/topobenchmarkx/utils/rich_utils.py index 6cf5080c..7a900459 100755 --- a/topobenchmarkx/utils/rich_utils.py +++ b/topobenchmarkx/utils/rich_utils.py @@ -29,14 +29,14 @@ def print_config_tree( resolve: bool = False, save_to_file: bool = False, ) -> None: - """Prints the contents of a DictConfig as a tree structure using the Rich + r"""Prints the contents of a DictConfig as a tree structure using the Rich library. - :param cfg: A DictConfig composed by Hydra. - :param print_order: Determines in what order config components are printed. Default is ``("data", "model", - "callbacks", "logger", "trainer", "paths", "extras")``. - :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``. - :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``. + Args: + cfg (DictConfig): A DictConfig composed by Hydra. + print_order (Sequence[str], optional): Determines in what order config components are printed. (default: `("data", "model", "callbacks", "logger", "trainer", "paths", "extras")`). + resolve (bool, optional): Whether to resolve reference fields of DictConfig. (default: False) + save_to_file (bool, optional): Whether to export config to the hydra output folder. (default: False) """ style = "dim" tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) @@ -81,11 +81,12 @@ def print_config_tree( @rank_zero_only def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None: - """Prompts user to input tags from command line if no tags are provided in + r"""Prompts user to input tags from command line if no tags are provided in config. - :param cfg: A DictConfig composed by Hydra. - :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``. + Args: + cfg (DictConfig): A DictConfig composed by Hydra. 
+ save_to_file (bool, optional): Whether to export tags to the hydra output folder. (default: False). """ if not cfg.get("tags"): if "id" in HydraConfig().cfg.hydra.job: diff --git a/topobenchmarkx/utils/utils.py b/topobenchmarkx/utils/utils.py index 74fdfbb5..2c6768ae 100755 --- a/topobenchmarkx/utils/utils.py +++ b/topobenchmarkx/utils/utils.py @@ -11,14 +11,15 @@ def extras(cfg: DictConfig) -> None: - """Applies optional utilities before the task is started. + r"""Applies optional utilities before the task is started. Utilities: - Ignoring python warnings - Setting tags from command line - Rich config printing - :param cfg: A DictConfig object containing the config tree. + Args: + cfg (DictConfig): A DictConfig object containing the config tree. """ # return if no `extras` config if not cfg.get("extras"): @@ -46,7 +47,7 @@ def extras(cfg: DictConfig) -> None: def task_wrapper(task_func: Callable) -> Callable: - """Optional decorator that controls the failure behavior when executing the + r"""Optional decorator that controls the failure behavior when executing the task function. This wrapper can be used to: @@ -62,10 +63,10 @@ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]: ... return metric_dict, object_dict ``` - - :param task_func: The task function to be wrapped. - - :return: The wrapped task function. + Args: + task_func: The task function to be wrapped. + Returns: + The wrapped task function. """ def wrap(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: @@ -104,11 +105,13 @@ def wrap(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]: def get_metric_value( metric_dict: dict[str, Any], metric_name: str | None ) -> float | None: - """Safely retrieves value of the metric logged in LightningModule. + r"""Safely retrieves value of the metric logged in LightningModule. - :param metric_dict: A dict containing metric values. - :param metric_name: If provided, the name of the metric to retrieve. - :return: If a metric name was provided, the value of the metric. + Args: + metric_dict: A dict containing metric values. + metric_name: If provided, the name of the metric to retrieve. + Returns: + If a metric name was provided, the value of the metric. """ if not metric_name: log.info("Metric name is None! 
Skipping metric value retrieval...") From c1bae72afdddb35df92267790491dd4599cbb2f2 Mon Sep 17 00:00:00 2001 From: levtelyatnikov Date: Fri, 17 May 2024 21:25:04 +0000 Subject: [PATCH 26/32] fixing torch_geometeric --- configs/train.yaml | 6 +- env.sh | 5 +- tables/cell_statistics.csv | 13 ++++ tables/dataset_statistics.csv | 13 ++++ topobenchmarkx/data/heteriphilic_dataset.py | 8 +- .../data/us_county_demos_dataset.py | 63 ++++++++++------ topobenchmarkx/dataset_statistics.py | 75 ++++++++++++------- topobenchmarkx/io/load/preprocessor.py | 48 +++++++++++- topobenchmarkx/stat.sh | 6 +- 9 files changed, 175 insertions(+), 62 deletions(-) create mode 100644 tables/cell_statistics.csv create mode 100644 tables/dataset_statistics.csv diff --git a/configs/train.yaml b/configs/train.yaml index 4acf13bb..92f56c99 100755 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -4,12 +4,12 @@ # order of defaults determines the order in which configs override each other defaults: - _self_ - - dataset: amazon_ratings #us_country_demos - - model: hypergraph/allsettransformer #hypergraph/unignn2 #allsettransformer + - dataset: roman_empire #us_country_demos + - model: cell/cwn #hypergraph/unignn2 #allsettransformer - evaluator: default - callbacks: default - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`) - - trainer: default + - trainer: cpu - paths: default - extras: default - hydra: default diff --git a/env.sh b/env.sh index f817fb2e..4a4133b0 100644 --- a/env.sh +++ b/env.sh @@ -1,7 +1,7 @@ # #!/bin/bash -#conda create -n topoxx python=3.11.3 -#conda activate topoxx +conda create -n topoxx python=3.11.3 +conda activate topoxx pip install --upgrade pip pip install -e '.[all]' @@ -12,6 +12,7 @@ pip install git+https://github.com/pyt-team/TopoEmbedX.git CUDA="cu117" # if available, select the CUDA version suitable for your system # e.g. 
cpu, cu102, cu111, cu113, cu115 +pip install torch_geometric==2.4.0 pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/${CUDA} pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html diff --git a/tables/cell_statistics.csv b/tables/cell_statistics.csv new file mode 100644 index 00000000..ca90bae6 --- /dev/null +++ b/tables/cell_statistics.csv @@ -0,0 +1,13 @@ +,3,4,5,6,7,8,9,10,greater_than_10,dataset,domain +0,1120,260,103,51,40,27,22,25,1000,Cora,cell +1,750,278,97,61,30,12,19,12,404,citeseer,cell +2,4174,3017,599,558,266,313,226,172,14280,PubMed,cell +3,769,192,10196,21486,407,53,0,1,17,ZINC,cell +4,7165,1701,700,310,178,97,56,24,35,roman_empire,cell +5,51642,9105,1392,710,359,289,170,103,4783,amazon_ratings,cell +6,24123,288,0,96,2,94,0,92,4260,minesweeper,cell +7,0,0,68,419,0,0,0,0,51,MUTAG,cell +8,24211,7495,1842,1531,568,585,358,349,1834,PROTEINS,cell +9,186,129,3025,10657,373,59,16,64,376,NCI1,cell +10,77706,36,13,3,0,0,0,0,0,IMDB-BINARY,cell +11,80846,35,17,3,0,0,0,0,0,IMDB-MULTI,cell diff --git a/tables/dataset_statistics.csv b/tables/dataset_statistics.csv new file mode 100644 index 00000000..38e43633 --- /dev/null +++ b/tables/dataset_statistics.csv @@ -0,0 +1,13 @@ +,num_hyperedges,zero_cell,one_cell,two_cell,three_cell,dataset,domain +0,0,2708,5278,2648,0,Cora,cell +1,0,3327,4552,1663,0,citeseer,cell +2,0,19717,44324,23605,0,PubMed,cell +3,0,277864,298985,33121,0,ZINC,cell +4,0,22662,32927,10266,0,roman_empire,cell +5,0,24492,93050,68553,0,amazon_ratings,cell +6,0,10000,39402,28955,0,minesweeper,cell +7,0,3371,3721,538,0,MUTAG,cell +8,0,43471,81044,38773,0,PROTEINS,cell +9,0,122747,132753,14885,0,NCI1,cell +10,0,19773,96531,77758,0,IMDB-BINARY,cell +11,0,19502,98903,80901,0,IMDB-MULTI,cell diff --git a/topobenchmarkx/data/heteriphilic_dataset.py b/topobenchmarkx/data/heteriphilic_dataset.py index a073ab8d..d4810055 100644 --- a/topobenchmarkx/data/heteriphilic_dataset.py +++ b/topobenchmarkx/data/heteriphilic_dataset.py @@ -5,7 +5,7 @@ import torch from omegaconf import DictConfig from torch_geometric.data import Data, InMemoryDataset -from torch_geometric.io import fs +#from torch_geometric.io import fs from topobenchmarkx.io.load.heterophilic import ( download_hetero_datasets, @@ -52,7 +52,7 @@ def __init__( transform: Callable | None = None, pre_transform: Callable | None = None, pre_filter: Callable | None = None, - force_reload: bool = True, + #force_reload: bool = True, use_node_attr: bool = False, use_edge_attr: bool = False, ) -> None: @@ -63,11 +63,11 @@ def __init__( transform, pre_transform, pre_filter, - force_reload=force_reload, + #force_reload=force_reload, ) # Load the processed data - data, _, _ = fs.torch_load(self.processed_paths[0]) + data, _, _ = torch.load(self.processed_paths[0]) # Map the loaded data into data = Data.from_dict(data) diff --git a/topobenchmarkx/data/us_county_demos_dataset.py b/topobenchmarkx/data/us_county_demos_dataset.py index df1fb632..76db183f 100644 --- a/topobenchmarkx/data/us_county_demos_dataset.py +++ b/topobenchmarkx/data/us_county_demos_dataset.py @@ -1,11 +1,12 @@ +import os import os.path as osp from collections.abc import Callable from typing import ClassVar - +import shutil import torch from omegaconf import DictConfig -from torch_geometric.data import Data, InMemoryDataset -from torch_geometric.io import fs +from torch_geometric.data import Data, InMemoryDataset, extract_zip +# from 
torch_geometric.io import fs from topobenchmarkx.io.load.download_utils import download_file_from_drive from topobenchmarkx.io.load.split_utils import random_splitting @@ -62,7 +63,7 @@ def __init__( transform: Callable | None = None, pre_transform: Callable | None = None, pre_filter: Callable | None = None, - force_reload: bool = True, + #force_reload: bool = True, use_node_attr: bool = False, use_edge_attr: bool = False, ) -> None: @@ -73,14 +74,14 @@ def __init__( transform, pre_transform, pre_filter, - force_reload=force_reload, + #force_reload=force_reload, ) # Load the processed data - data, _, _ = fs.torch_load(self.processed_paths[0]) - + data, _ = torch.load(self.processed_paths[0]) + # Map the loaded data into - data = Data.from_dict(data) + data = Data.from_dict(data) if isinstance(data, dict) else data # Create the splits and upload desired fold splits = random_splitting(data.y, parameters=self.parameters) @@ -113,8 +114,8 @@ def processed_dir(self) -> str: @property def raw_file_names(self) -> list[str]: - names = ["", f"_{self.parameters.year}"] - return [f"{self.name}_{name}.txt" for name in names] + #names = ["county", f"{self.parameters.year}"] + return [f"county_graph.csv", f"county_stats_{self.parameters.year}.csv"] @property def processed_file_names(self) -> str: @@ -139,20 +140,38 @@ def download(self) -> None: file_format=self.file_format, ) - # Extract the downloaded file if it is compressed - fs.cp( - f"{self.raw_dir}/{self.name}.{self.file_format}", - self.raw_dir, - extract=True, - ) + folder = self.raw_dir + filename = f"{self.name}.{self.file_format}" + path = osp.join(folder, filename) + extract_zip(path, folder) + # Delete zip file + os.unlink(path) + #shutil.rmtree(path) + # Move files from osp.join(folder, self.name) to folder + for file in os.listdir(osp.join(folder, self.name)): + shutil.move(osp.join(folder, self.name, file), folder) + + # Delete osp.join(folder, self.name) dir + shutil.rmtree(osp.join(folder, self.name)) - # Move the etracted files to the datasets/domain/dataset_name/raw/ directory - for filename in fs.ls(osp.join(self.raw_dir, self.name)): - fs.mv(filename, osp.join(self.raw_dir, osp.basename(filename))) - fs.rm(osp.join(self.raw_dir, self.name)) - # Delete also f'{self.raw_dir}/{self.name}.{self.file_format}' - fs.rm(f"{self.raw_dir}/{self.name}.{self.file_format}") + #os.rename(osp.join(folder, self.name), self.raw_dir) + + + # # Extract the downloaded file if it is compressed + # fs.cp(f"{self.raw_dir}/{self.name}.{self.file_format}", + # f"{self.raw_dir}/{self.name}.{self.file_format}", + # self.raw_dir, + # extract=True, + # ) + # # Move the etracted files to the datasets/domain/dataset_name/raw/ directory + # for filename in fs.ls(osp.join(self.raw_dir, self.name)): + # fs.mv(filename, osp.join(self.raw_dir, osp.basename(filename))) + # fs.rm(osp.join(self.raw_dir, self.name)) + + # # Delete also f'{self.raw_dir}/{self.name}.{self.file_format}' + # fs.rm(f"{self.raw_dir}/{self.name}.{self.file_format}") + def process(self) -> None: r"""Process the data for the dataset. 
diff --git a/topobenchmarkx/dataset_statistics.py b/topobenchmarkx/dataset_statistics.py
index 757c53d9..a6a6dfe7 100755
--- a/topobenchmarkx/dataset_statistics.py
+++ b/topobenchmarkx/dataset_statistics.py
@@ -127,6 +127,19 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
         "two_cell": 0,
         "three_cell": 0,
     }
+
+    cell_dict = {
+        "3": 0,
+        "4": 0,
+        "5": 0,
+        "6": 0,
+        "7": 0,
+        "8": 0,
+        "9": 0,
+        "10": 0,
+        "greater_than_10": 0,
+
+    }
 
     for loader in dataloaders:
         for batch in loader:
@@ -145,7 +158,12 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
                 dict_collector["one_cell"] += batch.x_1.shape[0]
                 dict_collector["two_cell"] += batch.x_2.shape[0]
                 cell_sizes, cell_counts = torch.unique(batch.incidence_2.to_dense().sum(0), return_counts=True)
-
+                cell_sizes = cell_sizes.long()
+                for i in range(len(cell_sizes)):
+                    if cell_sizes[i].item() > 10:
+                        cell_dict["greater_than_10"] += cell_counts[i].item()
+                    else:
+                        cell_dict[str(cell_sizes[i].item())] += cell_counts[i].item()
 
     # Get current working dir
     filename = f"{cfg.paths['root_dir']}/tables/dataset_statistics.csv"
@@ -165,34 +183,37 @@ def train(cfg: DictConfig) -> tuple[dict[str, Any], dict[str, Any]]:
     # write to csv file
     df_saved.to_csv(filename)
 
-    # if cfg.model.model_domain == "cell":
-    #     filename = f"{cfg.paths['root_dir']}/tables/cell_statistics.csv"
-    #     # Create a dict from two arrays
-    #     cell_dict = dict(zip(cell_sizes.long().tolist(), cell_counts.long().tolist()))
-
-    #     # Check if there are cells size of which greater than 10
-    #     n_large_cells = 0
-    #     subset_keys = [key for key in sorted(cell_dict.keys()) if key > 10]
+    if cfg.model.model_domain == "cell":
+        filename = f"{cfg.paths['root_dir']}/tables/cell_statistics.csv"
+        # Create a dict from two arrays
+        # cell_dict = dict(zip(cell_sizes.long().tolist(), cell_counts.long().tolist()))
+        # keys = list(cell_dict.keys())
+        # for key in keys:
+        #     cell_dict[str(key)] = cell_dict.pop(key)
+
+        # # Check if there are cells size of which greater than 10
+        # n_large_cells = 0
+        # subset_keys = [key for key in sorted(cell_dict.keys()) if int(key) > 10]
 
-    #     for key in subset_keys:
-    #         n_large_cells += cell_dict.pop(key)
+        # for key in subset_keys:
+        #     n_large_cells += cell_dict.pop(key)
 
-    #     cell_dict["greater_than_10"] = n_large_cells
-
-    #     cell_dict['dataset'] = cfg.dataset.parameters.data_name
-    #     cell_dict['domain'] = cfg.model.model_domain
-
-    #     df = pd.DataFrame.from_dict(cell_dict, orient='index')
-    #     if not os.path.exists(filename) == True:
-    #         # Save to csv file such as methods .... is a header
-    #         df.T.to_csv(filename, header=True)
-    #     else:
-    #         # read csv file with deader
-    #         df_saved = pd.read_csv(filename, index_col=0)
-    #         # add new row
-    #         df_saved = df_saved._append(df.T, ignore_index=True)
-    #         # write to csv file
-    #         df_saved.to_csv(filename)
+        # cell_dict["greater_than_10"] = n_large_cells
+
+        cell_dict['dataset'] = cfg.dataset.parameters.data_name
+        cell_dict['domain'] = cfg.model.model_domain
+
+        df = pd.DataFrame.from_dict(cell_dict, orient='index')
+        if not os.path.exists(filename):
+            # Save to csv file with the dict keys as the header
+            df.T.to_csv(filename, header=True)
+        else:
+            # read csv file with header
+            df_saved = pd.read_csv(filename, index_col=0)
+            # add new row
+            df_saved = df_saved._append(df.T, ignore_index=True)
+            # write to csv file
+            df_saved.to_csv(filename)
 
     return
diff --git a/topobenchmarkx/io/load/preprocessor.py b/topobenchmarkx/io/load/preprocessor.py
index e1ebd639..ac9f53b5 100644
--- a/topobenchmarkx/io/load/preprocessor.py
+++ b/topobenchmarkx/io/load/preprocessor.py
@@ -7,6 +7,10 @@
 from topobenchmarkx.io.load.utils import ensure_serializable, make_hash
 
+from torch_geometric.data.dataset import *
+
+
+
 class Preprocessor(torch_geometric.data.InMemoryDataset):
     r"""Preprocessor for datasets.
 
@@ -34,11 +38,14 @@ def __init__(
         pre_transform = self.instantiate_pre_transform(
             data_dir, transforms_config
         )
+        # Torch Geometric introduces force_reload from version 2.5.0, but it has a weird bug, so reloading is handled manually here
+        self.force_reload = force_reload
+
         super().__init__(
             self.processed_data_dir,
             None,
             pre_transform,
-            force_reload=force_reload,
+            #force_reload=force_reload,
             **kwargs,
         )
         self.save_transform_parameters()
@@ -139,3 +146,42 @@ def process(self) -> None:
         assert isinstance(self._data, torch_geometric.data.Data)
 
         self.save(self.data_list, self.processed_paths[0])
+
+    def _process(self):
+        f = osp.join(self.processed_dir, 'pre_transform.pt')
+        if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
+            warnings.warn(
+                f"The `pre_transform` argument differs from the one used in "
+                f"the pre-processed version of this dataset. If you want to "
+                f"make use of another pre-processing technique, make sure to "
+                f"delete '{self.processed_dir}' first")
+
+        f = osp.join(self.processed_dir, 'pre_filter.pt')
+        if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
+            warnings.warn(
+                "The `pre_filter` argument differs from the one used in "
+                "the pre-processed version of this dataset. If you want to "
+                "make use of another pre-filtering technique, make sure to "
+                f"delete '{self.processed_dir}' first")
+
+        if not self.force_reload and files_exist(self.processed_paths):
+            return
+
+        if self.log and 'pytest' not in sys.modules:
+            print('Processing...', file=sys.stderr)
+
+        makedirs(self.processed_dir)
+        self.process()
+
+        path = osp.join(self.processed_dir, 'pre_transform.pt')
+        torch.save(_repr(self.pre_transform), path)
+        path = osp.join(self.processed_dir, 'pre_filter.pt')
+        torch.save(_repr(self.pre_filter), path)
+
+        if self.log and 'pytest' not in sys.modules:
+            print('Done!', file=sys.stderr)
+
+def _repr(obj: Any) -> str:
+    if obj is None:
+        return 'None'
+    return re.sub('(<.*?)\\s.*(>)', r'\1\2', str(obj))
\ No newline at end of file
diff --git a/topobenchmarkx/stat.sh b/topobenchmarkx/stat.sh
index 78e36864..67e10248 100644
--- a/topobenchmarkx/stat.sh
+++ b/topobenchmarkx/stat.sh
@@ -1,7 +1,7 @@
 # Description: Main experiment script for GCN model.
# ----Node regression datasets: US County Demographics---- -models=( 'simplicial/scn' 'cell/cwn' 'hypergraph/unignn2' ) +models=( 'cell/cwn' ) for model in ${models[*]} do @@ -32,7 +32,7 @@ python dataset_statistics.py \ # ----Heterophilic datasets---- -datasets=( roman_empire amazon_ratings tolokers questions minesweeper ) +datasets=( roman_empire amazon_ratings minesweeper ) for dataset in ${datasets[*]} do @@ -49,7 +49,7 @@ python dataset_statistics.py \ model=$model # Train rest of the TU graph datasets -datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'REDDIT-BINARY' 'IMDB-BINARY' 'IMDB-MULTI') # +datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'IMDB-BINARY' 'IMDB-MULTI') # for dataset in ${datasets[*]} do From 7cce31ad695798f8bcc400788f1af4e25e036de0 Mon Sep 17 00:00:00 2001 From: levtelyatnikov Date: Fri, 17 May 2024 21:44:52 +0000 Subject: [PATCH 27/32] fixed issue with env --- configs/dataset/us_country_demos.yaml | 2 +- configs/train.yaml | 2 +- configs/trainer/default.yaml | 2 +- .../data/us_county_demos_dataset.py | 24 +++++-------------- 4 files changed, 9 insertions(+), 21 deletions(-) diff --git a/configs/dataset/us_country_demos.yaml b/configs/dataset/us_country_demos.yaml index 29d9b44b..c5b7ad82 100755 --- a/configs/dataset/us_country_demos.yaml +++ b/configs/dataset/us_country_demos.yaml @@ -17,7 +17,7 @@ parameters: num_features: 6 num_classes: 1 task: regression - task_variable: 'Election' # options: ['Election', 'MedianIncome', 'MigraRate', 'BirthRate', 'DeathRate', 'BachelorRate', 'UnemploymentRate'] + task_variable: 'MigraRate' # options: ['Election', 'MedianIncome', 'MigraRate', 'BirthRate', 'DeathRate', 'BachelorRate', 'UnemploymentRate'] force_reload: True loss_type: mse monitor_metric: mae diff --git a/configs/train.yaml b/configs/train.yaml index 92f56c99..00fadac7 100755 --- a/configs/train.yaml +++ b/configs/train.yaml @@ -4,7 +4,7 @@ # order of defaults determines the order in which configs override each other defaults: - _self_ - - dataset: roman_empire #us_country_demos + - dataset: us_country_demos #us_country_demos - model: cell/cwn #hypergraph/unignn2 #allsettransformer - evaluator: default - callbacks: default diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml index 6c132c34..c3c07cff 100755 --- a/configs/trainer/default.yaml +++ b/configs/trainer/default.yaml @@ -3,7 +3,7 @@ _target_: lightning.pytorch.trainer.Trainer default_root_dir: ${paths.output_dir} min_epochs: 1 # prevents early stopping -max_epochs: 300 +max_epochs: 5 #accumulate_grad_batches: 1 #${dataset.parameters.batch_size} accelerator: gpu diff --git a/topobenchmarkx/data/us_county_demos_dataset.py b/topobenchmarkx/data/us_county_demos_dataset.py index 1aba03b2..a5d5dd15 100644 --- a/topobenchmarkx/data/us_county_demos_dataset.py +++ b/topobenchmarkx/data/us_county_demos_dataset.py @@ -101,6 +101,12 @@ def __init__( # Assign data object to self.data, to make it be prodessed by Dataset class self.data, self.slices = self.collate([data]) + # Make sure the dataset will be reloaded during next run + shutil.rmtree(self.raw_dir) + # Get parent dir of self.processed_paths[0] + processed_dir = os.path.abspath(os.path.join(self.processed_paths[0], os.pardir)) + shutil.rmtree(processed_dir) + def __repr__(self) -> str: return f"{self.name}(self.root={self.root}, self.name={self.name}, self.parameters={self.parameters}, self.transform={self.transform}, self.pre_transform={self.pre_transform}, self.pre_filter={self.pre_filter}, self.force_reload={self.force_reload})" @@ -153,24 +159,6 @@ def 
+
     def __repr__(self) -> str:
         return f"{self.name}(self.root={self.root}, self.name={self.name}, self.parameters={self.parameters}, self.transform={self.transform}, self.pre_transform={self.pre_transform}, self.pre_filter={self.pre_filter}, self.force_reload={self.force_reload})"
 
@@ -153,24 +159,6 @@ def download(self) -> None:
         # Delete osp.join(folder, self.name) dir
         shutil.rmtree(osp.join(folder, self.name))
-
-
-        #os.rename(osp.join(folder, self.name), self.raw_dir)
-
-
-        # # Extract the downloaded file if it is compressed
-        # fs.cp(f"{self.raw_dir}/{self.name}.{self.file_format}",
-        #     f"{self.raw_dir}/{self.name}.{self.file_format}",
-        #     self.raw_dir,
-        #     extract=True,
-        # )
-        # # Move the etracted files to the datasets/domain/dataset_name/raw/ directory
-        # for filename in fs.ls(osp.join(self.raw_dir, self.name)):
-        #     fs.mv(filename, osp.join(self.raw_dir, osp.basename(filename)))
-        # fs.rm(osp.join(self.raw_dir, self.name))
-
-        # # Delete also f'{self.raw_dir}/{self.name}.{self.file_format}'
-        # fs.rm(f"{self.raw_dir}/{self.name}.{self.file_format}")
 
     def process(self) -> None:

From 637ba1073314a2165c4990e1b3d0575c5d2ee327 Mon Sep 17 00:00:00 2001
From: levtelyatnikov
Date: Fri, 17 May 2024 21:46:59 +0000
Subject: [PATCH 28/32] docker update

---
 Dockerfile | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/Dockerfile b/Dockerfile
index afc10650..0dd0bd01 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -7,10 +7,11 @@ COPY . .
 RUN pip install --upgrade pip
 RUN pip install -e '.[all]'
 
-RUN pip install --no-dependencies git+https://github.com/pyt-team/TopoNetX.git
-RUN pip install --no-dependencies git+https://github.com/pyt-team/TopoModelX.git
+RUN pip install git+https://github.com/pyt-team/TopoNetX.git
+RUN pip install git+https://github.com/pyt-team/TopoModelX.git
+RUN pip install git+https://github.com/pyt-team/TopoEmbedX.git
+
+RUN pip install torch_geometric==2.4.0
 RUN pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/cu115
 RUN pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cu115.html
 RUN pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+cu115.html
-#RUN pip install lightning>=2.0.0
-#RUN pip install numpy pre-commit jupyterlab notebook ipykernel
\ No newline at end of file

From 30c12353770eef1f8257326ef94696a3fba27219 Mon Sep 17 00:00:00 2001
From: levtelyatnikov
Date: Sat, 18 May 2024 00:32:08 +0000
Subject: [PATCH 29/32] cellular

---
 .../graph2cell_lifting/cell_cycles.yaml  |   2 +-
 configs/dataset/us_country_demos.yaml    |   2 +-
 configs/model/cell/can.yaml              |   4 +-
 configs/train.yaml                       |   6 +-
 configs/trainer/default.yaml             |   2 +-
 hp_scripts/main_exp/cellular/CAN.sh      | 147 ++++++++++++++++++
 hp_scripts/main_exp/cellular/CCCN.sh     | 147 ++++++++++++++++++
 hp_scripts/main_exp/cellular/CCXN.sh     | 147 ++++++++++++++++++
 hp_scripts/main_exp/cellular/CWN.sh      | 147 ++++++++++++++++++
 hp_scripts/main_exp/cellular/left_out.sh | 103 ++++++++++++
 hp_scripts/main_exp/graph/gcn.sh         |   2 +-
 topobenchmarkx/run_cellular_scripts.sh   |   9 ++
 12 files changed, 709 insertions(+), 9 deletions(-)
 create mode 100644 hp_scripts/main_exp/cellular/CAN.sh
 create mode 100644 hp_scripts/main_exp/cellular/CCCN.sh
 create mode 100644 hp_scripts/main_exp/cellular/CCXN.sh
 create mode 100644 hp_scripts/main_exp/cellular/CWN.sh
 create mode 100644 hp_scripts/main_exp/cellular/left_out.sh
 create mode 100644 topobenchmarkx/run_cellular_scripts.sh

diff --git a/configs/dataset/transforms/graph2cell_lifting/cell_cycles.yaml b/configs/dataset/transforms/graph2cell_lifting/cell_cycles.yaml
index 79b91303..d9a6c272 100644
--- a/configs/dataset/transforms/graph2cell_lifting/cell_cycles.yaml
+++ b/configs/dataset/transforms/graph2cell_lifting/cell_cycles.yaml
@@ -2,5 +2,5 @@ _target_: topobenchmarkx.transforms.data_transform.DataTransform
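 # Graph-to-cell lifting: cycles of the input graph are attached as 2-cells, and
 # max_cell_length below appears to bound how long a cycle may be and still be kept.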
 transform_type: 'lifting'
 transform_name: "CellCyclesLifting"
 complex_dim: ${oc.select:dataset.parameters.max_dim_if_lifted,3}
-max_cell_length: 6
+max_cell_length: 10
 preserve_edge_attr: ${oc.select:dataset.parameters.preserve_edge_attr_if_lifted,False}
diff --git a/configs/dataset/us_country_demos.yaml b/configs/dataset/us_country_demos.yaml
index c5b7ad82..61cd17a3 100755
--- a/configs/dataset/us_country_demos.yaml
+++ b/configs/dataset/us_country_demos.yaml
@@ -17,7 +17,7 @@ parameters:
   num_features: 6
   num_classes: 1
   task: regression
-  task_variable: 'MigraRate' # options: ['Election', 'MedianIncome', 'MigraRate', 'BirthRate', 'DeathRate', 'BachelorRate', 'UnemploymentRate']
+  task_variable: 'MedianIncome' # options: ['Election', 'MedianIncome', 'MigraRate', 'BirthRate', 'DeathRate', 'BachelorRate', 'UnemploymentRate']
   force_reload: True
   loss_type: mse
   monitor_metric: mae
diff --git a/configs/model/cell/can.yaml b/configs/model/cell/can.yaml
index f401f622..4d1e1575 100755
--- a/configs/model/cell/can.yaml
+++ b/configs/model/cell/can.yaml
@@ -7,7 +7,7 @@ feature_encoder:
   _target_: topobenchmarkx.models.encoders.${model.feature_encoder.encoder_name}
   encoder_name: AllCellFeatureEncoder
   in_channels: ${infer_in_channels:${dataset}} #${dataset.parameters.num_features}
-  out_channels: 32
+  out_channels: 128
   proj_dropout: 0.0
   selected_dimensions:
     - 0
@@ -22,7 +22,7 @@ backbone:
   heads: 1 # For now we stuck to out_channels//heads, keep heads = 1
   concat: True
   skip_connection: True
-  n_layers: 1
+  n_layers: 4
   att_lift: False
 
 backbone_wrapper:
diff --git a/configs/train.yaml b/configs/train.yaml
index 00fadac7..2c061d19 100755
--- a/configs/train.yaml
+++ b/configs/train.yaml
@@ -4,12 +4,12 @@
 # order of defaults determines the order in which configs override each other
 defaults:
   - _self_
-  - dataset: us_country_demos #us_country_demos
+  - dataset: minesweeper # us_country_demos
-  - model: cell/cwn #hypergraph/unignn2 #allsettransformer
+  - model: cell/can #hypergraph/unignn2 #allsettransformer
   - evaluator: default
   - callbacks: default
   - logger: wandb # set logger here or use command line (e.g. `python train.py logger=tensorboard`)
-  - trainer: cpu
+  - trainer: default
   - paths: default
   - extras: default
   - hydra: default
diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml
index c3c07cff..6c132c34 100755
--- a/configs/trainer/default.yaml
+++ b/configs/trainer/default.yaml
@@ -3,7 +3,7 @@ _target_: lightning.pytorch.trainer.Trainer
 
 default_root_dir: ${paths.output_dir}
 min_epochs: 1 # prevents early stopping
-max_epochs: 5
+max_epochs: 300
 
 #accumulate_grad_batches: 1 #${dataset.parameters.batch_size}
 accelerator: gpu
diff --git a/hp_scripts/main_exp/cellular/CAN.sh b/hp_scripts/main_exp/cellular/CAN.sh
new file mode 100644
index 00000000..d112e1a5
--- /dev/null
+++ b/hp_scripts/main_exp/cellular/CAN.sh
@@ -0,0 +1,147 @@
+# Description: Main experiment script for the CAN model.
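+#
+# Hydra expands every comma-separated value under --multirun into its own run;
+# as a sketch, the pattern
+#   python train.py model=cell/can model.optimizer.lr=0.01,0.001 --multirun
+# launches one run per learning rate, so the grids below multiply out combinatorially.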
+# ----Node regression datasets: US County Demographics----
+task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' )
+
+for task_variable in ${task_variables[*]}
+do
+    python train.py \
+    dataset=us_country_demos \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.task_variable=$task_variable \
+    model=cell/can \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.optimizer.lr="0.01,0.001" \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    trainer.max_epochs=1000 \
+    trainer.min_epochs=500 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=50 \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    tags="[MainExperiment]" \
+    --multirun
+
+done
+
+# ----Cocitation datasets----
+datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    model=cell/can \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    model.backbone.n_layers=1,2 \
+    model.optimizer.lr="0.01,0.001" \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=25 \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    tags="[MainExperiment]" \
+    --multirun
+done
+
+# # ----Graph regression dataset----
+# # Train on ZINC dataset
+# python train.py \
+#     dataset=ZINC \
+#     seed=42,3,5,23,150 \
+#     model=cell/can \
+#     model.optimizer.lr=0.01,0.001 \
+#     model.optimizer.weight_decay=0 \
+#     model.feature_encoder.out_channels=32,64,128 \
+#     model.backbone.n_layers=2,4 \
+#     model.feature_encoder.proj_dropout=0.25,0.5 \
+#     dataset.parameters.batch_size=128,256 \
+#     dataset.transforms.one_hot_node_degree_features.degrees_fields=x \
+#     dataset.parameters.data_seed=0 \
+#     dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+#     model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+#     logger.wandb.project=TopoBenchmarkX_Cellular \
+#     trainer.max_epochs=500 \
+#     trainer.min_epochs=50 \
+#     callbacks.early_stopping.min_delta=0.005 \
+#     trainer.check_val_every_n_epoch=5 \
+#     callbacks.early_stopping.patience=10 \
+#     tags="[MainExperiment]" \
+#     --multirun
+
+# ----TU graph datasets----
+# MUTAG has very few samples, so we use a smaller batch size
+# Train on MUTAG dataset
+python train.py \
+    dataset=MUTAG \
+    model=cell/can \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.batch_size=32,64 \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    callbacks.early_stopping.patience=25 \
+    tags="[MainExperiment]" \
+    --multirun
+
+# Train the rest of the TU graph datasets
+datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'IMDB-BINARY' 'IMDB-MULTI' )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    model=cell/can \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.batch_size=128,256 \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=5 \
+    callbacks.early_stopping.patience=10 \
+    --multirun
+done
+
+# ----Heterophilic datasets----
+
+datasets=( roman_empire minesweeper )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    model=cell/can \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.batch_size=1 \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    trainer.max_epochs=1000 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=50 \
+    tags="[MainExperiment]" \
+    --multirun
+done
diff --git a/hp_scripts/main_exp/cellular/CCCN.sh b/hp_scripts/main_exp/cellular/CCCN.sh
new file mode 100644
index 00000000..cef694ca
--- /dev/null
+++ b/hp_scripts/main_exp/cellular/CCCN.sh
@@ -0,0 +1,147 @@
+# Description: Main experiment script for the CCCN model.
+# ----Node regression datasets: US County Demographics----
+task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' )
+
+for task_variable in ${task_variables[*]}
+do
+    python train.py \
+    dataset=us_country_demos \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.task_variable=$task_variable \
+    model=cell/cccn \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.optimizer.lr="0.01,0.001" \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    trainer.max_epochs=1000 \
+    trainer.min_epochs=500 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=50 \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    tags="[MainExperiment]" \
+    --multirun
+
+done
+
+# ----Cocitation datasets----
+datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    model=cell/cccn \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    model.backbone.n_layers=1,2 \
+    model.optimizer.lr="0.01,0.001" \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=25 \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    tags="[MainExperiment]" \
+    --multirun
+done
+
+# # ----Graph regression dataset----
+# # Train on ZINC dataset
+# python train.py \
+#     dataset=ZINC \
+#     seed=42,3,5,23,150 \
+#     model=cell/cccn \
+#     model.optimizer.lr=0.01,0.001 \
+#     model.optimizer.weight_decay=0 \
+#     model.feature_encoder.out_channels=32,64,128 \
+#     model.backbone.n_layers=2,4 \
+#     model.feature_encoder.proj_dropout=0.25,0.5 \
+#     dataset.parameters.batch_size=128,256 \
+#     dataset.transforms.one_hot_node_degree_features.degrees_fields=x \
+#     dataset.parameters.data_seed=0 \
+#     dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+#     model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+#     logger.wandb.project=TopoBenchmarkX_Cellular \
+#     trainer.max_epochs=500 \
+#     trainer.min_epochs=50 \
+#     callbacks.early_stopping.min_delta=0.005 \
+#     trainer.check_val_every_n_epoch=5 \
+#     callbacks.early_stopping.patience=10 \
+#     tags="[MainExperiment]" \
+#     --multirun
+
+# ----TU graph datasets----
+# MUTAG has very few samples, so we use a smaller batch size
+# Train on MUTAG dataset
+python train.py \
+    dataset=MUTAG \
+    model=cell/cccn \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.batch_size=32,64 \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    callbacks.early_stopping.patience=25 \
+    tags="[MainExperiment]" \
+    --multirun
+
+# Train the rest of the TU graph datasets
+datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'IMDB-BINARY' 'IMDB-MULTI' )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    model=cell/cccn \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.batch_size=128,256 \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=5 \
+    callbacks.early_stopping.patience=10 \
+    --multirun
+done
+
+# ----Heterophilic datasets----
+
+datasets=( roman_empire minesweeper )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    model=cell/cccn \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.batch_size=1 \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    trainer.max_epochs=1000 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=50 \
+    tags="[MainExperiment]" \
+    --multirun
+done
diff --git a/hp_scripts/main_exp/cellular/CCXN.sh b/hp_scripts/main_exp/cellular/CCXN.sh
new file mode 100644
index 00000000..87f316c9
--- /dev/null
+++ b/hp_scripts/main_exp/cellular/CCXN.sh
@@ -0,0 +1,147 @@
+# Description: Main experiment script for the CCXN model.
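+#
+# A note on the trainer settings used throughout: trainer.min_epochs keeps the
+# early-stopping callback from halting training during warm-up, after which
+# callbacks.early_stopping.patience bounds how many validation checks (one every
+# trainer.check_val_every_n_epoch epochs) may pass without improvement.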
+# ----Node regression datasets: US County Demographics----
+task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' )
+
+for task_variable in ${task_variables[*]}
+do
+    python train.py \
+    dataset=us_country_demos \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.task_variable=$task_variable \
+    model=cell/ccxn \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.optimizer.lr="0.01,0.001" \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    trainer.max_epochs=1000 \
+    trainer.min_epochs=500 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=50 \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    tags="[MainExperiment]" \
+    --multirun
+
+done
+
+# ----Cocitation datasets----
+datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    model=cell/ccxn \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    model.backbone.n_layers=1,2 \
+    model.optimizer.lr="0.01,0.001" \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=25 \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    tags="[MainExperiment]" \
+    --multirun
+done
+
+# # ----Graph regression dataset----
+# # Train on ZINC dataset
+# python train.py \
+#     dataset=ZINC \
+#     seed=42,3,5,23,150 \
+#     model=cell/ccxn \
+#     model.optimizer.lr=0.01,0.001 \
+#     model.optimizer.weight_decay=0 \
+#     model.feature_encoder.out_channels=32,64,128 \
+#     model.backbone.n_layers=2,4 \
+#     model.feature_encoder.proj_dropout=0.25,0.5 \
+#     dataset.parameters.batch_size=128,256 \
+#     dataset.transforms.one_hot_node_degree_features.degrees_fields=x \
+#     dataset.parameters.data_seed=0 \
+#     dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+#     model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+#     logger.wandb.project=TopoBenchmarkX_Cellular \
+#     trainer.max_epochs=500 \
+#     trainer.min_epochs=50 \
+#     callbacks.early_stopping.min_delta=0.005 \
+#     trainer.check_val_every_n_epoch=5 \
+#     callbacks.early_stopping.patience=10 \
+#     tags="[MainExperiment]" \
+#     --multirun
+
+# ----TU graph datasets----
+# MUTAG has very few samples, so we use a smaller batch size
+# Train on MUTAG dataset
+python train.py \
+    dataset=MUTAG \
+    model=cell/ccxn \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.batch_size=32,64 \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    callbacks.early_stopping.patience=25 \
+    tags="[MainExperiment]" \
+    --multirun
+
+# Train the rest of the TU graph datasets
+datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'IMDB-BINARY' 'IMDB-MULTI' )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    model=cell/ccxn \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.batch_size=128,256 \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=5 \
+    callbacks.early_stopping.patience=10 \
+    --multirun
+done
+
+# ----Heterophilic datasets----
+
+datasets=( roman_empire minesweeper )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    model=cell/ccxn \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.batch_size=1 \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    trainer.max_epochs=1000 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=50 \
+    tags="[MainExperiment]" \
+    --multirun
+done
diff --git a/hp_scripts/main_exp/cellular/CWN.sh b/hp_scripts/main_exp/cellular/CWN.sh
new file mode 100644
index 00000000..58366f32
--- /dev/null
+++ b/hp_scripts/main_exp/cellular/CWN.sh
@@ -0,0 +1,147 @@
+# Description: Main experiment script for the CWN model.
+# ----Node regression datasets: US County Demographics----
+task_variables=( 'Election' 'MedianIncome' 'MigraRate' 'BirthRate' 'DeathRate' 'BachelorRate' 'UnemploymentRate' )
+
+for task_variable in ${task_variables[*]}
+do
+    python train.py \
+    dataset=us_country_demos \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.task_variable=$task_variable \
+    model=cell/cwn \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.optimizer.lr="0.01,0.001" \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    trainer.max_epochs=1000 \
+    trainer.min_epochs=500 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=50 \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    tags="[MainExperiment]" \
+    --multirun
+
+done
+
+# ----Cocitation datasets----
+datasets=( 'cocitation_cora' 'cocitation_citeseer' 'cocitation_pubmed' )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    model=cell/cwn \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    model.backbone.n_layers=1,2 \
+    model.optimizer.lr="0.01,0.001" \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=25 \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    tags="[MainExperiment]" \
+    --multirun
+done
+
+# # ----Graph regression dataset----
+# # Train on ZINC dataset
+# python train.py \
+#     dataset=ZINC \
+#     seed=42,3,5,23,150 \
+#     model=cell/cwn \
+#     model.optimizer.lr=0.01,0.001 \
+#     model.optimizer.weight_decay=0 \
+#     model.feature_encoder.out_channels=32,64,128 \
+#     model.backbone.n_layers=2,4 \
+#     model.feature_encoder.proj_dropout=0.25,0.5 \
+#     dataset.parameters.batch_size=128,256 \
+#     dataset.transforms.one_hot_node_degree_features.degrees_fields=x \
+#     dataset.parameters.data_seed=0 \
+#     dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+#     model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+#     logger.wandb.project=TopoBenchmarkX_Cellular \
+#     trainer.max_epochs=500 \
+#     trainer.min_epochs=50 \
+#     callbacks.early_stopping.min_delta=0.005 \
+#     trainer.check_val_every_n_epoch=5 \
+#     callbacks.early_stopping.patience=10 \
+#     tags="[MainExperiment]" \
+#     --multirun
+
+# ----TU graph datasets----
+# MUTAG has very few samples, so we use a smaller batch size
+# Train on MUTAG dataset
+python train.py \
+    dataset=MUTAG \
+    model=cell/cwn \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.batch_size=32,64 \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    callbacks.early_stopping.patience=25 \
+    tags="[MainExperiment]" \
+    --multirun
+
+# Train the rest of the TU graph datasets
+datasets=( 'PROTEINS_TU' 'NCI1' 'NCI109' 'IMDB-BINARY' 'IMDB-MULTI' )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    model=cell/cwn \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.batch_size=128,256 \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=5 \
+    callbacks.early_stopping.patience=10 \
+    --multirun
+done
+
+# ----Heterophilic datasets----
+
+datasets=( roman_empire minesweeper )
+
+for dataset in ${datasets[*]}
+do
+    python train.py \
+    dataset=$dataset \
+    model=cell/cwn \
+    model.optimizer.lr=0.01,0.001 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=1,2,3,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.data_seed=0,3,5,7,9 \
+    dataset.parameters.batch_size=1 \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    trainer.max_epochs=1000 \
+    trainer.min_epochs=50 \
+    trainer.check_val_every_n_epoch=1 \
+    callbacks.early_stopping.patience=50 \
+    tags="[MainExperiment]" \
+    --multirun
+done
diff --git a/hp_scripts/main_exp/cellular/left_out.sh b/hp_scripts/main_exp/cellular/left_out.sh
new file mode 100644
index 00000000..81f00969
--- /dev/null
+++ b/hp_scripts/main_exp/cellular/left_out.sh
@@ -0,0 +1,103 @@
+# ----Graph regression dataset----
+# Train on ZINC dataset
+
+# CWN
+python train.py \
+    dataset=ZINC \
+    seed=42,3,5,23,150 \
+    model=cell/cwn \
+    model.optimizer.lr=0.01,0.001 \
+    model.optimizer.weight_decay=0 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=2,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.batch_size=128,256 \
+    dataset.transforms.one_hot_node_degree_features.degrees_fields=x \
+    dataset.parameters.data_seed=0 \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    callbacks.early_stopping.min_delta=0.005 \
+    trainer.check_val_every_n_epoch=5 \
+    callbacks.early_stopping.patience=10 \
+    tags="[MainExperiment]" \
+    --multirun
+
+# CCXN
+python train.py \
+    dataset=ZINC \
+    seed=42,3,5,23,150 \
+    model=cell/ccxn \
+    model.optimizer.lr=0.01,0.001 \
+    model.optimizer.weight_decay=0 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=2,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.batch_size=128,256 \
+    dataset.transforms.one_hot_node_degree_features.degrees_fields=x \
+    dataset.parameters.data_seed=0 \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    callbacks.early_stopping.min_delta=0.005 \
+    trainer.check_val_every_n_epoch=5 \
+    callbacks.early_stopping.patience=10 \
+    tags="[MainExperiment]" \
+    --multirun
+
+# CCCN
+python train.py \
+    dataset=ZINC \
+    seed=42,3,5,23,150 \
+    model=cell/cccn \
+    model.optimizer.lr=0.01,0.001 \
+    model.optimizer.weight_decay=0 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=2,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.batch_size=128,256 \
+    dataset.transforms.one_hot_node_degree_features.degrees_fields=x \
+    dataset.parameters.data_seed=0 \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    callbacks.early_stopping.min_delta=0.005 \
+    trainer.check_val_every_n_epoch=5 \
+    callbacks.early_stopping.patience=10 \
+    tags="[MainExperiment]" \
+    --multirun
+
+
+# CAN
+
+python train.py \
+    dataset=ZINC \
+    seed=42,3,5,23,150 \
+    model=cell/can \
+    model.optimizer.lr=0.01,0.001 \
+    model.optimizer.weight_decay=0 \
+    model.feature_encoder.out_channels=32,64,128 \
+    model.backbone.n_layers=2,4 \
+    model.feature_encoder.proj_dropout=0.25,0.5 \
+    dataset.parameters.batch_size=128,256 \
+    dataset.transforms.one_hot_node_degree_features.degrees_fields=x \
+    dataset.parameters.data_seed=0 \
+    dataset.transforms.graph2cell_lifting.max_cell_length=10 \
+    model.readout.readout_name="NoReadOut,PropagateSignalDown" \
+    logger.wandb.project=TopoBenchmarkX_Cellular \
+    trainer.max_epochs=500 \
+    trainer.min_epochs=50 \
+    callbacks.early_stopping.min_delta=0.005 \
+    trainer.check_val_every_n_epoch=5 \
+    callbacks.early_stopping.patience=10 \
+    tags="[MainExperiment]" \
+    --multirun
+
+
+# REDDIT BINARY for all
\ No newline at end of file
diff --git a/hp_scripts/main_exp/graph/gcn.sh b/hp_scripts/main_exp/graph/gcn.sh
index 07107d16..3e9f85d2 100644
--- a/hp_scripts/main_exp/graph/gcn.sh
+++ b/hp_scripts/main_exp/graph/gcn.sh
@@ -82,7 +82,7 @@ do
     model.backbone.num_layers=1,2,3,4 \
     model.feature_encoder.proj_dropout=0.25,0.5 \
    dataset.parameters.data_seed=0,3,5 \
-    dataset.parameters.batch_size=128,256 \
+    dataset.parameters.batch_size=1 \
     logger.wandb.project=TopoBenchmarkX_Graph \
     trainer.max_epochs=1000 \
     trainer.min_epochs=50 \
diff --git a/topobenchmarkx/run_cellular_scripts.sh b/topobenchmarkx/run_cellular_scripts.sh
new file mode 100644
index 00000000..075d6824
--- /dev/null
+++ b/topobenchmarkx/run_cellular_scripts.sh
@@ -0,0 +1,9 @@
+# Run the scripts from the cellular directory
+bash /TopoBenchmarkX/hp_scripts/main_exp/cellular/CWN.sh
+bash /TopoBenchmarkX/hp_scripts/main_exp/cellular/CCCN.sh
+bash /TopoBenchmarkX/hp_scripts/main_exp/cellular/CAN.sh
+bash /TopoBenchmarkX/hp_scripts/main_exp/cellular/CCXN.sh
+
+# Run in case we have time
+# bash ~/TopoBenchmarkX/hp_scripts/main_exp/cellular/left_out.sh
+

From fd0c488abd50f6a84a026909064f9283a1549ab7 Mon Sep 17 00:00:00 2001
From: Guille
Date: Sat, 18 May 2024 22:10:59 +0000
Subject: [PATCH 30/32] New env.sh

---
 conda.sh           | 11 +++++++++++
 configs/train.yaml |  2 +-
 env.sh             |  7 ++-----
 3 files changed, 14 insertions(+), 6 deletions(-)
 create mode 100755 conda.sh
 mode change 100644 => 100755 env.sh

diff --git a/conda.sh b/conda.sh
new file mode 100755
index 00000000..40717fa9
--- /dev/null
+++ b/conda.sh
@@ -0,0 +1,11 @@
+# #!/bin/bash
+
+mkdir -p ~/miniconda3
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh
+bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
+rm -rf ~/miniconda3/miniconda.sh
+
+~/miniconda3/bin/conda init bash
+
+#conda create -n topox python=3.11.3
+#conda activate topox
\ No newline at end of file
diff --git a/configs/train.yaml b/configs/train.yaml
index 2c061d19..1f05aba4 100755
--- a/configs/train.yaml
+++ b/configs/train.yaml
@@ -4,7 +4,7 @@
 # order of defaults determines the order in which configs override each other
 defaults:
   - _self_
-  - dataset: minesweeper # us_country_demos
+  - dataset: ZINC # us_country_demos
   - model: cell/can #hypergraph/unignn2 #allsettransformer
   - evaluator: default
   - callbacks: default
diff --git a/env.sh b/env.sh
old mode 100644
new mode 100755
index 4a4133b0..0ccab522
--- a/env.sh
+++ b/env.sh
@@ -1,7 +1,4 @@
-# #!/bin/bash
-
-conda create -n topoxx python=3.11.3
-conda activate topoxx
+#!/bin/bash -l
 
 pip install --upgrade pip
 pip install -e '.[all]'
 
 pip install git+https://github.com/pyt-team/TopoNetX.git
 pip install git+https://github.com/pyt-team/TopoModelX.git
 pip install git+https://github.com/pyt-team/TopoEmbedX.git
 
-CUDA="cu117" # if available, select the CUDA version suitable for your system
+CUDA="cu121" # if available, select the CUDA version suitable for your system
              # e.g. cpu, cu102, cu111, cu113, cu115
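              #      (running `nvidia-smi` or `nvcc --version` shows which CUDA
              #      driver/toolkit is actually present locally)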
 pip install torch_geometric==2.4.0
 pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/${CUDA}
 pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html
 pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html

From ac85988f0cce4dfe2c44cd6020863f3e70b65ee1 Mon Sep 17 00:00:00 2001
From: Guille
Date: Mon, 20 May 2024 12:09:40 +0000
Subject: [PATCH 31/32] Solved problem with pyg loaders

---
 configs/dataset/ZINC.yaml                |  2 +-
 .../one_hot_node_degree_features.yaml    |  2 +-
 env.sh                                   |  2 +-
 topobenchmarkx/io/load/preprocessor.py   | 41 +------------------
 .../data_manipulations/one_hot_degree.py |  1 +
 .../one_hot_degree_features.py           |  2 +-
 6 files changed, 6 insertions(+), 44 deletions(-)

diff --git a/configs/dataset/ZINC.yaml b/configs/dataset/ZINC.yaml
index f39a9c4c..58d2ee3e 100644
--- a/configs/dataset/ZINC.yaml
+++ b/configs/dataset/ZINC.yaml
@@ -3,7 +3,7 @@ _target_: topobenchmarkx.io.load.loaders.GraphLoader
 
 # USE python train.py dataset.transforms.one_hot_node_degree_features.degrees_fields=x to run this config
 defaults:
-  #- transforms/data_manipulations: node_feat_to_float
+  - transforms/data_manipulations: node_degrees
   - transforms/data_manipulations@transforms.one_hot_node_degree_features: one_hot_node_degree_features
   - transforms: ${get_default_transform:graph,${model}}
diff --git a/configs/dataset/transforms/data_manipulations/one_hot_node_degree_features.yaml b/configs/dataset/transforms/data_manipulations/one_hot_node_degree_features.yaml
index b8892aa9..7fdd0e67 100755
--- a/configs/dataset/transforms/data_manipulations/one_hot_node_degree_features.yaml
+++ b/configs/dataset/transforms/data_manipulations/one_hot_node_degree_features.yaml
@@ -4,5 +4,5 @@ transform_type: "data manipulation"
 degrees_fields: "node_degrees"
 features_fields: "x"
 
-max_degrees: ${dataset.parameters.max_node_degree}
+max_degree: ${dataset.parameters.max_node_degree}
 
diff --git a/env.sh b/env.sh
index 0ccab522..66f84654 100755
--- a/env.sh
+++ b/env.sh
@@ -9,7 +9,7 @@ pip install git+https://github.com/pyt-team/TopoEmbedX.git
 
 CUDA="cu121" # if available, select the CUDA version suitable for your system
              # e.g. cpu, cu102, cu111, cu113, cu115
-pip install torch_geometric==2.4.0
+pip install fsspec==2024.5.0
 pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/${CUDA}
 pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html
 pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html
diff --git a/topobenchmarkx/io/load/preprocessor.py b/topobenchmarkx/io/load/preprocessor.py
index ac9f53b5..5cb8a125 100644
--- a/topobenchmarkx/io/load/preprocessor.py
+++ b/topobenchmarkx/io/load/preprocessor.py
@@ -145,43 +145,4 @@ def process(self) -> None:
         self._data_list = None  # Reset cache.
         assert isinstance(self._data, torch_geometric.data.Data)
-        self.save(self.data_list, self.processed_paths[0])
-
-    def _process(self):
-        f = osp.join(self.processed_dir, 'pre_transform.pt')
-        if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
-            warnings.warn(
-                f"The `pre_transform` argument differs from the one used in "
-                f"the pre-processed version of this dataset. If you want to "
-                f"make use of another pre-processing technique, make sure to "
-                f"delete '{self.processed_dir}' first")
-
-        f = osp.join(self.processed_dir, 'pre_filter.pt')
-        if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
-            warnings.warn(
-                "The `pre_filter` argument differs from the one used in "
-                "the pre-processed version of this dataset. If you want to "
-                "make use of another pre-fitering technique, make sure to "
-                "delete '{self.processed_dir}' first")
-
-        if not self.force_reload and files_exist(self.processed_paths):
-            return
-
-        if self.log and 'pytest' not in sys.modules:
-            print('Processing...', file=sys.stderr)
-
-        makedirs(self.processed_dir)
-        self.process()
-
-        path = osp.join(self.processed_dir, 'pre_transform.pt')
-        torch.save(_repr(self.pre_transform), path)
-        path = osp.join(self.processed_dir, 'pre_filter.pt')
-        torch.save(_repr(self.pre_filter), path)
-
-        if self.log and 'pytest' not in sys.modules:
-            print('Done!', file=sys.stderr)
-
-def _repr(obj: Any) -> str:
-    if obj is None:
-        return 'None'
-    return re.sub('(<.*?)\\s.*(>)', r'\1\2', str(obj))
\ No newline at end of file
+        self.save(self.data_list, self.processed_paths[0])
\ No newline at end of file
diff --git a/topobenchmarkx/transforms/data_manipulations/one_hot_degree.py b/topobenchmarkx/transforms/data_manipulations/one_hot_degree.py
index 580f87c1..80a2611a 100644
--- a/topobenchmarkx/transforms/data_manipulations/one_hot_degree.py
+++ b/topobenchmarkx/transforms/data_manipulations/one_hot_degree.py
@@ -13,6 +13,7 @@ def __init__(
         self,
         max_degree: int,
         cat: bool = False,
+        **kwargs,
     ) -> None:
         self.max_degree = max_degree
         self.cat = cat
diff --git a/topobenchmarkx/transforms/data_manipulations/one_hot_degree_features.py b/topobenchmarkx/transforms/data_manipulations/one_hot_degree_features.py
index 2347fb79..c12d1c9c 100644
--- a/topobenchmarkx/transforms/data_manipulations/one_hot_degree_features.py
+++ b/topobenchmarkx/transforms/data_manipulations/one_hot_degree_features.py
@@ -15,7 +15,7 @@ def __init__(self, **kwargs):
         self.type = "one_hot_degree_features"
         self.deg_field = kwargs["degrees_fields"]
         self.features_fields = kwargs["features_fields"]
-        self.transform = OneHotDegree(max_degree=kwargs["max_degrees"])
+        self.transform = OneHotDegree(**kwargs)
 
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}(type={self.type!r}, degrees_field={self.deg_field!r}, features_field={self.features_fields!r})"

From 2351fb7b66abba2424f131fc91d53811bb55f23b Mon Sep 17 00:00:00 2001
From: Guille
Date: Mon, 20 May 2024 16:43:03 +0000
Subject: [PATCH 32/32] Preliminary env setup

---
 configs/train.yaml |  2 +-
 env.sh             | 21 +++++++++++----------
 pyproject.toml     |  4 +---
 3 files changed, 13 insertions(+), 14 deletions(-)

diff --git a/configs/train.yaml b/configs/train.yaml
index 1f05aba4..0bdd2097 100755
--- a/configs/train.yaml
+++ b/configs/train.yaml
@@ -4,7 +4,7 @@
 # order of defaults determines the order in which configs override each other
 defaults:
   - _self_
-  - dataset: ZINC # us_country_demos
+  - dataset: MUTAG # us_country_demos
   - model: cell/can #hypergraph/unignn2 #allsettransformer
   - evaluator: default
   - callbacks: default
diff --git a/env.sh b/env.sh
index 66f84654..fe8c8123 100755
--- a/env.sh
+++ b/env.sh
@@ -3,17 +3,18 @@
 pip install --upgrade pip
 pip install -e '.[all]'
 
-pip install git+https://github.com/pyt-team/TopoNetX.git
-pip install git+https://github.com/pyt-team/TopoModelX.git
-pip install git+https://github.com/pyt-team/TopoEmbedX.git
-
-CUDA="cu121" # if available, select the CUDA version suitable for your system
-             # e.g. cpu, cu102, cu111, cu113, cu115
-pip install fsspec==2024.5.0
-pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/${CUDA}
-pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+${CUDA}.html
-pip install torch-cluster -f https://data.pyg.org/whl/torch-2.0.0+${CUDA}.html
+pip install --no-dependencies git+https://github.com/pyt-team/TopoNetX.git
+pip install --no-dependencies git+https://github.com/pyt-team/TopoModelX.git
+pip install --no-dependencies git+https://github.com/pyt-team/TopoEmbedX.git
+
+# Note that not all combinations of torch and CUDA are available
+# See https://github.com/pyg-team/pyg-lib to check the configuration that works for you
+TORCH="2.3.0" # available options: 1.12.0, 1.13.0, 2.0.0, 2.1.0, 2.2.0, or 2.3.0
+CUDA="cu121"  # if available, select the CUDA version suitable for your system
+              # available options: cpu, cu102, cu113, cu116, cu117, cu118, or cu121
+pip install torch==${TORCH} --extra-index-url https://download.pytorch.org/whl/${CUDA}
+pip install lightning torch_geometric==2.4.0
+pip install pyg-lib torch-scatter torch-sparse torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
 
 pytest
 pre-commit install
diff --git a/pyproject.toml b/pyproject.toml
index a6564426..37c63da0 100755
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,7 +34,6 @@ dependencies=[
     "networkx",
     "pandas",
     "gudhi",
-    "pyg-nightly",
     "decorator",
     "hypernetx < 2.0.0",
     "trimesh",
@@ -42,9 +41,8 @@ dependencies=[
     "hydra-core==1.3.2",
     "hydra-colorlog==1.2.0",
     "hydra-optuna-sweeper==1.2.0",
-    "lightning==2.2.1",
-    "einops==0.7.0",
     "wandb==0.16.4",
+    "einops==0.7.0",
     "tabulate",
     "ipykernel",
     "notebook",