Skip to content

Commit

Permalink
Refactor symbolic event processing (#2538)
Browse files Browse the repository at this point in the history
Extract some neutral changes from #1539 to reduce the diff there and avoid conflicts.
  • Loading branch information
dweindl authored Oct 9, 2024
1 parent 6261491 commit b7a3e91
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 17 deletions.
34 changes: 25 additions & 9 deletions python/sdist/amici/de_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1617,15 +1617,12 @@ def _compute_equation(self, name: str) -> None:
]
if name == "dzdx":
for ie in range(self.num_events()):
dtaudx = (
-self.eq("drootdx")[ie, :]
/ self.eq("drootdt_total")[ie]
)
dtaudx = self.eq("dtaudx")
for iz in range(self.num_eventobs()):
if ie != self._z2event[iz] - 1:
continue
dzdt = sp.diff(self.eq("z")[ie][iz], time_symbol)
self._eqs[name][ie][iz, :] += dzdt * dtaudx
self._eqs[name][ie][iz, :] += dzdt * -dtaudx[ie]

elif name in ["rz", "drzdx", "drzdp"]:
eq_events = []
Expand All @@ -1644,12 +1641,24 @@ def _compute_equation(self, name: str) -> None:

elif name == "stau":
self._eqs[name] = [
-self.eq("sroot")[ie, :] / self.eq("drootdt_total")[ie]
self.eq("sroot")[ie, :] / self.eq("drootdt_total")[ie]
if not self.eq("drootdt_total")[ie].is_zero
else sp.zeros(*self.eq("sroot")[ie, :].shape)
for ie in range(self.num_events())
]

elif name == "dtaudx":
self._eqs[name] = [
self.eq("drootdx")[ie, :] / self.eq("drootdt_total")[ie]
for ie in range(self.num_events())
]

elif name == "dtaudp":
self._eqs[name] = [
self.eq("drootdp")[ie, :] / self.eq("drootdt_total")[ie]
for ie in range(self.num_events())
]

elif name == "deltasx":
if self.num_states_solver() * self.num_par() == 0:
self._eqs[name] = []
Expand All @@ -1665,7 +1674,7 @@ def _compute_equation(self, name: str) -> None:
self.eq("stau")[ie]
) and not smart_is_zero_matrix(self.eq("xdot")):
tmp_eq += smart_multiply(
self.sym("xdot_old") - self.sym("xdot"),
self.sym("xdot") - self.sym("xdot_old"),
self.sym("stau").T,
)

Expand All @@ -1682,19 +1691,26 @@ def _compute_equation(self, name: str) -> None:
if not smart_is_zero_matrix(self.eq("stau")[ie]):
# chain rule for the time point
tmp_eq += smart_multiply(
self.eq("ddeltaxdt")[ie], self.sym("stau").T
self.eq("ddeltaxdt")[ie],
-self.sym("stau").T,
)

# additional part of chain rule state variables
tmp_dxdp += smart_multiply(
self.sym("xdot_old"), self.sym("stau").T
self.sym("xdot_old"),
-self.sym("stau").T,
)

# finish chain rule for the state variables
tmp_eq += smart_multiply(
self.eq("ddeltaxdx")[ie], tmp_dxdp
)

else:
tmp_eq = smart_multiply(
self.sym("xdot") - self.sym("xdot_old"),
self.eq("stau")[ie],
)
event_eqs.append(tmp_eq)

self._eqs[name] = event_eqs
Expand Down
19 changes: 11 additions & 8 deletions python/sdist/amici/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,8 @@ def _parse_piecewise_to_heaviside(args: Iterable[sp.Expr]) -> sp.Expr:
symbolic expressions for arguments of the piecewise function
"""
# how many condition-expression pairs will we have?
formula = sp.Float(0.0)
not_condition = sp.Float(1.0)
formula = sp.Integer(0)
not_condition = sp.Integer(1)

if all(isinstance(arg, ExprCondPair) for arg in args):
# sympy piecewise
Expand All @@ -483,7 +483,7 @@ def _parse_piecewise_to_heaviside(args: Iterable[sp.Expr]) -> sp.Expr:

for coeff, trigger in grouped_args:
if isinstance(coeff, BooleanAtom):
coeff = sp.Float(int(bool(coeff)))
coeff = sp.Integer(int(bool(coeff)))

if trigger == sp.true:
return formula + coeff * not_condition
Expand All @@ -493,7 +493,7 @@ def _parse_piecewise_to_heaviside(args: Iterable[sp.Expr]) -> sp.Expr:

tmp = _parse_heaviside_trigger(trigger)
formula += coeff * sp.simplify(not_condition * tmp)
not_condition *= 1 - tmp
not_condition *= sp.Integer(1) - tmp

return formula

Expand All @@ -517,21 +517,24 @@ def _parse_heaviside_trigger(trigger: sp.Expr) -> sp.Expr:
# step with H(0) = 1
if isinstance(trigger, sp.core.relational.StrictLessThan):
# x < y => x - y < 0 => r < 0
return 1 - sp.Heaviside(root)
return sp.Integer(1) - sp.Heaviside(root)
if isinstance(trigger, sp.core.relational.LessThan):
# x <= y => not(y < x) => not(y - x < 0) => not -r < 0
return sp.Heaviside(-root)
if isinstance(trigger, sp.core.relational.StrictGreaterThan):
# y > x => y - x < 0 => -r < 0
return 1 - sp.Heaviside(-root)
return sp.Integer(1) - sp.Heaviside(-root)
if isinstance(trigger, sp.core.relational.GreaterThan):
# y >= x => not(x < y) => not(x - y < 0) => not r < 0
return sp.Heaviside(root)

# or(x,y) = not(and(not(x),not(y))
if isinstance(trigger, sp.Or):
return 1 - sp.Mul(
*[1 - _parse_heaviside_trigger(arg) for arg in trigger.args]
return sp.Integer(1) - sp.Mul(
*[
sp.Integer(1) - _parse_heaviside_trigger(arg)
for arg in trigger.args
]
)

if isinstance(trigger, sp.And):
Expand Down

0 comments on commit b7a3e91

Please sign in to comment.