From 68fbc9dfc251d0e11e1b93e8e0c62dddcbf871ab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Tue, 17 Dec 2024 10:11:00 +0000 Subject: [PATCH] add event check in equilibration, add test --- python/sdist/amici/jax/model.py | 10 +++++++++- python/tests/test_jax.py | 25 ++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/python/sdist/amici/jax/model.py b/python/sdist/amici/jax/model.py index 98e123b5f0..90c287aeac 100644 --- a/python/sdist/amici/jax/model.py +++ b/python/sdist/amici/jax/model.py @@ -281,7 +281,15 @@ def _eq( event=diffrax.Event(cond_fn=diffrax.steady_state_event()), throw=False, ) - return sol.ys[-1, :], sol.stats + # If the event was triggered, the event mask is True and the solution is the steady state. Otherwise, the + # solution is the last state and the event mask is False. In the latter case, we return inf for the steady + # state. + ys = jnp.where( + sol.event_mask, + sol.ys[-1, :], + jnp.inf * jnp.ones_like(sol.ys[-1, :]), + ) + return ys, sol.stats def _solve( self, diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index ef9cbde576..43c79d6980 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -11,11 +11,12 @@ import diffrax import numpy as np from beartype import beartype +from petab.v2.C import PREEQUILIBRATION_CONDITION_ID, SIMULATION_CONDITION_ID from amici.pysb_import import pysb2amici, pysb2jax from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind from amici.petab.petab_import import import_petab_problem -from amici.jax import JAXProblem, ReturnValue +from amici.jax import JAXProblem, ReturnValue, run_simulations from numpy.testing import assert_allclose from test_petab_objective import lotka_volterra # noqa: F401 @@ -268,6 +269,28 @@ def check_fields_jax( ) +def test_preequilibration_failure(lotka_volterra): # noqa: F811 + petab_problem = lotka_volterra + # oscillating system, preequilibation should fail when interaction is active + with TemporaryDirectoryWinSafe(prefix="normal") as model_dir: + jax_model = import_petab_problem( + petab_problem, jax=True, model_output_dir=model_dir + ) + jax_problem = JAXProblem(jax_model, petab_problem) + r = run_simulations(jax_problem) + assert not np.isinf(r[0].item()) + petab_problem.measurement_df[PREEQUILIBRATION_CONDITION_ID] = ( + petab_problem.measurement_df[SIMULATION_CONDITION_ID] + ) + with TemporaryDirectoryWinSafe(prefix="failure") as model_dir: + jax_model = import_petab_problem( + petab_problem, jax=True, model_output_dir=model_dir + ) + jax_problem = JAXProblem(jax_model, petab_problem) + r = run_simulations(jax_problem) + assert np.isinf(r[0].item()) + + @skip_on_valgrind def test_serialisation(lotka_volterra): # noqa: F811 petab_problem = lotka_volterra