Skip to content

Commit

Permalink
Improve model extension import
Browse files Browse the repository at this point in the history
Previously, it wasn't possible to import two model modules with the same name.
Now this is at least possible if they are in different locations.
Overwriting and importing a previously imported extension is still not supported.
  • Loading branch information
dweindl committed Nov 27, 2024
1 parent 4ed131b commit 09e1141
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 101 deletions.
121 changes: 43 additions & 78 deletions python/sdist/amici/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@
"""

import contextlib
import datetime
import importlib.util
import importlib
import os
import re
import sys
import sysconfig
from pathlib import Path
from types import ModuleType as ModelModule
from types import ModuleType
from typing import Any
from collections.abc import Callable

Expand Down Expand Up @@ -145,6 +144,8 @@ def get_model(self) -> amici.Model:
def get_jax_model(self) -> JAXModel: ...

AmiciModel = Union[amici.Model, amici.ModelPtr]
else:
ModelModule = ModuleType

Check warning on line 148 in python/sdist/amici/__init__.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/__init__.py#L148

Added line #L148 was not covered by tests


class add_path:
Expand Down Expand Up @@ -182,6 +183,27 @@ def __exit__(self, exc_type, exc_value, traceback):
sys.path = self.orginal_path


def _module_from_path(module_name: str, module_path: Path | str) -> ModuleType:
"""Import a module from a given path.
Import a module from a given path. The module is not added to
`sys.modules`.
:param module_name:
Name of the module.
:param module_path:
Path to the module file. Absolute or relative to the current working
directory.
"""
module_path = Path(module_path).resolve()
if not module_path.is_file():
raise ModuleNotFoundError(f"Module file not found: {module_path}")
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module


def import_model_module(
module_name: str, module_path: Path | str
) -> ModelModule:
Expand All @@ -195,86 +217,29 @@ def import_model_module(
:return:
The model module
"""
module_path = str(module_path)
model_root = str(module_path)

# ensure we will find the newly created module
importlib.invalidate_caches()

if not os.path.isdir(module_path):
raise ValueError(f"module_path '{module_path}' is not a directory.")

module_path = os.path.abspath(module_path)
ext_suffix = sysconfig.get_config_var("EXT_SUFFIX")
ext_mod_name = f"{module_name}._{module_name}"

# module already loaded?
if (m := sys.modules.get(ext_mod_name)) and m.__file__.endswith(
ext_suffix
):
# this is the c++ extension we can't unload
loaded_file = Path(m.__file__)
needed_file = Path(
module_path,
module_name,
f"_{module_name}{ext_suffix}",
)
# if we import a matlab-generated model where the extension
# is in a different directory
needed_file_matlab = Path(
module_path,
f"_{module_name}{ext_suffix}",
)
if not needed_file.exists():
if needed_file_matlab.exists():
needed_file = needed_file_matlab
else:
raise ModuleNotFoundError(
f"Cannot find extension module for {module_name} in "
f"{module_path}."
)

if not loaded_file.samefile(needed_file):
# this is not the right module, and we can't unload it
raise RuntimeError(
f"Cannot import extension for {module_name} from "
f"{module_path}, because an extension with the same name was "
f"has already been imported from {loaded_file.parent}. "
"Import the module with a different name or restart the "
"Python kernel."
)
# this is the right file, but did it change on disk?
t_imported = m._get_import_time() # noqa: protected-access
t_modified = os.path.getmtime(m.__file__)
if t_imported < t_modified:
t_imp_str = datetime.datetime.fromtimestamp(t_imported).isoformat()
t_mod_str = datetime.datetime.fromtimestamp(t_modified).isoformat()
raise RuntimeError(
f"Cannot import extension for {module_name} from "
f"{module_path}, because an extension in the same location "
f"has already been imported, but the file was modified on "
f"disk. \nImported at {t_imp_str}\nModified at {t_mod_str}.\n"
"Import the module with a different name or restart the "
"Python kernel."
)

# unlike extension modules, Python modules can be unloaded
if module_name in sys.modules:
# if a module with that name is already in sys.modules, we remove it,
# along with all other modules from that package. otherwise, there
# will be trouble if two different models with the same name are to
# be imported.
del sys.modules[module_name]
# collect first, don't delete while iterating
to_unload = {
loaded_module_name
for loaded_module_name in sys.modules.keys()
if loaded_module_name.startswith(f"{module_name}.")
}
for m in to_unload:
del sys.modules[m]

with set_path(module_path):
return importlib.import_module(module_name)
raise ValueError(f"module_path '{model_root}' is not a directory.")

module_path = Path(model_root, module_name, "__init__.py")

# We may want to import a matlab-generated model where the extension
# is in a different directory. This is not a regular use case. It's only
# used in the amici tests and can be removed at any time.
# The models (currently) use the default swig-import and require
# modifying sys.path.
module_path_matlab = Path(model_root, f"{module_name}.py")
if not module_path.is_file() and module_path_matlab.is_file():
with set_path(model_root):
return _module_from_path(module_name, module_path_matlab)

module = _module_from_path(module_name, module_path)
module._self = module
return module


class AmiciVersionError(RuntimeError):
Expand Down
38 changes: 33 additions & 5 deletions python/sdist/amici/__init__.template.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
"""AMICI-generated module for model TPL_MODELNAME"""

import datetime
import os
from pathlib import Path
from typing import TYPE_CHECKING
import amici

# this module; will be set during import
_self = None

if TYPE_CHECKING:
from amici.jax import JAXModel

Expand All @@ -18,14 +23,37 @@
"version currently installed."
)

from .TPL_MODELNAME import * # noqa: F403, F401
from .TPL_MODELNAME import getModel as get_model # noqa: F401
TPL_MODELNAME = amici._module_from_path(
"TPL_MODELNAME", Path(__file__).parent / "TPL_MODELNAME.py"
)
for var in dir(TPL_MODELNAME):
if not var.startswith("__"):
globals()[var] = getattr(TPL_MODELNAME, var)
get_model = TPL_MODELNAME.getModel
TPL_MODELNAME._model_module = _self


def get_jax_model() -> "JAXModel":
from .jax import JAXModel_TPL_MODELNAME

return JAXModel_TPL_MODELNAME()
# If the model directory was meanwhile overwritten, this would load the
# new version, which would not match the previously imported extension.
# This is not allowed, as it would lead to inconsistencies.
jax_py_file = Path(__file__).parent / "jax.py"
jax_py_file = jax_py_file.resolve()
t_imported = TPL_MODELNAME._get_import_time() # noqa: protected-access
t_modified = os.path.getmtime(jax_py_file)
if t_imported < t_modified:
t_imp_str = datetime.datetime.fromtimestamp(t_imported).isoformat()
t_mod_str = datetime.datetime.fromtimestamp(t_modified).isoformat()
raise RuntimeError(
f"Refusing to import {jax_py_file} which was changed since "
f"TPL_MODELNAME was imported. This is to avoid inconsistencies "
"between the different model implementations.\n"
f"Imported at {t_imp_str}\nModified at {t_mod_str}.\n"
"Import the module with a different name or restart the "
"Python kernel."
)
jax = amici._module_from_path("jax", jax_py_file)
return jax.JAXModel_TPL_MODELNAME()


__version__ = "TPL_PACKAGE_VERSION"
61 changes: 46 additions & 15 deletions python/tests/test_sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,11 +763,14 @@ def test_constraints():


@skip_on_valgrind
def test_same_extension_error():
def test_import_same_model_name():
"""Test for error when loading a model with the same extension name as an
already loaded model."""
from amici.antimony_import import antimony2amici
from amici import import_model_module

# create three versions of a toy model with different parameter values
# to detect which model was loaded
ant_model_1 = """
model test_same_extension_error
species A = 0
Expand All @@ -776,40 +779,68 @@ def test_same_extension_error():
end
"""
ant_model_2 = ant_model_1.replace("1", "2")
ant_model_3 = ant_model_1.replace("1", "3")

module_name = "test_same_extension"
with TemporaryDirectory(prefix=module_name, delete=False) as outdir:
outdir_1 = Path(outdir, "model_1")
outdir_2 = Path(outdir, "model_2")

# import the first two models, with the same name,
# but in different location (this is now supported)
antimony2amici(
ant_model_1,
model_name=module_name,
output_dir=outdir,
output_dir=outdir_1,
compute_conservation_laws=False,
)
model_module_1 = amici.import_model_module(
module_name=module_name, module_path=outdir

antimony2amici(
ant_model_2,
model_name=module_name,
output_dir=outdir_2,
compute_conservation_laws=False,
)

model_module_1 = import_model_module(
module_name=module_name, module_path=outdir_1
)
assert model_module_1.get_model().getParameters()[0] == 1.0

# no error if the same model is loaded again without changes on disk
model_module_1 = amici.import_model_module(
module_name=module_name, module_path=outdir
model_module_1b = import_model_module(
module_name=module_name, module_path=outdir_1
)
# downside: the modules will compare as different
assert (model_module_1 == model_module_1b) is False
assert model_module_1.__file__ == model_module_1b.__file__
assert model_module_1b.get_model().getParameters()[0] == 1.0

model_module_2 = import_model_module(
module_name=module_name, module_path=outdir_2
)
assert model_module_1.get_model().getParameters()[0] == 1.0
assert model_module_2.get_model().getParameters()[0] == 2.0

# Try to import another model with the same name
# import the third model, with the same name and location as the second
# model -- this is not supported, because there is some caching at
# the C level we cannot control (or don't know how to)

# On Windows, this will give "permission denied" when building the
# extension
# extension, because we cannot delete a shared library that is in use

if sys.platform == "win32":
return

antimony2amici(
ant_model_2,
ant_model_3,
model_name=module_name,
output_dir=outdir,
compute_conservation_laws=False,
output_dir=outdir_2,
)
with pytest.raises(RuntimeError, match="has already been imported"):
amici.import_model_module(
module_name=module_name, module_path=outdir
)

with pytest.raises(RuntimeError, match="in the same location"):
import_model_module(module_name=module_name, module_path=outdir_2)

# this should not affect the previously loaded models
assert model_module_1.get_model().getParameters()[0] == 1.0
assert model_module_2.get_model().getParameters()[0] == 2.0
45 changes: 42 additions & 3 deletions swig/modelname.template.i
Original file line number Diff line number Diff line change
@@ -1,4 +1,44 @@
%module TPL_MODELNAME
%define MODULEIMPORT
"
import amici
import datetime
import importlib.util
import os
import sysconfig
from pathlib import Path
# the model-package __init__.py module (will be set during import)
_model_module = None
ext_suffix = sysconfig.get_config_var('EXT_SUFFIX')
_TPL_MODELNAME = amici._module_from_path(
'TPL_MODELNAME._TPL_MODELNAME' if __package__ or '.' in __name__
else '_TPL_MODELNAME',
Path(__file__).parent / f'_TPL_MODELNAME{ext_suffix}',
)
def _get_import_time():
return _TPL_MODELNAME._get_import_time()
t_imported = _get_import_time()
t_modified = os.path.getmtime(__file__)
if t_imported < t_modified:
t_imp_str = datetime.datetime.fromtimestamp(t_imported).isoformat()
t_mod_str = datetime.datetime.fromtimestamp(t_modified).isoformat()
module_path = Path(__file__).resolve()
raise RuntimeError(
f'Cannot import extension for TPL_MODELNAME from '
f'{module_path}, because an extension in the same location '
f'has already been imported, but the file was modified on '
f'disk. \\nImported at {t_imp_str}\\nModified at {t_mod_str}.\\n'
'Import the module with a different name or restart the '
'Python kernel.'
)
"
%enddef

%module(package="TPL_MODELNAME",moduleimport=MODULEIMPORT) TPL_MODELNAME

%import amici.i
// Add necessary symbols to generated header

Expand Down Expand Up @@ -30,8 +70,7 @@ static double _get_import_time();
// Make model module accessible from the model
%feature("pythonappend") amici::generic_model::getModel %{
if '.' in __name__:
import sys
val.module = sys.modules['.'.join(__name__.split('.')[:-1])]
val.module = _model_module
%}


Expand Down

0 comments on commit 09e1141

Please sign in to comment.