From 8b9c10ae330e669898e6405c318a6481eb15f3db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Fri, 25 Oct 2024 22:53:32 +0100 Subject: [PATCH] fix hessian --- python/sdist/amici/jax.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/python/sdist/amici/jax.py b/python/sdist/amici/jax.py index e798a0138f..74e601dd8c 100644 --- a/python/sdist/amici/jax.py +++ b/python/sdist/amici/jax.py @@ -258,7 +258,15 @@ def s2run( )(ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic) s2llh = jax.hessian(self._run, 2, True)( - ts, ts_dyn, p, k, k_preeq, my, pscale, dynamic=dynamic + ts, + ts_dyn, + p, + k, + k_preeq, + my, + pscale, + checkpointed=False, + dynamic=dynamic, ) return llh, sllh, s2llh, (x, obs, stats)