diff --git a/python/sdist/amici/__init__.py b/python/sdist/amici/__init__.py index f08d9e9349..ee78a6045b 100644 --- a/python/sdist/amici/__init__.py +++ b/python/sdist/amici/__init__.py @@ -148,7 +148,10 @@ def get_jax_model(self) -> JAXModel: ... class add_path: - """Context manager for temporarily changing PYTHONPATH""" + """Context manager for temporarily changing PYTHONPATH. + + Add a path to the PYTHONPATH for the duration of the context manager. + """ def __init__(self, path: str | Path): self.path: str = str(path) @@ -162,6 +165,23 @@ def __exit__(self, exc_type, exc_value, traceback): sys.path.remove(self.path) +class set_path: + """Context manager for temporarily changing PYTHONPATH. + + Set the PYTHONPATH to a given path for the duration of the context manager. + """ + + def __init__(self, path: str | Path): + self.path: str = str(path) + + def __enter__(self): + self.orginal_path = sys.path.copy() + sys.path = [self.path] + + def __exit__(self, exc_type, exc_value, traceback): + sys.path = self.orginal_path + + def import_model_module( module_name: str, module_path: Path | str ) -> ModelModule: @@ -198,13 +218,20 @@ def import_model_module( module_name, f"_{module_name}{ext_suffix}", ) + # if we import a matlab-generated model where the extension + # is in a different directory + needed_file_matlab = Path( + module_path, + f"_{module_name}{ext_suffix}", + ) if not needed_file.exists(): - # if we import a matlab-generated model where the extension - # is in a different directory - needed_file = Path( - module_path, - f"_{module_name}{ext_suffix}", - ) + if needed_file_matlab.exists(): + needed_file = needed_file_matlab + else: + raise ModuleNotFoundError( + f"Cannot find extension module for {module_name} in " + f"{module_path}." + ) if not loaded_file.samefile(needed_file): # this is not the right module, and we can't unload it @@ -246,7 +273,7 @@ def import_model_module( for m in to_unload: del sys.modules[m] - with add_path(module_path): + with set_path(module_path): return importlib.import_module(module_name) diff --git a/python/sdist/amici/petab/import_helpers.py b/python/sdist/amici/petab/import_helpers.py index 70af87c3b3..19afe5b237 100644 --- a/python/sdist/amici/petab/import_helpers.py +++ b/python/sdist/amici/petab/import_helpers.py @@ -3,7 +3,6 @@ Functions for PEtab import that are independent of the model format. """ -import importlib import logging import os import re @@ -138,8 +137,7 @@ def _can_import_model(model_name: str, model_output_dir: str | Path) -> bool: """ # try to import (in particular checks version) try: - with amici.add_path(model_output_dir): - model_module = importlib.import_module(model_name) + model_module = amici.import_model_module(model_name, model_output_dir) except ModuleNotFoundError: return False