Skip to content

Commit

Permalink
refactor: move functions to files that make more sense (#65)
Browse files Browse the repository at this point in the history
Signed-off-by: Henry Schreiner <[email protected]>
  • Loading branch information
henryiii authored Jun 1, 2023
1 parent 8267a25 commit 75dbae7
Show file tree
Hide file tree
Showing 8 changed files with 118 additions and 100 deletions.
2 changes: 1 addition & 1 deletion src/scikit_hep_repo_review/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@

__version__ = "0.5.1"

__all__ = ("__version__",)
__all__ = ["__version__"]
34 changes: 32 additions & 2 deletions src/scikit_hep_repo_review/checks/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from __future__ import annotations

from collections.abc import Set
from typing import ClassVar, Protocol
import importlib.metadata
from collections.abc import Mapping, Set
from typing import Any, ClassVar, Protocol

from ..fixtures import apply_fixtures

__all__ = ["Check", "collect_checks", "is_allowed"]


class Check(Protocol):
Expand All @@ -10,3 +15,28 @@ class Check(Protocol):

def check(self) -> bool | None:
...


def collect_checks(fixtures: Mapping[str, Any]) -> dict[str, Check]:
check_functions = (
ep.load()
for ep in importlib.metadata.entry_points(group="scikit_hep_repo_review.checks")
)

return {
k: v
for func in check_functions
for k, v in apply_fixtures(fixtures, func).items()
}


def is_allowed(ignore_list: Set[str], name: str) -> bool:
"""
Skips the check if the name is in the ignore list or if the name without
the number is in the ignore list.
"""
if name in ignore_list:
return False
if name.rstrip("0123456789") in ignore_list:
return False
return True
17 changes: 14 additions & 3 deletions src/scikit_hep_repo_review/families.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,30 @@
from __future__ import annotations

from typing import TypedDict
import importlib.metadata
import typing

__all__ = ["Family", "get_familes"]
__all__ = ["Family", "collect_families", "get_familes"]


def __dir__() -> list[str]:
return __all__


class Family(TypedDict, total=False):
class Family(typing.TypedDict, total=False):
name: str # defaults to key
order: int # defaults to 0


def collect_families() -> dict[str, Family]:
return {
name: family
for ep in importlib.metadata.entry_points(
group="scikit_hep_repo_review.families"
)
for name, family in ep.load()().items()
}


def get_familes() -> dict[str, Family]:
return {
"general": Family(
Expand Down
54 changes: 53 additions & 1 deletion src/scikit_hep_repo_review/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
from __future__ import annotations

import graphlib
import importlib.metadata
import inspect
import typing
from collections.abc import Callable, Mapping
from typing import Any

from ._compat import tomllib
from ._compat.importlib.resources.abc import Traversable

__all__ = ["pyproject", "package"]
__all__ = [
"pyproject",
"package",
"compute_fixtures",
"apply_fixtures",
"collect_fixtures",
]


def __dir__() -> list[str]:
Expand All @@ -22,3 +33,44 @@ def pyproject(package: Traversable) -> dict[str, Any]:

def package(package: Traversable) -> Traversable:
return package


def compute_fixtures(
package: Traversable, fixtures: Mapping[str, Callable[..., Any]]
) -> dict[str, Any]:
results: dict[str, Any] = {"package": package}
graph = {
name: set() if name == "package" else inspect.signature(fix).parameters.keys()
for name, fix in fixtures.items()
}
ts = graphlib.TopologicalSorter(graph)
for fixture_name in ts.static_order():
if fixture_name == "package":
continue
func = fixtures[fixture_name]
signature = inspect.signature(func)
kwargs = {name: results[name] for name in signature.parameters}
results[fixture_name] = fixtures[fixture_name](**kwargs)
return results


T = typing.TypeVar("T")


def apply_fixtures(computed_fixtures: Mapping[str, Any], func: Callable[..., T]) -> T:
signature = inspect.signature(func)
kwargs = {
name: value
for name, value in computed_fixtures.items()
if name in signature.parameters
}
return func(**kwargs)


def collect_fixtures() -> dict[str, Callable[[Traversable], Any]]:
return {
ep.name: ep.load()
for ep in importlib.metadata.entry_points(
group="scikit_hep_repo_review.fixtures"
)
}
6 changes: 6 additions & 0 deletions src/scikit_hep_repo_review/ghpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@

from ._compat.importlib.resources.abc import Traversable

__all__ = ["GHPath"]


def __dir__() -> list[str]:
return __all__


@dataclasses.dataclass(frozen=True, kw_only=True)
class GHPath(Traversable):
Expand Down
94 changes: 8 additions & 86 deletions src/scikit_hep_repo_review/processor.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
from __future__ import annotations

import dataclasses
import importlib.metadata
import inspect
import graphlib
import textwrap
import typing
from collections.abc import Callable, Mapping, Sequence
from graphlib import TopologicalSorter
from typing import Any, TypeVar
from collections.abc import Sequence

from markdown_it import MarkdownIt
import markdown_it

from ._compat.importlib.resources.abc import Traversable
from .checks import Check
from .families import Family
from .fixtures import pyproject
from .checks import Check, collect_checks, is_allowed
from .families import Family, collect_families
from .fixtures import apply_fixtures, collect_fixtures, compute_fixtures, pyproject

__all__ = ["Result", "ResultDict", "ProcessReturn", "process", "as_simple_dict"]

Expand All @@ -23,9 +20,7 @@ def __dir__() -> list[str]:
return __all__


T = TypeVar("T")

md = MarkdownIt()
md = markdown_it.MarkdownIt()


# Helper to get the type in the JSON style returns
Expand Down Expand Up @@ -54,79 +49,6 @@ class ProcessReturn(typing.NamedTuple):
results: list[Result]


def is_allowed(ignore_list: set[str], name: str) -> bool:
"""
Skips the check if the name is in the ignore list or if the name without
the number is in the ignore list.
"""
if name in ignore_list:
return False
if name.rstrip("0123456789") in ignore_list:
return False
return True


def compute_fixtures(
package: Traversable, fixtures: Mapping[str, Callable[..., Any]]
) -> dict[str, Any]:
results: dict[str, Any] = {"package": package}
graph = {
name: set() if name == "package" else inspect.signature(fix).parameters.keys()
for name, fix in fixtures.items()
}
ts = TopologicalSorter(graph)
for fixture_name in ts.static_order():
if fixture_name == "package":
continue
func = fixtures[fixture_name]
signature = inspect.signature(func)
kwargs = {name: results[name] for name in signature.parameters}
results[fixture_name] = fixtures[fixture_name](**kwargs)
return results


def apply_fixtures(computed_fixtures: Mapping[str, Any], func: Callable[..., T]) -> T:
signature = inspect.signature(func)
kwargs = {
name: value
for name, value in computed_fixtures.items()
if name in signature.parameters
}
return func(**kwargs)


def collect_fixtures() -> dict[str, Callable[[Traversable], Any]]:
return {
ep.name: ep.load()
for ep in importlib.metadata.entry_points(
group="scikit_hep_repo_review.fixtures"
)
}


def collect_checks(fixtures: Mapping[str, Any]) -> dict[str, Check]:
check_functions = (
ep.load()
for ep in importlib.metadata.entry_points(group="scikit_hep_repo_review.checks")
)

return {
k: v
for func in check_functions
for k, v in apply_fixtures(fixtures, func).items()
}


def collect_families() -> dict[str, Family]:
return {
name: family
for ep in importlib.metadata.entry_points(
group="scikit_hep_repo_review.families"
)
for name, family in ep.load()().items()
}


def process(package: Traversable, *, ignore: Sequence[str] = ()) -> ProcessReturn:
"""
Process the package and return a dictionary of results.
Expand Down Expand Up @@ -167,7 +89,7 @@ def process(package: Traversable, *, ignore: Sequence[str] = ()) -> ProcessRetur
completed: dict[str, bool | None] = {}

# Run all the checks in topological order
ts = TopologicalSorter(graph)
ts = graphlib.TopologicalSorter(graph)
for name in ts.static_order():
if all(completed.get(n, False) for n in graph[name]):
completed[name] = apply_fixtures(fixtures, tasks[name].check)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import scikit_hep_repo_review.processor
from scikit_hep_repo_review._compat.importlib.resources.abc import Traversable
from scikit_hep_repo_review.checks import collect_checks


class D100:
Expand Down Expand Up @@ -51,7 +52,7 @@ def test_load_entry_point(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(
importlib.metadata, "entry_points", lambda group: [ep] # noqa: ARG005
)
checks = scikit_hep_repo_review.processor.collect_checks({"package": Path(".")})
checks = collect_checks({"package": Path(".")})

assert len(checks) == 2
assert "D100" in checks
Expand Down
8 changes: 2 additions & 6 deletions tests/test_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,8 @@
import pytest

from scikit_hep_repo_review._compat.importlib.resources.abc import Traversable
from scikit_hep_repo_review.fixtures import package
from scikit_hep_repo_review.processor import (
apply_fixtures,
collect_checks,
compute_fixtures,
)
from scikit_hep_repo_review.checks import collect_checks
from scikit_hep_repo_review.fixtures import apply_fixtures, compute_fixtures, package


class D100:
Expand Down

0 comments on commit 75dbae7

Please sign in to comment.