diff --git a/numba_scipy/special/overloads.py b/numba_scipy/special/overloads.py index 31b90ae..67d2e0d 100644 --- a/numba_scipy/special/overloads.py +++ b/numba_scipy/special/overloads.py @@ -1,3 +1,5 @@ +from textwrap import dedent + import numba import scipy.special as sc @@ -7,10 +9,46 @@ def choose_kernel(name, all_signatures): def choice_function(*args): + scalar_args = () + has_arrays = False + for a in args: + if isinstance(a, numba.types.Array): + scalar_args += (a.dtype,) + has_arrays = True + else: + scalar_args += (a,) + for signature in all_signatures: - if args == signature: + if scalar_args == signature: f = signatures.name_and_types_to_pointer[(name, *signature)] - return lambda *args: f(*args) + + if has_arrays: + + args_str = ", ".join([f"arg{i}" for i in range(len(args))]) + + global_env = {"f": f, "numba": numba} + + vectorized_fn_src = dedent(f""" + @numba.vectorize + def f_vec({args_str}): + return f({args_str}) + """) + + mod_code = compile(vectorized_fn_src, "", mode="exec") + local_env = {} + exec(mod_code, global_env, local_env) + + f_vec = local_env["f_vec"] + + def f_res(*args): + return f_vec(*args) + + else: + + def f_res(*args): + return f(*args) + + return f_res return choice_function diff --git a/numba_scipy/tests/test_special.py b/numba_scipy/tests/test_special.py index 6933332..59718cf 100644 --- a/numba_scipy/tests/test_special.py +++ b/numba_scipy/tests/test_special.py @@ -3,6 +3,7 @@ import pytest + import numpy as np from numpy.testing import assert_allclose import numba @@ -45,11 +46,9 @@ def compare_functions(args, scipy_func, numba_func): for arg in args: - overload_value = numba_func(*arg) - scipy_value = scipy_func(*arg) - if np.isnan(overload_value): - assert np.isnan(scipy_value) - else: + for arg in [arg, [np.repeat(a, 2) for a in arg]]: + overload_value = numba_func(*arg) + scipy_value = scipy_func(*arg) rtol = 2**8 * np.finfo(scipy_value.dtype).eps assert_allclose(overload_value, scipy_value, atol=0, rtol=rtol)