Skip to content

Commit

Permalink
..
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Oct 9, 2024
1 parent 76811db commit 4851257
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions tests/benchmark-models/test_petab_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ class GradientCheckSettings:
rng_seed=1,
)
settings["Sneyd_PNAS2002"] = GradientCheckSettings(
atol_sim=1e-16,
rtol_sim=1e-10,
atol_sim=1e-15,
rtol_sim=1e-12,
atol_check=1e-5,
rtol_check=1e-4,
)
Expand All @@ -115,6 +115,11 @@ class GradientCheckSettings:
settings["Brannmark_JBC2010"] = GradientCheckSettings(
ss_sensitivity_mode=amici.SteadyStateSensitivityMode.integrationOnly,
)
settings["Borghans_BiophysChem1997"] = GradientCheckSettings(
rng_seed=2,
atol_check=1e-5,
rtol_check=1e-3,
)


def assert_gradient_check_success(
Expand All @@ -123,6 +128,7 @@ def assert_gradient_check_success(
point: np.array,
atol: float,
rtol: float,
always_print: bool = False,
) -> None:
if not derivative.df.success.all():
raise AssertionError(
Expand All @@ -135,7 +141,7 @@ def assert_gradient_check_success(
)
check_result = check(rtol=rtol, atol=atol)

if check_result.success is True:
if check_result.success is True and not always_print:
return

df = check_result.df
Expand All @@ -144,12 +150,22 @@ def assert_gradient_check_success(
max_adiff = df["abs_diff"].max()
max_rdiff = df["rel_diff"].max()
with pd.option_context("display.max_columns", None, "display.width", None):
message = (
f"Gradient check failed:\n{df}\n\n"
f"Maximum absolute difference: {max_adiff}\n"
f"Maximum relative difference: {max_rdiff}"
)

if check_result.success is False:
raise AssertionError(
f"Gradient check failed:\n{df}\n\n"
f"Maximum absolute difference: {max_adiff}\n"
f"Maximum relative difference: {max_rdiff}"
)

if always_print:
print(message)


@pytest.mark.filterwarnings(
"ignore:divide by zero encountered in log",
Expand Down Expand Up @@ -205,7 +221,7 @@ def test_benchmark_gradient(model, scale, sensitivity_method, request):
amici_solver = amici_model.getSolver()
amici_solver.setAbsoluteTolerance(cur_settings.atol_sim)
amici_solver.setRelativeTolerance(cur_settings.rtol_sim)
amici_solver.setMaxSteps(int(1e5))
amici_solver.setMaxSteps(2 * 10**5)
amici_solver.setSensitivityMethod(sensitivity_method)
# TODO: we should probably test all sensitivity modes
amici_model.setSteadyStateSensitivityMode(cur_settings.ss_sensitivity_mode)
Expand Down Expand Up @@ -278,6 +294,7 @@ def test_benchmark_gradient(model, scale, sensitivity_method, request):
point,
rtol=cur_settings.rtol_check,
atol=cur_settings.atol_check,
always_print=True,
)


Expand Down

0 comments on commit 4851257

Please sign in to comment.