Skip to content

Commit

Permalink
fix jax test
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Dec 2, 2024
1 parent 3f05c7e commit 2181559
Showing 1 changed file with 33 additions and 10 deletions.
43 changes: 33 additions & 10 deletions python/tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ def test_dimerization():
observables=["a_obs", "b_obs"],
constant_parameters=["ksyn_a", "ksyn_b"],
)
pysb2jax(model, outdir, verbose=True, observables=["ab"])
pysb2jax(
model,
outdir,
observables=["a_obs", "b_obs"],
)

amici_module = amici.import_model_module(
module_name=model.name, module_path=outdir
Expand Down Expand Up @@ -137,12 +141,19 @@ def _test_model(amici_module, jax_module, ts, p, k):
rs_amici = amici.runAmiciSimulations(amici_model, amici_solver, [edata])

check_fields_jax(
rs_amici, jax_model, edata, ["x", "y", "llh", "res", "x0"]
rs_amici,
jax_model,
amici_model.getParameterIds(),
amici_model.getFixedParameterIds(),
edata,
["x", "y", "llh", "res", "x0"],
)

check_fields_jax(
rs_amici,
jax_model,
amici_model.getParameterIds(),
amici_model.getFixedParameterIds(),
edata,
["sllh", "sx0", "sx", "sres", "sy"],
sensi_order=amici.SensitivityOrder.first,
Expand All @@ -152,6 +163,8 @@ def _test_model(amici_module, jax_module, ts, p, k):
def check_fields_jax(
rs_amici,
jax_model,
parameter_ids,
fixed_parameter_ids,
edata,
fields,
sensi_order=amici.SensitivityOrder.none,
Expand All @@ -168,7 +181,13 @@ def check_fields_jax(
ts_preeq = ts[ts == 0]
ts_dyn = ts[ts > 0]
ts_posteq = np.array([])
p = jnp.array(list(edata.parameters) + list(edata.fixedParameters))

par_dict = {
**dict(zip(parameter_ids, edata.parameters)),
**dict(zip(fixed_parameter_ids, edata.fixedParameters)),
}

p = jnp.array([par_dict[par_id] for par_id in jax_model.parameter_ids])
args = (
jnp.array([]), # p_preeq
jnp.array(ts_preeq), # ts_preeq
Expand All @@ -195,6 +214,10 @@ def check_fields_jax(
0
]

amici_par_idx = np.array(
[jax_model.parameter_ids.index(par_id) for par_id in parameter_ids]
)

for field in fields:
for r_amici, r_jax in zip(rs_amici, [r_jax]):
actual = r_jax[field]
Expand All @@ -207,26 +230,26 @@ def check_fields_jax(
axis=1,
)
elif field == "sllh":
actual = actual[: len(edata.parameters)]
actual = actual[amici_par_idx]
elif field == "sx":
actual = np.permute_dims(
actual[iys == 0, :, : len(edata.parameters)], (0, 2, 1)
)
actual = actual[:, :, amici_par_idx]
actual = np.permute_dims(actual[iys == 0, :, :], (0, 2, 1))
elif field == "sy":
actual = actual[:, amici_par_idx]
actual = np.permute_dims(
np.stack(
[
actual[iys == iy, : len(edata.parameters)]
actual[iys == iy, :]
for iy in sorted(np.unique(iys))
],
axis=1,
),
(0, 2, 1),
)
elif field == "sx0":
actual = actual[:, : len(edata.parameters)].T
actual = actual[:, amici_par_idx].T
elif field == "sres":
actual = actual[:, : len(edata.parameters)]
actual = actual[:, amici_par_idx]

assert_allclose(
actual=actual,
Expand Down

0 comments on commit 2181559

Please sign in to comment.