From 79d4f22184cc1f5d0778366f2c26f4ba66312e5e Mon Sep 17 00:00:00 2001 From: nstarman Date: Wed, 27 Nov 2024 17:19:24 -0500 Subject: [PATCH] WIP Signed-off-by: nstarman --- .pre-commit-config.yaml | 24 ++++--- pyproject.toml | 30 +++++++++ quantity/_array_api.py | 17 +++++ quantity/_quantity_api.py | 24 +++++++ quantity/py.typed | 0 src/quantity/_src/core.py | 129 ++++++++++++++++++++++++++----------- src/quantity/_src/utils.py | 7 +- src/quantity/version.py | 3 + 8 files changed, 184 insertions(+), 50 deletions(-) create mode 100644 quantity/_array_api.py create mode 100644 quantity/_quantity_api.py create mode 100644 quantity/py.typed diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d16f775..39b20ee 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -24,6 +24,20 @@ repos: args: ["--pytest-test-first"] - id: trailing-whitespace + - repo: https://github.com/pre-commit/pygrep-hooks + rev: "v1.10.0" + hooks: + - id: rst-backticks + - id: rst-directive-colons + - id: rst-inline-touching-normal + + - repo: https://github.com/python-jsonschema/check-jsonschema + rev: 0.29.4 + hooks: + - id: check-dependabot + - id: check-github-workflows + - id: check-readthedocs + - repo: https://github.com/astral-sh/ruff-pre-commit rev: "v0.7.3" hooks: @@ -56,16 +70,6 @@ repos: additional_dependencies: - pytest - - repo: https://github.com/pre-commit/pygrep-hooks - rev: v1.10.0 - hooks: - - id: rst-directive-colons - # Detect mistake of rst directive not ending with double colon. - - id: rst-inline-touching-normal - # Detect mistake of inline code touching normal text in rst. - - id: text-unicode-replacement-char - # Forbid files which have a UTF-8 Unicode replacement character. - - repo: https://github.com/codespell-project/codespell rev: "v2.3.0" hooks: diff --git a/pyproject.toml b/pyproject.toml index 9803536..29efc9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,6 +11,7 @@ dependencies = [ "array-api-compat>=1.9.1", "astropy>=7.0", "numpy>=2.0", + "typing-extensions>=4.12.2", ] dynamic = ["version"] @@ -105,6 +106,30 @@ exclude_lines = [ "@overload", ] +[tool.mypy] + python_version = "3.11" + files = ["quantity"] + strict = true + + disallow_incomplete_defs = true + disallow_untyped_defs = false + enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"] + warn_return_any = true + warn_unreachable = true + warn_unused_configs = true + + [[tool.mypy.overrides]] + module = ["quantity._dev.*", "quantity.tests.*"] + ignore_errors = true + + [[tool.mypy.overrides]] + ignore_missing_imports = true + module = [ + "astropy.*", + "array_api_compat.*" + ] + + [tool.ruff] exclude=[ # package template provided files. "setup.py", @@ -146,3 +171,8 @@ ignore = [ [tool.ruff.lint.per-file-ignores] "tests/**" = ["T20"] "noxfile.py" = ["T20"] + +[dependency-groups] +typing = [ + "mypy>=1.13.0", +] diff --git a/quantity/_array_api.py b/quantity/_array_api.py new file mode 100644 index 0000000..8a28821 --- /dev/null +++ b/quantity/_array_api.py @@ -0,0 +1,17 @@ +"""Minimal definition of the Array API.""" + +from __future__ import annotations + +from typing import Any, Protocol + + +class HasArrayNameSpace(Protocol): + """Minimal defintion of the Array API.""" + + def __array_namespace__(self) -> Any: ... + + +class Array(HasArrayNameSpace, Protocol): + """Minimal defintion of the Array API.""" + + def __pow__(self, other: Any) -> Array: ... diff --git a/quantity/_quantity_api.py b/quantity/_quantity_api.py new file mode 100644 index 0000000..1e861c6 --- /dev/null +++ b/quantity/_quantity_api.py @@ -0,0 +1,24 @@ +"""Minimal definition of the Quantity API.""" + +__all__ = ["Quantity", "ArrayQuantity", "Unit"] + +from typing import Protocol, runtime_checkable + +from astropy.units import UnitBase as Unit + +from ._array_api import Array + + +@runtime_checkable +class Quantity(Protocol): + """Minimal definition of the Quantity API.""" + + value: Array + unit: Unit + + +@runtime_checkable +class ArrayQuantity(Quantity, Array, Protocol): + """An array-valued Quantity.""" + + ... diff --git a/quantity/py.typed b/quantity/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/src/quantity/_src/core.py b/src/quantity/_src/core.py index adfb568..49face9 100644 --- a/src/quantity/_src/core.py +++ b/src/quantity/_src/core.py @@ -1,19 +1,30 @@ +"""Quantity.""" # Licensed under a 3-clause BSD style license - see LICENSE.rst + from __future__ import annotations import operator +from collections.abc import Callable from dataclasses import dataclass, replace -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, Union, cast, overload import array_api_compat import astropy.units as u import numpy as np +from astropy.units import UnitBase as Unit from astropy.units.quantity_helper import UFUNC_HELPERS from .utils import has_array_namespace if TYPE_CHECKING: - from typing import Any + from types import NotImplementedType + from typing import Any, Self + + from ._array_api import Array + from ._quantity_api import ArrayQuantity, Unit + + +T = TypeVar("T") DIMENSIONLESS = u.dimensionless_unscaled @@ -21,19 +32,17 @@ PYTHON_NUMBER = float | int | complex -def get_value_and_unit(arg, default_unit=None): - # HACK: interoperability with astropy Quantity. Have protocol? - try: - unit = arg.unit - except AttributeError: - return arg, default_unit - else: - return arg.value, unit +def get_value_and_unit( + arg: ArrayQuantity | Array, default_unit: Unit | None = None +) -> tuple[Array, Unit]: + return ( + (arg.value, arg.unit) if isinstance(arg, ArrayQuantity) else (arg, default_unit) + ) -def value_in_unit(value, unit): +def value_in_unit(value: Array, unit: Unit) -> Array: v_value, v_unit = get_value_and_unit(value, default_unit=DIMENSIONLESS) - return v_unit.to(unit, v_value) + return cast(Array, v_unit.to(unit, v_value)) _OP_TO_NP_FUNC = { @@ -48,7 +57,12 @@ def value_in_unit(value, unit): OP_HELPERS = {op: UFUNC_HELPERS[np_func] for op, np_func in _OP_TO_NP_FUNC.items()} -def _make_op(fop, mode): +QuantityOpCallable: TypeAlias = Callable[ + ["Quantity", Any], Union["Quantity", NotImplementedType] +] + + +def _make_op(fop: str, mode: str) -> QuantityOpCallable: assert mode in "fri" op = fop if mode == "f" else "__" + mode + fop[2:] helper = OP_HELPERS[fop] @@ -68,27 +82,29 @@ def __op__(self, other): return __op__ -def _make_ops(op): - return tuple(_make_op(op, mode) for mode in "fri") +def _make_ops( + op: str, +) -> tuple[QuantityOpCallable, QuantityOpCallable, QuantityOpCallable]: + return (_make_op(op, "f"), _make_op(op, "r"), _make_op(op, "i")) -def _make_comp(comp): - def __comp__(self, other): +def _make_comp(comp: str) -> Callable[[Quantity, Any], Array]: + def _comp_(self: Quantity, other: Any) -> Array | NotImplementedType: try: other = value_in_unit(other, self.unit) except Exception: return NotImplemented return getattr(self.value, comp)(other) - return __comp__ + return _comp_ -def _make_deferred(attr): +def _make_deferred(attr: str) -> Callable[[Quantity], property]: # Use array_api_compat getter if available (size, device), since # some array formats provide inconsistent implementations. attr_getter = getattr(array_api_compat, attr, operator.attrgetter(attr)) - def deferred(self): + def deferred(self: Quantity): return attr_getter(self.value) return property(deferred) @@ -127,32 +143,61 @@ def defer_dimensionless(self): return defer_dimensionless -def _check_pow_args(exp, mod): - if mod is not None: - return NotImplemented +# ----------------- + + +@overload +def _parse_pow_mod(mod: None, /) -> None: ... + + +@overload +def _parse_pow_mod(mod: object, /) -> NotImplementedType: ... + + +def _parse_pow_mod(mod: T, /) -> T | NotImplementedType: + return mod if mod is None else NotImplemented # type: ignore[redundant-expr] + - if not isinstance(exp, PYTHON_NUMBER): +# ----------------- + + +@overload +def _check_pow_exp(exp: Array | PYTHON_NUMBER, /) -> PYTHON_NUMBER: ... + + +@overload +def _check_pow_exp(exp: object, /) -> NotImplementedType: ... + + +def _check_pow_exp(exp: Any, /) -> PYTHON_NUMBER | NotImplementedType: + out: PYTHON_NUMBER + if isinstance(exp, PYTHON_NUMBER): + out = exp + else: try: - exp = exp.__complex__() + out = complex(exp) except Exception: try: - return exp.__float__() + return float(exp) except Exception: return NotImplemented - return exp.real if exp.imag == 0 else exp + return out.real if out.imag == 0 else out @dataclass(frozen=True, eq=False) class Quantity: - value: Any - unit: u.UnitBase + value: Array + unit: Unit def __array_namespace__(self, *, api_version: str | None = None) -> Any: # TODO: make our own? + del api_version return np - def _operate(self, other, op, units_helper): + def _operate( + self, other: Any, op: Any, units_helper: Any + ) -> Self | NotImplementedType: if not has_array_namespace(other) and not isinstance(other, PYTHON_NUMBER): # HACK: unit should take care of this! if not isinstance(other, u.UnitBase): @@ -221,9 +266,11 @@ def _operate(self, other, op, units_helper): # TODO: __dlpack__, __dlpack_device__ - def __pow__(self, exp, mod=None): - exp = _check_pow_args(exp, mod) - if exp is NotImplemented: + def __pow__(self, exp: Any, mod: Any = None) -> Self | NotImplementedType: + if (mod := _parse_pow_mod(mod)) is NotImplemented: + return NotImplemented + + if (exp := _check_pow_exp(exp)) is NotImplemented: return NotImplemented value = self.value.__pow__(exp) @@ -232,8 +279,10 @@ def __pow__(self, exp, mod=None): return replace(self, value=value, unit=self.unit**exp) def __ipow__(self, exp, mod=None): - exp = _check_pow_args(exp, mod) - if exp is NotImplemented: + if (mod := _parse_pow_mod(mod)) is NotImplemented: + return NotImplemented + + if (exp := _check_pow_exp(exp)) is NotImplemented: return NotImplemented value = self.value.__ipow__(exp) @@ -241,8 +290,14 @@ def __ipow__(self, exp, mod=None): return NotImplemented return replace(self, value=value, unit=self.unit**exp) - def __setitem__(self, item, value): - self.value[item] = value_in_unit(value, self.unit) + def __setitem__(self, item: Any, value: Any) -> None: + """Call the setitem method of the array for the value in the unit. + + The Array API does not guarantee mutability of the underlying array, + so this method will raise an exception if the array is immutable. + + """ + self.value[item] = value_in_unit(value, self.unit) # type: ignore[index] __array_ufunc__ = None __array_function__ = None diff --git a/src/quantity/_src/utils.py b/src/quantity/_src/utils.py index f163052..2bc5a57 100644 --- a/src/quantity/_src/utils.py +++ b/src/quantity/_src/utils.py @@ -1,12 +1,13 @@ """Utility functions for the quantity package.""" +from typing import Any, TypeGuard + import array_api_compat -def has_array_namespace(arg: object) -> bool: +def has_array_namespace(arg: Any) -> TypeGuard[Array]: try: array_api_compat.array_namespace(arg) except TypeError: return False - else: - return True + return True diff --git a/src/quantity/version.py b/src/quantity/version.py index 9b13e60..7a32c3a 100644 --- a/src/quantity/version.py +++ b/src/quantity/version.py @@ -1,6 +1,9 @@ # NOTE: First try _dev.scm_version if it exists and setuptools_scm is installed # This file is not included in wheels/tarballs, so otherwise it will # fall back on the generated _version module. + +__all__ = ['version'] + version: str try: try: