diff --git a/python/sdist/amici/__init__.py b/python/sdist/amici/__init__.py index f08d9e9349..c12559ce71 100644 --- a/python/sdist/amici/__init__.py +++ b/python/sdist/amici/__init__.py @@ -8,13 +8,13 @@ import contextlib import datetime -import importlib +import importlib.util import os import re import sys import sysconfig from pathlib import Path -from types import ModuleType as ModelModule +from types import ModuleType from typing import Any from collections.abc import Callable @@ -145,6 +145,8 @@ def get_model(self) -> amici.Model: def get_jax_model(self) -> JAXModel: ... AmiciModel = Union[amici.Model, amici.ModelPtr] +else: + ModelModule = ModuleType class add_path: @@ -162,6 +164,27 @@ def __exit__(self, exc_type, exc_value, traceback): sys.path.remove(self.path) +def _module_from_path(module_name: str, module_path: Path | str) -> ModuleType: + """Import a module from a given path. + + Import a module from a given path. The module is not added to + `sys.modules`. + + :param module_name: + Name of the module. + :param module_path: + Path to the module file. Absolute or relative to the current working + directory. + """ + module_path = Path(module_path).resolve() + if not module_path.is_file(): + raise ModuleNotFoundError(f"Module file not found: {module_path}") + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + def import_model_module( module_name: str, module_path: Path | str ) -> ModelModule: diff --git a/python/sdist/amici/__init__.template.py b/python/sdist/amici/__init__.template.py index 56064535e8..b6522e6a8a 100644 --- a/python/sdist/amici/__init__.template.py +++ b/python/sdist/amici/__init__.template.py @@ -18,14 +18,21 @@ "version currently installed." ) -from .TPL_MODELNAME import * # noqa: F403, F401 -from .TPL_MODELNAME import getModel as get_model # noqa: F401 +# from .TPL_MODELNAME import * # noqa: F403, F401 +# from .TPL_MODELNAME import getModel as get_model # noqa: F401 +TPL_MODELNAME = amici._module_from_path( + "TPL_MODELNAME", Path(__file__).parent / "TPL_MODELNAME.py" +) +for var in dir(TPL_MODELNAME): + if not var.startswith("__"): + globals()[var] = getattr(TPL_MODELNAME, var) +get_model = TPL_MODELNAME.getModel def get_jax_model() -> "JAXModel": - from .jax import JAXModel_TPL_MODELNAME - - return JAXModel_TPL_MODELNAME() + # from .jax import JAXModel_TPL_MODELNAME + jax = amici._module_from_path("jax", Path(__file__).parent / "jax.py") + return jax.JAXModel_TPL_MODELNAME() __version__ = "TPL_PACKAGE_VERSION" diff --git a/python/tests/test_sbml_import.py b/python/tests/test_sbml_import.py index 7a3f0a2720..0c7dd204dc 100644 --- a/python/tests/test_sbml_import.py +++ b/python/tests/test_sbml_import.py @@ -813,3 +813,89 @@ def test_same_extension_error(): module_name=module_name, module_path=outdir ) assert model_module_1.get_model().getParameters()[0] == 1.0 + + +def test_same_extension_no_error(): + """Test for error when loading a model with the same extension name as an + already loaded model.""" + from amici.antimony_import import antimony2amici + + def import_model_module( + module_name: str, module_path: Path | str + ) -> amici.ModelModule: + import importlib + + importlib.invalidate_caches() + module_path = Path(module_path, module_name, "__init__.py") + module = amici._module_from_path(module_name, module_path) + return module + + ant_model_1 = """ + model test_same_extension_error + species A = 0 + p = 1 + A' = p + end + """ + ant_model_2 = ant_model_1.replace("1", "2") + # ant_model_3 = ant_model_1.replace("1", "3") + + module_name = "test_same_extension" + outdir_1 = "deleteme1" + outdir_2 = "deleteme2" + + antimony2amici( + ant_model_1, + model_name=module_name, + output_dir=outdir_1, + compute_conservation_laws=False, + ) + + antimony2amici( + ant_model_2, + model_name=module_name, + output_dir=outdir_2, + compute_conservation_laws=False, + ) + + model_module_1 = import_model_module( + module_name=module_name, module_path=outdir_1 + ) + assert model_module_1.get_model().getParameters()[0] == 1.0 + # no error if the same model is loaded again without changes on disk + model_module_1b = import_model_module( + module_name=module_name, module_path=outdir_1 + ) + # downside: the modules will compare as different + assert (model_module_1 == model_module_1b) is False + assert model_module_1.__file__ == model_module_1b.__file__ + + assert model_module_1.get_model().getParameters()[0] == 1.0 + + # Try to import another model with the same name + + # # On Windows, this will give "permission denied" when building the + # # extension + # if sys.platform == "win32": + # return + + model_module_2 = import_model_module( + module_name=module_name, module_path=outdir_2 + ) + print(model_module_1) + print(model_module_2) + assert model_module_1.get_model().getParameters()[0] == 1.0 + assert model_module_2.get_model().getParameters()[0] == 2.0 + + # replace #2 by #3 + # antimony2amici( + # ant_model_3, model_name=module_name, output_dir=outdir_2, verbose=True + # ) + + # unsupported -- TODO: pytest.raises + # model_module_3 = import_model_module( + # module_name=module_name, module_path=outdir_2 + # ) + # assert model_module_1.get_model().getParameters()[0] == 1.0 + # assert model_module_2.get_model().getParameters()[0] == 2.0 + # assert model_module_3.get_model().getParameters()[0] == 3.0 diff --git a/swig/modelname.template.i b/swig/modelname.template.i index d7aab8ed8a..036c8678d6 100644 --- a/swig/modelname.template.i +++ b/swig/modelname.template.i @@ -1,4 +1,20 @@ -%module TPL_MODELNAME +%define MODULEIMPORT +" +import amici +import importlib.util +import sysconfig +from pathlib import Path + +ext_suffix = sysconfig.get_config_var('EXT_SUFFIX') +_TPL_MODELNAME = amici._module_from_path( + '_TPL_MODELNAME', + Path(__file__).parent / f'_TPL_MODELNAME{ext_suffix}', +) +" +%enddef + +%module(package="TPL_MODELNAME",moduleimport=MODULEIMPORT) TPL_MODELNAME + %import amici.i // Add necessary symbols to generated header