Skip to content

Commit

Permalink
Merge branch 'develop' into adjoint_event
Browse files Browse the repository at this point in the history
  • Loading branch information
dweindl committed Oct 9, 2024
2 parents a8aea88 + b7a3e91 commit f66c6a5
Show file tree
Hide file tree
Showing 7 changed files with 149 additions and 96 deletions.
3 changes: 1 addition & 2 deletions include/amici/vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,7 @@ class AmiVector {
return;
}
nvec_ = N_VMake_Serial(
gsl::narrow<long int>(vold.vec_.size()), vec_.data(),
vold.nvec_->sunctx
gsl::narrow<long int>(vec_.size()), vec_.data(), vold.nvec_->sunctx
);
}

Expand Down
19 changes: 11 additions & 8 deletions python/sdist/amici/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,8 +471,8 @@ def _parse_piecewise_to_heaviside(args: Iterable[sp.Expr]) -> sp.Expr:
symbolic expressions for arguments of the piecewise function
"""
# how many condition-expression pairs will we have?
formula = sp.Float(0.0)
not_condition = sp.Float(1.0)
formula = sp.Integer(0)
not_condition = sp.Integer(1)

if all(isinstance(arg, ExprCondPair) for arg in args):
# sympy piecewise
Expand All @@ -483,7 +483,7 @@ def _parse_piecewise_to_heaviside(args: Iterable[sp.Expr]) -> sp.Expr:

for coeff, trigger in grouped_args:
if isinstance(coeff, BooleanAtom):
coeff = sp.Float(int(bool(coeff)))
coeff = sp.Integer(int(bool(coeff)))

if trigger == sp.true:
return formula + coeff * not_condition
Expand All @@ -493,7 +493,7 @@ def _parse_piecewise_to_heaviside(args: Iterable[sp.Expr]) -> sp.Expr:

tmp = _parse_heaviside_trigger(trigger)
formula += coeff * sp.simplify(not_condition * tmp)
not_condition *= sp.Float(1.0) - tmp
not_condition *= sp.Integer(1) - tmp

return formula

Expand All @@ -517,21 +517,24 @@ def _parse_heaviside_trigger(trigger: sp.Expr) -> sp.Expr:
# step with H(0) = 1
if isinstance(trigger, sp.core.relational.StrictLessThan):
# x < y => x - y < 0 => r < 0
return sp.Float(1.0) - sp.Heaviside(root)
return sp.Integer(1) - sp.Heaviside(root)
if isinstance(trigger, sp.core.relational.LessThan):
# x <= y => not(y < x) => not(y - x < 0) => not -r < 0
return sp.Heaviside(-root)
if isinstance(trigger, sp.core.relational.StrictGreaterThan):
# y > x => y - x < 0 => -r < 0
return sp.Float(1.0) - sp.Heaviside(-root)
return sp.Integer(1) - sp.Heaviside(-root)
if isinstance(trigger, sp.core.relational.GreaterThan):
# y >= x => not(x < y) => not(x - y < 0) => not r < 0
return sp.Heaviside(root)

# or(x,y) = not(and(not(x),not(y))
if isinstance(trigger, sp.Or):
return 1 - sp.Mul(
*[1 - _parse_heaviside_trigger(arg) for arg in trigger.args]
return sp.Integer(1) - sp.Mul(
*[
sp.Integer(1) - _parse_heaviside_trigger(arg)
for arg in trigger.args
]
)

if isinstance(trigger, sp.And):
Expand Down
2 changes: 1 addition & 1 deletion src/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ void Solver::setConstraints(std::vector<realtype> const& constraints) {

if (!any_constraint) {
// all-0 must be converted to empty, otherwise sundials will fail
constraints_ = AmiVector();
constraints_ = AmiVector(0, sunctx_);
return;
}

Expand Down
10 changes: 5 additions & 5 deletions src/vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,11 @@ void AmiVector::copy(AmiVector const& other) {
void AmiVector::synchroniseNVector(SUNContext sunctx) {
if (nvec_)
N_VDestroy_Serial(nvec_);
nvec_ = vec_.empty()
? nullptr
: N_VMake_Serial(
gsl::narrow<long int>(vec_.size()), vec_.data(), sunctx
);
if (sunctx) {
nvec_ = N_VMake_Serial(
gsl::narrow<long int>(vec_.size()), vec_.data(), sunctx
);
}
}

AmiVector::~AmiVector() {
Expand Down
Loading

0 comments on commit f66c6a5

Please sign in to comment.