Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support (most of) the elementwise functions in the Array API #22

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 127 additions & 2 deletions src/quantity/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import array_api_compat
import astropy.units as u
import numpy as np
from astropy.units.quantity_helper import UFUNC_HELPERS
from astropy.units.quantity_helper import UFUNC_HELPERS, converters_and_unit

from .api import QuantityArray
from .utils import has_array_namespace
Expand All @@ -23,6 +23,21 @@
DIMENSIONLESS = u.dimensionless_unscaled

PYTHON_NUMBER = float | int | complex
TUPLE_OR_LIST = tuple | list

NORMALIZED_FUNCTION_NAMES = {
"absolute": "abs",
"arccos": "acos",
"arccosh": "acosh",
"arcsin": "asin",
"arcsinh": "asinh",
"arctan": "atan",
"arctan2": "atan2",
"arctanh": "atanh",
"conjugate": "conj",
"true_divide": "divide",
"power": "pow",
}


def get_value_and_unit(
Expand Down Expand Up @@ -247,5 +262,115 @@ def __ipow__(self, exp, mod=None):
def __setitem__(self, item, value):
self.value[item] = value_in_unit(value, self.unit)

__array_ufunc__ = None
def __array_ufunc__(self, function, method, *inputs, **kwargs):
# Used for our namespace.
if (where := kwargs.get("where")) is not None and isinstance(where, Quantity):
return NotImplemented

converters, unit = converters_and_unit(function, method, *inputs)
out = kwargs.get("out")
if out is not None:
# If pre-allocated output is used, check it is suitable.
# This also returns array view, to ensure we don't loop back.
kwargs["out"] = self._check_output(
out[0] if len(out) == 1 else out, unit, function=function
)

if method == "reduce" and "initial" in kwargs and unit is not None:
# Special-case for initial argument for reductions like
# np.add.reduce. This should be converted to the output unit as
# well, which is typically the same as the input unit (but can
# in principle be different: unitless for np.equal, radian
# for np.arctan2, though those are not necessarily useful!)
kwargs["initial"] = self._to_own_unit(kwargs["initial"], unit=unit)

input_values = [get_value_and_unit(in_)[0] for in_ in inputs]
if not all(
isinstance(v, PYTHON_NUMBER) or has_array_namespace(v)
for v in (input_values[:-1] if method == "reduceat" else input_values)
):
return NotImplemented
input_values = [
v if conv is None else conv(v)
for (v, conv) in zip(input_values, converters, strict=True)
]
try:
xp = self.value.__array_namespace__()
except AttributeError:
try:
xp = array_api_compat.array_namespace(self.value)
except TypeError:
xp = np

if xp is np:
xp_func = function
else:
function_name = NORMALIZED_FUNCTION_NAMES.get(n := function.__name__, n)
xp_func = getattr(xp, function_name)
try:
result = getattr(xp_func, method)(*input_values, **kwargs)
except Exception:
# TODO: JAX supports "at" method if one passes in inplace=False.
return NotImplemented
return self._result_as_quantity(result, unit)

@classmethod
def _check_output(cls, output, unit, function=None):
if isinstance(output, tuple):
return tuple(
cls._check_output(output_, unit_, function)
for output_, unit_ in zip(output, unit, strict=True)
)

if output is None:
return None

if unit is None:
if isinstance(output, Quantity):
msg = "Cannot store non-quantity output{} in {} instance"
raise TypeError(
msg.format(
(
f" from {function.__name__} function"
if function is not None
else ""
),
type(output),
)
)
return output

if not isinstance(output, Quantity):
if unit == DIMENSIONLESS:
return output

msg = (
"Cannot store output with unit '{}'{} in {} instance. "
"Use {} instance instead."
)
raise u.UnitTypeError(
msg.format(
unit,
(
f" from {function.__name__} function"
if function is not None
else ""
),
type(output),
cls,
)
)
return output.value

@classmethod
def _result_as_quantity(cls, result, unit):
if isinstance(result, TUPLE_OR_LIST):
# Some np.linalg functions return namedtuple, which is handy to access
# elements by name, but cannot be directly initialized with an iterator.
result_cls = getattr(result, "_make", result.__class__)
return result_cls(cls(r, u) for (r, u) in zip(result, unit, strict=True))

# If needed, weap the result array as a Quantity with the proper unit.
return result if unit is None else cls(result, unit)

__array_function__ = None
30 changes: 30 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
import astropy.units as u
import numpy as np
from astropy.utils.decorators import classproperty
from numpy.testing import assert_array_almost_equal_nulp, assert_array_equal

ARRAY_NAMESPACES = []


class ANSTests:
IMMUTABLE = False # default
NO_SETITEM = False
NO_OUTPUTS = False

def __init_subclass__(cls, **kwargs):
# Add class to namespaces available for testing if the underlying
Expand Down Expand Up @@ -47,6 +49,8 @@ def teardown_class(cls):


class UsingArrayAPIStrict(MonkeyPatchUnitConversion, ANSTests):
NO_OUTPUTS = True

@classproperty(lazy=True)
def xp(cls):
return __import__("array_api_strict")
Expand All @@ -65,7 +69,33 @@ def xp(cls):
class UsingJAX(MonkeyPatchUnitConversion, ANSTests):
IMMUTABLE = True
NO_SETITEM = True
NO_OUTPUTS = True

@classproperty(lazy=True)
def xp(cls):
return __import__("jax").numpy


def assert_quantity_equal(q1, q2, nulp=0):
assert q1.unit == q2.unit
assert q1.value.__class__ is q2.value.__class__
if nulp:
assert_array_almost_equal_nulp(q1.value, q2.value, nulp=nulp)
else:
assert_array_equal(q1.value, q2.value)


class TrackingNameSpace:
"""Intermediate namespace that tracks attributes that were used.

Used to check whether we test complete sets of functions in the Array API.
"""

def __init__(self, ns):
self.ns = ns
self.used_attrs = set()

def __getattr__(self, attr):
if not attr.startswith("_"):
self.used_attrs.add(attr)
return getattr(self.ns, attr)
Loading
Loading