Skip to content

Commit

Permalink
Merge branch 'develop' into require_py310
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Apr 20, 2024
2 parents 2102744 + e8493cf commit d7805d6
Show file tree
Hide file tree
Showing 9 changed files with 106 additions and 45 deletions.
11 changes: 9 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,16 @@ elseif(AMICI_TRY_ENABLE_HDF5)
endif()

set(VENDORED_SUNDIALS_DIR ${CMAKE_CURRENT_SOURCE_DIR}/ThirdParty/sundials)
set(VENDORED_SUNDIALS_BUILD_DIR ${VENDORED_SUNDIALS_DIR}/build)
set(VENDORED_SUNDIALS_INSTALL_DIR ${VENDORED_SUNDIALS_BUILD_DIR})
set(SUNDIALS_PRIVATE_INCLUDE_DIRS "${VENDORED_SUNDIALS_DIR}/src")
# Handle different sundials build/install dirs, depending on whether we are
# building the Python extension only or the full C++ interface
if(AMICI_PYTHON_BUILD_EXT_ONLY)
set(VENDORED_SUNDIALS_BUILD_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(VENDORED_SUNDIALS_INSTALL_DIR ${VENDORED_SUNDIALS_BUILD_DIR})
else()
set(VENDORED_SUNDIALS_BUILD_DIR ${VENDORED_SUNDIALS_DIR}/build)
set(VENDORED_SUNDIALS_INSTALL_DIR ${VENDORED_SUNDIALS_BUILD_DIR})
endif()
find_package(
SUNDIALS REQUIRED PATHS
"${VENDORED_SUNDIALS_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/cmake/sundials/")
Expand Down
28 changes: 21 additions & 7 deletions documentation/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,33 @@
import subprocess
import sys
from enum import EnumType

# need to import before setting typing.TYPE_CHECKING=True, fails otherwise
import amici
import exhale.deploy
import exhale_multiproject_monkeypatch
from unittest import mock
import pandas as pd
import sphinx
import sympy as sp
from exhale import configs as exhale_configs
from sphinx.transforms.post_transforms import ReferencesResolver

exhale_multiproject_monkeypatch, pd, sp # to avoid removal of unused import
try:
import exhale_multiproject_monkeypatch # noqa: F401
except ModuleNotFoundError:
# for unclear reasons, the import of exhale_multiproject_monkeypatch
# fails on some systems, because the the location of the editable install
# is not automatically added to sys.path ¯\_(ツ)_/¯
from importlib.metadata import Distribution
import json
from urllib.parse import unquote_plus, urlparse

dist = Distribution.from_name("sphinx-contrib-exhale-multiproject")
url = json.loads(dist.read_text("direct_url.json"))["url"]
package_dir = unquote_plus(urlparse(url).path)
sys.path.append(package_dir)
import exhale_multiproject_monkeypatch # noqa: F401

# need to import before setting typing.TYPE_CHECKING=True, fails otherwise
import amici
import pandas as pd # noqa: F401
import sympy as sp # noqa: F401


# BEGIN Monkeypatch exhale
from exhale.deploy import _generate_doxygen as exhale_generate_doxygen
Expand Down
17 changes: 10 additions & 7 deletions python/sdist/amici/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,18 @@ def __getitem__(self, item: str) -> Union[np.ndarray, float]:
if item in self._cache:
return self._cache[item]

if item == "id":
return getattr(self._swigptr, item)
if item in self._field_names:
value = _field_as_numpy(
self._field_dimensions, item, self._swigptr
)
self._cache[item] = value

if item not in self._field_names:
self.__missing__(item)
return value

if not item.startswith("_") and hasattr(self._swigptr, item):
return getattr(self._swigptr, item)

value = _field_as_numpy(self._field_dimensions, item, self._swigptr)
self._cache[item] = value
return value
self.__missing__(item)

def __missing__(self, key: str) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion python/sdist/amici/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def plot_observable_trajectories(
if not ax:
fig, ax = plt.subplots()
if not observable_indices:
observable_indices = range(rdata["y"].shape[1])
observable_indices = range(rdata.ny)

if marker is None:
# Show marker if only one time point is available,
Expand Down
75 changes: 55 additions & 20 deletions python/sdist/amici/swig.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Functions related to SWIG or SWIG-generated code"""
from __future__ import annotations
import ast
import contextlib
import re


class TypeHintFixer(ast.NodeTransformer):
"""Replaces SWIG-generated C++ typehints by corresponding Python types"""
"""Replaces SWIG-generated C++ typehints by corresponding Python types."""

mapping = {
"void": None,
Expand Down Expand Up @@ -53,9 +54,13 @@ class TypeHintFixer(ast.NodeTransformer):
"std::allocator< amici::ParameterScaling > > const &": ast.Constant(
"ParameterScalingVector"
),
"H5::H5File": None,
}

def visit_FunctionDef(self, node):
# convert type/rtype from docstring to annotation, if possible.
# those may be c++ types, not valid in python, that need to be
# converted to python types below.
self._annotation_from_docstring(node)

# Has a return type annotation?
Expand All @@ -67,14 +72,17 @@ def visit_FunctionDef(self, node):
for arg in node.args.args:
if not arg.annotation:
continue
if isinstance(arg.annotation, ast.Name):
if not isinstance(arg.annotation, ast.Constant):
# there is already proper annotation
continue

arg.annotation = self._new_annot(arg.annotation.value)
return node

def _new_annot(self, old_annot: str):
def _new_annot(self, old_annot: str | ast.Name):
if isinstance(old_annot, ast.Name):
old_annot = old_annot.id

with contextlib.suppress(KeyError):
return self.mapping[old_annot]

Expand Down Expand Up @@ -117,6 +125,8 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
Swig sometimes generates ``:type solver: :py:class:`Solver`` instead of
``:type solver: Solver``. Those need special treatment.
Overloaded functions are skipped.
"""
docstring = ast.get_docstring(node, clean=False)
if not docstring or "*Overload 1:*" in docstring:
Expand All @@ -127,22 +137,18 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
lines_to_remove = set()

for line_no, line in enumerate(docstring):
if (
match := re.match(
r"\s*:rtype:\s*(?::py:class:`)?(\w+)`?\s+$", line
)
) and not match.group(1).startswith(":"):
node.returns = ast.Constant(match.group(1))
if type_str := self.extract_rtype(line):
# handle `:rtype:`
node.returns = ast.Constant(type_str)
lines_to_remove.add(line_no)
continue

if (
match := re.match(
r"\s*:type\s*(\w+):\W*(?::py:class:`)?(\w+)`?\s+$", line
)
) and not match.group(1).startswith(":"):
arg_name, type_str = self.extract_type(line)
if arg_name is not None:
# handle `:type ...:`
for arg in node.args.args:
if arg.arg == match.group(1):
arg.annotation = ast.Constant(match.group(2))
if arg.arg == arg_name:
arg.annotation = ast.Constant(type_str)
lines_to_remove.add(line_no)

if lines_to_remove:
Expand All @@ -155,13 +161,42 @@ def _annotation_from_docstring(self, node: ast.FunctionDef):
)
node.body[0].value = ast.Str(new_docstring)

@staticmethod
def extract_type(line: str) -> tuple[str, str] | tuple[None, None]:
"""Extract argument name and type string from ``:type:`` docstring
line."""
match = re.match(r"\s*:type\s+(\w+):\s+(.+?)(?:, optional)?\s*$", line)
if not match:
return None, None

arg_name = match.group(1)

# get rid of any :py:class`...` in the type string if necessary
if not match.group(2).startswith(":py:"):
return arg_name, match.group(2)

match = re.match(r":py:\w+:`(.+)`", match.group(2))
assert match
return arg_name, match.group(1)

@staticmethod
def extract_rtype(line: str) -> str | None:
"""Extract type string from ``:rtype:`` docstring line."""
match = re.match(r"\s*:rtype:\s+(.+)\s*$", line)
if not match:
return None

# get rid of any :py:class`...` in the type string if necessary
if not match.group(1).startswith(":py:"):
return match.group(1)

match = re.match(r":py:\w+:`(.+)`", match.group(1))
assert match
return match.group(1)


def fix_typehints(infilename, outfilename):
"""Change SWIG-generated C++ typehints to Python typehints"""
# Only available from Python3.9
if not getattr(ast, "unparse", None):
return

# file -> AST
with open(infilename) as f:
source = f.read()
Expand Down
3 changes: 3 additions & 0 deletions python/tests/test_swig_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,9 @@ def test_rdataview(sbml_example_presimulation_module):
rdata = amici.runAmiciSimulation(model, model.getSolver())
assert isinstance(rdata, amici.ReturnDataView)

# check that non-array attributes are looked up in the wrapped object
assert rdata.ptr.ny == rdata.ny

# fields are accessible via dot notation and [] operator,
# __contains__ and __getattr__ are implemented correctly
with pytest.raises(AttributeError):
Expand Down
3 changes: 0 additions & 3 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1671,7 +1671,6 @@ int Model::checkFinite(
&& model_quantity != ModelQuantity::ts) {
checkFinite(state_.fixedParameters, ModelQuantity::k, t);
checkFinite(state_.unscaledParameters, ModelQuantity::p, t);
checkFinite(simulation_parameters_.ts_, ModelQuantity::ts, t);
if (!always_check_finite_ && model_quantity != ModelQuantity::w) {
// don't check twice if always_check_finite_ is true
checkFinite(derived_state_.w_, ModelQuantity::w, t);
Expand Down Expand Up @@ -1789,7 +1788,6 @@ int Model::checkFinite(
// check upstream
checkFinite(state_.fixedParameters, ModelQuantity::k, t);
checkFinite(state_.unscaledParameters, ModelQuantity::p, t);
checkFinite(simulation_parameters_.ts_, ModelQuantity::ts, t);
checkFinite(derived_state_.w_, ModelQuantity::w, t);

return AMICI_RECOVERABLE_ERROR;
Expand Down Expand Up @@ -1880,7 +1878,6 @@ int Model::checkFinite(SUNMatrix m, ModelQuantity model_quantity, realtype t)
// check upstream
checkFinite(state_.fixedParameters, ModelQuantity::k, t);
checkFinite(state_.unscaledParameters, ModelQuantity::p, t);
checkFinite(simulation_parameters_.ts_, ModelQuantity::ts, t);
checkFinite(derived_state_.w_, ModelQuantity::w, t);

return AMICI_RECOVERABLE_ERROR;
Expand Down
10 changes: 5 additions & 5 deletions src/solver_idas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
namespace amici {

/*
* The following static members are callback function to CVODES.
* The following static members are callback function to IDAS.
* Their signatures must not be changes.
*/

Expand Down Expand Up @@ -437,7 +437,7 @@ void IDASolver::reInitPostProcess(

auto status = IDASetStopTime(ida_mem, tout);
if (status != IDA_SUCCESS)
throw IDAException(status, "CVodeSetStopTime");
throw IDAException(status, "IDASetStopTime");

status = IDASolve(
ami_mem, tout, t, yout->getNVector(), ypout->getNVector(), IDA_ONE_STEP
Expand Down Expand Up @@ -853,7 +853,7 @@ void IDASolver::setNonLinearSolver() const {
solver_memory_.get(), non_linear_solver_->get()
);
if (status != IDA_SUCCESS)
throw CvodeException(status, "CVodeSetNonlinearSolver");
throw IDAException(status, "IDASetNonlinearSolver");
}

void IDASolver::setNonLinearSolverSens() const {
Expand Down Expand Up @@ -883,15 +883,15 @@ void IDASolver::setNonLinearSolverSens() const {
}

if (status != IDA_SUCCESS)
throw CvodeException(status, "CVodeSolver::setNonLinearSolverSens");
throw IDAException(status, "IDASolver::setNonLinearSolverSens");
}

void IDASolver::setNonLinearSolverB(int which) const {
int status = IDASetNonlinearSolverB(
solver_memory_.get(), which, non_linear_solver_B_->get()
);
if (status != IDA_SUCCESS)
throw CvodeException(status, "CVodeSetNonlinearSolverB");
throw IDAException(status, "IDASetNonlinearSolverB");
}

/**
Expand Down
2 changes: 2 additions & 0 deletions swig/misc.i
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
%ignore amici::regexErrorToString;
%ignore amici::writeSlice;
%ignore ContextManager;
%ignore amici::scaleParameters;
%ignore amici::unscaleParameters;

// Add necessary symbols to generated header
%{
Expand Down

0 comments on commit d7805d6

Please sign in to comment.