diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index db3e4cd7..d50e0d83 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -12,7 +12,8 @@ import math import sys import warnings -from collections.abc import Collection +from collections.abc import Collection, Hashable +from functools import lru_cache from typing import ( TYPE_CHECKING, Any, @@ -61,23 +62,37 @@ _API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"}) +@lru_cache(100) +def _issubclass_fast(cls: type, modname: str, clsname: str) -> bool: + try: + mod = sys.modules[modname] + except KeyError: + return False + parent_cls = getattr(mod, clsname) + return issubclass(cls, parent_cls) + + def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: """Return True if `x` is a zero-gradient array. These arrays are a design quirk of Jax that may one day be removed. See https://github.com/google/jax/issues/20620. """ - if "numpy" not in sys.modules or "jax" not in sys.modules: + # Fast exit + try: + dtype = x.dtype # type: ignore[attr-defined] + except AttributeError: + return False + cls = cast(Hashable, type(dtype)) + if not _issubclass_fast(cls, "numpy.dtypes", "VoidDType"): return False - import jax - import numpy as np + if "jax" not in sys.modules: + return False - jax_float0 = cast("np.dtype[np.void]", jax.float0) - return ( - isinstance(x, np.ndarray) - and cast("npt.NDArray[np.void]", x).dtype == jax_float0 - ) + import jax + # jax.float0 is a np.dtype([('float0', 'V')]) + return dtype == jax.float0 def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: @@ -101,15 +116,12 @@ def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: is_jax_array is_pydata_sparse_array """ - # Avoid importing NumPy if it isn't already - if "numpy" not in sys.modules: - return False - - import numpy as np - # TODO: Should we reject ndarray subclasses? - return (isinstance(x, (np.ndarray, np.generic)) - and not _is_jax_zero_gradient_array(x)) # pyright: ignore[reportUnknownArgumentType] # fmt: skip + cls = cast(Hashable, type(x)) + return ( + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + ) and not _is_jax_zero_gradient_array(x) def is_cupy_array(x: object) -> bool: @@ -133,14 +145,8 @@ def is_cupy_array(x: object) -> bool: is_jax_array is_pydata_sparse_array """ - # Avoid importing CuPy if it isn't already - if "cupy" not in sys.modules: - return False - - import cupy as cp # pyright: ignore[reportMissingTypeStubs] - - # TODO: Should we reject ndarray subclasses? - return isinstance(x, cp.ndarray) # pyright: ignore[reportUnknownMemberType] + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "cupy", "ndarray") def is_torch_array(x: object) -> TypeIs[torch.Tensor]: @@ -161,14 +167,8 @@ def is_torch_array(x: object) -> TypeIs[torch.Tensor]: is_jax_array is_pydata_sparse_array """ - # Avoid importing torch if it isn't already - if "torch" not in sys.modules: - return False - - import torch - - # TODO: Should we reject ndarray subclasses? - return isinstance(x, torch.Tensor) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "torch", "Tensor") def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]: @@ -190,13 +190,8 @@ def is_ndonnx_array(x: object) -> TypeIs[ndx.Array]: is_jax_array is_pydata_sparse_array """ - # Avoid importing torch if it isn't already - if "ndonnx" not in sys.modules: - return False - - import ndonnx as ndx - - return isinstance(x, ndx.Array) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "ndonnx", "Array") def is_dask_array(x: object) -> TypeIs[da.Array]: @@ -218,13 +213,8 @@ def is_dask_array(x: object) -> TypeIs[da.Array]: is_jax_array is_pydata_sparse_array """ - # Avoid importing dask if it isn't already - if "dask.array" not in sys.modules: - return False - - import dask.array - - return isinstance(x, dask.array.Array) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "dask.array", "Array") def is_jax_array(x: object) -> TypeIs[jax.Array]: @@ -247,13 +237,8 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]: is_dask_array is_pydata_sparse_array """ - # Avoid importing jax if it isn't already - if "jax" not in sys.modules: - return False - - import jax - - return isinstance(x, jax.Array) or _is_jax_zero_gradient_array(x) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x) def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: @@ -276,14 +261,9 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: is_dask_array is_jax_array """ - # Avoid importing jax if it isn't already - if "sparse" not in sys.modules: - return False - - import sparse # pyright: ignore[reportMissingTypeStubs] - # TODO: Account for other backends. - return isinstance(x, sparse.SparseArray) + cls = cast(Hashable, type(x)) + return _issubclass_fast(cls, "sparse", "SparseArray") def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] @@ -302,13 +282,23 @@ def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[repo is_jax_array """ return ( - is_numpy_array(x) - or is_cupy_array(x) - or is_torch_array(x) - or is_dask_array(x) - or is_jax_array(x) - or is_pydata_sparse_array(x) - or hasattr(x, "__array_namespace__") + hasattr(x, '__array_namespace__') + or _is_array_api_cls(cast(Hashable, type(x))) + ) + + +@lru_cache(100) +def _is_array_api_cls(cls: type) -> bool: + return ( + # TODO: drop support for numpy<2 which didn't have __array_namespace__ + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "cupy", "ndarray") + or _issubclass_fast(cls, "torch", "Tensor") + or _issubclass_fast(cls, "dask.array", "Array") + or _issubclass_fast(cls, "sparse", "SparseArray") + # TODO: drop support for jax<0.4.32 which didn't have __array_namespace__ + or _issubclass_fast(cls, "jax", "Array") ) @@ -317,6 +307,7 @@ def _compat_module_name() -> str: return __name__.removesuffix(".common._helpers") +@lru_cache(100) def is_numpy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a NumPy namespace. @@ -338,6 +329,7 @@ def is_numpy_namespace(xp: Namespace) -> bool: return xp.__name__ in {"numpy", _compat_module_name() + ".numpy"} +@lru_cache(100) def is_cupy_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a CuPy namespace. @@ -359,6 +351,7 @@ def is_cupy_namespace(xp: Namespace) -> bool: return xp.__name__ in {"cupy", _compat_module_name() + ".cupy"} +@lru_cache(100) def is_torch_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a PyTorch namespace. @@ -399,6 +392,7 @@ def is_ndonnx_namespace(xp: Namespace) -> bool: return xp.__name__ == "ndonnx" +@lru_cache(100) def is_dask_namespace(xp: Namespace) -> bool: """ Returns True if `xp` is a Dask namespace. @@ -939,6 +933,19 @@ def size(x: HasShape[Collection[SupportsIndex | None]]) -> int | None: return None if math.isnan(out) else out +@lru_cache(100) +def _is_writeable_cls(cls: type) -> bool | None: + if ( + _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "sparse", "SparseArray") + ): + return False + if _is_array_api_cls(cls): + return True + return None + + def is_writeable_array(x: object) -> bool: """ Return False if ``x.__setitem__`` is expected to raise; True otherwise. @@ -949,11 +956,32 @@ def is_writeable_array(x: object) -> bool: As there is no standard way to check if an array is writeable without actually writing to it, this function blindly returns True for all unknown array types. """ - if is_numpy_array(x): - return x.flags.writeable - if is_jax_array(x) or is_pydata_sparse_array(x): + cls = cast(Hashable, type(x)) + if _issubclass_fast(cls, "numpy", "ndarray"): + return cast("npt.NDArray", x).flags.writeable + res = _is_writeable_cls(cls) + if res is not None: + return res + return hasattr(x, '__array_namespace__') + + +@lru_cache(100) +def _is_lazy_cls(cls: type) -> bool | None: + if ( + _issubclass_fast(cls, "numpy", "ndarray") + or _issubclass_fast(cls, "numpy", "generic") + or _issubclass_fast(cls, "cupy", "ndarray") + or _issubclass_fast(cls, "torch", "Tensor") + or _issubclass_fast(cls, "sparse", "SparseArray") + ): return False - return is_array_api_obj(x) + if ( + _issubclass_fast(cls, "jax", "Array") + or _issubclass_fast(cls, "dask.array", "Array") + or _issubclass_fast(cls, "ndonnx", "Array") + ): + return True + return None def is_lazy_array(x: object) -> bool: @@ -969,14 +997,6 @@ def is_lazy_array(x: object) -> bool: This function errs on the side of caution for array types that may or may not be lazy, e.g. JAX arrays, by always returning True for them. """ - if ( - is_numpy_array(x) - or is_cupy_array(x) - or is_torch_array(x) - or is_pydata_sparse_array(x) - ): - return False - # **JAX note:** while it is possible to determine if you're inside or outside # jax.jit by testing the subclass of a jax.Array object, as well as testing bool() # as we do below for unknown arrays, this is not recommended by JAX best practices. @@ -986,10 +1006,14 @@ def is_lazy_array(x: object) -> bool: # compatibility, is highly detrimental to performance as the whole graph will end # up being computed multiple times. - if is_jax_array(x) or is_dask_array(x) or is_ndonnx_array(x): - return True + # Note: skipping reclassification of JAX zero gradient arrays, as one will + # exclusively get them once they leave a jax.grad JIT context. + cls = cast(Hashable, type(x)) + res = _is_lazy_cls(cls) + if res is not None: + return res - if not is_array_api_obj(x): + if not hasattr(x, "__array_namespace__"): return False # Unknown Array API compatible object. Note that this test may have dire consequences @@ -1042,7 +1066,7 @@ def is_lazy_array(x: object) -> bool: "to_device", ] -_all_ignore = ["sys", "math", "inspect", "warnings"] +_all_ignore = ['lru_cache', 'sys', 'math', 'inspect', 'warnings'] def __dir__() -> list[str]: return __all__