Skip to content

Commit

Permalink
change integration tolerances JAX test (#2601)
Browse files Browse the repository at this point in the history
fixes #2598, test pass locally with `assert_allclose(..., atol=1e-7, rtol=1e-7)`
  • Loading branch information
FFroehlich authored Nov 26, 2024
1 parent 5a16d3a commit c75909d
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion python/tests/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
jax.config.update("jax_enable_x64", True)


ATOL_SIM = 1e-12
RTOL_SIM = 1e-12


@skip_on_valgrind
def test_conversion():
pysb.SelfExporter.cleanup() # reset pysb
Expand Down Expand Up @@ -115,6 +119,8 @@ def _test_model(model_module, ts, p, k):
amici_solver = amici_model.getSolver()
amici_solver.setSensitivityMethod(amici.SensitivityMethod.forward)
amici_solver.setSensitivityOrder(amici.SensitivityOrder.first)
amici_solver.setAbsoluteTolerance(ATOL_SIM)
amici_solver.setRelativeTolerance(RTOL_SIM)
rs_amici = amici.runAmiciSimulations(amici_model, amici_solver, [edata])

check_fields_jax(
Expand Down Expand Up @@ -158,7 +164,7 @@ def check_fields_jax(
jnp.array(my), # my
jnp.array(iys), # iys
diffrax.Kvaerno5(), # solver
diffrax.PIDController(atol=1e-8, rtol=1e-8), # controller
diffrax.PIDController(atol=ATOL_SIM, rtol=RTOL_SIM), # controller
diffrax.RecursiveCheckpointAdjoint(), # adjoint
2**8, # max_steps
)
Expand Down

0 comments on commit c75909d

Please sign in to comment.