Skip to content

Commit

Permalink
cleanup & actually run tests
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Oct 21, 2024
1 parent 7faae32 commit 907acb7
Showing 1 changed file with 41 additions and 56 deletions.
97 changes: 41 additions & 56 deletions tests/benchmark-models/test_petab_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,65 +145,50 @@ def main():
rdatas = res[RDATAS]
llh = res[LLH]

if (
args.model_name
not in (
# "Beer_MolBioSystems2014",
# "Brannmark_JBC2010",
# "Fujita_SciSignal2010",
# "Isensee_JCB2018",
# "Smith_BMCSystBiol2013",
# "Weber_BMC2015",
)
):
# Beer: Heaviside
# Brannmark: Heaviside
# Fujita: Heaviside
# Isensee: Heaviside
# Smith: Heaviside
# Weber: Heaviside

jax_model = model_module.get_jax_model()
simulation_conditions = (
problem.get_simulation_conditions_from_measurement_df()
)
edatas = create_edatas(
amici_model=amici_model,
petab_problem=problem,
simulation_conditions=simulation_conditions,
)
problem_parameters = {
t.Index: getattr(t, petab.NOMINAL_VALUE)
for t in problem.parameter_df.itertuples()
}
parameter_mapping = create_parameter_mapping(
petab_problem=problem,
simulation_conditions=simulation_conditions,
scaled_parameters=False,
amici_model=amici_model,
)
fill_in_parameters(
edatas=edatas,
problem_parameters=problem_parameters,
scaled_parameters=False,
parameter_mapping=parameter_mapping,
amici_model=amici_model,
)
# run once to JIT
jax_model.run_simulations(edatas)
start_jax = timer()
rdatas_jax = jax_model.run_simulations(edatas)
end_jax = timer()
jax_model = model_module.get_jax_model()
simulation_conditions = (
problem.get_simulation_conditions_from_measurement_df()
)
edatas = create_edatas(
amici_model=amici_model,
petab_problem=problem,
simulation_conditions=simulation_conditions,
)
problem_parameters = {
t.Index: getattr(t, petab.NOMINAL_VALUE)
for t in problem.parameter_df.itertuples()
}
parameter_mapping = create_parameter_mapping(
petab_problem=problem,
simulation_conditions=simulation_conditions,
scaled_parameters=False,
amici_model=amici_model,
)
fill_in_parameters(
edatas=edatas,
problem_parameters=problem_parameters,
scaled_parameters=False,
parameter_mapping=parameter_mapping,
amici_model=amici_model,
)
# run once to JIT
jax_model.run_simulations(edatas)
start_jax = timer()
rdatas_jax = jax_model.run_simulations(edatas)
end_jax = timer()

t_jax = end_jax - start_jax
t_amici = sum(r.cpu_time for r in rdatas) / 1e3
t_jax = end_jax - start_jax
t_amici = sum(r.cpu_time for r in rdatas) / 1e3

llh_jax = sum(r.llh for r in rdatas_jax)
llh_jax = sum(r.llh for r in rdatas_jax)

print(
f'amici (llh={res["llh"]} after {t_amici}s) vs '
f'jax (llh={llh_jax} after {t_jax}s)'
)
print(
f'amici (llh={res["llh"]} after {t_amici}s) vs '
f'jax (llh={llh_jax} after {t_jax}s)'
)
assert np.isclose(
llh, llh_jax, rtol=1e-3, atol=1e-3
), "LLH mismatch {llh} (amici) vs {llh_jax} (jax)"

times = dict()

Expand Down

0 comments on commit 907acb7

Please sign in to comment.