From 266a6118410b0be36dd82be58c700f60663081cf Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Tue, 26 Nov 2024 22:29:03 +0100 Subject: [PATCH] Improve model extension import Previously, it wasn't possible to import two model modules with the same name. Now this is at least possible if that are in different locations. Overwriting and importing a previously imported extension is still not supported. --- python/sdist/amici/__init__.py | 27 +++++++- python/sdist/amici/__init__.template.py | 17 +++-- python/tests/test_sbml_import.py | 91 +++++++++++++++++++++++++ swig/modelname.template.i | 18 ++++- 4 files changed, 145 insertions(+), 8 deletions(-) diff --git a/python/sdist/amici/__init__.py b/python/sdist/amici/__init__.py index f08d9e9349..81fc2c5980 100644 --- a/python/sdist/amici/__init__.py +++ b/python/sdist/amici/__init__.py @@ -8,13 +8,13 @@ import contextlib import datetime -import importlib +import importlib.util import os import re import sys import sysconfig from pathlib import Path -from types import ModuleType as ModelModule +from types import ModuleType from typing import Any from collections.abc import Callable @@ -145,6 +145,8 @@ def get_model(self) -> amici.Model: def get_jax_model(self) -> JAXModel: ... AmiciModel = Union[amici.Model, amici.ModelPtr] +else: + ModelModule = ModuleType class add_path: @@ -162,6 +164,27 @@ def __exit__(self, exc_type, exc_value, traceback): sys.path.remove(self.path) +def _module_from_path(module_name: str, module_path: Path | str) -> ModuleType: + """Import a module from a given path. + + Import a module from a given path. The module is not added to + `sys.modules`. + + :param module_name: + Name of the module. + :param module_path: + Path to the module file. Absolute or relative to the current working + directory. + """ + module_path = Path(module_path).resolve() + if not module_path.is_file(): + raise ImportError(f"Module file not found: {module_path}") + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + def import_model_module( module_name: str, module_path: Path | str ) -> ModelModule: diff --git a/python/sdist/amici/__init__.template.py b/python/sdist/amici/__init__.template.py index 56064535e8..b6522e6a8a 100644 --- a/python/sdist/amici/__init__.template.py +++ b/python/sdist/amici/__init__.template.py @@ -18,14 +18,21 @@ "version currently installed." ) -from .TPL_MODELNAME import * # noqa: F403, F401 -from .TPL_MODELNAME import getModel as get_model # noqa: F401 +# from .TPL_MODELNAME import * # noqa: F403, F401 +# from .TPL_MODELNAME import getModel as get_model # noqa: F401 +TPL_MODELNAME = amici._module_from_path( + "TPL_MODELNAME", Path(__file__).parent / "TPL_MODELNAME.py" +) +for var in dir(TPL_MODELNAME): + if not var.startswith("__"): + globals()[var] = getattr(TPL_MODELNAME, var) +get_model = TPL_MODELNAME.getModel def get_jax_model() -> "JAXModel": - from .jax import JAXModel_TPL_MODELNAME - - return JAXModel_TPL_MODELNAME() + # from .jax import JAXModel_TPL_MODELNAME + jax = amici._module_from_path("jax", Path(__file__).parent / "jax.py") + return jax.JAXModel_TPL_MODELNAME() __version__ = "TPL_PACKAGE_VERSION" diff --git a/python/tests/test_sbml_import.py b/python/tests/test_sbml_import.py index 7a3f0a2720..58c34f0ecb 100644 --- a/python/tests/test_sbml_import.py +++ b/python/tests/test_sbml_import.py @@ -813,3 +813,94 @@ def test_same_extension_error(): module_name=module_name, module_path=outdir ) assert model_module_1.get_model().getParameters()[0] == 1.0 + + +def test_same_extension_no_error(): + """Test for error when loading a model with the same extension name as an + already loaded model.""" + from amici.antimony_import import antimony2amici + + def import_model_module( + module_name: str, module_path: Path | str + ) -> amici.ModelModule: + import importlib + + importlib.invalidate_caches() + module_path = Path(module_path, module_name, "__init__.py") + module = amici._module_from_path(module_name, module_path) + return module + + ant_model_1 = """ + model test_same_extension_error + species A = 0 + p = 1 + A' = p + end + """ + ant_model_2 = ant_model_1.replace("1", "2") + ant_model_3 = ant_model_1.replace("1", "3") + + module_name = "test_same_extension" + outdir_1 = "deleteme1" + outdir_2 = "deleteme2" + + antimony2amici( + ant_model_1, + model_name=module_name, + output_dir=outdir_1, + compute_conservation_laws=False, + ) + + antimony2amici( + ant_model_2, + model_name=module_name, + output_dir=outdir_2, + compute_conservation_laws=False, + ) + + model_module_1 = import_model_module( + module_name=module_name, module_path=outdir_1 + ) + assert model_module_1.get_model().getParameters()[0] == 1.0 + # no error if the same model is loaded again without changes on disk + model_module_1b = import_model_module( + module_name=module_name, module_path=outdir_1 + ) + # assert model_module_1 == model_module_1b + assert ( + model_module_1.get_model().__class__ + is model_module_1b.get_model().__class__ + ) + assert isinstance( + model_module_1.get_model(), model_module_1b.get_model().__class__ + ) + + assert model_module_1.get_model().getParameters()[0] == 1.0 + + # Try to import another model with the same name + + # # On Windows, this will give "permission denied" when building the + # # extension + # if sys.platform == "win32": + # return + + model_module_2 = import_model_module( + module_name=module_name, module_path=outdir_2 + ) + print(model_module_1) + print(model_module_2) + assert model_module_1.get_model().getParameters()[0] == 1.0 + assert model_module_2.get_model().getParameters()[0] == 2.0 + + # replace #2 by #3 + antimony2amici( + ant_model_3, model_name=module_name, output_dir=outdir_2, verbose=True + ) + + # unsupported -- TODO: pytest.raises + # model_module_3 = import_model_module( + # module_name=module_name, module_path=outdir_2 + # ) + # assert model_module_1.get_model().getParameters()[0] == 1.0 + # assert model_module_2.get_model().getParameters()[0] == 2.0 + # assert model_module_3.get_model().getParameters()[0] == 3.0 diff --git a/swig/modelname.template.i b/swig/modelname.template.i index d7aab8ed8a..036c8678d6 100644 --- a/swig/modelname.template.i +++ b/swig/modelname.template.i @@ -1,4 +1,20 @@ -%module TPL_MODELNAME +%define MODULEIMPORT +" +import amici +import importlib.util +import sysconfig +from pathlib import Path + +ext_suffix = sysconfig.get_config_var('EXT_SUFFIX') +_TPL_MODELNAME = amici._module_from_path( + '_TPL_MODELNAME', + Path(__file__).parent / f'_TPL_MODELNAME{ext_suffix}', +) +" +%enddef + +%module(package="TPL_MODELNAME",moduleimport=MODULEIMPORT) TPL_MODELNAME + %import amici.i // Add necessary symbols to generated header