Skip to content

Commit

Permalink
add runtime typechecks to jax tests
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Nov 17, 2024
1 parent 74cd498 commit d94714b
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 8 deletions.
2 changes: 1 addition & 1 deletion python/sdist/amici/jax.template.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _tcl(self, x, pk):

return TPL_TOTAL_CL_RET

def y(self, t, x, pk, tcl):
def _y(self, t, x, pk, tcl):

TPL_X_SYMS = x
TPL_PK_SYMS = pk
Expand Down
8 changes: 4 additions & 4 deletions python/sdist/amici/jax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,16 +410,16 @@ def _sigmays(
def simulate_condition(
self,
p: jt.Float[jt.Array, "np"],
p_preeq: jt.Float[jt.Array, "?np"],
p_preeq: jt.Float[jt.Array, "*np"],
ts_preeq: jt.Float[jt.Array, "nt_preeq"],
ts_dyn: jt.Float[jt.Array, "nt_dyn"],
ts_posteq: jt.Float[jt.Array, "nt_posteq"],
my: jt.Float[jt.Array, "nt_preeq+nt_dyn+nt_posteq"],
iys: jt.Float[jt.Array, "nt_preeq+nt_dyn+nt_posteq"],
my: jt.Float[jt.Array, "nt"],
iys: jt.Int[jt.Array, "nt"],
solver: diffrax.AbstractSolver,
controller: diffrax.AbstractStepSizeController,
adjoint: diffrax.AbstractAdjoint,
max_steps: jnp.int_,
max_steps: int | jnp.int_,
ret: str = "llh",
):
r"""
Expand Down
3 changes: 2 additions & 1 deletion python/sdist/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ test = [
# unsupported x86_64 / x86_64h
"antimony!=2.14; platform_system=='Darwin' and platform_machine in 'x86_64h'",
"scipy",
"pooch"
"pooch",
"beartype",
]
vis = [
"matplotlib",
Expand Down
3 changes: 2 additions & 1 deletion python/tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import jax
import diffrax
import numpy as np
from beartype import beartype

from amici.pysb_import import pysb2amici
from numpy.testing import assert_allclose
Expand Down Expand Up @@ -158,7 +159,7 @@ def check_fields_jax(
diffrax.RecursiveCheckpointAdjoint(), # adjoint
2**8, # max_steps
)
fun = jax_model.simulate_condition
fun = beartype(jax_model.simulate_condition)

for output in ["llh", "x0", "x", "y", "res"]:
oargs = (*args[:-2], diffrax.DirectAdjoint(), 2**8, output)
Expand Down
3 changes: 2 additions & 1 deletion tests/benchmark-models/test_petab_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from amici.jax.petab import run_simulations, JAXProblem
from petab.v1.visualize import plot_problem
from beartype import beartype

jax.config.update("jax_enable_x64", True)

Expand Down Expand Up @@ -354,7 +355,7 @@ def test_jax_llh(benchmark_problem):
eqx.filter_value_and_grad(run_simulations, has_aux=True)
)(jax_problem, simulation_conditions)
else:
llh_jax, _ = eqx.filter_jit(run_simulations)(
llh_jax, _ = beartype(eqx.filter_jit(run_simulations))(
jax_problem, simulation_conditions
)

Expand Down

0 comments on commit d94714b

Please sign in to comment.