From 63b17783077e0c1d532da3f5c725e27595bb2acf Mon Sep 17 00:00:00 2001 From: Daniel Weindl Date: Wed, 9 Oct 2024 08:34:43 +0200 Subject: [PATCH] Refactor symbolic event processing Extract some neutral changes from #1539 to reduce the diff there and avoid conflicts. --- python/sdist/amici/de_model.py | 34 ++++++++++++++++++++++-------- python/sdist/amici/import_utils.py | 19 ++++++++++------- 2 files changed, 36 insertions(+), 17 deletions(-) diff --git a/python/sdist/amici/de_model.py b/python/sdist/amici/de_model.py index 0e48bf3af8..ba11687929 100644 --- a/python/sdist/amici/de_model.py +++ b/python/sdist/amici/de_model.py @@ -1617,15 +1617,12 @@ def _compute_equation(self, name: str) -> None: ] if name == "dzdx": for ie in range(self.num_events()): - dtaudx = ( - -self.eq("drootdx")[ie, :] - / self.eq("drootdt_total")[ie] - ) + dtaudx = self.eq("dtaudx") for iz in range(self.num_eventobs()): if ie != self._z2event[iz] - 1: continue dzdt = sp.diff(self.eq("z")[ie][iz], time_symbol) - self._eqs[name][ie][iz, :] += dzdt * dtaudx + self._eqs[name][ie][iz, :] += dzdt * -dtaudx[ie] elif name in ["rz", "drzdx", "drzdp"]: eq_events = [] @@ -1644,12 +1641,24 @@ def _compute_equation(self, name: str) -> None: elif name == "stau": self._eqs[name] = [ - -self.eq("sroot")[ie, :] / self.eq("drootdt_total")[ie] + self.eq("sroot")[ie, :] / self.eq("drootdt_total")[ie] if not self.eq("drootdt_total")[ie].is_zero else sp.zeros(*self.eq("sroot")[ie, :].shape) for ie in range(self.num_events()) ] + elif name == "dtaudx": + self._eqs[name] = [ + self.eq("drootdx")[ie, :] / self.eq("drootdt_total")[ie] + for ie in range(self.num_events()) + ] + + elif name == "dtaudp": + self._eqs[name] = [ + self.eq("drootdp")[ie, :] / self.eq("drootdt_total")[ie] + for ie in range(self.num_events()) + ] + elif name == "deltasx": if self.num_states_solver() * self.num_par() == 0: self._eqs[name] = [] @@ -1665,7 +1674,7 @@ def _compute_equation(self, name: str) -> None: self.eq("stau")[ie] ) and not smart_is_zero_matrix(self.eq("xdot")): tmp_eq += smart_multiply( - self.sym("xdot_old") - self.sym("xdot"), + self.sym("xdot") - self.sym("xdot_old"), self.sym("stau").T, ) @@ -1682,12 +1691,14 @@ def _compute_equation(self, name: str) -> None: if not smart_is_zero_matrix(self.eq("stau")[ie]): # chain rule for the time point tmp_eq += smart_multiply( - self.eq("ddeltaxdt")[ie], self.sym("stau").T + self.eq("ddeltaxdt")[ie], + -self.sym("stau").T, ) # additional part of chain rule state variables tmp_dxdp += smart_multiply( - self.sym("xdot_old"), self.sym("stau").T + self.sym("xdot_old"), + -self.sym("stau").T, ) # finish chain rule for the state variables @@ -1695,6 +1706,11 @@ def _compute_equation(self, name: str) -> None: self.eq("ddeltaxdx")[ie], tmp_dxdp ) + else: + tmp_eq = smart_multiply( + self.sym("xdot") - self.sym("xdot_old"), + self.eq("stau")[ie], + ) event_eqs.append(tmp_eq) self._eqs[name] = event_eqs diff --git a/python/sdist/amici/import_utils.py b/python/sdist/amici/import_utils.py index 1a0dc782db..793cade3e2 100644 --- a/python/sdist/amici/import_utils.py +++ b/python/sdist/amici/import_utils.py @@ -471,8 +471,8 @@ def _parse_piecewise_to_heaviside(args: Iterable[sp.Expr]) -> sp.Expr: symbolic expressions for arguments of the piecewise function """ # how many condition-expression pairs will we have? - formula = sp.Float(0.0) - not_condition = sp.Float(1.0) + formula = sp.Integer(0) + not_condition = sp.Integer(1) if all(isinstance(arg, ExprCondPair) for arg in args): # sympy piecewise @@ -483,7 +483,7 @@ def _parse_piecewise_to_heaviside(args: Iterable[sp.Expr]) -> sp.Expr: for coeff, trigger in grouped_args: if isinstance(coeff, BooleanAtom): - coeff = sp.Float(int(bool(coeff))) + coeff = sp.Integer(int(bool(coeff))) if trigger == sp.true: return formula + coeff * not_condition @@ -493,7 +493,7 @@ def _parse_piecewise_to_heaviside(args: Iterable[sp.Expr]) -> sp.Expr: tmp = _parse_heaviside_trigger(trigger) formula += coeff * sp.simplify(not_condition * tmp) - not_condition *= 1 - tmp + not_condition *= sp.Integer(1) - tmp return formula @@ -517,21 +517,24 @@ def _parse_heaviside_trigger(trigger: sp.Expr) -> sp.Expr: # step with H(0) = 1 if isinstance(trigger, sp.core.relational.StrictLessThan): # x < y => x - y < 0 => r < 0 - return 1 - sp.Heaviside(root) + return sp.Integer(1) - sp.Heaviside(root) if isinstance(trigger, sp.core.relational.LessThan): # x <= y => not(y < x) => not(y - x < 0) => not -r < 0 return sp.Heaviside(-root) if isinstance(trigger, sp.core.relational.StrictGreaterThan): # y > x => y - x < 0 => -r < 0 - return 1 - sp.Heaviside(-root) + return sp.Integer(1) - sp.Heaviside(-root) if isinstance(trigger, sp.core.relational.GreaterThan): # y >= x => not(x < y) => not(x - y < 0) => not r < 0 return sp.Heaviside(root) # or(x,y) = not(and(not(x),not(y)) if isinstance(trigger, sp.Or): - return 1 - sp.Mul( - *[1 - _parse_heaviside_trigger(arg) for arg in trigger.args] + return sp.Integer(1) - sp.Mul( + *[ + sp.Integer(1) - _parse_heaviside_trigger(arg) + for arg in trigger.args + ] ) if isinstance(trigger, sp.And):