Documentation: https://array-api-jit.readthedocs.io
Source Code: https://github.com/34j/array-api-jit
JIT decorator supporting multiple array API compatible libraries
Install this via pip (or your favourite package manager):
```shell
pip install array-api-jit
```
Simply decorate your function with `@jit()`:
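```python
from typing import Any

from array_api_compat import array_namespace

from array_api_jit import jit


@jit()
def my_function(x: Any) -> Any:
    xp = array_namespace(x)  # the array namespace (library) that x belongs to
    return xp.sin(x) + xp.cos(x)
```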
You can also specify the decorator, its arguments, and its keyword arguments for each library individually:
```python
from typing import Any

import numba
from array_api_compat import array_namespace

from array_api_jit import jit


@jit(
    {"numpy": numba.jit()},  # numba.jit is not used by default because compilation may fail
    decorator_kwargs={
        "jax": {"static_argnames": ["n"]}
    },  # jax needs the loop bound `n` to be static, so it is listed in "static_argnames"
    # fail_on_error: bool = False,  # do not raise an error if the decorator fails (default)
    # rerun_on_error: bool = True,  # re-run the original function if the wrapped function fails (not the default)
)
def sin_n_times(x: Any, n: int) -> Any:
    xp = array_namespace(x)
    for _ in range(n):
        x = xp.sin(x)
    return x
```
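The two commented-out options control what happens when a library's decorator cannot be applied, or when the decorated function fails at call time. The following sketch illustrates one plausible reading of `fail_on_error` and `rerun_on_error` using stand-in decorators instead of real JIT compilers. `with_fallback`, `broken_decorator`, and `failing_at_call` are hypothetical, and falling back to the undecorated function after a failed decoration is an assumption, not documented package behavior.

```python
import functools
from typing import Any, Callable


def with_fallback(
    decorator: Callable[[Callable], Callable],
    fail_on_error: bool = False,
    rerun_on_error: bool = False,
) -> Callable[[Callable], Callable]:
    """Hypothetical wrapper mimicking the two documented flags."""

    def wrap(func: Callable) -> Callable:
        try:
            compiled = decorator(func)
        except Exception:
            if fail_on_error:
                raise
            compiled = func  # assumption: silently keep the undecorated function

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            try:
                return compiled(*args, **kwargs)
            except Exception:
                # Nothing to fall back to if disabled or already running the original.
                if not rerun_on_error or compiled is func:
                    raise
                return func(*args, **kwargs)  # re-run the original Python function

        return wrapper

    return wrap


def broken_decorator(func: Callable) -> Callable:
    # Stand-in for a JIT that cannot compile the function at all.
    raise RuntimeError("cannot compile")


def failing_at_call(func: Callable) -> Callable:
    # Stand-in for a JIT that compiles but fails when the result is called.
    def compiled(*args: Any, **kwargs: Any) -> Any:
        raise TypeError("unsupported input")

    return compiled


@with_fallback(broken_decorator)  # decoration fails, no error raised
def add(a: int, b: int) -> int:
    return a + b


@with_fallback(failing_at_call, rerun_on_error=True)  # call fails, original re-runs
def mul(a: int, b: int) -> int:
    return a * b
```

With the defaults, a call-time failure still propagates; only `rerun_on_error=True` trades the error for a slower, uncompiled re-run.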
Thanks go to these wonderful people (emoji key):
This project follows the all-contributors specification. Contributions of any kind welcome!
This package was created with Copier and the browniebroke/pypackage-template project template.