Skip to content

Commit

Permalink
no compilation for jax
Browse files Browse the repository at this point in the history
  • Loading branch information
FFroehlich committed Nov 28, 2024
1 parent 9fd5835 commit 862586d
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions python/sdist/amici/petab/petab_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def import_petab_problem(
parameters are required, this should be set to ``False``.
:param jax:
Whether to load the jax version of the model.
Whether to load the jax version of the model. Note that this disables
compilation of the model module unless `compile` is set to `True`.
:param kwargs:
Additional keyword arguments to be passed to
Expand Down Expand Up @@ -145,6 +146,7 @@ def import_petab_problem(
petab_problem,
model_name=model_name,
model_output_dir=model_output_dir,
compile=kwargs.pop("compile", not jax),
**kwargs,
)
else:
Expand All @@ -153,14 +155,19 @@ def import_petab_problem(
model_name=model_name,
model_output_dir=model_output_dir,
non_estimated_parameters_as_constants=non_estimated_parameters_as_constants,
compile=kwargs.pop("compile", not jax),
**kwargs,
)

# import model
model_module = amici.import_model_module(model_name, model_output_dir)
if not jax:
model_module = amici.import_model_module(model_name, model_output_dir)

if jax:
model = model_module.get_jax_model()
else:
jax_model_module = amici._module_from_path(
"jax", Path(model_output_dir) / model_name / "jax.py"
)
model = jax_model_module.Model()

logger.info(
f"Successfully loaded jax model {model_name} "
Expand Down

0 comments on commit 862586d

Please sign in to comment.