Skip to content

Commit

Permalink
lint: fix unsafe defaults (B006) and unnecessary Literal type hints (…
Browse files Browse the repository at this point in the history
…PYI051)
  • Loading branch information
DaniBodor committed Feb 3, 2024
1 parent f3c350a commit 38334ad
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 15 deletions.
8 changes: 3 additions & 5 deletions deeprank2/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
from deeprank2.domain import nodestorage as Nfeat
from deeprank2.domain import targetstorage as targets

# ruff: noqa: PYI051 (redundant-literal-union), the literal is a special case, while the str is generic

_log = logging.getLogger(__name__)


Expand Down Expand Up @@ -482,7 +480,7 @@ def __init__(
hdf5_path: str | list,
subset: list[str] | None = None,
train_source: str | GridDataset | None = None,
features: list[str] | str | Literal["all"] | None = "all",
features: list[str] | str | None = "all",
target: str | None = None,
target_transform: bool = False,
target_filter: dict[str, str] | None = None,
Expand Down Expand Up @@ -733,8 +731,8 @@ def __init__( # noqa: C901
hdf5_path: str | list,
subset: list[str] | None = None,
train_source: str | GridDataset | None = None,
node_features: list[str] | str | Literal["all"] | None = "all",
edge_features: list[str] | str | Literal["all"] | None = "all",
node_features: list[str] | str | None = "all",
edge_features: list[str] | str | None = "all",
features_transform: dict | None = None,
clustering_method: str | None = None,
target: str | None = None,
Expand Down
8 changes: 3 additions & 5 deletions deeprank2/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,10 +504,7 @@ def _process_one_query(self, query: Query) -> None:
def process(
self,
prefix: str = "processed-queries",
feature_modules: list[ModuleType, str] | ModuleType | str | Literal["all"] = [ # noqa: PYI051
components,
contact,
],
feature_modules: list[ModuleType, str] | ModuleType | str | None = None,
cpu_count: int | None = None,
combine_output: bool = True,
grid_settings: GridSettings | None = None,
Expand Down Expand Up @@ -540,6 +537,7 @@ def process(
list[str]: The list of paths of the generated HDF5 files.
"""
# set defaults
feature_modules = feature_modules or [components, contact]
self._prefix = "processed-queries" if not prefix else re.sub(".hdf5$", "", prefix) # scrape extension if present

max_cpus = os.cpu_count()
Expand Down Expand Up @@ -577,7 +575,7 @@ def process(

return output_paths

def _set_feature_modules(self, feature_modules: list[ModuleType, str] | ModuleType | str | Literal["all"]) -> list[str]: # noqa: PYI051
def _set_feature_modules(self, feature_modules: list[ModuleType, str] | ModuleType | str) -> list[str]:
"""Convert `feature_modules` to list[str] irrespective of input type.
Raises:
Expand Down
3 changes: 2 additions & 1 deletion tests/test_querycollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def _querycollection_tester(
query_type: str,
n_queries: int = 3,
feature_modules: ModuleType | list[ModuleType] = [components, contact],
feature_modules: ModuleType | list[ModuleType] | None = None,
cpu_count: int = 1,
combine_output: bool = True,
) -> (QueryCollection, str, list[str]):
Expand All @@ -35,6 +35,7 @@ def _querycollection_tester(
combine_output (bool): boolean for combining the hdf5 files generated by the processes.
By default, the hdf5 files generated are combined into one, and then deleted.
"""
feature_modules = feature_modules or [components, contact]
if query_type == "ppi":
queries = [
ProteinProteinInterfaceQuery(
Expand Down
8 changes: 4 additions & 4 deletions tests/utils/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from deeprank2.domain import edgestorage as Efeat
from deeprank2.domain import gridstorage
from deeprank2.domain import nodestorage as Nfeat
from deeprank2.domain import targetstorage as Target
from deeprank2.domain import targetstorage as targets
from deeprank2.molstruct.pair import ResidueContact
from deeprank2.utils.buildgraph import get_structure
from deeprank2.utils.graph import Edge, Graph, Node
Expand Down Expand Up @@ -102,7 +102,7 @@ def test_graph_write_to_hdf5(graph: Graph) -> None:
assert len(np.nonzero(edge_features_group[Efeat.INDEX][()])) > 0

# target
assert grp[Target.VALUES][target_name][()] == target_value
assert grp[targets.VALUES][target_name][()] == target_value

finally:
shutil.rmtree(tmp_dir_path) # clean up after the test
Expand Down Expand Up @@ -146,7 +146,7 @@ def test_graph_write_as_grid_to_hdf5(graph: Graph) -> None:
assert np.all(data.shape == tuple(grid_settings.points_counts))

# target
assert grp[Target.VALUES][target_name][()] == target_value
assert grp[targets.VALUES][target_name][()] == target_value

finally:
shutil.rmtree(tmp_dir_path) # clean up after the test
Expand Down Expand Up @@ -209,7 +209,7 @@ def test_graph_augmented_write_as_grid_to_hdf5(graph: Graph) -> None:
assert np.abs(np.sum(data) - np.sum(unaugmented_data)).item() < 0.2

# target
assert grp[Target.VALUES][target_name][()] == target_value
assert grp[targets.VALUES][target_name][()] == target_value

finally:
shutil.rmtree(tmp_dir_path) # clean up after the test

0 comments on commit 38334ad

Please sign in to comment.