diff --git a/python/tests/test_jax.py b/python/tests/test_jax.py index 1ccd388257..d124a6e1be 100644 --- a/python/tests/test_jax.py +++ b/python/tests/test_jax.py @@ -162,7 +162,7 @@ def check_fields_jax( ) fun = beartype(jax_model.simulate_condition) - for output in ["nllh", "x0", "x", "y", "res"]: + for output in ["llh", "x0", "x", "y", "res"]: oargs = (*args[:-2], diffrax.DirectAdjoint(), 2**8, output) if sensi_order == amici.SensitivityOrder.none: r_jax[output] = fun(p, *oargs)[0]