Skip to content

Commit

Permalink
feat: add discrete saving feature to ODESolution
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Mar 8, 2024
1 parent 0998e07 commit b5971a3
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 5 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Expand Down Expand Up @@ -78,6 +79,7 @@ RecursiveArrayTools = "3.8.0"
Reexport = "1"
RuntimeGeneratedFunctions = "0.5.12"
SciMLOperators = "0.3.7"
SciMLStructures = "1.1"
StaticArrays = "1.7"
StaticArraysCore = "1.4"
Statistics = "1.10"
Expand Down
1 change: 1 addition & 0 deletions src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ end
using ConstructionBase
using RecipesBase, RecursiveArrayTools, Tables
using SymbolicIndexingInterface
using SciMLStructures
using DocStringExtensions
using LinearAlgebra
using Statistics
Expand Down
37 changes: 32 additions & 5 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,15 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
exited due to an error. For more details, see
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
"""
struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, S,
struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A, IType, S,
AC <: Union{Nothing, Vector{Int}}} <:
AbstractODESolution{T, N, uType}
u::uType
u_analytic::uType2
errors::DType
t::tType
k::rateType
discretes::discType
prob::P
alg::A
interp::IType
Expand All @@ -133,12 +134,12 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Sy
return getfield(x, s)
end

function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense,
function ODESolution{T, N}(u, u_analytic, errors, t, k, discretes, prob, alg, interp, dense,
tslocation, stats, alg_choice, retcode) where {T, N}
return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t),
typeof(k), typeof(prob), typeof(alg), typeof(interp),
typeof(k), typeof(discretes), typeof(prob), typeof(alg), typeof(interp),
typeof(stats),
typeof(alg_choice)}(u, u_analytic, errors, t, k, prob, alg, interp,
typeof(alg_choice)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp,
dense, tslocation, stats, alg_choice, retcode)
end

Expand Down Expand Up @@ -257,13 +258,25 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
Base.depwarn(msg, :build_solution)
end

ps = parameter_values(prob)
if ps === NullParameters() || ps === nothing
discretes = nothing
else
discs, _, _ = SciMLStructures.canonicalize(SciMLStructures.Discrete(), ps)
if discs === nothing || isempty(discs)
discretes = nothing
else
discretes = DiffEqArray(typeof(discs)[copy(discs)], eltype(t)[current_time(prob)], nothing, nothing)
end
end
if has_analytic(f)
u_analytic = Vector{typeof(prob.u0)}()
errors = Dict{Symbol, real(eltype(prob.u0))}()
sol = ODESolution{T, N}(u,
u_analytic,
errors,
t, k,
discretes,
prob,
alg,
interp,
Expand All @@ -282,6 +295,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
nothing,
nothing,
t, k,
discretes,
prob,
alg,
interp,
Expand Down Expand Up @@ -339,6 +353,7 @@ function build_solution(sol::ODESolution{T, N}, u_analytic, errors) where {T, N}
errors,
sol.t,
sol.k,
sol.discretes,
sol.prob,
sol.alg,
sol.interp,
Expand All @@ -355,6 +370,7 @@ function solution_new_retcode(sol::ODESolution{T, N}, retcode) where {T, N}
sol.errors,
sol.t,
sol.k,
sol.discretes,
sol.prob,
sol.alg,
sol.interp,
Expand All @@ -371,6 +387,7 @@ function solution_new_tslocation(sol::ODESolution{T, N}, tslocation) where {T, N
sol.errors,
sol.t,
sol.k,
sol.discretes,
sol.prob,
sol.alg,
sol.interp,
Expand All @@ -382,11 +399,20 @@ function solution_new_tslocation(sol::ODESolution{T, N}, tslocation) where {T, N
end

function solution_slice(sol::ODESolution{T, N}, I) where {T, N}
new_t = sol.t[I]
if sol.discretes === nothing
discretes = nothing
else
mint, maxt = extrema(new_t)
disc_mask = mint .<= sol.discretes.t .<= maxt
discretes = sol.discretes[:, disc_mask]
end
ODESolution{T, N}(sol.u[I],
sol.u_analytic === nothing ? nothing : sol.u_analytic[I],
sol.errors,
sol.t[I],
new_t,
sol.dense ? sol.k[I] : sol.k,
discretes,
sol.prob,
sol.alg,
sol.interp,
Expand All @@ -411,6 +437,7 @@ function sensitivity_solution(sol::ODESolution, u, t)
interp = enable_interpolation_sensitivitymode(sol.interp)
ODESolution{T, N}(u, sol.u_analytic, sol.errors,
t isa Vector ? t : collect(t),
sol.discretes,
sol.k, sol.prob,
sol.alg, interp,
sol.dense, sol.tslocation,
Expand Down
41 changes: 41 additions & 0 deletions src/solutions/solution_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,51 @@ SymbolicIndexingInterface.is_time_dependent(::AbstractTimeseriesSolution) = true

SymbolicIndexingInterface.is_time_dependent(::AbstractNoTimeSolution) = false

function SymbolicIndexingInterface.parameter_timeseries(A::AbstractTimeseriesSolution)
if !isdefined(A, :discretes) || (discretes = A.discretes) === nothing
return [0]
end
return discretes.t
end

function SymbolicIndexingInterface.parameter_values_at_time(A::AbstractTimeseriesSolution, t)
ps = parameter_values(A)
if !isdefined(A, :discretes) || (discretes = A.discretes) === nothing
return ps
end
return SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes.u[t])
end

function SymbolicIndexingInterface.parameter_values_at_state_time(A::AbstractTimeseriesSolution, tidx)
ps = parameter_values(A)
if !isdefined(A, :discretes) || (discretes = A.discretes) === nothing
return ps
end
t = A.t[tidx]
idx = searchsortedfirst(discretes.t, t; lt = <=)
if idx == firstindex(discretes.t)
error("This should never happen: there is no discrete parameter value before the current time")
end
return SciMLStructures.replace(SciMLStructures.Discrete(), ps, discretes.u[idx - 1])
end

# TODO make this nontrivial once dynamic state selection works
SymbolicIndexingInterface.constant_structure(::AbstractSolution) = true
SymbolicIndexingInterface.state_values(A::AbstractNoTimeSolution) = A.u

function save_discrete_parameters_after_callback(A::AbstractTimeseriesSolution, ps, t)
if !isdefined(A, :discretes) || A.discretes === nothing
return
end
discretes, _, alias = SciMLStructures.canonicalize(SciMLStructures.Discrete(), ps)
if alias
discretes = copy(discretes)
end
push!(A.discretes.u, discretes)
push!(A.discretes.t, t)
nothing
end

Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, ::Colon)
return A.u[:]
end
Expand Down

0 comments on commit b5971a3

Please sign in to comment.