Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Oct 25, 2024
1 parent 5366632 commit 5a86f4c
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 18 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/test_benchmark_collection_models.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,7 @@ jobs:
- name: Download benchmark collection
run: |
pip install git+https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab.git@master#subdirectory=src/python \
&& AMICI_PARALLEL_COMPILE="" tests/benchmark-models/test_benchmark_collection.sh
pip install git+https://github.com/Benchmarking-Initiative/Benchmark-Models-PEtab.git@master#subdirectory=src/python
- name: Run tests
env:
Expand Down
16 changes: 0 additions & 16 deletions tests/benchmark-models/test_petab_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import contextlib
import logging
import yaml
import equinox as eqx
from amici.logging import get_logger
from amici.petab.simulations import (
LLH,
Expand Down Expand Up @@ -145,8 +144,6 @@ class GradientCheckSettings:
# forward/backward/central differences.
atol_consistency: float = 1e-5
rtol_consistency: float = 1e-1
# maximum number of integration steps
maxsteps: int = 10_000
# Step sizes for finite difference gradient checks.
step_sizes: list[float] = field(
default_factory=lambda: [
Expand Down Expand Up @@ -264,9 +261,6 @@ def test_jax_llh(benchmark_problem):
problem_id, petab_problem, amici_model = benchmark_problem

amici_solver = amici_model.getSolver()
amici_solver.setAbsoluteTolerance(settings[problem_id].atol_sim)
amici_solver.setRelativeTolerance(settings[problem_id].rtol_sim)
amici_solver.setMaxSteps(settings[problem_id].maxsteps)

llh_amici = simulate_petab(
petab_problem=petab_problem,
Expand Down Expand Up @@ -306,16 +300,6 @@ def test_jax_llh(benchmark_problem):
amici_model=amici_model,
)

jax_model = eqx.tree_at(
lambda x: x.maxsteps, jax_model, settings[problem_id].maxsteps
)
jax_model = eqx.tree_at(
lambda x: x.atol, jax_model, settings[problem_id].atol_sim
)
jax_model = eqx.tree_at(
lambda x: x.rtol, jax_model, settings[problem_id].rtol_sim
)

rdatas_jax = jax_model.run_simulations(edatas)

llh_jax = sum(r.llh for r in rdatas_jax)
Expand Down

0 comments on commit 5a86f4c

Please sign in to comment.