Skip to content

Commit

Permalink
feat: implement interpolation for parameter timeseries
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Mar 11, 2024
1 parent 78a6f41 commit 7f6658d
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
if is_parameter(sol, idxs)
return getp(sol, idxs)(sol)
unknown_tidx = searchsortedfirst(sol.t, t; lt = <=) - 1
ps = parameter_values_at_state_time(sol, unknown_tidx)
return getp(sol, idxs)(ps)

Check warning on line 194 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L192-L194

Added lines #L192 - L194 were not covered by tests
else
return augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)[idxs][1]
end
Expand All @@ -200,14 +202,19 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) ||
error("Incorrect specification of `idxs`")
interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)
[is_parameter(sol, idx) ? getp(sol, idx)(sol) : first(interp_sol[idx]) for idx in idxs]
unknown_tidx = searchsortedfirst(sol.t, t; lt = <=) - 1
ps = parameter_values_at_state_time(sol, unknown_tidx)
[is_parameter(sol, idx) ? getp(sol, idx)(ps) : first(interp_sol[idx]) for idx in idxs]

Check warning on line 207 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L205-L207

Added lines #L205 - L207 were not covered by tests
end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
if is_parameter(sol, idxs)
return getp(sol, idxs)(sol)
unknown_tidxs = searchsortedfirst((sol.t,), t; lt = <=) - 1
pss = parameter_values_at_state_time.((sol,), unknown_tidxs)
getter = getp(sol, idxs)
return getter.(pss)

Check warning on line 217 in src/solutions/ode_solutions.jl

View check run for this annotation

Codecov / codecov/patch

src/solutions/ode_solutions.jl#L214-L217

Added lines #L214 - L217 were not covered by tests
else
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
Expand All @@ -222,7 +229,7 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol)
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
return DiffEqArray(
[[interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t, p, sol)
[[is_parameter(sol, idx) ? getp(sol, idx)(interp_sol, i) : interp_sol[idx][i] for idx in idxs] for i in 1:length(t)], t, p, sol)
end

function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
Expand Down

0 comments on commit 7f6658d

Please sign in to comment.