Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Dec 8, 2024
1 parent b1d4bf6 commit ddfa177
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 44 deletions.
26 changes: 14 additions & 12 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ repos:
args: ["--pytest-test-first"]
- id: trailing-whitespace

- repo: https://github.com/pre-commit/pygrep-hooks
rev: "v1.10.0"
hooks:
- id: rst-backticks
- id: rst-directive-colons
- id: rst-inline-touching-normal

- repo: https://github.com/python-jsonschema/check-jsonschema
rev: 0.29.4
hooks:
- id: check-dependabot
- id: check-github-workflows
- id: check-readthedocs

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: "v0.7.3"
hooks:
Expand Down Expand Up @@ -53,18 +67,6 @@ repos:
hooks:
- id: mypy
files: src
additional_dependencies:
- pytest

- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.10.0
hooks:
- id: rst-directive-colons
# Detect mistake of rst directive not ending with double colon.
- id: rst-inline-touching-normal
# Detect mistake of inline code touching normal text in rst.
- id: text-unicode-replacement-char
# Forbid files which have a UTF-8 Unicode replacement character.

- repo: https://github.com/codespell-project/codespell
rev: "v2.3.0"
Expand Down
30 changes: 30 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"array-api-compat>=1.9.1",
"astropy>=7.0",
"numpy>=2.0",
"typing-extensions>=4.12.2",
]
dynamic = ["version"]

Expand Down Expand Up @@ -105,6 +106,30 @@ exclude_lines = [
"@overload",
]

[tool.mypy]
python_version = "3.11"
files = ["quantity"]
strict = true

disallow_incomplete_defs = true
disallow_untyped_defs = false
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
warn_return_any = true
warn_unreachable = true
warn_unused_configs = true

[[tool.mypy.overrides]]
module = ["quantity._dev.*", "quantity.tests.*"]
ignore_errors = true

[[tool.mypy.overrides]]
ignore_missing_imports = true
module = [
"astropy.*",
"array_api_compat.*"
]


[tool.ruff]
exclude=[ # package template provided files.
"setup.py",
Expand Down Expand Up @@ -147,3 +172,8 @@ ignore = [
[tool.ruff.lint.per-file-ignores]
"tests/**" = ["T20"]
"noxfile.py" = ["T20"]

[dependency-groups]
typing = [
"mypy>=1.13.0",
]
17 changes: 17 additions & 0 deletions src/quantity/_array_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""Minimal definition of the Array API."""

from __future__ import annotations

from typing import Any, Protocol


class HasArrayNameSpace(Protocol):
"""Minimal defintion of the Array API."""

def __array_namespace__(self) -> Any: ...


class Array(HasArrayNameSpace, Protocol):
"""Minimal defintion of the Array API."""

def __pow__(self, other: Any) -> Array: ...
24 changes: 24 additions & 0 deletions src/quantity/_quantity_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Minimal definition of the Quantity API."""

__all__ = ["Quantity", "ArrayQuantity", "Unit"]

from typing import Protocol, runtime_checkable

from astropy.units import UnitBase as Unit

from ._array_api import Array


@runtime_checkable
class Quantity(Protocol):
"""Minimal definition of the Quantity API."""

value: Array
unit: Unit


@runtime_checkable
class ArrayQuantity(Quantity, Array, Protocol):
"""An array-valued Quantity."""

...
115 changes: 86 additions & 29 deletions src/quantity/_src/core.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,31 @@
"""Quantity."""
# Licensed under a 3-clause BSD style license - see LICENSE.rst

from __future__ import annotations

import operator
from collections.abc import Callable
from dataclasses import dataclass, replace
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, Union, cast, overload

import array_api_compat
import astropy.units as u
import numpy as np
from astropy.units import UnitBase as Unit
from astropy.units.quantity_helper import UFUNC_HELPERS

from .api import QuantityArray
from .utils import has_array_namespace

if TYPE_CHECKING:
from typing import Any
from types import NotImplementedType
from typing import Any, Self

from ._array_api import Array
from ._quantity_api import ArrayQuantity, Unit


T = TypeVar("T")

from .api import Unit
from .array_api import Array
Expand All @@ -33,9 +44,9 @@ def get_value_and_unit(
)


def value_in_unit(value, unit):
def value_in_unit(value: Array, unit: Unit) -> Array:
v_value, v_unit = get_value_and_unit(value, default_unit=DIMENSIONLESS)
return v_unit.to(unit, v_value)
return cast(Array, v_unit.to(unit, v_value))


_OP_TO_NP_FUNC = {
Expand All @@ -50,7 +61,12 @@ def value_in_unit(value, unit):
OP_HELPERS = {op: UFUNC_HELPERS[np_func] for op, np_func in _OP_TO_NP_FUNC.items()}


def _make_op(fop, mode):
QuantityOpCallable: TypeAlias = Callable[
["Quantity", Any], Union["Quantity", NotImplementedType]
]


def _make_op(fop: str, mode: str) -> QuantityOpCallable:
assert mode in "fri"
op = fop if mode == "f" else "__" + mode + fop[2:]
helper = OP_HELPERS[fop]
Expand All @@ -70,27 +86,29 @@ def __op__(self, other):
return __op__


def _make_ops(op):
return tuple(_make_op(op, mode) for mode in "fri")
def _make_ops(
op: str,
) -> tuple[QuantityOpCallable, QuantityOpCallable, QuantityOpCallable]:
return (_make_op(op, "f"), _make_op(op, "r"), _make_op(op, "i"))


def _make_comp(comp):
def __comp__(self, other):
def _make_comp(comp: str) -> Callable[[Quantity, Any], Array]:
def _comp_(self: Quantity, other: Any) -> Array | NotImplementedType:
try:
other = value_in_unit(other, self.unit)
except Exception:
return NotImplemented
return getattr(self.value, comp)(other)

return __comp__
return _comp_


def _make_deferred(attr):
def _make_deferred(attr: str) -> Callable[[Quantity], property]:
# Use array_api_compat getter if available (size, device), since
# some array formats provide inconsistent implementations.
attr_getter = getattr(array_api_compat, attr, operator.attrgetter(attr))

def deferred(self):
def deferred(self: Quantity):
return attr_getter(self.value)

return property(deferred)
Expand Down Expand Up @@ -129,32 +147,61 @@ def defer_dimensionless(self):
return defer_dimensionless


def _check_pow_args(exp, mod):
if mod is not None:
return NotImplemented
# -----------------


@overload
def _parse_pow_mod(mod: None, /) -> None: ...


@overload
def _parse_pow_mod(mod: object, /) -> NotImplementedType: ...


def _parse_pow_mod(mod: T, /) -> T | NotImplementedType:
return mod if mod is None else NotImplemented # type: ignore[redundant-expr]


if not isinstance(exp, PYTHON_NUMBER):
# -----------------


@overload
def _check_pow_exp(exp: Array | PYTHON_NUMBER, /) -> PYTHON_NUMBER: ...


@overload
def _check_pow_exp(exp: object, /) -> NotImplementedType: ...


def _check_pow_exp(exp: Any, /) -> PYTHON_NUMBER | NotImplementedType:
out: PYTHON_NUMBER
if isinstance(exp, PYTHON_NUMBER):
out = exp
else:
try:
exp = exp.__complex__()
out = complex(exp)
except Exception:
try:
return exp.__float__()
return float(exp)
except Exception:
return NotImplemented

return exp.real if exp.imag == 0 else exp
return out.real if out.imag == 0 else out


@dataclass(frozen=True, eq=False)
class Quantity:
value: Any
unit: u.UnitBase
value: Array
unit: Unit

def __array_namespace__(self, *, api_version: str | None = None) -> Any:
# TODO: make our own?
del api_version
return np

def _operate(self, other, op, units_helper):
def _operate(
self, other: Any, op: Any, units_helper: Any
) -> Self | NotImplementedType:
if not has_array_namespace(other) and not isinstance(other, PYTHON_NUMBER):
# HACK: unit should take care of this!
if not isinstance(other, u.UnitBase):
Expand Down Expand Up @@ -223,9 +270,11 @@ def _operate(self, other, op, units_helper):

# TODO: __dlpack__, __dlpack_device__

def __pow__(self, exp, mod=None):
exp = _check_pow_args(exp, mod)
if exp is NotImplemented:
def __pow__(self, exp: Any, mod: Any = None) -> Self | NotImplementedType:
if (mod := _parse_pow_mod(mod)) is NotImplemented:
return NotImplemented

if (exp := _check_pow_exp(exp)) is NotImplemented:
return NotImplemented

value = self.value.__pow__(exp)
Expand All @@ -234,17 +283,25 @@ def __pow__(self, exp, mod=None):
return replace(self, value=value, unit=self.unit**exp)

def __ipow__(self, exp, mod=None):
exp = _check_pow_args(exp, mod)
if exp is NotImplemented:
if (mod := _parse_pow_mod(mod)) is NotImplemented:
return NotImplemented

if (exp := _check_pow_exp(exp)) is NotImplemented:
return NotImplemented

value = self.value.__ipow__(exp)
if value is NotImplemented:
return NotImplemented
return replace(self, value=value, unit=self.unit**exp)

def __setitem__(self, item, value):
self.value[item] = value_in_unit(value, self.unit)
def __setitem__(self, item: Any, value: Any) -> None:
"""Call the setitem method of the array for the value in the unit.
The Array API does not guarantee mutability of the underlying array,
so this method will raise an exception if the array is immutable.
"""
self.value[item] = value_in_unit(value, self.unit) # type: ignore[index]

__array_ufunc__ = None
__array_function__ = None
7 changes: 4 additions & 3 deletions src/quantity/_src/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Utility functions for the quantity package."""

from typing import Any, TypeGuard

import array_api_compat


def has_array_namespace(arg: object) -> bool:
def has_array_namespace(arg: Any) -> TypeGuard[Array]:
try:
array_api_compat.array_namespace(arg)
except TypeError:
return False
else:
return True
return True
Empty file added src/quantity/py.typed
Empty file.
3 changes: 3 additions & 0 deletions src/quantity/version.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# NOTE: First try _dev.scm_version if it exists and setuptools_scm is installed
# This file is not included in wheels/tarballs, so otherwise it will
# fall back on the generated _version module.

__all__ = ['version']

version: str
try:
try:
Expand Down

0 comments on commit ddfa177

Please sign in to comment.