Commit 1f0af69

fix tests
FFroehlich committed Dec 5, 2024
1 parent ba37d0b commit 1f0af69
Showing 1 changed file with 26 additions and 20 deletions.

python/tests/test_jax.py
@@ -47,7 +47,7 @@ def test_conversion():
         module_name=model.name, module_path=outdir
     )
     jax_module = amici.import_model_module(
-        module_name=model.name + "_jax", module_path=outdir
+        module_name=Path(outdir).stem, module_path=Path(outdir).parent
     )
 
     ts = tuple(np.linspace(0, 1, 10))
@@ -108,7 +108,7 @@ def test_dimerization():
         module_name=model.name, module_path=outdir
     )
     jax_module = amici.import_model_module(
-        module_name=model.name + "_jax", module_path=outdir
+        module_name=Path(outdir).stem, module_path=Path(outdir).parent
     )
 
     ts = tuple(np.linspace(0, 1, 10))
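
The updated import in both tests derives the module name and the module search
path from the output directory itself via pathlib. A minimal sketch of that
decomposition, using a hypothetical output directory (not taken from this
commit):

    from pathlib import Path

    # Hypothetical output directory for illustration; in the tests above it is
    # created by the AMICI model code generator.
    outdir = "/tmp/amici_models/dimerization_jax"

    module_name = Path(outdir).stem    # "dimerization_jax" (final path component)
    module_path = Path(outdir).parent  # "/tmp/amici_models" (directory to import from)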
@@ -178,7 +178,7 @@ def check_fields_jax(
     ts = ts.flatten()
     iys = iys.flatten()
 
-    ts_preeq = ts[ts == 0]
+    ts_init = ts[ts == 0]
     ts_dyn = ts[ts > 0]
     ts_posteq = np.array([])
@@ -188,31 +188,37 @@
     }
 
     p = jnp.array([par_dict[par_id] for par_id in jax_model.parameter_ids])
-    args = (
-        jnp.array([]),  # p_preeq
-        jnp.array(ts_preeq),  # ts_preeq
-        jnp.array(ts_dyn),  # ts_dyn
-        jnp.array(ts_posteq),  # ts_posteq
-        jnp.array(my),  # my
-        jnp.array(iys),  # iys
-        diffrax.Kvaerno5(),  # solver
-        diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM),  # controller
-        diffrax.RecursiveCheckpointAdjoint(),  # adjoint
-        2**8,  # max_steps
-    )
+    kwargs = {
+        "ts_init": jnp.array(ts_init),
+        "ts_dyn": jnp.array(ts_dyn),
+        "ts_posteq": jnp.array(ts_posteq),
+        "my": jnp.array(my),
+        "iys": jnp.array(iys),
+        "x_preeq": jnp.array([]),
+        "solver": diffrax.Kvaerno5(),
+        "controller": diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM),
+        "adjoint": diffrax.RecursiveCheckpointAdjoint(),
+        "max_steps": 2**8,
+    }
     fun = beartype(jax_model.simulate_condition)
 
     for output in ["llh", "x0", "x", "y", "res"]:
-        oargs = (*args[:-2], diffrax.DirectAdjoint(), 2**8, output)
+        okwargs = kwargs | {
+            "adjoint": diffrax.DirectAdjoint(),
+            "max_steps": 2**8,
+            "ret": output,
+        }
         if sensi_order == amici.SensitivityOrder.none:
-            r_jax[output] = fun(p, *oargs)[0]
+            r_jax[output] = fun(p, **okwargs)[0]
         if sensi_order == amici.SensitivityOrder.first:
             if output == "llh":
-                r_jax[f"s{output}"] = jax.grad(fun, has_aux=True)(p, *args)[0]
+                r_jax[f"s{output}"] = jax.grad(fun, has_aux=True)(p, **kwargs)[
+                    0
+                ]
             else:
-                r_jax[f"s{output}"] = jax.jacfwd(fun, has_aux=True)(p, *oargs)[
-                    0
-                ]
+                r_jax[f"s{output}"] = jax.jacfwd(fun, has_aux=True)(
+                    p, **okwargs
+                )[0]
 
     amici_par_idx = np.array(
         [jax_model.parameter_ids.index(par_id) for par_id in parameter_ids]
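
The per-output keyword set okwargs is built with Python's dict merge operator
(|, available since Python 3.9): it returns a new dict in which right-hand
entries override left-hand ones, so the shared kwargs dict stays untouched
between loop iterations. A self-contained sketch of the idiom with placeholder
values (the real dict above holds arrays and diffrax objects):

    kwargs = {"adjoint": "checkpoint", "max_steps": 2**8}

    for output in ["llh", "x0", "x", "y", "res"]:
        # `|` creates a fresh dict per iteration; `kwargs` is never mutated.
        okwargs = kwargs | {"adjoint": "direct", "ret": output}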
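Both derivative branches pass has_aux=True, the JAX convention for a function
that returns a (value, aux) pair: only the first element is differentiated,
and jax.grad / jax.jacfwd then return (derivative, aux), which is why the test
indexes the result with [0]. A minimal self-contained example:

    import jax
    import jax.numpy as jnp

    def fun(p):
        # (scalar value, auxiliary data) -- matches the has_aux=True contract
        return jnp.sum(p**2), {"p": p}

    g, aux = jax.grad(fun, has_aux=True)(jnp.array([1.0, 2.0]))  # g == [2., 4.]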
