From c75909d6029261a72d79d6934a9b83568038d845 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Tue, 26 Nov 2024 13:46:58 +0000 Subject: [PATCH] change integration tolerances JAX test (#2601) fixes #2598, test pass locally with `assert_allclose(..., atol=1e-7, rtol=1e-7)` --- python/tests/test_jax.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 40dbb27e47..3254667c50 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -19,6 +19,10 @@ jax.config.update("jax_enable_x64", True) +ATOL_SIM = 1e-12 +RTOL_SIM = 1e-12 + + @skip_on_valgrind def test_conversion(): pysb.SelfExporter.cleanup() # reset pysb @@ -115,6 +119,8 @@ def _test_model(model_module, ts, p, k): amici_solver = amici_model.getSolver() amici_solver.setSensitivityMethod(amici.SensitivityMethod.forward) amici_solver.setSensitivityOrder(amici.SensitivityOrder.first) + amici_solver.setAbsoluteTolerance(ATOL_SIM) + amici_solver.setRelativeTolerance(RTOL_SIM) rs_amici = amici.runAmiciSimulations(amici_model, amici_solver, [edata]) check_fields_jax( @@ -158,7 +164,7 @@ def check_fields_jax( jnp.array(my), # my jnp.array(iys), # iys diffrax.Kvaerno5(), # solver - diffrax.PIDController(atol=1e-8, rtol=1e-8), # controller + diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM), # controller diffrax.RecursiveCheckpointAdjoint(), # adjoint 2**8, # max_steps )