From 5366632e716de3a82513e489b7221f94885408f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Fri, 25 Oct 2024 13:40:26 +0100 Subject: [PATCH] fixup --- python/sdist/amici/jax.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/python/sdist/amici/jax.py b/python/sdist/amici/jax.py index 67e0869f9c..c1f083a799 100644 --- a/python/sdist/amici/jax.py +++ b/python/sdist/amici/jax.py @@ -171,7 +171,7 @@ def _run( x0 = self.x0(ps, k) # Dynamic simulation - if dynamic and ts_dyn.shape[0] > 0: + if dynamic == "true": x, tcl, stats = self._solve( ts_dyn, ps, k, x0, checkpointed=checkpointed ) @@ -220,7 +220,9 @@ def run( pscale: np.ndarray, dynamic=True, ): - return self._run(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic) + return self._run( + ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic + ) @eqx.filter_jit def srun( @@ -236,7 +238,7 @@ def srun( ): (llh, (x, obs, stats)), sllh = ( jax.value_and_grad(self._run, 1, True) - )(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic) + )(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic) return llh, sllh, (x, obs, stats) @eqx.filter_jit @@ -253,10 +255,10 @@ def s2run( ): (llh, (x, obs, stats)), sllh = ( jax.value_and_grad(self._run, 1, True) - )(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic) + )(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic) s2llh = jax.hessian(self._run, 1, True)( - ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic + ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic ) return llh, sllh, s2llh, (x, obs, stats) @@ -271,7 +273,7 @@ def run_simulation( my = np.asarray(edata.getObservedData()) pscale = np.asarray(edata.pscale) ts_dyn = ts[np.isfinite(ts)] - dynamic = len(ts_dyn) > 0 and np.max(ts_dyn) > 0 + dynamic = "true" if len(ts_dyn) and np.max(ts_dyn) > 0 else "false" rdata_kwargs = dict() @@ -279,20 +281,26 @@ def run_simulation( ( rdata_kwargs["llh"], (rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]), - ) = self.run(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic) + ) = self.run( + ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic + ) elif sensitivity_order == amici.SensitivityOrder.first: ( rdata_kwargs["llh"], rdata_kwargs["sllh"], (rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]), - ) = self.srun(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic) + ) = self.srun( + ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic + ) elif sensitivity_order == amici.SensitivityOrder.second: ( rdata_kwargs["llh"], rdata_kwargs["sllh"], rdata_kwargs["s2llh"], (rdata_kwargs["x"], rdata_kwargs["y"], rdata_kwargs["stats"]), - ) = self.s2run(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic) + ) = self.s2run( + ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic + ) for field in rdata_kwargs.keys(): if field == "llh":