From 69f2fa451239b4b957fc234bface63308f9c1caa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fabian=20Fr=C3=B6hlich?= Date: Thu, 5 Dec 2024 15:02:47 +0000 Subject: [PATCH] fix path SNAFU --- python/sdist/amici/jax/ode_export.py | 5 ++--- python/sdist/amici/petab/import_helpers.py | 25 +++++++++++++++++++++- python/sdist/amici/petab/petab_import.py | 13 ++++++++--- 3 files changed, 36 insertions(+), 7 deletions(-) diff --git a/python/sdist/amici/jax/ode_export.py b/python/sdist/amici/jax/ode_export.py index 7ea4a29d8a..cec5104ded 100644 --- a/python/sdist/amici/jax/ode_export.py +++ b/python/sdist/amici/jax/ode_export.py @@ -234,12 +234,10 @@ def _generate_jax_code(self) -> None: "MODEL_API_VERSION": f"'{JAXModel.MODEL_API_VERSION}'", }, } - outdir = self.model_path / (self.model_name + "_jax") - outdir.mkdir(parents=True, exist_ok=True) apply_template( Path(amiciModulePath) / "jax" / "jax.template.py", - outdir / "__init__.py", + self.model_path / "__init__.py", tpl_data, ) @@ -258,6 +256,7 @@ def set_paths(self, output_dir: str | Path | None = None) -> None: output_dir = Path(os.getcwd()) / f"amici-{self.model_name}" self.model_path = Path(output_dir).resolve() + self.model_path.mkdir(parents=True, exist_ok=True) def set_name(self, model_name: str) -> None: """ diff --git a/python/sdist/amici/petab/import_helpers.py b/python/sdist/amici/petab/import_helpers.py index 57bc551205..d42e99b1e3 100644 --- a/python/sdist/amici/petab/import_helpers.py +++ b/python/sdist/amici/petab/import_helpers.py @@ -139,7 +139,9 @@ def _can_import_model( """ # try to import (in particular checks version) try: - model_module = amici.import_model_module(model_name, model_output_dir) + model_module = amici.import_model_module( + *_get_package_name_and_path(model_name, model_output_dir, jax) + ) except ModuleNotFoundError: return False @@ -268,3 +270,24 @@ def check_model( "the current model might also resolve this. Parameters: " f"{amici_ids_free_required.difference(amici_ids_free)}" ) + + +def _get_package_name_and_path( + model_name: str, model_output_dir: str | Path, jax: bool = False +) -> tuple[str, Path]: + """ + Get the package name and path for the generated model module. + + :param model_name: + Name of the model + :param model_output_dir: + Target directory for the generated model module + :param jax: + Whether to generate the paths for a JAX or CPP model + :return: + """ + if jax: + outdir = Path(model_output_dir) + return outdir.stem, outdir.parent + else: + return model_name, Path(model_output_dir) diff --git a/python/sdist/amici/petab/petab_import.py b/python/sdist/amici/petab/petab_import.py index 24cb21a466..b7fccca241 100644 --- a/python/sdist/amici/petab/petab_import.py +++ b/python/sdist/amici/petab/petab_import.py @@ -16,7 +16,12 @@ from petab.v1.models import MODEL_TYPE_PYSB, MODEL_TYPE_SBML from ..logging import get_logger -from .import_helpers import _can_import_model, _create_model_name, check_model +from .import_helpers import ( + _can_import_model, + _create_model_name, + check_model, + _get_package_name_and_path, +) from .sbml_import import import_model_sbml try: @@ -136,7 +141,7 @@ def import_petab_problem( ) # remove folder if exists - if os.path.exists(model_output_dir): + if not jax and os.path.exists(model_output_dir): shutil.rmtree(model_output_dir) logger.info(f"Compiling model {model_name} to {model_output_dir}.") @@ -160,7 +165,9 @@ def import_petab_problem( ) # import model - model_module = amici.import_model_module(model_name, model_output_dir) + model_module = amici.import_model_module( + *_get_package_name_and_path(model_name, model_output_dir, jax=jax) + ) if jax: model = model_module.Model()