Skip to content

Commit

Permalink
feat: support inplace parameter observed
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 22, 2024
1 parent d4c430e commit da26283
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 6 deletions.
16 changes: 12 additions & 4 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ function SymbolicIndexingInterface.is_timeseries_parameter(sys::AbstractSystem,
end

function SymbolicIndexingInterface.timeseries_parameter_index(sys::AbstractSystem, sym)
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return false
has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing || return nothing
timeseries_parameter_index(ic, sym)

Check warning on line 485 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L483-L485

Added lines #L483 - L485 were not covered by tests
end

Expand All @@ -504,9 +504,17 @@ function SymbolicIndexingInterface.parameter_observed(sys::AbstractSystem, sym)
else
ts_idx = nothing

Check warning on line 505 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L505

Added line #L505 was not covered by tests
end
obsfn = let raw_obs_fn = build_explicit_observed_function(
sys, sym; param_only = true)
f(p::MTKParameters, t) = raw_obs_fn(p..., t)
rawobs = build_explicit_observed_function(

Check warning on line 507 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L507

Added line #L507 was not covered by tests
sys, sym; param_only = true, return_inplace = true)
if rawobs isa Tuple
obsfn = let oop = rawobs[1], iip = rawobs[2]
f1(p::MTKParameters, t) = oop(p..., t)
f1(out, p::MTKParameters, t) = iip(out, p..., t)

Check warning on line 512 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L509-L512

Added lines #L509 - L512 were not covered by tests
end
else
obsfn = let rawobs = rawobs
f2(p::MTKParameters, t) = rawobs(p..., t)

Check warning on line 516 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L515-L516

Added lines #L515 - L516 were not covered by tests
end
end
else
ts_idx = nothing
Expand Down
9 changes: 8 additions & 1 deletion src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ function build_explicit_observed_function(sys, ts;
drop_expr = drop_expr,
ps = full_parameters(sys),
param_only = false,
return_inplace = false,
op = Operator,
throw = true)
if (isscalar = symbolic_type(ts) !== NotSymbolic())
Expand Down Expand Up @@ -490,7 +491,13 @@ function build_explicit_observed_function(sys, ts;
args = param_only ? [ipts, ps..., ivs...] : [dvs, ipts, ps..., ivs...]

Check warning on line 491 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L491

Added line #L491 was not covered by tests
end
pre = get_postprocess_fbody(sys)

res = build_function(isscalar ? ts[1] : ts, args...; get_postprocess_fbody = pre, wrap_code = wrap_array_vars(sys, isscalar ? ts[1] : ts; dvs = param_only ? [] : unknowns(sys)), expression = Val{expression})
if isscalar || return_inplace
return res

Check warning on line 496 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L494-L496

Added lines #L494 - L496 were not covered by tests
else
return res[1]

Check warning on line 498 in src/systems/diffeqs/odesystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/diffeqs/odesystem.jl#L498

Added line #L498 was not covered by tests
end

ex = Func(args, [],
pre(Let(obsexprs,
isscalar ? ts[1] : MakeArray(ts, output_type),
Expand Down
1 change: 0 additions & 1 deletion src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ function IndexCache(sys::AbstractSystem)
if istree(sym) && operation(sym) == Shift(t, 1)
sym = only(arguments(sym))

Check warning on line 128 in src/systems/index_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/index_cache.jl#L125-L128

Added lines #L125 - L128 were not covered by tests
end
# is_parameter(sys, sym) || is_parameter(sys, Hold(sym)) || continue
disc_clocks[sym] = i - 1
disc_clocks[sym] = i - 1
disc_clocks[default_toterm(sym)] = i - 1
Expand Down

0 comments on commit da26283

Please sign in to comment.