From 01cd6bbcfd04fac2d53a441fac21a0add7beecdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Sat, 7 Dec 2024 19:51:30 +0000 Subject: [PATCH] fixup --- python/sdist/amici/jax/__init__.py | 15 +++++++++++++-- python/sdist/amici/jax/petab.py | 4 +++- python/tests/test_jax.py | 4 ++-- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/python/sdist/amici/jax/__init__.py b/python/sdist/amici/jax/__init__.py index 34642e3d49..a5b5dc1cae 100644 --- a/python/sdist/amici/jax/__init__.py +++ b/python/sdist/amici/jax/__init__.py @@ -9,7 +9,12 @@ from warnings import warn -from amici.jax.petab import JAXProblem, run_simulations, petab_simulate +from amici.jax.petab import ( + JAXProblem, + run_simulations, + petab_simulate, + ReturnValue, +) from amici.jax.model import JAXModel warn( @@ -18,4 +23,10 @@ stacklevel=2, ) -__all__ = ["JAXModel", "JAXProblem", "run_simulations", "petab_simulate"] +__all__ = [ + "JAXModel", + "JAXProblem", + "run_simulations", + "petab_simulate", + "ReturnValue", +] diff --git a/python/sdist/amici/jax/petab.py b/python/sdist/amici/jax/petab.py index c1f3da5c2e..3785a39594 100644 --- a/python/sdist/amici/jax/petab.py +++ b/python/sdist/amici/jax/petab.py @@ -562,7 +562,7 @@ def run_simulations( **DEFAULT_CONTROLLER_SETTINGS ), max_steps: int = 2**10, - ret: ReturnValue = ReturnValue.llh, + ret: ReturnValue | str = ReturnValue.llh, ): """ Run simulations for a problem. @@ -582,6 +582,8 @@ def run_simulations( :return: Overall output value and condition specific results and statistics. """ + ret = ReturnValue[ret] + if simulation_conditions is None: simulation_conditions = problem.get_all_simulation_conditions() diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 9a1b3fed31..ef9cbde576 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -15,7 +15,7 @@ from amici.pysb_import import pysb2amici, pysb2jax from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind from amici.petab.petab_import import import_petab_problem -from amici.jax import JAXProblem +from amici.jax import JAXProblem, ReturnValue from numpy.testing import assert_allclose from test_petab_objective import lotka_volterra # noqa: F401 @@ -208,7 +208,7 @@ def check_fields_jax( okwargs = kwargs | { "adjoint": diffrax.DirectAdjoint(), "max_steps": 2**8, - "ret": output, + "ret": ReturnValue[output], } if sensi_order == amici.SensitivityOrder.none: r_jax[output] = fun(p, **okwargs)[0]