Skip to content

Commit

Permalink
Finish de-coupling
Browse files Browse the repository at this point in the history
  • Loading branch information
blakeNaccarato committed Jul 31, 2024
1 parent b4f6b2f commit 0d4a697
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 65 deletions.
187 changes: 138 additions & 49 deletions src/boilercore/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,11 @@
"""Common functionality of boiler repositories."""

from collections.abc import Sequence
from importlib.machinery import ModuleSpec
from itertools import chain
from pathlib import Path
from types import ModuleType
from collections.abc import Iterable
from typing import NamedTuple
from warnings import filterwarnings

from boilercore.paths import get_module_name
from boilercore.types import Action

PROJECT_PATH = Path()


class WarningFilter(NamedTuple):
"""A warning filter, e.g. to be unpacked into `warnings.filterwarnings`."""
Expand All @@ -25,49 +18,145 @@ class WarningFilter(NamedTuple):
append: bool = False


DEFAULT_CATEGORIES = [DeprecationWarning, PendingDeprecationWarning, EncodingWarning]
ERROR = "error"
DEFAULT = "default"
NO_WARNINGS = []


def filter_certain_warnings(
package: ModuleType | ModuleSpec | Path | str,
categories: Sequence[type[Warning]] = DEFAULT_CATEGORIES,
root_action: Action | None = ERROR,
package_action: Action = ERROR,
other_action: Action = DEFAULT,
other_warnings: Sequence[WarningFilter] = NO_WARNINGS,
):
"""Filter certain warnings for a package."""
def filter_boiler_warnings(other_warnings: Iterable[WarningFilter] | None = None):
"""Filter certain warnings for `boiler` projects."""
for filt in [
# Optionally filter warnings with the root action
*([WarningFilter(action=root_action)] if root_action else []),
# Filter certain categories with a package action, and third-party action otherwise
*chain.from_iterable(
filter_package_warnings(
package=package,
category=category,
action=package_action,
other_action=other_action,
)
for category in categories
),
# Additionally filter these other warnings
*other_warnings,
WarningFilter(action="default"),
*[
WarningFilter(action="error", category=category, module=r"^boiler.*")
for category in [
DeprecationWarning,
PendingDeprecationWarning,
EncodingWarning,
]
],
*WARNING_FILTERS,
*(other_warnings or []),
]:
filterwarnings(*filt)


def filter_package_warnings(
package: ModuleType | ModuleSpec | Path | str,
category: type[Warning],
action: Action = ERROR,
other_action: Action = DEFAULT,
) -> tuple[WarningFilter, WarningFilter]:
"""Get filter which filters warnings differently for the package."""
all_package_modules = rf"{get_module_name(package)}\..*"
return (
WarningFilter(action=other_action, category=category),
WarningFilter(action=action, category=category, module=all_package_modules),
)
WARNING_FILTERS = [
# * --------------------------------------------------------------------------------
# * MARK: DeprecationWarning
WarningFilter(
category=DeprecationWarning,
module="pybtex.plugin",
message=r"pkg_resources is deprecated as an API\.",
),
WarningFilter(
category=DeprecationWarning,
message=r"Deprecated call to `pkg_resources\.declare_namespace\('mpl_toolkits'\)`\.",
),
WarningFilter(
category=DeprecationWarning,
message=r"Deprecated call to `pkg_resources\.declare_namespace\('sphinxcontrib'\)`\.",
),
WarningFilter(
category=DeprecationWarning,
message=r"Deprecated call to `pkg_resources\.declare_namespace\('zc'\)`\.",
),
WarningFilter(
category=DeprecationWarning,
module=r"latexcodec\.codec",
message=r"open_text is deprecated\. Use files\(\) instead",
),
WarningFilter(
category=DeprecationWarning,
module=r"nptyping\.typing_",
message=r"`.+` is a deprecated alias for `.+`\.",
),
WarningFilter(
category=DeprecationWarning,
module=r"IPython\.core\.pylabtools",
message=r"backend2gui is deprecated.",
),
*[
WarningFilter(
category=DeprecationWarning,
message=rf"Deprecated call to `pkg_resources\.declare_namespace\('{ns}'\)`\.",
)
for ns in ["mpl_toolkits", "sphinxcontrib", "zc"]
],
*[
WarningFilter(
category=DeprecationWarning, module=r"pytest_harvest.*", message=message
)
for message in [
r"The hookspec pytest_harvest_xdist.+ uses old-style configuration options",
r"The hookimpl pytest_configure uses old-style configuration options \(marks or attributes\)\.",
]
],
# * --------------------------------------------------------------------------------
# * MARK: EncodingWarning
WarningFilter(
category=EncodingWarning, message="'encoding' argument not specified"
),
*[
WarningFilter(
category=EncodingWarning,
module=module,
message=r"'encoding' argument not specified\.",
)
for module in [r"sphinx.*", r"jupyter_client\.connect"]
],
# * --------------------------------------------------------------------------------
# * MARK: FutureWarning
WarningFilter(
category=FutureWarning,
message=r"A grouping was used that is not in the columns of the DataFrame and so was excluded from the result\. This grouping will be included in a future version of pandas\. Add the grouping as a column of the DataFrame to silence this warning\.",
),
# * --------------------------------------------------------------------------------
# * MARK: ImportWarning
WarningFilter(
# Happens during tests under some configurations
category=ImportWarning,
message=r"ImportDenier\.find_spec\(\) not found; falling back to find_module\(\)",
),
# * --------------------------------------------------------------------------------
# * MARK: RuntimeWarning
WarningFilter(
category=RuntimeWarning, message=r"invalid value encountered in power"
),
WarningFilter(
# ? https://github.com/pytest-dev/pytest-qt/issues/558#issuecomment-2143975018
category=RuntimeWarning,
message=r"Failed to disconnect .* from signal",
),
WarningFilter(
category=RuntimeWarning,
message=r"numpy\.ndarray size changed, may indicate binary incompatibility\. Expected \d+ from C header, got \d+ from PyObject",
),
WarningFilter(
category=RuntimeWarning,
message=r"Proactor event loop does not implement add_reader family of methods required for zmq.+",
),
WarningFilter(
category=UserWarning,
message=r"The palette list has more values \(\d+\) than needed \(\d+\), which may not be intended\.",
),
WarningFilter(
category=UserWarning,
action="default",
message=r"Loaded local environment variables from `\.env`",
),
WarningFilter(
category=UserWarning,
message=r"The palette list has more values \(\d+\) than needed \(\d+\), which may not be intended\.",
),
WarningFilter(
category=UserWarning,
message=r"To output multiple subplots, the figure containing the passed axes is being cleared\.",
),
# * --------------------------------------------------------------------------------
# * MARK: Combinations
*[
WarningFilter(
category=category,
# module=r"colorspacious\.comparison", # ? CI still complains
message=r"invalid escape sequence",
)
for category in [DeprecationWarning, SyntaxWarning]
],
]
"""Warning filters."""
11 changes: 4 additions & 7 deletions src/boilercore/models/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from pydantic import Field

from boilercore import PROJECT_PATH
from boilercore.fits import Fit
from boilercore.models import SynchronizedPathsYamlModel
from boilercore.models.geometry import Geometry
Expand All @@ -18,12 +17,10 @@ class Params(SynchronizedPathsYamlModel):
geometry: Geometry = Field(default_factory=Geometry, description="Geometry.")
paths: Paths

def __init__(
self,
data_file: Path = PROJECT_PATH / Path("params.yaml"),
root: Path = PROJECT_PATH,
):
super().__init__(data_file, paths=Paths(root=root.resolve()))
def __init__(self, root: Path | None = None, data_file: Path | None = None):
root = (root or Path.cwd()).resolve()
data_file = data_file or root / "params.yaml"
super().__init__(data_file, paths=Paths(root=root))


PARAMS = Params()
Expand Down
3 changes: 2 additions & 1 deletion src/boilercore/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def get_session_path(
test_data_name = Path("root")
project_test_data = Path("tests") / test_data_name
session_path = tmp_path_factory.getbasetemp() / test_data_name
package.PROJECT_PATH = session_path # type: ignore
if getattr(package, "PROJECT_PATH", None):
package.PROJECT_PATH = session_path # type: ignore
copytree(project_test_data, session_path, dirs_exist_ok=True)
return session_path

Expand Down
1 change: 1 addition & 0 deletions src/boilercore/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

Action: TypeAlias = Literal["default", "error", "ignore", "always", "module", "once"]
"""Action to take for a warning."""

Freezable: TypeAlias = (
Callable[..., Any] | Mapping[str, Any] | ItemsView[str, Any] | Iterable[Any]
)
Expand Down
19 changes: 11 additions & 8 deletions tests/boilercore_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
from cachier import cachier, set_default_params # pyright: ignore[reportMissingImports]

import boilercore
from boilercore import filter_certain_warnings
from boilercore import filter_boiler_warnings
from boilercore.hashes import hash_args
from boilercore.models.params import Params
from boilercore.notebooks import namespaces
from boilercore.notebooks.namespaces import NO_PARAMS, get_cached_nb_ns, get_ns_attrs
from boilercore.notebooks.types import Params
from boilercore.testing import get_session_path, unwrap_node
from boilercore_tests import EMPTY_NB

Expand All @@ -23,10 +24,10 @@
@pytest.fixture(autouse=True)
def _filter_certain_warnings():
"""Filter certain warnings."""
filter_certain_warnings(package=boilercore)
filter_boiler_warnings()


@pytest.fixture(autouse=True, scope="session")
@pytest.fixture()
def project_session_path(tmp_path_factory) -> Path:
"""Project session path."""
return get_session_path(tmp_path_factory, boilercore)
Expand All @@ -35,9 +36,9 @@ def project_session_path(tmp_path_factory) -> Path:
@pytest.fixture()
def params(project_session_path):
"""Parameters."""
from boilercore.models.params import PARAMS # noqa: PLC0415

return PARAMS
return Params(
root=project_session_path, data_file=project_session_path / "params.yaml"
)


@pytest.fixture(scope="session")
Expand All @@ -64,7 +65,9 @@ def custom_cachier(fun: Callable[..., Any]):
return wrapper

@custom_cachier
def fun(nb: str = EMPTY_NB, params: Params = NO_PARAMS) -> SimpleNamespace:
def fun(
nb: str = EMPTY_NB, params: namespaces.Params = NO_PARAMS
) -> SimpleNamespace:
"""Get cached minimal namespace suitable for passing to a receiving function."""
return get_cached_nb_ns(nb, params, get_ns_attrs(unwrap_node(request.node)))

Expand Down

0 comments on commit 0d4a697

Please sign in to comment.