Skip to content

Commit

Permalink
Fix benchmark collection gradient check (AMICI-dev#2564)
Browse files Browse the repository at this point in the history
* Fix instance vs class attribute for step sizes
* Only flatten problems where necessary
  • Loading branch information
dweindl authored Oct 24, 2024
1 parent 9c603e1 commit 56e6956
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 13 deletions.
15 changes: 14 additions & 1 deletion tests/benchmark-models/test_benchmark_collection.sh
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,22 @@ for model in $models; do
yaml="${model_dir}"/"${model}"/problem.yaml
fi

# problems we need to flatten
to_flatten=(
"Bruno_JExpBot2016" "Chen_MSB2009" "Crauste_CellSystems2017"
"Fiedler_BMCSystBiol2016" "Fujita_SciSignal2010" "SalazarCavazos_MBoC2020"
)
flatten=""
for item in "${to_flatten[@]}"; do
if [[ "$item" == "$model" ]]; then
flatten="--flatten"
break
fi
done

amici_model_dir=test_bmc/"${model}"
mkdir -p "$amici_model_dir"
cmd_import="amici_import_petab ${yaml} -o ${amici_model_dir} -n ${model} --flatten"
cmd_import="amici_import_petab ${yaml} -o ${amici_model_dir} -n ${model} ${flatten}"
cmd_run="$script_path/test_petab_model.py -y ${yaml} -d ${amici_model_dir} -m ${model} -c"

printf '=%.0s' {1..40}
Expand Down
29 changes: 18 additions & 11 deletions tests/benchmark-models/test_petab_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from amici.petab.petab_import import import_petab_problem
import benchmark_models_petab
from collections import defaultdict
from dataclasses import dataclass
from dataclasses import dataclass, field
from amici import SensitivityMethod
from petab.v1.lint import measurement_table_has_timepoint_specific_mappings
from fiddy import MethodId, get_derivative
from fiddy.derivative_check import NumpyIsCloseDerivativeCheck
from fiddy.extensions.amici import simulate_petab_to_cached_functions
Expand Down Expand Up @@ -58,14 +59,18 @@ class GradientCheckSettings:
atol_consistency: float = 1e-5
rtol_consistency: float = 1e-1
# Step sizes for finite difference gradient checks.
step_sizes = [
1e-1,
5e-2,
1e-2,
1e-3,
1e-4,
1e-5,
]
step_sizes: list[float] = field(
default_factory=lambda: [
2e-1,
1e-1,
5e-2,
1e-2,
5e-1,
1e-3,
1e-4,
1e-5,
]
)
rng_seed: int = 0
ss_sensitivity_mode: amici.SteadyStateSensitivityMode = (
amici.SteadyStateSensitivityMode.integrateIfNewtonFails
Expand Down Expand Up @@ -97,7 +102,6 @@ class GradientCheckSettings:
noise_level=0.01,
atol_consistency=1e-3,
)
settings["Okuonghae_ChaosSolitonsFractals2020"].step_sizes.extend([0.2, 0.005])
settings["Oliveira_NatCommun2021"] = GradientCheckSettings(
# Avoid "root after reinitialization"
atol_sim=1e-12,
Expand Down Expand Up @@ -176,7 +180,10 @@ def test_benchmark_gradient(model, scale, sensitivity_method, request):
pytest.skip()

petab_problem = benchmark_models_petab.get_problem(model)
petab.flatten_timepoint_specific_output_overrides(petab_problem)
if measurement_table_has_timepoint_specific_mappings(
petab_problem.measurement_df,
):
petab.flatten_timepoint_specific_output_overrides(petab_problem)

# Only compute gradient for estimated parameters.
parameter_ids = petab_problem.x_free_ids
Expand Down
7 changes: 6 additions & 1 deletion tests/benchmark-models/test_petab_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
simulate_petab,
)
from petab.v1.visualize import plot_problem
from petab.v1.lint import measurement_table_has_timepoint_specific_mappings

logger = get_logger(f"amici.{__name__}", logging.WARNING)

Expand Down Expand Up @@ -115,7 +116,11 @@ def main():

# load PEtab files
problem = petab.Problem.from_yaml(args.yaml_file_name)
petab.flatten_timepoint_specific_output_overrides(problem)

if measurement_table_has_timepoint_specific_mappings(
problem.measurement_df
):
petab.flatten_timepoint_specific_output_overrides(problem)

# load model
if args.model_directory:
Expand Down

0 comments on commit 56e6956

Please sign in to comment.