diff --git a/tests/benchmark-models/test_petab_model.py b/tests/benchmark-models/test_petab_model.py index 64af79d8a8..89a482cd7a 100755 --- a/tests/benchmark-models/test_petab_model.py +++ b/tests/benchmark-models/test_petab_model.py @@ -145,65 +145,50 @@ def main(): rdatas = res[RDATAS] llh = res[LLH] - if ( - args.model_name - not in ( - # "Beer_MolBioSystems2014", - # "Brannmark_JBC2010", - # "Fujita_SciSignal2010", - # "Isensee_JCB2018", - # "Smith_BMCSystBiol2013", - # "Weber_BMC2015", - ) - ): - # Beer: Heaviside - # Brannmark: Heaviside - # Fujita: Heaviside - # Isensee: Heaviside - # Smith: Heaviside - # Weber: Heaviside - - jax_model = model_module.get_jax_model() - simulation_conditions = ( - problem.get_simulation_conditions_from_measurement_df() - ) - edatas = create_edatas( - amici_model=amici_model, - petab_problem=problem, - simulation_conditions=simulation_conditions, - ) - problem_parameters = { - t.Index: getattr(t, petab.NOMINAL_VALUE) - for t in problem.parameter_df.itertuples() - } - parameter_mapping = create_parameter_mapping( - petab_problem=problem, - simulation_conditions=simulation_conditions, - scaled_parameters=False, - amici_model=amici_model, - ) - fill_in_parameters( - edatas=edatas, - problem_parameters=problem_parameters, - scaled_parameters=False, - parameter_mapping=parameter_mapping, - amici_model=amici_model, - ) - # run once to JIT - jax_model.run_simulations(edatas) - start_jax = timer() - rdatas_jax = jax_model.run_simulations(edatas) - end_jax = timer() + jax_model = model_module.get_jax_model() + simulation_conditions = ( + problem.get_simulation_conditions_from_measurement_df() + ) + edatas = create_edatas( + amici_model=amici_model, + petab_problem=problem, + simulation_conditions=simulation_conditions, + ) + problem_parameters = { + t.Index: getattr(t, petab.NOMINAL_VALUE) + for t in problem.parameter_df.itertuples() + } + parameter_mapping = create_parameter_mapping( + petab_problem=problem, + simulation_conditions=simulation_conditions, + scaled_parameters=False, + amici_model=amici_model, + ) + fill_in_parameters( + edatas=edatas, + problem_parameters=problem_parameters, + scaled_parameters=False, + parameter_mapping=parameter_mapping, + amici_model=amici_model, + ) + # run once to JIT + jax_model.run_simulations(edatas) + start_jax = timer() + rdatas_jax = jax_model.run_simulations(edatas) + end_jax = timer() - t_jax = end_jax - start_jax - t_amici = sum(r.cpu_time for r in rdatas) / 1e3 + t_jax = end_jax - start_jax + t_amici = sum(r.cpu_time for r in rdatas) / 1e3 - llh_jax = sum(r.llh for r in rdatas_jax) + llh_jax = sum(r.llh for r in rdatas_jax) - print( - f'amici (llh={res["llh"]} after {t_amici}s) vs ' - f'jax (llh={llh_jax} after {t_jax}s)' - ) + print( + f'amici (llh={res["llh"]} after {t_amici}s) vs ' + f'jax (llh={llh_jax} after {t_jax}s)' + ) + assert np.isclose( + llh, llh_jax, rtol=1e-3, atol=1e-3 + ), "LLH mismatch {llh} (amici) vs {llh_jax} (jax)" times = dict()