From 1f0af69bc478337db999dcb2a85c863462011f09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 5 Dec 2024 15:45:05 +0000 Subject: [PATCH] fix tests --- python/tests/test_jax.py | 46 +++++++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 8f4c68510b..ce7018e078 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -47,7 +47,7 @@ def test_conversion(): module_name=model.name, module_path=outdir ) jax_module = amici.import_model_module( - module_name=model.name + "_jax", module_path=outdir + module_name=Path(outdir).stem, module_path=Path(outdir).parent ) ts = tuple(np.linspace(0, 1, 10)) @@ -108,7 +108,7 @@ def test_dimerization(): module_name=model.name, module_path=outdir ) jax_module = amici.import_model_module( - module_name=model.name + "_jax", module_path=outdir + module_name=Path(outdir).stem, module_path=Path(outdir).parent ) ts = tuple(np.linspace(0, 1, 10)) @@ -178,7 +178,7 @@ def check_fields_jax( ts = ts.flatten() iys = iys.flatten() - ts_preeq = ts[ts == 0] + ts_init = ts[ts == 0] ts_dyn = ts[ts > 0] ts_posteq = np.array([]) @@ -188,31 +188,37 @@ def check_fields_jax( } p = jnp.array([par_dict[par_id] for par_id in jax_model.parameter_ids]) - args = ( - jnp.array([]), # p_preeq - jnp.array(ts_preeq), # ts_preeq - jnp.array(ts_dyn), # ts_dyn - jnp.array(ts_posteq), # ts_posteq - jnp.array(my), # my - jnp.array(iys), # iys - diffrax.Kvaerno5(), # solver - diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM), # controller - diffrax.RecursiveCheckpointAdjoint(), # adjoint - 2**8, # max_steps - ) + kwargs = { + "ts_init": jnp.array(ts_init), + "ts_dyn": jnp.array(ts_dyn), + "ts_posteq": jnp.array(ts_posteq), + "my": jnp.array(my), + "iys": jnp.array(iys), + "x_preeq": jnp.array([]), + "solver": diffrax.Kvaerno5(), + "controller": diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM), + "adjoint": diffrax.RecursiveCheckpointAdjoint(), + "max_steps": 2**8, # max_steps + } fun = beartype(jax_model.simulate_condition) for output in ["llh", "x0", "x", "y", "res"]: - oargs = (*args[:-2], diffrax.DirectAdjoint(), 2**8, output) + okwargs = kwargs | { + "adjoint": diffrax.DirectAdjoint(), + "max_steps": 2**8, + "ret": output, + } if sensi_order == amici.SensitivityOrder.none: - r_jax[output] = fun(p, *oargs)[0] + r_jax[output] = fun(p, **okwargs)[0] if sensi_order == amici.SensitivityOrder.first: if output == "llh": - r_jax[f"s{output}"] = jax.grad(fun, has_aux=True)(p, *args)[0] - else: - r_jax[f"s{output}"] = jax.jacfwd(fun, has_aux=True)(p, *oargs)[ + r_jax[f"s{output}"] = jax.grad(fun, has_aux=True)(p, **kwargs)[ 0 ] + else: + r_jax[f"s{output}"] = jax.jacfwd(fun, has_aux=True)( + p, **okwargs + )[0] amici_par_idx = np.array( [jax_model.parameter_ids.index(par_id) for par_id in parameter_ids]