|
| 1 | +# https://github.com/search?q=gumerov+translation+language%3APython&type=code&l=Python |
| 2 | +import importlib.util |
| 3 | +import warnings |
| 4 | +from collections.abc import Callable, Mapping, Sequence |
| 5 | +from functools import cache, wraps |
| 6 | +from types import ModuleType |
| 7 | +from typing import Any, ParamSpec, TypeVar |
| 8 | + |
| 9 | +from array_api_compat import ( |
| 10 | + array_namespace, |
| 11 | + is_cupy_namespace, |
| 12 | + is_dask_namespace, |
| 13 | + is_jax_namespace, |
| 14 | + is_numpy_namespace, |
| 15 | + is_torch_namespace, |
| 16 | +) |
| 17 | +from frozendict import frozendict |
| 18 | + |
| 19 | +if importlib.util.find_spec("numba"): |
| 20 | + import numpy as np |
| 21 | + from numba.extending import overload |
| 22 | + |
| 23 | + @overload(array_namespace) |
| 24 | + def _array_namespace_overload(*args: Any) -> Any: |
| 25 | + def inner(*args: Any) -> Any: |
| 26 | + return np |
| 27 | + |
| 28 | + return inner |
| 29 | + |
| 30 | + |
| 31 | +P = ParamSpec("P") |
| 32 | +T = TypeVar("T") |
| 33 | +Pin = ParamSpec("Pin") |
| 34 | +Tin = TypeVar("Tin") |
| 35 | +Pinner = ParamSpec("Pinner") |
| 36 | +Tinner = TypeVar("Tinner") |
| 37 | +STR_TO_IS_NAMESPACE = { |
| 38 | + "numpy": is_numpy_namespace, |
| 39 | + "jax": is_jax_namespace, |
| 40 | + "cupy": is_cupy_namespace, |
| 41 | + "torch": is_torch_namespace, |
| 42 | + "dask": is_dask_namespace, |
| 43 | +} |
| 44 | + |
| 45 | + |
| 46 | +def _default_decorator( |
| 47 | + module: ModuleType, |
| 48 | + /, |
| 49 | +) -> Callable[[Callable[P, T]], Callable[P, T]]: |
| 50 | + if is_jax_namespace(module): |
| 51 | + import jax |
| 52 | + |
| 53 | + return jax.jit |
| 54 | + elif is_numpy_namespace(module) or is_cupy_namespace(module): |
| 55 | + # import numba |
| 56 | + |
| 57 | + # return numba.jit() |
| 58 | + # The success rate of numba.jit is low |
| 59 | + return lambda x: x |
| 60 | + elif is_torch_namespace(module): |
| 61 | + import torch |
| 62 | + |
| 63 | + return torch.compile |
| 64 | + elif is_dask_namespace(module): |
| 65 | + return lambda x: x |
| 66 | + else: |
| 67 | + return getattr(module, "jit", lambda x: x) |
| 68 | + |
| 69 | + |
| 70 | +Decorator = Callable[[Callable[Pin, Tin]], Callable[Pin, Tin]] |
| 71 | + |
| 72 | + |
| 73 | +def jit( |
| 74 | + decorator: Mapping[str, Decorator[..., Any]] | None = None, |
| 75 | + /, |
| 76 | + *, |
| 77 | + fail_on_error: bool = False, |
| 78 | + rerun_on_error: bool = False, |
| 79 | + decorator_args: Mapping[str, Sequence[Any]] | None = None, |
| 80 | + decorator_kwargs: Mapping[str, Mapping[str, Any]] | None = None, |
| 81 | +) -> Callable[[Callable[P, T]], Callable[P, T]]: |
| 82 | + """ |
| 83 | + Just-in-time compilation decorator with multiple backends. |
| 84 | +
|
| 85 | + Parameters |
| 86 | + ---------- |
| 87 | + decorator : Mapping[str, Callable[[Callable[P, T]], Callable[P, T]]] | None, optional |
| 88 | + The JIT decorator to use for each array namespace, by default None |
| 89 | + fail_on_error : bool, optional |
| 90 | + If True, raise an error if the JIT decorator fails to apply. |
| 91 | + If False, just warn and return the original function, by default False |
| 92 | + rerun_on_error : bool, optional |
| 93 | + If True, rerun the function without JIT if the function |
| 94 | + with JIT applied fails, by default False |
| 95 | + decorator_args : Mapping[str, Sequence[Any]] | None, optional |
| 96 | + Additional positional arguments to be passed along with the function |
| 97 | + to the decorator for each array namespace, by default None |
| 98 | + decorator_kwargs : Mapping[str, Mapping[str, Any]] | None, optional |
| 99 | + Additional keyword arguments to be passed along with the function |
| 100 | + to the decorator for each array namespace, by default None |
| 101 | +
|
| 102 | + Returns |
| 103 | + ------- |
| 104 | + Callable[[Callable[P, T]], Callable[P, T]] |
| 105 | + The JIT decorator that can be applied to a function. |
| 106 | +
|
| 107 | + Example |
| 108 | + ------- |
| 109 | + >>> from array_api_jit import jit |
| 110 | + >>> from array_api_compat import array_namespace |
| 111 | + >>> from typing import Any |
| 112 | + >>> import numba |
| 113 | + >>> @jit( |
| 114 | + ... {"numpy": numba.jit()}, # numba.jit is not used by default |
| 115 | + ... decorator_kwargs={"jax": {"static_argnames": ["n"]}}, # jax requires static_argnames |
| 116 | + ... ) |
| 117 | + ... def sin_n_times(x: Any, n: int) -> Any: |
| 118 | + ... xp = array_namespace(x) |
| 119 | + ... for i in range(n): |
| 120 | + ... x = xp.sin(x) |
| 121 | + ... return x |
| 122 | +
|
| 123 | + """ |
| 124 | + |
| 125 | + def new_decorator(f: Callable[Pinner, Tinner]) -> Callable[Pinner, Tinner]: |
| 126 | + decorator_args_ = frozendict(decorator_args or {}) |
| 127 | + decorator_kwargs_ = frozendict(decorator_kwargs or {}) |
| 128 | + decorator_ = decorator or {} |
| 129 | + |
| 130 | + @cache |
| 131 | + def jit_cached(xp: ModuleType) -> Callable[Pinner, Tinner]: |
| 132 | + for name_, is_namespace in STR_TO_IS_NAMESPACE.items(): |
| 133 | + if is_namespace(xp): |
| 134 | + name = name_ |
| 135 | + else: |
| 136 | + name = xp.__name__.split(".")[0] |
| 137 | + decorator_args__ = decorator_args_.get(name, ()) |
| 138 | + decorator_kwargs__ = decorator_kwargs_.get(name, {}) |
| 139 | + if name in decorator_: |
| 140 | + decorator_current = decorator_[name] |
| 141 | + else: |
| 142 | + decorator_current = _default_decorator(xp) |
| 143 | + try: |
| 144 | + return decorator_current(f, *decorator_args__, **decorator_kwargs__) |
| 145 | + except Exception as e: |
| 146 | + if fail_on_error: |
| 147 | + raise RuntimeError(f"Failed to apply JIT decorator for {name}") from e |
| 148 | + warnings.warn( |
| 149 | + f"Failed to apply JIT decorator for {name}: {e}", |
| 150 | + RuntimeWarning, |
| 151 | + stacklevel=2, |
| 152 | + ) |
| 153 | + return f |
| 154 | + |
| 155 | + @wraps(f) |
| 156 | + def inner(*args_inner: Pinner.args, **kwargs_inner: Pinner.kwargs) -> Tinner: |
| 157 | + try: |
| 158 | + xp = array_namespace(*args_inner) |
| 159 | + except TypeError as e: |
| 160 | + if e.args[0] == "Unrecognized array input": |
| 161 | + return f(*args_inner, **kwargs_inner) |
| 162 | + raise |
| 163 | + f_jit = jit_cached(xp) |
| 164 | + try: |
| 165 | + return f_jit(*args_inner, **kwargs_inner) |
| 166 | + except Exception as e: |
| 167 | + if rerun_on_error: |
| 168 | + warnings.warn( |
| 169 | + f"JIT failed for {xp.__name__}: {e}. Rerunning without JIT.", |
| 170 | + RuntimeWarning, |
| 171 | + stacklevel=2, |
| 172 | + ) |
| 173 | + return f(*args_inner, **kwargs_inner) |
| 174 | + raise RuntimeError(f"Failed to run JIT function for {xp.__name__}") from e |
| 175 | + |
| 176 | + return inner |
| 177 | + |
| 178 | + return new_decorator |
0 commit comments