Skip to content

Commit

Permalink
Refactor de_export.py, extract sympy_utils.py (#2307)
Browse files Browse the repository at this point in the history
No changes in functionality.

Related to #2306.
  • Loading branch information
dweindl authored Feb 26, 2024
1 parent 8271da1 commit aeb5f34
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 218 deletions.
205 changes: 11 additions & 194 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
C++ Export
----------
This module provides all necessary functionality specify an DE model and
This module provides all necessary functionality specify a DE model and
generate executable C++ simulation code. The user generally won't have to
directly call any function from this module as this will be done by
:py:func:`amici.pysb_import.pysb2amici`,
Expand All @@ -18,12 +18,11 @@
import subprocess
import sys
from dataclasses import dataclass
from itertools import chain, starmap
from itertools import chain
from pathlib import Path
from string import Template
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
Optional,
Expand Down Expand Up @@ -59,8 +58,17 @@
strip_pysb,
toposort_symbols,
unique_preserve_order,
_default_simplify,
)
from .logging import get_logger, log_execution_time, set_log_level
from .sympy_utils import (
_custom_pow_eval_derivative,
_monkeypatched,
smart_jacobian,
smart_multiply,
smart_is_zero_matrix,
_parallel_applyfunc,
)

if TYPE_CHECKING:
from . import sbml_import
Expand Down Expand Up @@ -509,109 +517,6 @@ def var_in_function_signature(name: str, varname: str, ode: bool) -> bool:
}


@log_execution_time("running smart_jacobian", logger)
def smart_jacobian(
eq: sp.MutableDenseMatrix, sym_var: sp.MutableDenseMatrix
) -> sp.MutableSparseMatrix:
"""
Wrapper around symbolic jacobian with some additional checks that reduce
computation time for large matrices
:param eq:
equation
:param sym_var:
differentiation variable
:return:
jacobian of eq wrt sym_var
"""
nrow = eq.shape[0]
ncol = sym_var.shape[0]
if (
not min(eq.shape)
or not min(sym_var.shape)
or smart_is_zero_matrix(eq)
or smart_is_zero_matrix(sym_var)
):
return sp.MutableSparseMatrix(nrow, ncol, dict())

# preprocess sparsity pattern
elements = (
(i, j, a, b)
for i, a in enumerate(eq)
for j, b in enumerate(sym_var)
if a.has(b)
)

if (n_procs := int(os.environ.get("AMICI_IMPORT_NPROCS", 1))) == 1:
# serial
return sp.MutableSparseMatrix(
nrow, ncol, dict(starmap(_jacobian_element, elements))
)

# parallel
from multiprocessing import get_context

# "spawn" should avoid potential deadlocks occurring with fork
# see e.g. https://stackoverflow.com/a/66113051
ctx = get_context("spawn")
with ctx.Pool(n_procs) as p:
mapped = p.starmap(_jacobian_element, elements)
return sp.MutableSparseMatrix(nrow, ncol, dict(mapped))


@log_execution_time("running smart_multiply", logger)
def smart_multiply(
x: Union[sp.MutableDenseMatrix, sp.MutableSparseMatrix],
y: sp.MutableDenseMatrix,
) -> Union[sp.MutableDenseMatrix, sp.MutableSparseMatrix]:
"""
Wrapper around symbolic multiplication with some additional checks that
reduce computation time for large matrices
:param x:
educt 1
:param y:
educt 2
:return:
product
"""
if (
not x.shape[0]
or not y.shape[1]
or smart_is_zero_matrix(x)
or smart_is_zero_matrix(y)
):
return sp.zeros(x.shape[0], y.shape[1])
return x.multiply(y)


def smart_is_zero_matrix(
x: Union[sp.MutableDenseMatrix, sp.MutableSparseMatrix],
) -> bool:
"""A faster implementation of sympy's is_zero_matrix
Avoids repeated indexer type checks and double iteration to distinguish
False/None. Found to be about 100x faster for large matrices.
:param x: Matrix to check
"""

if isinstance(x, sp.MutableDenseMatrix):
return all(xx.is_zero is True for xx in x.flat())

if isinstance(x, list):
return all(smart_is_zero_matrix(xx) for xx in x)

return x.nnz() == 0


def _default_simplify(x):
"""Default simplification applied in DEModel"""
# We need this as a free function instead of a lambda to have it picklable
# for parallel simplification
return sp.powsimp(x, deep=True)


class DEModel:
"""
Defines a Differential Equation as set of ModelQuantities.
Expand Down Expand Up @@ -4304,94 +4209,6 @@ def is_valid_identifier(x: str) -> bool:
return IDENTIFIER_PATTERN.match(x) is not None


@contextlib.contextmanager
def _monkeypatched(obj: object, name: str, patch: Any):
"""
Temporarily monkeypatches an object.
:param obj:
object to be patched
:param name:
name of the attribute to be patched
:param patch:
patched value
"""
pre_patched_value = getattr(obj, name)
setattr(obj, name, patch)
try:
yield object
finally:
setattr(obj, name, pre_patched_value)


def _custom_pow_eval_derivative(self, s):
"""
Custom Pow derivative that removes a removable singularity for
``self.base == 0`` and ``self.base.diff(s) == 0``. This function is
intended to be monkeypatched into :py:method:`sympy.Pow._eval_derivative`.
:param self:
sp.Pow class
:param s:
variable with respect to which the derivative will be computed
"""
dbase = self.base.diff(s)
dexp = self.exp.diff(s)
part1 = sp.Pow(self.base, self.exp - 1) * self.exp * dbase
part2 = self * dexp * sp.log(self.base)
if self.base.is_nonzero or dbase.is_nonzero or part2.is_zero:
# first piece never applies or is zero anyways
return part1 + part2

return part1 + sp.Piecewise(
(self.base, sp.And(sp.Eq(self.base, 0), sp.Eq(dbase, 0))),
(part2, True),
)


def _jacobian_element(i, j, eq_i, sym_var_j):
"""Compute a single element of a jacobian"""
return (i, j), eq_i.diff(sym_var_j)


def _parallel_applyfunc(obj: sp.Matrix, func: Callable) -> sp.Matrix:
"""Parallel implementation of sympy's Matrix.applyfunc"""
if (n_procs := int(os.environ.get("AMICI_IMPORT_NPROCS", 1))) == 1:
# serial
return obj.applyfunc(func)

# parallel
from multiprocessing import get_context
from pickle import PicklingError

from sympy.matrices.dense import DenseMatrix

# "spawn" should avoid potential deadlocks occurring with fork
# see e.g. https://stackoverflow.com/a/66113051
ctx = get_context("spawn")
with ctx.Pool(n_procs) as p:
try:
if isinstance(obj, DenseMatrix):
return obj._new(obj.rows, obj.cols, p.map(func, obj))
elif isinstance(obj, sp.SparseMatrix):
dok = obj.todok()
mapped = p.map(func, dok.values())
dok = {k: v for k, v in zip(dok.keys(), mapped) if v != 0}
return obj._new(obj.rows, obj.cols, dok)
else:
raise ValueError(f"Unsupported matrix type {type(obj)}")
except PicklingError as e:
raise ValueError(
f"Couldn't pickle {func}. This is likely because the argument "
"was not a module-level function. Either rewrite the argument "
"to a module-level function or disable parallelization by "
"setting `AMICI_IMPORT_NPROCS=1`."
) from e


def _write_gitignore(dest_dir: Path) -> None:
"""Write .gitignore file.
Expand Down
7 changes: 7 additions & 0 deletions python/sdist/amici/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,3 +748,10 @@ def unique_preserve_order(seq: Sequence) -> list:

sbml_time_symbol = symbol_with_assumptions("time")
amici_time_symbol = symbol_with_assumptions("t")


def _default_simplify(x):
"""Default simplification applied in DEModel"""
# We need this as a free function instead of a lambda to have it picklable
# for parallel simplification
return sp.powsimp(x, deep=True)
2 changes: 1 addition & 1 deletion python/sdist/amici/pysb_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@
Observable,
Parameter,
SigmaY,
_default_simplify,
)
from .import_utils import (
_get_str_symbol_identifiers,
_parse_special_functions,
generate_measurement_symbol,
noise_distribution_to_cost_function,
noise_distribution_to_observable_transformation,
_default_simplify,
)
from .logging import get_logger, log_execution_time, set_log_level

Expand Down
4 changes: 2 additions & 2 deletions python/sdist/amici/sbml_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@
from .de_export import (
DEExporter,
DEModel,
_default_simplify,
smart_is_zero_matrix,
)
from .sympy_utils import smart_is_zero_matrix
from .import_utils import (
RESERVED_SYMBOLS,
_check_unsupported_functions,
Expand All @@ -50,6 +49,7 @@
smart_subs_dict,
symbol_with_assumptions,
toposort_symbols,
_default_simplify,
)
from .logging import get_logger, log_execution_time, set_log_level
from .sbml_utils import SBMLException, _parse_logical_operators
Expand Down
Loading

0 comments on commit aeb5f34

Please sign in to comment.