Skip to content

Commit

Permalink
Fix missing toposort after rateOf-substitutions in w (#2291)
Browse files Browse the repository at this point in the history
Fixes potentially incorrect simulation results when using rateOf in `w`
where the rates depend on `w`.

Fixes #2290
  • Loading branch information
dweindl authored Feb 21, 2024
1 parent 74d4e1f commit f0c7c59
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 12 deletions.
82 changes: 71 additions & 11 deletions python/sdist/amici/de_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
Union,
)
from collections.abc import Sequence

import numpy as np
import sympy as sp
from sympy.matrices.dense import MutableDenseMatrix
Expand Down Expand Up @@ -1117,20 +1116,17 @@ def transform_dxdt_to_concentration(species_id, dxdt):
for llh in si.symbols[SymbolId.LLHY].values()
)

self._process_sbml_rate_of(
symbols
) # substitute SBML-rateOf constructs
# substitute SBML-rateOf constructs
self._process_sbml_rate_of()

def _process_sbml_rate_of(self, symbols) -> None:
def _process_sbml_rate_of(self) -> None:
"""Substitute any SBML-rateOf constructs in the model equations"""
rate_of_func = sp.core.function.UndefinedFunction("rateOf")
species_sym_to_xdot = dict(zip(self.sym("x"), self.sym("xdot")))
species_sym_to_idx = {x: i for i, x in enumerate(self.sym("x"))}

def get_rate(symbol: sp.Symbol):
"""Get rate of change of the given symbol"""
nonlocal symbols

if symbol.find(rate_of_func):
raise SBMLException("Nesting rateOf() is not allowed.")

Expand All @@ -1142,6 +1138,7 @@ def get_rate(symbol: sp.Symbol):
return 0

# replace rateOf-instances in xdot by xdot symbols
made_substitutions = False
for i_state in range(len(self.eq("xdot"))):
if rate_ofs := self._eqs["xdot"][i_state].find(rate_of_func):
self._eqs["xdot"][i_state] = self._eqs["xdot"][i_state].subs(
Expand All @@ -1151,9 +1148,14 @@ def get_rate(symbol: sp.Symbol):
for rate_of in rate_ofs
}
)
# substitute in topological order
subs = toposort_symbols(dict(zip(self.sym("xdot"), self.eq("xdot"))))
self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs)
made_substitutions = True

if made_substitutions:
# substitute in topological order
subs = toposort_symbols(
dict(zip(self.sym("xdot"), self.eq("xdot")))
)
self._eqs["xdot"] = smart_subs_dict(self.eq("xdot"), subs)

# replace rateOf-instances in x0 by xdot equation
for i_state in range(len(self.eq("x0"))):
Expand All @@ -1165,9 +1167,55 @@ def get_rate(symbol: sp.Symbol):
}
)

# replace rateOf-instances in w by xdot equation
# here we may need toposort, as xdot may depend on w
made_substitutions = False
for i_expr in range(len(self.eq("w"))):
if rate_ofs := self._eqs["w"][i_expr].find(rate_of_func):
self._eqs["w"][i_expr] = self._eqs["w"][i_expr].subs(
{
rate_of: get_rate(rate_of.args[0])
for rate_of in rate_ofs
}
)
made_substitutions = True

if made_substitutions:
# Sort expressions in self._expressions, w symbols, and w equations
# in topological order. Ideally, this would already happen before
# adding the expressions to the model, but at that point, we don't
# have access to xdot yet.
# NOTE: elsewhere, conservations law expressions are expected to
# occur before any other w expressions, so we must maintain their
# position
# toposort everything but conservation law expressions,
# then prepend conservation laws
w_sorted = toposort_symbols(
dict(
zip(
self.sym("w")[self.num_cons_law() :, :],
self.eq("w")[self.num_cons_law() :, :],
)
)
)
w_sorted = (
dict(
zip(
self.sym("w")[: self.num_cons_law(), :],
self.eq("w")[: self.num_cons_law(), :],
)
)
| w_sorted
)
old_syms = tuple(self._syms["w"])
topo_expr_syms = tuple(w_sorted.keys())
new_order = [old_syms.index(s) for s in topo_expr_syms]
self._expressions = [self._expressions[i] for i in new_order]
self._syms["w"] = sp.Matrix(topo_expr_syms)
self._eqs["w"] = sp.Matrix(list(w_sorted.values()))

for component in chain(
self.observables(),
self.expressions(),
self.events(),
self._algebraic_equations,
):
Expand Down Expand Up @@ -2210,6 +2258,18 @@ def _compute_equation(self, name: str) -> None:
self._eqs[name] = self.sym(name)

elif name == "dwdx":
if (
expected := list(
map(
ConservationLaw.get_x_rdata,
reversed(self.conservation_laws()),
)
)
) != (actual := self.eq("w")[: self.num_cons_law()]):
raise AssertionError(
"Conservation laws are not at the beginning of 'w'. "
f"Got {actual}, expected {expected}."
)
x = self.sym("x")
self._eqs[name] = sp.Matrix(
[
Expand Down
54 changes: 53 additions & 1 deletion python/tests/test_sbml_import_special_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from amici.antimony_import import antimony2amici
from amici.gradient_check import check_derivatives
from amici.testing import TemporaryDirectoryWinSafe, skip_on_valgrind
from numpy.testing import assert_approx_equal, assert_array_almost_equal_nulp
from numpy.testing import (
assert_approx_equal,
assert_array_almost_equal_nulp,
assert_allclose,
)
from scipy.special import loggamma


Expand Down Expand Up @@ -222,3 +226,51 @@ def test_rateof():
assert_array_almost_equal_nulp(
rdata.by_id("p2"), 1 + rdata.by_id("S1")
)


@skip_on_valgrind
def test_rateof_with_expression_dependent_rate():
"""Test rateOf, where the rateOf argument depends on `w` and requires
toposorting."""
ant_model = """
model test_rateof_with_expression_dependent_rate
S1 = 0;
S2 = 0;
S1' = rate;
S2' = 2 * rateOf(S1);
# the id of the following expression must be alphabetically before
# `rate`, so that toposort is required to evaluate the expressions
# in the correct order
e1 := 2 * rateOf(S1);
rate := time
end
"""
module_name = "test_rateof_with_expression_dependent_rate"
with TemporaryDirectoryWinSafe(prefix=module_name) as outdir:
antimony2amici(
ant_model,
model_name=module_name,
output_dir=outdir,
)
model_module = amici.import_model_module(
module_name=module_name, module_path=outdir
)
amici_model = model_module.getModel()
t = np.linspace(0, 10, 11)
amici_model.setTimepoints(t)
amici_solver = amici_model.getSolver()
rdata = amici.runAmiciSimulation(amici_model, amici_solver)

state_ids_solver = amici_model.getStateIdsSolver()

assert_array_almost_equal_nulp(rdata.by_id("e1"), 2 * t, 1)

i_S1 = state_ids_solver.index("S1")
i_S2 = state_ids_solver.index("S2")
assert_approx_equal(rdata["xdot"][i_S1], t[-1])
assert_approx_equal(rdata["xdot"][i_S2], 2 * t[-1])

assert_allclose(np.diff(rdata.by_id("S1")), t[:-1] + 0.5, atol=1e-9)
assert_array_almost_equal_nulp(
rdata.by_id("S2"), 2 * rdata.by_id("S1"), 10
)

0 comments on commit f0c7c59

Please sign in to comment.