diff --git a/python/sdist/amici/jax.py b/python/sdist/amici/jax.py index 5d70a08aef..ec69f34361 100644 --- a/python/sdist/amici/jax.py +++ b/python/sdist/amici/jax.py @@ -186,7 +186,7 @@ def _run( my: jnp.ndarray, pscale: np.ndarray, checkpointed=True, - dynamic=True, + dynamic="true", ): ps = self.unscale_p(p, pscale) @@ -227,11 +227,11 @@ def _run( else: x = x_posteq - obs = self._obs(ts, x, ps, k, tcl) + obs = jnp.stack(self._obs(ts, x, ps, k, tcl), axis=1) my_r = my.reshape((len(ts), -1)) sigmay = self._sigmay(obs, ps, k) llh = self._loss(obs, sigmay, my_r) - x_rdata = self._x_rdata(x, tcl) + x_rdata = jnp.stack(self._x_rdata(x, tcl), axis=1) return llh, (x_rdata, obs, stats) @eqx.filter_jit @@ -244,7 +244,7 @@ def run( k_preeq: np.ndarray, my: np.ndarray, pscale: np.ndarray, - dynamic=True, + dynamic="true", ): return self._run( ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic @@ -260,7 +260,7 @@ def srun( k_preeq: np.ndarray, my: np.ndarray, pscale: np.ndarray, - dynamic=True, + dynamic="true", ): (llh, (x, obs, stats)), sllh = ( jax.value_and_grad(self._run, 2, True) @@ -277,7 +277,7 @@ def s2run( k_preeq: np.ndarray, my: np.ndarray, pscale: np.ndarray, - dynamic=True, + dynamic="true", ): (llh, (x, obs, stats)), sllh = ( jax.value_and_grad(self._run, 2, True) diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 34fd70a201..5898262f90 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -100,25 +100,21 @@ def _test_model(model_module, ts, p, k): amici_model.setParameters(np.asarray(p, dtype=np.float64)) amici_model.setFixedParameters(np.asarray(k, dtype=np.float64)) - edatas = ( - amici.ExpData(sol_amici_ref, 1.0, 1.0), - amici.ExpData(sol_amici_ref, 1.0, 1.0), - ) - for edata in edatas: - edata.parameters = amici_model.getParameters() - edata.fixedParameters = amici_model.getFixedParameters() - edata.pscale = amici_model.getParameterScale() + edata = amici.ExpData(sol_amici_ref, 1.0, 1.0) + edata.parameters = amici_model.getParameters() + edata.fixedParameters = amici_model.getFixedParameters() + edata.pscale = amici_model.getParameterScale() amici_solver = amici_model.getSolver() amici_solver.setSensitivityMethod(amici.SensitivityMethod.forward) amici_solver.setSensitivityOrder(amici.SensitivityOrder.first) - rs_amici = amici.runAmiciSimulations(amici_model, amici_solver, edatas) + rs_amici = amici.runAmiciSimulations(amici_model, amici_solver, [edata]) - check_fields_jax(rs_amici, jax_model, edatas, ["x", "y", "llh"]) + check_fields_jax(rs_amici, jax_model, edata, ["x", "y", "llh"]) check_fields_jax( rs_amici, jax_model, - edatas, + edata, ["x", "y", "llh", "sllh"], sensi_order=amici.SensitivityOrder.first, ) @@ -126,7 +122,7 @@ def _test_model(model_module, ts, p, k): check_fields_jax( rs_amici, jax_model, - edatas, + edata, ["x", "y", "llh", "sllh"], sensi_order=amici.SensitivityOrder.second, ) @@ -135,13 +131,43 @@ def _test_model(model_module, ts, p, k): def check_fields_jax( rs_amici, jax_model, - edatas, + edata, fields, sensi_order=amici.SensitivityOrder.none, ): - rs_jax = jax_model.run_simulations(edatas, sensitivity_order=sensi_order) + r_jax = dict() + kwargs = { + "ts": np.array(edata.getTimepoints()), + "ts_dyn": np.array(edata.getTimepoints()), + "p": np.array(edata.parameters), + "k": np.array(edata.fixedParameters), + "k_preeq": np.array([]), + "my": np.array(edata.getObservedData()).reshape( + np.array(edata.getTimepoints()).shape[0], -1 + ), + "pscale": np.array(edata.pscale), + } + if sensi_order == amici.SensitivityOrder.none: + ( + r_jax["llh"], + (r_jax["x"], r_jax["y"], r_jax["stats"]), + ) = jax_model.run(**kwargs) + elif sensi_order == amici.SensitivityOrder.first: + ( + r_jax["llh"], + r_jax["sllh"], + (r_jax["x"], r_jax["y"], r_jax["stats"]), + ) = jax_model.srun(**kwargs) + elif sensi_order == amici.SensitivityOrder.second: + ( + r_jax["llh"], + r_jax["sllh"], + r_jax["s2llh"], + (r_jax["x"], r_jax["y"], r_jax["stats"]), + ) = jax_model.s2run(**kwargs) + for field in fields: - for r_amici, r_jax in zip(rs_amici, rs_jax): + for r_amici, r_jax in zip(rs_amici, [r_jax]): assert_allclose( actual=r_amici[field], desired=r_jax[field],