Skip to content

Commit 6039da8

Browse files
authored
feat: add main feat (#10)
1 parent 48518a6 commit 6039da8

File tree

10 files changed

+1506
-558
lines changed

10 files changed

+1506
-558
lines changed

.github/workflows/ci.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ jobs:
3636
fail-fast: false
3737
matrix:
3838
python-version:
39-
- "3.9"
4039
- "3.10"
4140
- "3.11"
4241
- "3.12"

.pre-commit-config.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,10 @@ repos:
5555
hooks:
5656
- id: mypy
5757
additional_dependencies: []
58+
- repo: https://github.com/adamtheturtle/doccmd-pre-commit
59+
rev: v2025.4.8
60+
hooks:
61+
- id: doccmd
62+
args:
63+
["--language", "python", "--no-pad-file", "--command", "ruff format"]
64+
additional_dependencies: ["ruff"]

README.md

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,49 @@ JIT decorator supporting multiple array API compatible libraries
4444

4545
Install this via pip (or your favourite package manager):
4646

47-
`pip install array-api-jit`
47+
```shell
48+
pip install array-api-jit
49+
```
50+
51+
## Usage
52+
53+
Simply decorate your function with `@jit()`:
54+
55+
```python
56+
from array_api_jit import jit
57+
58+
59+
@jit()
60+
def my_function(x: Any) -> Any:
61+
xp = array_namespace(x)
62+
return xp.sin(x) + xp.cos(x)
63+
```
64+
65+
## Advanced Usage
66+
67+
You can specify the decorator, arguments, and keyword arguments for each library.
68+
69+
```python
70+
from array_api_jit import jit
71+
from array_api_compat import array_namespace
72+
from typing import Any
73+
import numba
74+
75+
76+
@jit(
77+
{"numpy": numba.jit()}, # numba.jit is not used by default because it may not succeed
78+
decorator_kwargs={
79+
"jax": {"static_argnames": ["n"]}
80+
}, # jax requires for-loop variable to be "static_argnames"
81+
# fail_on_error: bool = False, # do not raise an error if the decorator fails (Default)
82+
# rerun_on_error: bool = True, # re-run the original function if the wrapped function fails (NOT Default)
83+
)
84+
def sin_n_times(x: Any, n: int) -> Any:
85+
xp = array_namespace(x)
86+
for i in range(n):
87+
x = xp.sin(x)
88+
return x
89+
```
4890

4991
## Contributors ✨
5092

pyproject.toml

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,12 @@ license = { text = "MIT" }
1111
authors = [
1212
{ name = "34j", email = "[email protected]" },
1313
]
14-
requires-python = ">=3.9"
14+
requires-python = ">=3.10"
1515
classifiers = [
1616
"Development Status :: 2 - Pre-Alpha",
1717
"Intended Audience :: Developers",
1818
"Natural Language :: English",
1919
"Operating System :: OS Independent",
20-
"Programming Language :: Python :: 3.9",
2120
"Programming Language :: Python :: 3.10",
2221
"Programming Language :: Python :: 3.11",
2322
"Programming Language :: Python :: 3.12",
@@ -26,6 +25,9 @@ classifiers = [
2625
]
2726

2827
dependencies = [
28+
"array-api-compat>=1.11.2",
29+
"cm-time>=0.1.2",
30+
"frozendict>=2.4.6",
2931
]
3032
urls."Bug Tracker" = "https://github.com/34j/array-api-jit/issues"
3133
urls.Changelog = "https://github.com/34j/array-api-jit/blob/main/CHANGELOG.md"
@@ -34,8 +36,11 @@ urls.repository = "https://github.com/34j/array-api-jit"
3436

3537
[dependency-groups]
3638
dev = [
39+
"jax>=0.4.30",
40+
"numba>=0.60.0",
3741
"pytest>=8,<9",
3842
"pytest-cov>=6,<7",
43+
"torch>=2.7.1",
3944
]
4045
docs = [
4146
"furo>=2023.5.20; python_version>='3.11'",
@@ -45,7 +50,8 @@ docs = [
4550
]
4651

4752
[tool.ruff]
48-
line-length = 88
53+
format.docstring-code-format = true
54+
line-length = 100
4955
lint.select = [
5056
"B", # flake8-bugbear
5157
"D", # flake8-docstrings
@@ -83,9 +89,6 @@ lint.isort.known-first-party = [ "array_api_jit", "tests" ]
8389
addopts = """\
8490
-v
8591
-Wdefault
86-
--cov=array_api_jit
87-
--cov-report=term
88-
--cov-report=xml
8992
"""
9093
pythonpath = [ "src" ]
9194

src/array_api_jit/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
11
__version__ = "0.0.0"
2+
from ._main import jit
3+
4+
__all__ = ["jit"]

src/array_api_jit/_main.py

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# https://github.com/search?q=gumerov+translation+language%3APython&type=code&l=Python
2+
import importlib.util
3+
import warnings
4+
from collections.abc import Callable, Mapping, Sequence
5+
from functools import cache, wraps
6+
from types import ModuleType
7+
from typing import Any, ParamSpec, TypeVar
8+
9+
from array_api_compat import (
10+
array_namespace,
11+
is_cupy_namespace,
12+
is_dask_namespace,
13+
is_jax_namespace,
14+
is_numpy_namespace,
15+
is_torch_namespace,
16+
)
17+
from frozendict import frozendict
18+
19+
if importlib.util.find_spec("numba"):
20+
import numpy as np
21+
from numba.extending import overload
22+
23+
@overload(array_namespace)
24+
def _array_namespace_overload(*args: Any) -> Any:
25+
def inner(*args: Any) -> Any:
26+
return np
27+
28+
return inner
29+
30+
31+
P = ParamSpec("P")
32+
T = TypeVar("T")
33+
Pin = ParamSpec("Pin")
34+
Tin = TypeVar("Tin")
35+
Pinner = ParamSpec("Pinner")
36+
Tinner = TypeVar("Tinner")
37+
STR_TO_IS_NAMESPACE = {
38+
"numpy": is_numpy_namespace,
39+
"jax": is_jax_namespace,
40+
"cupy": is_cupy_namespace,
41+
"torch": is_torch_namespace,
42+
"dask": is_dask_namespace,
43+
}
44+
45+
46+
def _default_decorator(
47+
module: ModuleType,
48+
/,
49+
) -> Callable[[Callable[P, T]], Callable[P, T]]:
50+
if is_jax_namespace(module):
51+
import jax
52+
53+
return jax.jit
54+
elif is_numpy_namespace(module) or is_cupy_namespace(module):
55+
# import numba
56+
57+
# return numba.jit()
58+
# The success rate of numba.jit is low
59+
return lambda x: x
60+
elif is_torch_namespace(module):
61+
import torch
62+
63+
return torch.compile
64+
elif is_dask_namespace(module):
65+
return lambda x: x
66+
else:
67+
return getattr(module, "jit", lambda x: x)
68+
69+
70+
Decorator = Callable[[Callable[Pin, Tin]], Callable[Pin, Tin]]
71+
72+
73+
def jit(
74+
decorator: Mapping[str, Decorator[..., Any]] | None = None,
75+
/,
76+
*,
77+
fail_on_error: bool = False,
78+
rerun_on_error: bool = False,
79+
decorator_args: Mapping[str, Sequence[Any]] | None = None,
80+
decorator_kwargs: Mapping[str, Mapping[str, Any]] | None = None,
81+
) -> Callable[[Callable[P, T]], Callable[P, T]]:
82+
"""
83+
Just-in-time compilation decorator with multiple backends.
84+
85+
Parameters
86+
----------
87+
decorator : Mapping[str, Callable[[Callable[P, T]], Callable[P, T]]] | None, optional
88+
The JIT decorator to use for each array namespace, by default None
89+
fail_on_error : bool, optional
90+
If True, raise an error if the JIT decorator fails to apply.
91+
If False, just warn and return the original function, by default False
92+
rerun_on_error : bool, optional
93+
If True, rerun the function without JIT if the function
94+
with JIT applied fails, by default False
95+
decorator_args : Mapping[str, Sequence[Any]] | None, optional
96+
Additional positional arguments to be passed along with the function
97+
to the decorator for each array namespace, by default None
98+
decorator_kwargs : Mapping[str, Mapping[str, Any]] | None, optional
99+
Additional keyword arguments to be passed along with the function
100+
to the decorator for each array namespace, by default None
101+
102+
Returns
103+
-------
104+
Callable[[Callable[P, T]], Callable[P, T]]
105+
The JIT decorator that can be applied to a function.
106+
107+
Example
108+
-------
109+
>>> from array_api_jit import jit
110+
>>> from array_api_compat import array_namespace
111+
>>> from typing import Any
112+
>>> import numba
113+
>>> @jit(
114+
... {"numpy": numba.jit()}, # numba.jit is not used by default
115+
... decorator_kwargs={"jax": {"static_argnames": ["n"]}}, # jax requires static_argnames
116+
... )
117+
... def sin_n_times(x: Any, n: int) -> Any:
118+
... xp = array_namespace(x)
119+
... for i in range(n):
120+
... x = xp.sin(x)
121+
... return x
122+
123+
"""
124+
125+
def new_decorator(f: Callable[Pinner, Tinner]) -> Callable[Pinner, Tinner]:
126+
decorator_args_ = frozendict(decorator_args or {})
127+
decorator_kwargs_ = frozendict(decorator_kwargs or {})
128+
decorator_ = decorator or {}
129+
130+
@cache
131+
def jit_cached(xp: ModuleType) -> Callable[Pinner, Tinner]:
132+
for name_, is_namespace in STR_TO_IS_NAMESPACE.items():
133+
if is_namespace(xp):
134+
name = name_
135+
else:
136+
name = xp.__name__.split(".")[0]
137+
decorator_args__ = decorator_args_.get(name, ())
138+
decorator_kwargs__ = decorator_kwargs_.get(name, {})
139+
if name in decorator_:
140+
decorator_current = decorator_[name]
141+
else:
142+
decorator_current = _default_decorator(xp)
143+
try:
144+
return decorator_current(f, *decorator_args__, **decorator_kwargs__)
145+
except Exception as e:
146+
if fail_on_error:
147+
raise RuntimeError(f"Failed to apply JIT decorator for {name}") from e
148+
warnings.warn(
149+
f"Failed to apply JIT decorator for {name}: {e}",
150+
RuntimeWarning,
151+
stacklevel=2,
152+
)
153+
return f
154+
155+
@wraps(f)
156+
def inner(*args_inner: Pinner.args, **kwargs_inner: Pinner.kwargs) -> Tinner:
157+
try:
158+
xp = array_namespace(*args_inner)
159+
except TypeError as e:
160+
if e.args[0] == "Unrecognized array input":
161+
return f(*args_inner, **kwargs_inner)
162+
raise
163+
f_jit = jit_cached(xp)
164+
try:
165+
return f_jit(*args_inner, **kwargs_inner)
166+
except Exception as e:
167+
if rerun_on_error:
168+
warnings.warn(
169+
f"JIT failed for {xp.__name__}: {e}. Rerunning without JIT.",
170+
RuntimeWarning,
171+
stacklevel=2,
172+
)
173+
return f(*args_inner, **kwargs_inner)
174+
raise RuntimeError(f"Failed to run JIT function for {xp.__name__}") from e
175+
176+
return inner
177+
178+
return new_decorator

src/array_api_jit/main.py

Lines changed: 0 additions & 3 deletions
This file was deleted.

tests/conftest.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from typing import Any
2+
3+
import pytest
4+
5+
6+
@pytest.fixture(scope="session", params=["numpy", "torch", "jax"])
7+
def xp(request: pytest.FixtureRequest) -> Any:
8+
"""Get the array namespace for the given backend."""
9+
backend = request.param
10+
if backend == "numpy":
11+
import numpy as xp
12+
elif backend == "torch":
13+
import torch as xp
14+
elif backend == "jax":
15+
import jax.numpy as xp
16+
else:
17+
raise ValueError(f"Unknown backend: {backend}")
18+
return xp

0 commit comments

Comments
 (0)