Skip to content

Commit

Permalink
Improve model extension import
Browse files Browse the repository at this point in the history
Previously, it wasn't possible to import two model modules with the same name.
Now this is at least possible if they are in different locations.
Overwriting and importing a previously imported extension is still not supported.
  • Loading branch information
dweindl committed Nov 27, 2024
1 parent 4ed131b commit 6ae0638
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 8 deletions.
27 changes: 25 additions & 2 deletions python/sdist/amici/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@

import contextlib
import datetime
import importlib
import importlib.util
import os
import re
import sys
import sysconfig
from pathlib import Path
from types import ModuleType as ModelModule
from types import ModuleType
from typing import Any
from collections.abc import Callable

Expand Down Expand Up @@ -145,6 +145,8 @@ def get_model(self) -> amici.Model:
def get_jax_model(self) -> JAXModel: ...

AmiciModel = Union[amici.Model, amici.ModelPtr]
else:
ModelModule = ModuleType

Check warning on line 149 in python/sdist/amici/__init__.py

View check run for this annotation

Codecov / codecov/patch

python/sdist/amici/__init__.py#L149

Added line #L149 was not covered by tests


class add_path:
Expand Down Expand Up @@ -182,6 +184,27 @@ def __exit__(self, exc_type, exc_value, traceback):
sys.path = self.orginal_path


def _module_from_path(module_name: str, module_path: Path | str) -> ModuleType:
"""Import a module from a given path.
Import a module from a given path. The module is not added to
`sys.modules`.
:param module_name:
Name of the module.
:param module_path:
Path to the module file. Absolute or relative to the current working
directory.
"""
module_path = Path(module_path).resolve()
if not module_path.is_file():
raise ModuleNotFoundError(f"Module file not found: {module_path}")
spec = importlib.util.spec_from_file_location(module_name, module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module


def import_model_module(
module_name: str, module_path: Path | str
) -> ModelModule:
Expand Down
17 changes: 12 additions & 5 deletions python/sdist/amici/__init__.template.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,21 @@
"version currently installed."
)

from .TPL_MODELNAME import * # noqa: F403, F401
from .TPL_MODELNAME import getModel as get_model # noqa: F401
# from .TPL_MODELNAME import * # noqa: F403, F401
# from .TPL_MODELNAME import getModel as get_model # noqa: F401
TPL_MODELNAME = amici._module_from_path(
"TPL_MODELNAME", Path(__file__).parent / "TPL_MODELNAME.py"
)
for var in dir(TPL_MODELNAME):
if not var.startswith("__"):
globals()[var] = getattr(TPL_MODELNAME, var)
get_model = TPL_MODELNAME.getModel


def get_jax_model() -> "JAXModel":
from .jax import JAXModel_TPL_MODELNAME

return JAXModel_TPL_MODELNAME()
# from .jax import JAXModel_TPL_MODELNAME
jax = amici._module_from_path("jax", Path(__file__).parent / "jax.py")
return jax.JAXModel_TPL_MODELNAME()


__version__ = "TPL_PACKAGE_VERSION"
86 changes: 86 additions & 0 deletions python/tests/test_sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,3 +813,89 @@ def test_same_extension_error():
module_name=module_name, module_path=outdir
)
assert model_module_1.get_model().getParameters()[0] == 1.0


def test_same_extension_no_error():
"""Test for error when loading a model with the same extension name as an
already loaded model."""
from amici.antimony_import import antimony2amici

def import_model_module(
module_name: str, module_path: Path | str
) -> amici.ModelModule:
import importlib

importlib.invalidate_caches()
module_path = Path(module_path, module_name, "__init__.py")
module = amici._module_from_path(module_name, module_path)
return module

ant_model_1 = """
model test_same_extension_error
species A = 0
p = 1
A' = p
end
"""
ant_model_2 = ant_model_1.replace("1", "2")
# ant_model_3 = ant_model_1.replace("1", "3")

module_name = "test_same_extension"
outdir_1 = "deleteme1"
outdir_2 = "deleteme2"

antimony2amici(
ant_model_1,
model_name=module_name,
output_dir=outdir_1,
compute_conservation_laws=False,
)

antimony2amici(
ant_model_2,
model_name=module_name,
output_dir=outdir_2,
compute_conservation_laws=False,
)

model_module_1 = import_model_module(
module_name=module_name, module_path=outdir_1
)
assert model_module_1.get_model().getParameters()[0] == 1.0
# no error if the same model is loaded again without changes on disk
model_module_1b = import_model_module(
module_name=module_name, module_path=outdir_1
)
# downside: the modules will compare as different
assert (model_module_1 == model_module_1b) is False
assert model_module_1.__file__ == model_module_1b.__file__

assert model_module_1.get_model().getParameters()[0] == 1.0

# Try to import another model with the same name

# # On Windows, this will give "permission denied" when building the
# # extension
# if sys.platform == "win32":
# return

model_module_2 = import_model_module(
module_name=module_name, module_path=outdir_2
)
print(model_module_1)
print(model_module_2)
assert model_module_1.get_model().getParameters()[0] == 1.0
assert model_module_2.get_model().getParameters()[0] == 2.0

# replace #2 by #3
# antimony2amici(
# ant_model_3, model_name=module_name, output_dir=outdir_2, verbose=True
# )

# unsupported -- TODO: pytest.raises
# model_module_3 = import_model_module(
# module_name=module_name, module_path=outdir_2
# )
# assert model_module_1.get_model().getParameters()[0] == 1.0
# assert model_module_2.get_model().getParameters()[0] == 2.0
# assert model_module_3.get_model().getParameters()[0] == 3.0
18 changes: 17 additions & 1 deletion swig/modelname.template.i
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
%module TPL_MODELNAME
%define MODULEIMPORT
"
import amici
import importlib.util
import sysconfig
from pathlib import Path
ext_suffix = sysconfig.get_config_var('EXT_SUFFIX')
_TPL_MODELNAME = amici._module_from_path(
'_TPL_MODELNAME',
Path(__file__).parent / f'_TPL_MODELNAME{ext_suffix}',
)
"
%enddef

%module(package="TPL_MODELNAME",moduleimport=MODULEIMPORT) TPL_MODELNAME

%import amici.i
// Add necessary symbols to generated header

Expand Down

0 comments on commit 6ae0638

Please sign in to comment.