Skip to content

Commit

Permalink
fixup! feat: implement new SII discrete saving interface
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed May 21, 2024
1 parent 70756b3 commit ee65ce1
Showing 1 changed file with 80 additions and 34 deletions.
114 changes: 80 additions & 34 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,16 @@ function SymbolicIndexingInterface.is_parameter_timeseries(::Type{S}) where {
Timeseries()
end

function get_interpolated_discretes(sol::AbstractODESolution, t, deriv, continuity)
is_parameter_timeseries(sol) == Timeseries() || return nothing

discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol)
interp_discs = map(discs) do partition
ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity)
end
return ParameterTimeseriesCollection(interp_discs, parameter_values(discs))
end

function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing,
continuity = :left) where {deriv}
sol(t, deriv, idxs, continuity)
Expand All @@ -170,14 +180,7 @@ end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
idxs::Nothing, continuity) where {deriv}
if is_parameter_timeseries(sol) == Timeseries()
discs = RecursiveArrayTools.get_discretes(sol)
interp_discs = ConstantInterpolation(discs.t, discs.u)
discretes = interp_discs(t, nothing, deriv, parameter_values(sol), continuity)
else
discretes = nothing
end

discretes = get_interpolated_discretes(sol, t, deriv, continuity)
augment(sol.interp(t, idxs, deriv, sol.prob.p, continuity), sol; discretes)
end

Expand Down Expand Up @@ -214,42 +217,46 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
ps = parameter_values(sol)
# NOTE: This is basically SII.parameter_values_at_time but that isn't public API
# and once we move interpolation to SII, there's no reason for it to be
if is_parameter_timeseries(sol) == Timeseries()
discs = RecursiveArrayTools.get_discretes(sol)
interp_discs = ConstantInterpolation(discs.t, discs.u)
discretes = interp_discs(t, nothing, deriv, parameter_values(sol), continuity)
ps = SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes)
discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol)
ps = parameter_values(discs)
for ts_idx in eachindex(discs)
partition = discs[ts_idx]
interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity)
ps = with_updated_parameter_timeseries_values(ps, ts_idx => interp_val)
end
end
interp_sol = augment(sol.interp([t], nothing, deriv, ps, continuity), sol)
return getu(interp_sol, idxs)(interp_sol, 1)
state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t)
return getu(sol, idxs)(state)
end

function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector,
continuity) where {deriv}
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) ||
error("Incorrect specification of `idxs`")
ps = parameter_values(sol)
# NOTE: This is basically SII.parameter_values_at_time but that isn't public API
# and once we move interpolation to SII, there's no reason for it to be
if is_parameter_timeseries(sol) == Timeseries()
discs = RecursiveArrayTools.get_discretes(sol)
interp_discs = ConstantInterpolation(discs.t, discs.u)
discretes = interp_discs(t, nothing, deriv, parameter_values(sol), continuity)
ps = SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes)
discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol)
ps = parameter_values(discs)
for ts_idx in eachindex(discs)
partition = discs[ts_idx]
interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity)
ps = with_updated_parameter_timeseries_values(ps, ts_idx => interp_val)
end
end
interp_sol = augment(sol.interp([t], nothing, deriv, ps, continuity), sol)
first(getu(interp_sol, idxs)(interp_sol))
state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t)
return getu(sol, idxs)(state)
end

function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs,
continuity) where {deriv}
symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`")
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
if is_parameter_timeseries(sol) == Timeseries()
discs = RecursiveArrayTools.get_discretes(sol)
interp_discs = ConstantInterpolation(discs.t, discs.u)
discretes = interp_discs(t, nothing, deriv, parameter_values(sol), continuity)
else
discretes = nothing
end
discretes = get_interpolated_discretes(sol, t, deriv, continuity)
interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol; discretes)
return DiffEqArray(getu(interp_sol, idxs)(interp_sol), t, p, sol; discretes)
end
Expand All @@ -259,18 +266,57 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv},
all(!isequal(NotSymbolic()), symbolic_type.(idxs)) ||
error("Incorrect specification of `idxs`")
p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing
if is_parameter_timeseries(sol) == Timeseries()
discs = RecursiveArrayTools.get_discretes(sol)
interp_discs = ConstantInterpolation(discs.t, discs.u)
discretes = interp_discs(t, nothing, deriv, parameter_values(sol), continuity)
else
discretes = nothing
end
discretes = get_interpolated_discretes(sol, t, deriv, continuity)
interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol; discretes)
return DiffEqArray(
getu(interp_sol, idxs)(interp_sol), t, p, sol; discretes)
end

# public API, used by MTK
"""
create_parameter_timeseries_collection(sys, ps)
Create a `SymbolicIndexingInterface.ParameterTimeseriesCollection` for the given system
`sys` and parameter object `ps`. Return `nothing` if there are no timeseries parameters.
Defaults to `nothing`.
"""
function create_parameter_timeseries_collection(sys, ps, tspan)
return nothing
end

const PeriodicDiffEqArray = DiffEqArray{T, N, A, B} where {T, N, A, B <: AbstractRange}

# public API, used by MTK
"""
get_saveable_values(ps, timeseries_idx)
"""
function get_saveable_values end

function save_discretes!(integ::DEIntegrator, timeseries_idx)
save_discretes!(integ.sol, current_time(integ), get_saveable_values(parameter_values(integ), timeseries_idx), timeseries_idx)
end

save_discretes!(args...) = nothing

# public API, used by MTK
function save_discretes!(sol::AbstractODESolution, t, vals, timeseries_idx)
RecursiveArrayTools.has_discretes(sol) || return
disc = RecursiveArrayTools.get_discretes(sol)
_save_discretes_internal!(disc[timeseries_idx], t, vals)
end

function _save_discretes_internal!(A::AbstractDiffEqArray, t, vals)
push!(A.t, t)
push!(A.u, vals)
end

function _save_discretes_internal!(A::PeriodicDiffEqArray, t, vals)
if all(!isapprox(t), A.t)
error("Tried to save periodic discrete value with timeseries $(A.t) at time $t")
end
push!(A.u, vals)
end

function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
alg, t, u; timeseries_errors = length(u) > 2,
dense = false, dense_errors = dense,
Expand Down

0 comments on commit ee65ce1

Please sign in to comment.