Skip to content

Commit

Permalink
fix python jax tests
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Nov 12, 2024
1 parent 4a5e7d2 commit f745be0
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 21 deletions.
12 changes: 6 additions & 6 deletions python/sdist/amici/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def _run(
my: jnp.ndarray,
pscale: np.ndarray,
checkpointed=True,
dynamic=True,
dynamic="true",
):
ps = self.unscale_p(p, pscale)

Expand Down Expand Up @@ -227,11 +227,11 @@ def _run(
else:
x = x_posteq

Check warning on line 228 in python/sdist/amici/jax.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/jax.py#L228

Added line #L228 was not covered by tests

obs = self._obs(ts, x, ps, k, tcl)
obs = jnp.stack(self._obs(ts, x, ps, k, tcl), axis=1)
my_r = my.reshape((len(ts), -1))
sigmay = self._sigmay(obs, ps, k)
llh = self._loss(obs, sigmay, my_r)
x_rdata = self._x_rdata(x, tcl)
x_rdata = jnp.stack(self._x_rdata(x, tcl), axis=1)
return llh, (x_rdata, obs, stats)

@eqx.filter_jit
Expand All @@ -244,7 +244,7 @@ def run(
k_preeq: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
dynamic=True,
dynamic="true",
):
return self._run(
ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic
Expand All @@ -260,7 +260,7 @@ def srun(
k_preeq: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
dynamic=True,
dynamic="true",
):
(llh, (x, obs, stats)), sllh = (
jax.value_and_grad(self._run, 2, True)
Expand All @@ -277,7 +277,7 @@ def s2run(
k_preeq: np.ndarray,
my: np.ndarray,
pscale: np.ndarray,
dynamic=True,
dynamic="true",
):
(llh, (x, obs, stats)), sllh = (
jax.value_and_grad(self._run, 2, True)
Expand Down
56 changes: 41 additions & 15 deletions python/tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,33 +100,29 @@ def _test_model(model_module, ts, p, k):

amici_model.setParameters(np.asarray(p, dtype=np.float64))
amici_model.setFixedParameters(np.asarray(k, dtype=np.float64))
edatas = (
amici.ExpData(sol_amici_ref, 1.0, 1.0),
amici.ExpData(sol_amici_ref, 1.0, 1.0),
)
for edata in edatas:
edata.parameters = amici_model.getParameters()
edata.fixedParameters = amici_model.getFixedParameters()
edata.pscale = amici_model.getParameterScale()
edata = amici.ExpData(sol_amici_ref, 1.0, 1.0)
edata.parameters = amici_model.getParameters()
edata.fixedParameters = amici_model.getFixedParameters()
edata.pscale = amici_model.getParameterScale()
amici_solver = amici_model.getSolver()
amici_solver.setSensitivityMethod(amici.SensitivityMethod.forward)
amici_solver.setSensitivityOrder(amici.SensitivityOrder.first)
rs_amici = amici.runAmiciSimulations(amici_model, amici_solver, edatas)
rs_amici = amici.runAmiciSimulations(amici_model, amici_solver, [edata])

check_fields_jax(rs_amici, jax_model, edatas, ["x", "y", "llh"])
check_fields_jax(rs_amici, jax_model, edata, ["x", "y", "llh"])

check_fields_jax(
rs_amici,
jax_model,
edatas,
edata,
["x", "y", "llh", "sllh"],
sensi_order=amici.SensitivityOrder.first,
)

check_fields_jax(
rs_amici,
jax_model,
edatas,
edata,
["x", "y", "llh", "sllh"],
sensi_order=amici.SensitivityOrder.second,
)
Expand All @@ -135,13 +131,43 @@ def _test_model(model_module, ts, p, k):
def check_fields_jax(
rs_amici,
jax_model,
edatas,
edata,
fields,
sensi_order=amici.SensitivityOrder.none,
):
rs_jax = jax_model.run_simulations(edatas, sensitivity_order=sensi_order)
r_jax = dict()
kwargs = {
"ts": np.array(edata.getTimepoints()),
"ts_dyn": np.array(edata.getTimepoints()),
"p": np.array(edata.parameters),
"k": np.array(edata.fixedParameters),
"k_preeq": np.array([]),
"my": np.array(edata.getObservedData()).reshape(
np.array(edata.getTimepoints()).shape[0], -1
),
"pscale": np.array(edata.pscale),
}
if sensi_order == amici.SensitivityOrder.none:
(
r_jax["llh"],
(r_jax["x"], r_jax["y"], r_jax["stats"]),
) = jax_model.run(**kwargs)
elif sensi_order == amici.SensitivityOrder.first:
(
r_jax["llh"],
r_jax["sllh"],
(r_jax["x"], r_jax["y"], r_jax["stats"]),
) = jax_model.srun(**kwargs)
elif sensi_order == amici.SensitivityOrder.second:
(
r_jax["llh"],
r_jax["sllh"],
r_jax["s2llh"],
(r_jax["x"], r_jax["y"], r_jax["stats"]),
) = jax_model.s2run(**kwargs)

for field in fields:
for r_amici, r_jax in zip(rs_amici, rs_jax):
for r_amici, r_jax in zip(rs_amici, [r_jax]):
assert_allclose(
actual=r_amici[field],
desired=r_jax[field],
Expand Down

0 comments on commit f745be0

Please sign in to comment.