diff --git a/python/sdist/amici/jax.template.py b/python/sdist/amici/jax.template.py index b9b37c8402..67a9decf07 100644 --- a/python/sdist/amici/jax.template.py +++ b/python/sdist/amici/jax.template.py @@ -65,7 +65,7 @@ def _tcl(self, x, pk): return TPL_TOTAL_CL_RET - def y(self, t, x, pk, tcl): + def _y(self, t, x, pk, tcl): TPL_X_SYMS = x TPL_PK_SYMS = pk diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 2534728a96..22f994229d 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -410,16 +410,16 @@ def _sigmays( def simulate_condition( self, p: jt.Float[jt.Array, "np"], - p_preeq: jt.Float[jt.Array, "?np"], + p_preeq: jt.Float[jt.Array, "*np"], ts_preeq: jt.Float[jt.Array, "nt_preeq"], ts_dyn: jt.Float[jt.Array, "nt_dyn"], ts_posteq: jt.Float[jt.Array, "nt_posteq"], - my: jt.Float[jt.Array, "nt_preeq+nt_dyn+nt_posteq"], - iys: jt.Float[jt.Array, "nt_preeq+nt_dyn+nt_posteq"], + my: jt.Float[jt.Array, "nt"], + iys: jt.Int[jt.Array, "nt"], solver: diffrax.AbstractSolver, controller: diffrax.AbstractStepSizeController, adjoint: diffrax.AbstractAdjoint, - max_steps: jnp.int_, + max_steps: int | jnp.int_, ret: str = "llh", ): r""" diff --git a/python/sdist/pyproject.toml b/python/sdist/pyproject.toml index 0635aec0aa..c2a20fd0f2 100644 --- a/python/sdist/pyproject.toml +++ b/python/sdist/pyproject.toml @@ -71,7 +71,8 @@ test = [ # unsupported x86_64 / x86_64h "antimony!=2.14; platform_system=='Darwin' and platform_machine in 'x86_64h'", "scipy", - "pooch" + "pooch", + "beartype", ] vis = [ "matplotlib", diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 0e1c48eb34..d66f258e24 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -8,6 +8,7 @@ import jax import diffrax import numpy as np +from beartype import beartype from amici.pysb_import import pysb2amici from numpy.testing import assert_allclose @@ -158,7 +159,7 @@ def check_fields_jax( diffrax.RecursiveCheckpointAdjoint(), # adjoint 2**8, # max_steps ) - fun = jax_model.simulate_condition + fun = beartype(jax_model.simulate_condition) for output in ["llh", "x0", "x", "y", "res"]: oargs = (*args[:-2], diffrax.DirectAdjoint(), 2**8, output) diff --git a/tests/benchmark-models/test_petab_benchmark.py b/tests/benchmark-models/test_petab_benchmark.py index c25356ed33..132402f3c8 100644 --- a/tests/benchmark-models/test_petab_benchmark.py +++ b/tests/benchmark-models/test_petab_benchmark.py @@ -39,6 +39,7 @@ ) from amici.jax.petab import run_simulations, JAXProblem from petab.v1.visualize import plot_problem +from beartype import beartype jax.config.update("jax_enable_x64", True) @@ -354,7 +355,7 @@ def test_jax_llh(benchmark_problem): eqx.filter_value_and_grad(run_simulations, has_aux=True) )(jax_problem, simulation_conditions) else: - llh_jax, _ = eqx.filter_jit(run_simulations)( + llh_jax, _ = beartype(eqx.filter_jit(run_simulations))( jax_problem, simulation_conditions )