Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
gbg141 committed Jun 2, 2024
1 parent dcd7f93 commit 6c1a0c1
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 57 deletions.
38 changes: 26 additions & 12 deletions test/data/preprocess/test_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,24 @@ def setup_method(self, mocker):
)

params = [
{"mock_inmemory_init": "torch_geometric.data.InMemoryDataset.__init__"},
{"mock_save_transform": (PreProcessor, "save_transform_parameters")},
{
"mock_inmemory_init": "torch_geometric.data.InMemoryDataset.__init__"
},
{
"mock_save_transform": (
PreProcessor,
"save_transform_parameters",
)
},
{"mock_load": (PreProcessor, "load")},
{"mock_len": (PreProcessor, "__len__"), "init_args": {"return_value":3}},
{"mock_getitem": (PreProcessor, "get"), "init_args": {"return_value": "0"}},
{
"mock_len": (PreProcessor, "__len__"),
"init_args": {"return_value": 3},
},
{
"mock_getitem": (PreProcessor, "get"),
"init_args": {"return_value": "0"},
},
]
self.flow_mocker = FlowMocker(mocker, params)

Expand All @@ -52,31 +65,32 @@ def test_init_with_transform(self, mocker):
val_processed_paths = ["/some/path"]
params = [
{"assert_args": ("created_property", "processed_data_dir")},
{
"assert_args": ("created_property", "processed_data_dir")
},
{"assert_args": ("created_property", "processed_data_dir")},
{
"mock_inmemory_init": "torch_geometric.data.InMemoryDataset.__init__",
"assert_args": ("called_once_with", ANY, None, ANY)
"assert_args": ("called_once_with", ANY, None, ANY),
},
{
"mock_processed_paths": (PreProcessor, "processed_paths"),
"init_args": {"property_val": val_processed_paths},
},
{
"mock_save_transform": (PreProcessor, "save_transform_parameters"),
"assert_args": ("created_property", "processed_paths")
"mock_save_transform": (
PreProcessor,
"save_transform_parameters",
),
"assert_args": ("created_property", "processed_paths"),
},
{
"mock_load": (PreProcessor, "load"),
"assert_args": ("called_once_with", val_processed_paths[0])
"assert_args": ("called_once_with", val_processed_paths[0]),
},
{"mock_len": (PreProcessor, "__len__")},
{"mock_getitem": (PreProcessor, "get")},
]
self.flow_mocker = FlowMocker(mocker, params)
self.preprocessor_with_tranform = PreProcessor(
self.dataset, self.data_dir, self.transforms_config
self.dataset, self.data_dir, self.transforms_config
)
self.flow_mocker.assert_all(self.preprocessor_with_tranform)

Expand Down
4 changes: 3 additions & 1 deletion topobenchmarkx/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def __init__(

# This line allows to access init params with 'self.hparams' attribute
# also ensures init params will be stored in ckpt
self.save_hyperparameters(logger=False, ignore=["backbone","readout","feature_encoder"])
self.save_hyperparameters(
logger=False, ignore=["backbone", "readout", "feature_encoder"]
)

self.feature_encoder = feature_encoder
if backbone_wrapper is None:
Expand Down
2 changes: 1 addition & 1 deletion topobenchmarkx/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any # noqa: I001
from typing import Any

from topobenchmarkx.transforms.data_manipulations import DATA_MANIPULATIONS
from topobenchmarkx.transforms.feature_liftings import FEATURE_LIFTINGS
Expand Down
2 changes: 1 addition & 1 deletion topobenchmarkx/transforms/data_manipulations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .calculate_simplicial_curvature import (
CalculateSimplicialCurvature,
) # noqa: I001
)
from .equal_gaus_features import EqualGausFeatures
from .identity_transform import IdentityTransform
from .infere_knn_connectivity import InfereKNNConnectivity
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import torch # noqa: I001
import torch
import torch_geometric

from topobenchmarkx.transforms.liftings.graph2hypergraph import (
Expand Down
2 changes: 1 addition & 1 deletion topobenchmarkx/transforms/liftings/graph2hypergraph/knn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import torch # noqa: I001
import torch
import torch_geometric

from topobenchmarkx.transforms.liftings.graph2hypergraph import (
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from itertools import combinations # noqa: I001
from itertools import combinations
from typing import Any

import networkx as nx
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import random # noqa: I001
import random
from itertools import combinations
from typing import Any

Expand Down
78 changes: 40 additions & 38 deletions tutorials/tutorial_lifting.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,29 @@
"metadata": {},
"outputs": [],
"source": [
"from omegaconf import OmegaConf\n",
"from itertools import combinations\n",
"from typing import Any\n",
"import networkx as nx\n",
"from toponetx.classes import SimplicialComplex\n",
"\n",
"import lightning as pl\n",
"import networkx as nx\n",
"import torch\n",
"import torch_geometric\n",
"import lightning as pl\n",
"from omegaconf import OmegaConf\n",
"from topomodelx.nn.simplicial.scn2 import SCN2\n",
"from toponetx.classes import SimplicialComplex\n",
"\n",
"from topobenchmarkx.data.load import GraphLoader\n",
"from topobenchmarkx.data.preprocess import PreProcessor\n",
"from topobenchmarkx.dataloader import TBXDataloader\n",
"from topobenchmarkx.nn.encoders import AllCellFeatureEncoder\n",
"from topobenchmarkx.nn.readouts import PropagateSignalDown\n",
"from topomodelx.nn.simplicial.scn2 import SCN2\n",
"from topobenchmarkx.nn.wrappers.simplicial import SCNWrapper\n",
"from topobenchmarkx.evaluator import TBXEvaluator\n",
"from topobenchmarkx.loss import TBXLoss\n",
"from topobenchmarkx.model import TBXModel\n",
"from topobenchmarkx.transforms.liftings.graph2simplicial import Graph2SimplicialLifting"
"from topobenchmarkx.nn.encoders import AllCellFeatureEncoder\n",
"from topobenchmarkx.nn.readouts import PropagateSignalDown\n",
"from topobenchmarkx.nn.wrappers.simplicial import SCNWrapper\n",
"from topobenchmarkx.transforms.liftings.graph2simplicial import (\n",
" Graph2SimplicialLifting,\n",
")"
]
},
{
Expand All @@ -75,50 +77,50 @@
"outputs": [],
"source": [
"loader_config = {\n",
" 'data_domain': 'graph',\n",
" 'data_type': 'TUDataset',\n",
" 'data_name': 'MUTAG',\n",
" 'data_dir': './data/MUTAG/',\n",
" \"data_domain\": \"graph\",\n",
" \"data_type\": \"TUDataset\",\n",
" \"data_name\": \"MUTAG\",\n",
" \"data_dir\": \"./data/MUTAG/\",\n",
"}\n",
"\n",
"transform_config = { 'clique_lifting':\n",
" {'_target_': '__main__.SimplicialCliquesLEQLifting',\n",
" 'transform_name': 'SimplicialCliquesLEQLifting',\n",
" 'transform_type': 'lifting',\n",
" 'complex_dim': 3,}\n",
"transform_config = { \"clique_lifting\":\n",
" {\"_target_\": \"__main__.SimplicialCliquesLEQLifting\",\n",
" \"transform_name\": \"SimplicialCliquesLEQLifting\",\n",
" \"transform_type\": \"lifting\",\n",
" \"complex_dim\": 3,}\n",
"}\n",
"\n",
"split_config = {\n",
" 'learning_setting': 'inductive',\n",
" 'split_type': 'k-fold',\n",
" 'data_seed': 0,\n",
" 'data_split_dir': './data/MUTAG/splits/',\n",
" 'k': 10,\n",
" \"learning_setting\": \"inductive\",\n",
" \"split_type\": \"k-fold\",\n",
" \"data_seed\": 0,\n",
" \"data_split_dir\": \"./data/MUTAG/splits/\",\n",
" \"k\": 10,\n",
"}\n",
"\n",
"in_channels = 7\n",
"out_channels = 2\n",
"dim_hidden = 128\n",
"\n",
"wrapper_config = {\n",
" 'out_channels': dim_hidden,\n",
" 'num_cell_dimensions': 3,\n",
" \"out_channels\": dim_hidden,\n",
" \"num_cell_dimensions\": 3,\n",
"}\n",
"\n",
"readout_config = {\n",
" 'readout_name': 'PropagateSignalDown',\n",
" 'num_cell_dimensions': 1,\n",
" 'hidden_dim': dim_hidden,\n",
" 'out_channels': out_channels,\n",
" 'task_level': 'graph',\n",
" 'pooling_type': 'sum',\n",
" \"readout_name\": \"PropagateSignalDown\",\n",
" \"num_cell_dimensions\": 1,\n",
" \"hidden_dim\": dim_hidden,\n",
" \"out_channels\": out_channels,\n",
" \"task_level\": \"graph\",\n",
" \"pooling_type\": \"sum\",\n",
"}\n",
"\n",
"loss_config = {'task': 'classification', 'loss_type': 'cross_entropy'}\n",
"loss_config = {\"task\": \"classification\", \"loss_type\": \"cross_entropy\"}\n",
"\n",
"evaluator_config = {'task': 'classification', \n",
" 'num_classes': out_channels, \n",
" 'classification_metrics': ['accuracy', 'precision', 'recall']}\n",
"evaluator_config = {\"task\": \"classification\",\n",
" \"num_classes\": out_channels,\n",
" \"classification_metrics\": [\"accuracy\", \"precision\", \"recall\"]}\n",
"\n",
"loader_config = OmegaConf.create(loader_config)\n",
"transform_config = OmegaConf.create(transform_config)\n",
Expand Down Expand Up @@ -222,7 +224,7 @@
"source": [
"from topobenchmarkx.transforms import TRANSFORMS\n",
"\n",
"TRANSFORMS['SimplicialCliquesLEQLifting'] = SimplicialCliquesLEQLifting"
"TRANSFORMS[\"SimplicialCliquesLEQLifting\"] = SimplicialCliquesLEQLifting"
]
},
{
Expand Down Expand Up @@ -367,7 +369,7 @@
}
],
"source": [
"trainer = pl.Trainer(max_epochs=200, accelerator='gpu', enable_progress_bar=False)\n",
"trainer = pl.Trainer(max_epochs=200, accelerator=\"gpu\", enable_progress_bar=False)\n",
"\n",
"trainer.fit(model, datamodule)\n",
"train_metrics = trainer.callback_metrics"
Expand Down Expand Up @@ -395,7 +397,7 @@
],
"source": [
"for key in train_metrics:\n",
" print(key,': ', train_metrics[key].item())"
" print(key,\": \", train_metrics[key].item())"
]
},
{
Expand Down

0 comments on commit 6c1a0c1

Please sign in to comment.