Skip to content

Commit

Permalink
Raise during unsupported model imports
Browse files Browse the repository at this point in the history
We can't import multiple extensions with the same name.
Raise if this is attempted.

Related to #1936.
  • Loading branch information
dweindl committed Nov 25, 2024
1 parent 7b9340f commit 8d9b827
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 2 deletions.
38 changes: 36 additions & 2 deletions python/sdist/amici/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os
import re
import sys
import sysconfig
from pathlib import Path
from types import ModuleType as ModelModule
from typing import Any
Expand Down Expand Up @@ -140,8 +141,7 @@ def get_model(self) -> amici.Model:
"""Create a model instance."""
...

def get_jax_model(self) -> JAXModel:
...
def get_jax_model(self) -> JAXModel: ...

AmiciModel = Union[amici.Model, amici.ModelPtr]

Expand Down Expand Up @@ -183,8 +183,42 @@ def import_model_module(
raise ValueError(f"module_path '{module_path}' is not a directory.")

module_path = os.path.abspath(module_path)
ext_suffix = sysconfig.get_config_var("EXT_SUFFIX")
ext_mod_name = f"{module_name}._{module_name}"

# module already loaded?
if (m := sys.modules.get(ext_mod_name)) and m.__file__.endswith(
ext_suffix
):
# this is the c++ extension we can't unload
loaded_file = Path(m.__file__)
needed_file = Path(
module_path,
module_name,
f"_{module_name}{ext_suffix}",
)
if not loaded_file.samefile(needed_file):
# this is not the right module, and we can't unload it
raise RuntimeError(
f"Cannot import extension for {module_name} from "
f"{module_path}, because an extension with the same name was "
f"has already been imported from {loaded_file.parent}. "
"Import the module with a different name or restart the "
"Python kernel."
)
# this is the right file, but did it change on disk?
t_imported = m.get_import_time()
t_modified = os.path.getmtime(m.__file__)
if t_imported < t_modified:
raise RuntimeError(
f"Cannot import extension for {module_name} from "
f"{module_path}, because an extension in the same location "
f"has already been imported, but the file was modified on "
"disk. Import the module with a different name or restart the "
"Python kernel."
)

# unlike extension modules, Python modules can be unloaded
if module_name in sys.modules:
# if a module with that name is already in sys.modules, we remove it,
# along with all other modules from that package. otherwise, there
Expand Down
45 changes: 45 additions & 0 deletions python/tests/test_sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,3 +773,48 @@ def test_constraints():
amici_solver.getAbsoluteTolerance(),
)
)


@skip_on_valgrind
def test_same_extension_warning():
"""Test for error when loading a model with the same extension name as an
already loaded model."""
from amici.antimony_import import antimony2amici

ant_model_1 = """
model test_same_extension_warning
species A = 0
p = 1
A' = p
end
"""
ant_model_2 = ant_model_1.replace("1", "2")

module_name = "test_same_extension"
with TemporaryDirectory(prefix=module_name, delete=False) as outdir:
antimony2amici(
ant_model_1,
model_name=module_name,
output_dir=outdir,
compute_conservation_laws=False,
)
model_module_1 = amici.import_model_module(
module_name=module_name, module_path=outdir
)
assert model_module_1.get_model().getParameters()[0] == 1.0
# no error if the same model is loaded again without changes on disk
model_module_1 = amici.import_model_module(
module_name=module_name, module_path=outdir
)
assert model_module_1.get_model().getParameters()[0] == 1.0
antimony2amici(
ant_model_2,
model_name=module_name,
output_dir=outdir,
compute_conservation_laws=False,
)
with pytest.raises(RuntimeError, match="has already been imported"):
amici.import_model_module(
module_name=module_name, module_path=outdir
)
assert model_module_1.get_model().getParameters()[0] == 1.0
20 changes: 20 additions & 0 deletions swig/modelname.template.i
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,26 @@
using namespace amici;
%}

// store the time a module was imported
%{
#include <ctime>
static std::time_t _module_import_time;

static std::time_t get_module_import_time() {
return _module_import_time;
}

static double get_import_time() {
return static_cast<double>(get_module_import_time());
}
%}

static double get_import_time();

%init %{
_module_import_time = std::time(nullptr);
%}


// Make model module accessible from the model
%feature("pythonappend") amici::generic_model::getModel %{
Expand Down

0 comments on commit 8d9b827

Please sign in to comment.