Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
gbg141 committed Jun 2, 2024
1 parent 6c1a0c1 commit de4cc9b
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 271 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,7 @@ notebooks/tmp
*.pickle

# wandb
wandb/
wandb/

# ruff
.ruff_cache/
20 changes: 10 additions & 10 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ repos:
- --maxkb=2048
- id: requirements-txt-fixer

# - repo: https://github.com/astral-sh/ruff-pre-commit
# rev: v0.4.4
# hooks:
# - id: ruff
# args: [ --fix ]
# - id: ruff-format
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
hooks:
- id: ruff
args: [ --fix ]
- id: ruff-format

# - repo: https://github.com/numpy/numpydoc
# rev: v1.6.0
# hooks:
# - id: numpydoc-validation
- repo: https://github.com/numpy/numpydoc
rev: v1.6.0
hooks:
- id: numpydoc-validation
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ wrap-descriptions = 79

[tool.ruff]
target-version = "py310"
extend-include = ["*.ipynb"]
#extend-include = ["*.ipynb"]
extend-exclude = ["test", "tutorials", "notebooks"]
line-length = 79 # PEP 8 standard for maximum line length

[tool.ruff.format]
Expand Down Expand Up @@ -128,7 +129,7 @@ ignore = [
[tool.ruff.lint.pydocstyle]
convention = "numpy"

[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F403"]

[tool.setuptools.dynamic]
Expand Down
33 changes: 18 additions & 15 deletions topobenchmarkx/data/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# numpydoc ignore=GL08
from .utils import (
ensure_serializable,
generate_zero_sparse_connectivity,
get_complex_connectivity,
load_cell_complex_dataset,
load_manual_graph,
load_simplicial_dataset,
make_hash,
ensure_serializable, # noqa: F401
generate_zero_sparse_connectivity, # noqa: F401
get_complex_connectivity, # noqa: F401
load_cell_complex_dataset, # noqa: F401
load_manual_graph, # noqa: F401
load_simplicial_dataset, # noqa: F401
make_hash, # noqa: F401
)

utils_functions = [
Expand All @@ -14,12 +15,14 @@
"load_cell_complex_dataset",
"load_simplicial_dataset",
"load_manual_graph",
"make_hash",
"ensure_serializable",
]

from .split_utils import (
load_coauthorship_hypergraph_splits,
load_inductive_splits,
load_transductive_splits,
from .split_utils import ( # noqa: E402
load_coauthorship_hypergraph_splits, # noqa: F401
load_inductive_splits, # noqa: F401
load_transductive_splits, # noqa: F401
)

split_helper_functions = [
Expand All @@ -28,10 +31,10 @@
"load_transductive_splits",
]

from .io_utils import (
download_file_from_drive,
load_hypergraph_pickle_dataset,
read_us_county_demos,
from .io_utils import ( # noqa: E402
download_file_from_drive, # noqa: F401
load_hypergraph_pickle_dataset, # noqa: F401
read_us_county_demos, # noqa: F401
)

io_helper_functions = [
Expand Down
230 changes: 0 additions & 230 deletions topobenchmarkx/evaluator/comparisons.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# numpydoc ignore=GL08,PR01,RT01
import torch_geometric


class KeepSelectedDataFields(torch_geometric.transforms.BaseTransform):
class KeepSelectedDataFields(
torch_geometric.transforms.BaseTransform
): # numpydoc ignore=PR01
r"""A transform that keeps only the selected fields of the input data.
Args:
Expand All @@ -16,7 +19,9 @@ def __init__(self, **kwargs):
def __repr__(self) -> str:
return f"{self.__class__.__name__}(type={self.type!r}, parameters={self.parameters!r})"

def forward(self, data: torch_geometric.data.Data):
def forward(
self, data: torch_geometric.data.Data
): # numpydoc ignore=GL08,PR01,RT01
r"""Apply the transform to the input data.
Args:
Expand All @@ -30,7 +35,7 @@ def forward(self, data: torch_geometric.data.Data):
+ self.parameters["preserved_fields"]
)

for key in data.keys():
for key in data:
if key not in fields_to_keep:
del data[key]
return data
Loading

0 comments on commit de4cc9b

Please sign in to comment.