Skip to content

Commit

Permalink
fixup! feat: support inplace parameter observed
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 24, 2024
1 parent 3080ee7 commit 054fe1f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
13 changes: 12 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,17 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys
end
end

function wrap_assignments(isscalar, assignments; let_block = false)
function wrapper(expr)
Func(expr.args, [], Let(assignments, expr.body, let_block))

Check warning on line 206 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L204-L206

Added lines #L204 - L206 were not covered by tests
end
if isscalar
wrapper

Check warning on line 209 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L208-L209

Added lines #L208 - L209 were not covered by tests
else
wrapper, wrapper

Check warning on line 211 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L211

Added line #L211 was not covered by tests
end
end

function wrap_array_vars(sys::AbstractSystem, exprs; dvs = unknowns(sys))
isscalar = !(exprs isa AbstractArray)
array_vars = Dict{Any, AbstractArray{Int}}()
Expand Down Expand Up @@ -505,7 +516,7 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
ts_idx = nothing

Check warning on line 516 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L516

Added line #L516 was not covered by tests
end
rawobs = build_explicit_observed_function(

Check warning on line 518 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L518

Added line #L518 was not covered by tests
sys, sym; param_only = true, return_inplace = true)
sys, sym; param_only = true, return_inplace = true)
if rawobs isa Tuple
obsfn = let oop = rawobs[1], iip = rawobs[2]
f1(p::MTKParameters, t) = oop(p..., t)
Expand Down
12 changes: 9 additions & 3 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -487,17 +487,23 @@ function build_explicit_observed_function(sys, ts;
if inputs === nothing
args = param_only ? [ps..., ivs...] : [dvs, ps..., ivs...]

Check warning on line 488 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L488

Added line #L488 was not covered by tests
else
ipts = DestructuredArgs(inputs, inbounds = !checkbounds)
ipts = DestructuredArgs(unwrap.(inputs), inbounds = !checkbounds)
args = param_only ? [ipts, ps..., ivs...] : [dvs, ipts, ps..., ivs...]

Check warning on line 491 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L490-L491

Added lines #L490 - L491 were not covered by tests
end
pre = get_postprocess_fbody(sys)
res = build_function(isscalar ? ts[1] : ts, args...; get_postprocess_fbody = pre, wrap_code = wrap_array_vars(sys, isscalar ? ts[1] : ts; dvs = param_only ? [] : unknowns(sys)), expression = Val{expression})
res = build_function(isscalar ? ts[1] : ts,

Check warning on line 494 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L494

Added line #L494 was not covered by tests
args...;
postprocess_fbody = pre,
wrap_code = wrap_array_vars(
sys, isscalar ? ts[1] : ts; dvs = param_only ? [] : unknowns(sys)) .∘
wrap_assignments(isscalar, obsexprs),
expression = Val{expression})
if isscalar || return_inplace
return res

Check warning on line 502 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L501-L502

Added lines #L501 - L502 were not covered by tests
else
return res[1]

Check warning on line 504 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L504

Added line #L504 was not covered by tests
end

ex = Func(args, [],
pre(Let(obsexprs,
isscalar ? ts[1] : MakeArray(ts, output_type),
Expand Down

0 comments on commit 054fe1f

Please sign in to comment.