diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 40dbb27e47..3254667c50 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -19,6 +19,10 @@ jax.config.update("jax_enable_x64", True) +ATOL_SIM = 1e-12 +RTOL_SIM = 1e-12 + + @skip_on_valgrind def test_conversion(): pysb.SelfExporter.cleanup() # reset pysb @@ -115,6 +119,8 @@ def _test_model(model_module, ts, p, k): amici_solver = amici_model.getSolver() amici_solver.setSensitivityMethod(amici.SensitivityMethod.forward) amici_solver.setSensitivityOrder(amici.SensitivityOrder.first) + amici_solver.setAbsoluteTolerance(ATOL_SIM) + amici_solver.setRelativeTolerance(RTOL_SIM) rs_amici = amici.runAmiciSimulations(amici_model, amici_solver, [edata]) check_fields_jax( @@ -158,7 +164,7 @@ def check_fields_jax( jnp.array(my), # my jnp.array(iys), # iys diffrax.Kvaerno5(), # solver - diffrax.PIDController(atol=1e-8, rtol=1e-8), # controller + diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM), # controller diffrax.RecursiveCheckpointAdjoint(), # adjoint 2**8, # max_steps )