diff --git a/python/sdist/amici/__init__.py b/python/sdist/amici/__init__.py index f8776d8ba6..ee78a6045b 100644 --- a/python/sdist/amici/__init__.py +++ b/python/sdist/amici/__init__.py @@ -148,7 +148,10 @@ def get_jax_model(self) -> JAXModel: ... class add_path: - """Context manager for temporarily changing PYTHONPATH""" + """Context manager for temporarily changing PYTHONPATH. + + Add a path to the PYTHONPATH for the duration of the context manager. + """ def __init__(self, path: str | Path): self.path: str = str(path) @@ -162,6 +165,23 @@ def __exit__(self, exc_type, exc_value, traceback): sys.path.remove(self.path) +class set_path: + """Context manager for temporarily changing PYTHONPATH. + + Set the PYTHONPATH to a given path for the duration of the context manager. + """ + + def __init__(self, path: str | Path): + self.path: str = str(path) + + def __enter__(self): + self.orginal_path = sys.path.copy() + sys.path = [self.path] + + def __exit__(self, exc_type, exc_value, traceback): + sys.path = self.orginal_path + + def import_model_module( module_name: str, module_path: Path | str ) -> ModelModule: @@ -253,7 +273,7 @@ def import_model_module( for m in to_unload: del sys.modules[m] - with add_path(module_path): + with set_path(module_path): return importlib.import_module(module_name)