From 70e387d30a0a047ec0e7274084e8c56ca69ad9a8 Mon Sep 17 00:00:00 2001 From: xtalax Date: Wed, 15 Feb 2023 18:25:34 +0000 Subject: [PATCH 01/10] symbolic save idxs, new observed --- src/solutions/dae_solutions.jl | 40 +++++++++++++---------------- src/solutions/ode_solutions.jl | 22 ++++++++++------ src/solutions/solution_interface.jl | 24 ++++++++++++----- 3 files changed, 50 insertions(+), 36 deletions(-) diff --git a/src/solutions/dae_solutions.jl b/src/solutions/dae_solutions.jl index edaa3b71a..2284740e7 100644 --- a/src/solutions/dae_solutions.jl +++ b/src/solutions/dae_solutions.jl @@ -27,7 +27,7 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/ exited due to an error. For more details, see [the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes). """ -struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, DE} <: +struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, DE, MType} <: AbstractDAESolution{T, N, uType} u::uType du::duType @@ -39,6 +39,7 @@ struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, DE} <: interp::ID dense::Bool tslocation::Int + sym_map::MType destats::DE retcode::ReturnCode.T end @@ -53,6 +54,7 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; HermiteInterpolation(t, u, du), retcode = ReturnCode.Default, destats = nothing, + sym_map = nothing, kwargs...) T = eltype(eltype(u)) @@ -67,18 +69,10 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; errors = Dict{Symbol, real(eltype(prob.u0))}() sol = DAESolution{T, N, typeof(u), typeof(du), typeof(u_analytic), typeof(errors), - typeof(t), - typeof(prob), typeof(alg), typeof(interp), typeof(destats)}(u, du, - u_analytic, - errors, - t, - prob, - alg, - interp, - dense, - 0, - destats, - retcode) + typeof(t), typeof(prob), typeof(alg), typeof(interp), + typeof(destats), typeof(sym_map)}(u, du, u_analytic, errors, t, + prob, alg, interp, dense, 0, + sym_map, destats, retcode) if calculate_error calculate_solution_errors!(sol; timeseries_errors = timeseries_errors, @@ -86,15 +80,17 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; end sol else - DAESolution{T, N, typeof(u), typeof(du), Nothing, Nothing, typeof(t), - typeof(prob), typeof(alg), typeof(interp), typeof(destats)}(u, du, - nothing, - nothing, t, - prob, alg, - interp, - dense, 0, - destats, - retcode) + DAESolution{T, N, typeof(u), typeof(du), Nothing, Nothing, typeof(t), typeof(prob), + typeof(alg), typeof(interp), typeof(destats), typeof(sym_map)}(u, du, + nothing, + nothing, + t, prob, + alg, + interp, + dense, 0, + sym_map, + destats, + retcode) end end diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 0b899a525..e9f26f0b9 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -22,12 +22,12 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/ - `destats`: statistics of the solver, such as the number of function evaluations required, number of Jacobians computed, and more. - `retcode`: the return code from the solver. Used to determine whether the solver solved - successfully, whether it terminated early due to a user-defined callback, or whether it - exited due to an error. For more details, see + successfully, whether it terminated early due to a user-defined callback, or whether it + exited due to an error. For more details, see [the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes). """ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, DE, - AC <: Union{Nothing, Vector{Int}}} <: + AC <: Union{Nothing, Vector{Int}}, MType} <: AbstractODESolution{T, N, uType} u::uType u_analytic::uType2 @@ -39,17 +39,21 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, DE, interp::IType dense::Bool tslocation::Int + sym_map::MType destats::DE alg_choice::AC retcode::ReturnCode.T end function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense, - tslocation, destats, alg_choice, retcode) where {T, N} + tslocation, destats, alg_choice, retcode, + sym_map = nothing) where {T, N} return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t), typeof(k), typeof(prob), typeof(alg), typeof(interp), typeof(destats), - typeof(alg_choice)}(u, u_analytic, errors, t, k, prob, alg, interp, - dense, tslocation, destats, alg_choice, retcode) + typeof(alg_choice), + typeof(sym_map)}(u, u_analytic, errors, t, k, prob, alg, interp, + dense, tslocation, sym_map, destats, alg_choice, + retcode) end function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing, continuity = :left) where {deriv} @@ -160,6 +164,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, k = nothing, alg_choice = nothing, interp = LinearInterpolation(t, u), + sym_map = nothing, retcode = ReturnCode.Default, destats = nothing, kwargs...) T = eltype(eltype(u)) @@ -189,7 +194,8 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, 0, destats, alg_choice, - retcode) + retcode, + sym_map) if calculate_error calculate_solution_errors!(sol; timeseries_errors = timeseries_errors, dense_errors = dense_errors) @@ -207,7 +213,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, 0, destats, alg_choice, - retcode) + retcode, sym_map) end end diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 525696136..44afb7f03 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -120,7 +120,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s i = sym end - if i === nothing + if isnothing(i) if issymbollike(sym) && has_sys(A.prob.f) && is_indep_sym(A.prob.f.sys, sym) || Symbol(sym) == getindepsym(A) A.t[args...] @@ -134,16 +134,28 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s end end +function get_dep_idxs(A::AbstractTimeseriesSolution) + if has_sys(A.prob.f) && has_observed(A.prob.f) && !isnothing(A.sym_map) + idxs = map(x -> A.sym_map[x], get_deps_of_observed(A)[sym][i]) + else + idxs = CartesianIndices(first(A.u)) + end + idxs +end + function observed(A::AbstractTimeseriesSolution, sym, i::Int) - getobserved(A)(sym, A[i], A.prob.p, A.t[i]) + idxs = get_dep_idxs(A) + getobserved(A)(sym, A[i][idxs], A.prob.p, A.t[i]) end -function observed(A::AbstractTimeseriesSolution, sym, i::AbstractArray{Int}) - getobserved(A).((sym,), A.u[i], (A.prob.p,), A.t[i]) +function observed(A::AbstractTimeseriesSolution, sym, is::AbstractArray{Int}) + idxs = get_dep_idxs(A) + getobserved(A).((sym,), map(j -> A.u[j][idxs], is), (A.prob.p,), A.t[is]) end function observed(A::AbstractTimeseriesSolution, sym, i::Colon) - getobserved(A).((sym,), A.u, (A.prob.p,), A.t) + idxs = get_dep_idxs(A) + getobserved(A).((sym,), map(j -> A.u[j][idxs], eachindex(A.t)), (A.prob.p,), A.t) end Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym) @@ -158,7 +170,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym) i = sym end - if i == nothing + if isnothing(i) paramsyms = getparamsyms(A) if issymbollike(sym) && paramsyms !== nothing && Symbol(sym) in paramsyms get_p(A)[findfirst(x -> isequal(x, Symbol(sym)), paramsyms)] From 875658bf4bf79ed880be1f30be912974d60a06c0 Mon Sep 17 00:00:00 2001 From: xtalax Date: Fri, 17 Feb 2023 13:31:24 +0000 Subject: [PATCH 02/10] more solutions --- src/solutions/dae_solutions.jl | 114 +++++++++------- src/solutions/nonlinear_solutions.jl | 36 +++-- src/solutions/ode_solutions.jl | 23 +++- src/solutions/optimization_solutions.jl | 33 +++-- src/solutions/rode_solutions.jl | 172 +++++++++++++++--------- src/solutions/solution_interface.jl | 16 ++- 6 files changed, 248 insertions(+), 146 deletions(-) diff --git a/src/solutions/dae_solutions.jl b/src/solutions/dae_solutions.jl index 2284740e7..a13cd83b4 100644 --- a/src/solutions/dae_solutions.jl +++ b/src/solutions/dae_solutions.jl @@ -54,7 +54,9 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; HermiteInterpolation(t, u, du), retcode = ReturnCode.Default, destats = nothing, - sym_map = nothing, + sym_map = has_sys(prob.f) ? + Dict(states(prob.f.sys) .=> + 1:length(prob.f.sys |> states)) : nothing, kwargs...) T = eltype(eltype(u)) @@ -135,70 +137,78 @@ end function build_solution(sol::AbstractDAESolution{T, N}, u_analytic, errors) where {T, N} DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(u_analytic), typeof(errors), typeof(sol.t), - typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats)}(sol.u, - sol.du, - u_analytic, - errors, - sol.t, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - sol.tslocation, - sol.destats, - sol.retcode) + typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats), + typeof(sol.sym_map)}(sol.u, + sol.du, + u_analytic, + errors, + sol.t, + sol.prob, + sol.alg, + sol.interp, + sol.dense, + sol.tslocation, + sol.sym_map, + sol.destats, + sol.retcode) end function solution_new_retcode(sol::AbstractDAESolution{T, N}, retcode) where {T, N} DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic), typeof(sol.errors), typeof(sol.t), - typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats)}(sol.u, - sol.du, - sol.u_analytic, - sol.errors, - sol.t, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - sol.tslocation, - sol.destats, - retcode) + typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats), + typeof(sol.sym_map)}(sol.u, + sol.du, + sol.u_analytic, + sol.errors, + sol.t, + sol.prob, + sol.alg, + sol.interp, + sol.dense, + sol.tslocation, + sol.sym_map, + sol.destats, + retcode) end function solution_new_tslocation(sol::AbstractDAESolution{T, N}, tslocation) where {T, N} DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic), typeof(sol.errors), typeof(sol.t), - typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats)}(sol.u, - sol.du, - sol.u_analytic, - sol.errors, - sol.t, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - tslocation, - sol.destats, - sol.retcode) + typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats), + typeof(sol.sym_map)}(sol.u, + sol.du, + sol.u_analytic, + sol.errors, + sol.t, + sol.prob, + sol.alg, + sol.interp, + sol.dense, + tslocation, + sol.sym_map, + sol.destats, + sol.retcode) end function solution_slice(sol::AbstractDAESolution{T, N}, I) where {T, N} DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic), typeof(sol.errors), typeof(sol.t), - typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats)}(sol.u[I], - sol.du[I], - sol.u_analytic === - nothing ? - nothing : - sol.u_analytic[I], - sol.errors, - sol.t[I], - sol.prob, - sol.alg, - sol.interp, - false, - sol.tslocation, - sol.destats, - sol.retcode) + typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats), + typeof(sol.sym_map)}(sol.u[I], + sol.du[I], + sol.u_analytic === + nothing ? + nothing : + sol.u_analytic[I], + sol.errors, + sol.t[I], + sol.prob, + sol.alg, + sol.interp, + false, + sol.tslocation, + sol.sym_map, + sol.destats, + sol.retcode) end diff --git a/src/solutions/nonlinear_solutions.jl b/src/solutions/nonlinear_solutions.jl index f954a9576..67dfaf098 100644 --- a/src/solutions/nonlinear_solutions.jl +++ b/src/solutions/nonlinear_solutions.jl @@ -13,12 +13,14 @@ or the steady state solution to a differential equation defined by a SteadyState - `original`: if the solver is wrapped from an alternative solver ecosystem, such as NLsolve.jl, then this is the original return from said solver library. - `retcode`: the return code from the solver. Used to determine whether the solver solved - successfully or whether it exited due to an error. For more details, see + successfully or whether it exited due to an error. For more details, see [the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes). - `left`: if the solver is bracketing method, this is the final left bracket value. - `right`: if the solver is bracketing method, this is the final right bracket value. +- `sym_map`: Map of symbols to their index in the solution """ -struct NonlinearSolution{T, N, uType, R, P, A, O, uType2} <: AbstractNonlinearSolution{T, N} +struct NonlinearSolution{T, N, uType, R, P, A, O, uType2, MType} <: + AbstractNonlinearSolution{T, N} u::uType resid::R prob::P @@ -27,6 +29,7 @@ struct NonlinearSolution{T, N, uType, R, P, A, O, uType2} <: AbstractNonlinearSo original::O left::uType2 right::uType2 + sym_map::MType end const SteadyStateSolution = NonlinearSolution @@ -37,17 +40,21 @@ function build_solution(prob::AbstractNonlinearProblem, original = nothing, left = nothing, right = nothing, + sym_map = has_sys(prob.f) ? + Dict(states(prob.f.sys) .=> + 1:length(prob.f.sys |> states)) : nothing, kwargs...) T = eltype(eltype(u)) N = ndims(u) NonlinearSolution{T, N, typeof(u), typeof(resid), - typeof(prob), typeof(alg), typeof(original), typeof(left)}(u, resid, - prob, alg, - retcode, - original, - left, - right) + typeof(prob), typeof(alg), typeof(original), typeof(left), + typeof(sym_map)}(u, resid, + prob, alg, + retcode, + original, left, + right, + sym_map) end function sensitivity_solution(sol::AbstractNonlinearSolution, u) @@ -56,8 +63,13 @@ function sensitivity_solution(sol::AbstractNonlinearSolution, u) NonlinearSolution{T, N, typeof(u), typeof(sol.resid), typeof(sol.prob), typeof(sol.alg), - typeof(sol.original), typeof(sol.left)}(u, sol.resid, sol.prob, - sol.alg, sol.retcode, - sol.original, sol.left, - sol.right) + typeof(sol.original), typeof(sol.left), typeof(sol.sym_map)}(u, + sol.resid, + sol.prob, + sol.alg, + sol.retcode, + sol.original, + sol.left, + sol.right, + sol.sym_map) end diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index e9f26f0b9..5fabba6b8 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -21,6 +21,7 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/ - `alg`: the algorithm type used by the solver. - `destats`: statistics of the solver, such as the number of function evaluations required, number of Jacobians computed, and more. + - `sym_map`: Map of symbols to their index in the solution - `retcode`: the return code from the solver. Used to determine whether the solver solved successfully, whether it terminated early due to a user-defined callback, or whether it exited due to an error. For more details, see @@ -46,7 +47,9 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, DE, end function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense, tslocation, destats, alg_choice, retcode, - sym_map = nothing) where {T, N} + sym_map = has_sys(prob.f) ? + Dict(states(prob.f.sys) .=> + 1:length(prob.f.sys |> states)) : nothing) where {T, N} return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t), typeof(k), typeof(prob), typeof(alg), typeof(interp), typeof(destats), @@ -164,7 +167,9 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, k = nothing, alg_choice = nothing, interp = LinearInterpolation(t, u), - sym_map = nothing, + sym_map = has_sys(prob.f) ? + Dict(states(prob.f.sys) .=> + 1:length(prob.f.sys |> states)) : nothing, retcode = ReturnCode.Default, destats = nothing, kwargs...) T = eltype(eltype(u)) @@ -269,7 +274,8 @@ function build_solution(sol::ODESolution{T, N}, u_analytic, errors) where {T, N} sol.tslocation, sol.destats, sol.alg_choice, - sol.retcode) + sol.retcode, + sol.sym_map) end function solution_new_retcode(sol::ODESolution{T, N}, retcode) where {T, N} @@ -285,7 +291,8 @@ function solution_new_retcode(sol::ODESolution{T, N}, retcode) where {T, N} sol.tslocation, sol.destats, sol.alg_choice, - retcode) + retcode, + sol.sym_map) end function solution_new_tslocation(sol::ODESolution{T, N}, tslocation) where {T, N} @@ -301,7 +308,8 @@ function solution_new_tslocation(sol::ODESolution{T, N}, tslocation) where {T, N tslocation, sol.destats, sol.alg_choice, - sol.retcode) + sol.retcode, + sol.sym_map) end function solution_slice(sol::ODESolution{T, N}, I) where {T, N} @@ -317,7 +325,8 @@ function solution_slice(sol::ODESolution{T, N}, I) where {T, N} sol.tslocation, sol.destats, sol.alg_choice, - sol.retcode) + sol.retcode, + sol.sym_map) end function sensitivity_solution(sol::ODESolution, u, t) @@ -335,5 +344,5 @@ function sensitivity_solution(sol::ODESolution, u, t) nothing, sol.prob, sol.alg, interp, sol.dense, sol.tslocation, - sol.destats, sol.alg_choice, sol.retcode) + sol.destats, sol.alg_choice, sol.retcode, sol.sym_map) end diff --git a/src/solutions/optimization_solutions.jl b/src/solutions/optimization_solutions.jl index fc830d6aa..41464124c 100644 --- a/src/solutions/optimization_solutions.jl +++ b/src/solutions/optimization_solutions.jl @@ -15,8 +15,10 @@ Representation of the solution to a non-linear optimization defined by an Optimi - `original`: if the solver is wrapped from an alternative solver ecosystem, such as Optim.jl, then this is the original return from said solver library. - `solve_time`: Solve time from the solver in Seconds +- `sym_map`: Map of symbols to their index in the solution """ -struct OptimizationSolution{T, N, uType, C <: AbstractOptimizationCache, A, OV, O, S} <: +struct OptimizationSolution{T, N, uType, C <: AbstractOptimizationCache, A, OV, O, S, MType + } <: AbstractOptimizationSolution{T, N} u::uType # minimizer cache::C # optimization cache @@ -25,6 +27,7 @@ struct OptimizationSolution{T, N, uType, C <: AbstractOptimizationCache, A, OV, retcode::ReturnCode.T original::O # original output of the optimizer solve_time::S # [s] solve time from the solver + sym_map::MType end function build_solution(cache::AbstractOptimizationCache, @@ -32,6 +35,9 @@ function build_solution(cache::AbstractOptimizationCache, retcode = ReturnCode.Default, original = nothing, solve_time = nothing, + sym_map = has_sys(prob.f) ? + Dict(states(prob.f.sys) .=> + 1:length(prob.f.sys |> states)) : nothing, kwargs...) T = eltype(eltype(u)) N = ndims(u) @@ -40,12 +46,13 @@ function build_solution(cache::AbstractOptimizationCache, retcode = symbol_to_ReturnCode(retcode) OptimizationSolution{T, N, typeof(u), typeof(cache), typeof(alg), - typeof(objective), typeof(original), typeof(solve_time)}(u, cache, - alg, - objective, - retcode, - original, - solve_time) + typeof(objective), typeof(original), typeof(solve_time), + typeof(sym_map)}(u, cache, + alg, + objective, + retcode, + original, + solve_time, sym_map) end """ @@ -64,6 +71,9 @@ function build_solution(prob::AbstractOptimizationProblem, alg, u, objective; retcode = ReturnCode.Default, original = nothing, + sym_map = has_sys(prob.f) ? + Dict(states(prob.f.sys) .=> + 1:length(prob.f.sys |> states)) : nothing, kwargs...) T = eltype(eltype(u)) N = ndims(u) @@ -78,9 +88,12 @@ function build_solution(prob::AbstractOptimizationProblem, retcode = symbol_to_ReturnCode(retcode) OptimizationSolution{T, N, typeof(u), typeof(cache), typeof(alg), - typeof(objective), typeof(original)}(u, cache, alg, objective, - retcode, - original) + typeof(objective), typeof(original), typeof(sym_map)}(u, cache, + alg, + objective, + retcode, + original, + sym_map) end get_p(sol::OptimizationSolution) = sol.cache.p diff --git a/src/solutions/rode_solutions.jl b/src/solutions/rode_solutions.jl index 5be1bd614..b6e6e2657 100644 --- a/src/solutions/rode_solutions.jl +++ b/src/solutions/rode_solutions.jl @@ -27,13 +27,14 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/ - `alg`: the algorithm type used by the solver. - `destats`: statistics of the solver, such as the number of function evaluations required, number of Jacobians computed, and more. +- `sym_map`: Map of symbols to their index in the solution - `retcode`: the return code from the solver. Used to determine whether the solver solved successfully, whether it terminated early due to a user-defined callback, or whether it exited due to an error. For more details, see [the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes). """ struct RODESolution{T, N, uType, uType2, DType, tType, randType, P, A, IType, DE, - AC <: Union{Nothing, Vector{Int}}} <: + AC <: Union{Nothing, Vector{Int}}, MType} <: AbstractRODESolution{T, N, uType} u::uType u_analytic::uType2 @@ -45,6 +46,7 @@ struct RODESolution{T, N, uType, uType2, DType, tType, randType, P, A, IType, DE interp::IType dense::Bool tslocation::Int + sym_map::MType destats::DE alg_choice::AC retcode::ReturnCode.T @@ -65,6 +67,9 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem}, interp = LinearInterpolation(t, u), retcode = ReturnCode.Default, alg_choice = nothing, + sym_map = has_sys(prob.f) ? + Dict(states(prob.f.sys) .=> + 1:length(prob.f.sys |> states)) : nothing, seed = UInt64(0), destats = nothing, kwargs...) T = eltype(eltype(u)) N = length((size(prob.u0)..., length(u))) @@ -81,19 +86,20 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem}, sol = RODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t), typeof(W), typeof(prob), typeof(alg), typeof(interp), typeof(destats), - typeof(alg_choice)}(u, - u_analytic, - errors, - t, W, - prob, - alg, - interp, - dense, - 0, - destats, - alg_choice, - retcode, - seed) + typeof(alg_choice), typeof(sym_map)}(u, + u_analytic, + errors, + t, W, + prob, + alg, + interp, + dense, + 0, + sym_map, + destats, + alg_choice, + retcode, + seed) if calculate_error calculate_solution_errors!(sol; timeseries_errors = timeseries_errors, @@ -104,10 +110,18 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem}, else return RODESolution{T, N, typeof(u), Nothing, Nothing, typeof(t), typeof(W), typeof(prob), typeof(alg), typeof(interp), - typeof(destats), typeof(alg_choice)}(u, nothing, nothing, t, W, - prob, alg, interp, - dense, 0, destats, - alg_choice, retcode, seed) + typeof(destats), typeof(alg_choice), typeof(sym_map)}(u, + nothing, + nothing, + t, W, + prob, alg, + interp, + dense, 0, + sym_map, + destats, + alg_choice, + retcode, + seed) end end @@ -161,53 +175,87 @@ end function build_solution(sol::AbstractRODESolution{T, N}, u_analytic, errors) where {T, N} RODESolution{T, N, typeof(sol.u), typeof(u_analytic), typeof(errors), typeof(sol.t), typeof(sol.W), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), - typeof(sol.destats), typeof(sol.alg_choice)}(sol.u, u_analytic, errors, - sol.t, sol.W, sol.prob, - sol.alg, sol.interp, - sol.dense, sol.tslocation, - sol.destats, sol.alg_choice, - sol.retcode, sol.seed) + typeof(sol.destats), typeof(sol.alg_choice), typeof(sol.sym_map)}(sol.u, + u_analytic, + errors, + sol.t, + sol.W, + sol.prob, + sol.alg, + sol.interp, + sol.dense, + sol.tslocation, + sol.sym_map, + sol.destats, + sol.alg_choice, + sol.retcode, + sol.seed) end function solution_new_retcode(sol::AbstractRODESolution{T, N}, retcode) where {T, N} RODESolution{T, N, typeof(sol.u), typeof(sol.u_analytic), typeof(sol.errors), typeof(sol.t), typeof(sol.W), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), - typeof(sol.destats), typeof(sol.alg_choice)}(sol.u, sol.u_analytic, - sol.errors, sol.t, sol.W, - sol.prob, sol.alg, sol.interp, - sol.dense, sol.tslocation, - sol.destats, sol.alg_choice, - retcode, - sol.seed) + typeof(sol.destats), typeof(sol.alg_choice), typeof(sol.sym_map)}(sol.u, + sol.u_analytic, + sol.errors, + sol.t, + sol.W, + sol.prob, + sol.alg, + sol.interp, + sol.dense, + sol.tslocation, + sol.sym_map, + sol.destats, + sol.alg_choice, + retcode, + sol.seed) end function solution_new_tslocation(sol::AbstractRODESolution{T, N}, tslocation) where {T, N} RODESolution{T, N, typeof(sol.u), typeof(sol.u_analytic), typeof(sol.errors), typeof(sol.t), typeof(sol.W), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), - typeof(sol.destats), typeof(sol.alg_choice)}(sol.u, sol.u_analytic, - sol.errors, sol.t, sol.W, - sol.prob, sol.alg, sol.interp, - sol.dense, tslocation, - sol.destats, sol.alg_choice, - sol.retcode, - sol.seed) + typeof(sol.destats), typeof(sol.alg_choice), typeof(sol.sym_map)}(sol.u, + sol.u_analytic, + sol.errors, + sol.t, + sol.W, + sol.prob, + sol.alg, + sol.interp, + sol.dense, + tslocation, + sol.sym_map, + sol.destats, + sol.alg_choice, + sol.retcode, + sol.seed) end function solution_slice(sol::AbstractRODESolution{T, N}, I) where {T, N} RODESolution{T, N, typeof(sol.u), typeof(sol.u_analytic), typeof(sol.errors), typeof(sol.t), typeof(sol.W), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), - typeof(sol.destats), typeof(sol.alg_choice)}(sol.u[I], - sol.u_analytic === nothing ? - nothing : sol.u_analytic, - sol.errors, sol.t[I], - sol.W, sol.prob, - sol.alg, sol.interp, - false, sol.tslocation, - sol.destats, sol.alg_choice, - sol.retcode, sol.seed) + typeof(sol.destats), typeof(sol.alg_choice), typeof(sol.sym_map)}(sol.u[I], + sol.u_analytic === + nothing ? + nothing : + sol.u_analytic, + sol.errors, + sol.t[I], + sol.W, + sol.prob, + sol.alg, + sol.interp, + false, + sol.tslocation, + sol.sym_map, + sol.destats, + sol.alg_choice, + sol.retcode, + sol.seed) end function sensitivity_solution(sol::AbstractRODESolution, u, t) @@ -224,18 +272,20 @@ function sensitivity_solution(sol::AbstractRODESolution, u, t) RODESolution{T, N, typeof(u), typeof(sol.u_analytic), typeof(sol.errors), typeof(t), typeof(nothing), typeof(sol.prob), typeof(sol.alg), - typeof(sol.interp), typeof(sol.destats), typeof(sol.alg_choice)}(u, - sol.u_analytic, - sol.errors, - t, - nothing, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - sol.tslocation, - sol.destats, - sol.alg_choice, - sol.retcode, - sol.seed) + typeof(sol.interp), typeof(sol.destats), typeof(sol.alg_choice), + typeof(sol.sym_map)}(u, + sol.u_analytic, + sol.errors, + t, + nothing, + sol.prob, + sol.alg, + sol.interp, + sol.dense, + sol.tslocation, + sol.sym_map, + sol.destats, + sol.alg_choice, + sol.retcode, + sol.seed) end diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 44afb7f03..7e630a859 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -73,7 +73,11 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s if sym isa AbstractArray return A[collect(sym)] end - i = sym_to_index(sym, A) + if hasfield(typeof(A), :sym_map) && !isnothing(A.sym_map) + i = get(A.sym_map, sym, nothing) + else + i = sym_to_index(sym, A) + end elseif all(issymbollike, sym) if has_sys(A.prob.f) && all(Base.Fix1(is_param_sym, A.prob.f.sys), sym) || !has_sys(A.prob.f) && has_paramsyms(A.prob.f) && @@ -113,7 +117,11 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s if sym isa AbstractArray return A[collect(sym), args...] end - i = sym_to_index(sym, A) + if hasfield(typeof(A), :sym_map) && !isnothing(A.sym_map) + i = get(A.sym_map, sym, nothing) + else + i = sym_to_index(sym, A) + end elseif all(issymbollike, sym) return reduce(vcat, map(s -> A[s, args...]', sym)) else @@ -136,7 +144,7 @@ end function get_dep_idxs(A::AbstractTimeseriesSolution) if has_sys(A.prob.f) && has_observed(A.prob.f) && !isnothing(A.sym_map) - idxs = map(x -> A.sym_map[x], get_deps_of_observed(A)[sym][i]) + idxs = map(x -> A.sym_map[x], get_deps_of_observed(A.prob.f.sys)) else idxs = CartesianIndices(first(A.u)) end @@ -185,7 +193,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym) end function observed(A::AbstractNoTimeSolution, sym) - getobserved(A)(sym, A.u, A.prob.p) + observed(A, sym, :) end function observed(A::AbstractOptimizationSolution, sym) From 4e1611e5ed92b0a3dff48b4d5ca2f7145086871e Mon Sep 17 00:00:00 2001 From: xtalax Date: Tue, 21 Feb 2023 17:24:07 +0000 Subject: [PATCH 03/10] handle case --- src/solutions/dae_solutions.jl | 4 +--- src/solutions/nonlinear_solutions.jl | 4 +--- src/solutions/ode_solutions.jl | 8 ++------ src/solutions/optimization_solutions.jl | 8 ++------ src/solutions/rode_solutions.jl | 4 +--- src/solutions/solution_interface.jl | 8 ++++++-- 6 files changed, 13 insertions(+), 23 deletions(-) diff --git a/src/solutions/dae_solutions.jl b/src/solutions/dae_solutions.jl index a13cd83b4..672e81dd0 100644 --- a/src/solutions/dae_solutions.jl +++ b/src/solutions/dae_solutions.jl @@ -54,9 +54,7 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; HermiteInterpolation(t, u, du), retcode = ReturnCode.Default, destats = nothing, - sym_map = has_sys(prob.f) ? - Dict(states(prob.f.sys) .=> - 1:length(prob.f.sys |> states)) : nothing, + sym_map = nothing, kwargs...) T = eltype(eltype(u)) diff --git a/src/solutions/nonlinear_solutions.jl b/src/solutions/nonlinear_solutions.jl index 67dfaf098..986d469a6 100644 --- a/src/solutions/nonlinear_solutions.jl +++ b/src/solutions/nonlinear_solutions.jl @@ -40,9 +40,7 @@ function build_solution(prob::AbstractNonlinearProblem, original = nothing, left = nothing, right = nothing, - sym_map = has_sys(prob.f) ? - Dict(states(prob.f.sys) .=> - 1:length(prob.f.sys |> states)) : nothing, + sym_map = nothing, kwargs...) T = eltype(eltype(u)) N = ndims(u) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 5fabba6b8..41b62648a 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -47,9 +47,7 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, DE, end function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense, tslocation, destats, alg_choice, retcode, - sym_map = has_sys(prob.f) ? - Dict(states(prob.f.sys) .=> - 1:length(prob.f.sys |> states)) : nothing) where {T, N} + sym_map = nothing) where {T, N} return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t), typeof(k), typeof(prob), typeof(alg), typeof(interp), typeof(destats), @@ -167,9 +165,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, k = nothing, alg_choice = nothing, interp = LinearInterpolation(t, u), - sym_map = has_sys(prob.f) ? - Dict(states(prob.f.sys) .=> - 1:length(prob.f.sys |> states)) : nothing, + sym_map = nothing, retcode = ReturnCode.Default, destats = nothing, kwargs...) T = eltype(eltype(u)) diff --git a/src/solutions/optimization_solutions.jl b/src/solutions/optimization_solutions.jl index 41464124c..dad4fb3df 100644 --- a/src/solutions/optimization_solutions.jl +++ b/src/solutions/optimization_solutions.jl @@ -35,9 +35,7 @@ function build_solution(cache::AbstractOptimizationCache, retcode = ReturnCode.Default, original = nothing, solve_time = nothing, - sym_map = has_sys(prob.f) ? - Dict(states(prob.f.sys) .=> - 1:length(prob.f.sys |> states)) : nothing, + sym_map = nothing, kwargs...) T = eltype(eltype(u)) N = ndims(u) @@ -71,9 +69,7 @@ function build_solution(prob::AbstractOptimizationProblem, alg, u, objective; retcode = ReturnCode.Default, original = nothing, - sym_map = has_sys(prob.f) ? - Dict(states(prob.f.sys) .=> - 1:length(prob.f.sys |> states)) : nothing, + sym_map = nothing, kwargs...) T = eltype(eltype(u)) N = ndims(u) diff --git a/src/solutions/rode_solutions.jl b/src/solutions/rode_solutions.jl index b6e6e2657..f04d08f74 100644 --- a/src/solutions/rode_solutions.jl +++ b/src/solutions/rode_solutions.jl @@ -67,9 +67,7 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem}, interp = LinearInterpolation(t, u), retcode = ReturnCode.Default, alg_choice = nothing, - sym_map = has_sys(prob.f) ? - Dict(states(prob.f.sys) .=> - 1:length(prob.f.sys |> states)) : nothing, + sym_map = nothing, seed = UInt64(0), destats = nothing, kwargs...) T = eltype(eltype(u)) N = length((size(prob.u0)..., length(u))) diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 7e630a859..91ccb2733 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -143,8 +143,12 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s end function get_dep_idxs(A::AbstractTimeseriesSolution) - if has_sys(A.prob.f) && has_observed(A.prob.f) && !isnothing(A.sym_map) - idxs = map(x -> A.sym_map[x], get_deps_of_observed(A.prob.f.sys)) + if has_sys(A.prob.f) && has_observed(A.prob.f) + if !isnothing(A.sym_map) + idxs = map(x -> A.sym_map[x], get_deps_of_observed(A.prob.f.sys)) + else + idxs = map(x -> sym_to_index(x, A), get_deps_of_observed(A.prob.f.sys)) + end else idxs = CartesianIndices(first(A.u)) end From 659a3e4b5dda953e3c5e53decb87bb8a42163b1d Mon Sep 17 00:00:00 2001 From: xtalax Date: Tue, 21 Feb 2023 20:28:22 +0000 Subject: [PATCH 04/10] passing --- src/solutions/nonlinear_solutions.jl | 37 ++++++++++++++++------------ src/solutions/ode_solutions.jl | 34 +++++++++++++++---------- src/solutions/solution_interface.jl | 30 ++++++++++++++++++---- 3 files changed, 67 insertions(+), 34 deletions(-) diff --git a/src/solutions/nonlinear_solutions.jl b/src/solutions/nonlinear_solutions.jl index 986d469a6..fed9ba42d 100644 --- a/src/solutions/nonlinear_solutions.jl +++ b/src/solutions/nonlinear_solutions.jl @@ -19,7 +19,7 @@ or the steady state solution to a differential equation defined by a SteadyState - `right`: if the solver is bracketing method, this is the final right bracket value. - `sym_map`: Map of symbols to their index in the solution """ -struct NonlinearSolution{T, N, uType, R, P, A, O, uType2, MType} <: +struct NonlinearSolution{T, N, uType, R, P, A, O, uType2, MType, DI} <: AbstractNonlinearSolution{T, N} u::uType resid::R @@ -30,6 +30,7 @@ struct NonlinearSolution{T, N, uType, R, P, A, O, uType2, MType} <: left::uType2 right::uType2 sym_map::MType + dep_idxs::DI end const SteadyStateSolution = NonlinearSolution @@ -41,18 +42,20 @@ function build_solution(prob::AbstractNonlinearProblem, left = nothing, right = nothing, sym_map = nothing, + dep_idxs = Ref{Vector{Int}}(Int[]), kwargs...) T = eltype(eltype(u)) N = ndims(u) NonlinearSolution{T, N, typeof(u), typeof(resid), typeof(prob), typeof(alg), typeof(original), typeof(left), - typeof(sym_map)}(u, resid, - prob, alg, - retcode, - original, left, - right, - sym_map) + typeof(sym_map), typeof(dep_idxs)}(u, resid, + prob, alg, + retcode, + original, left, + right, + sym_map, + dep_idxs) end function sensitivity_solution(sol::AbstractNonlinearSolution, u) @@ -61,13 +64,15 @@ function sensitivity_solution(sol::AbstractNonlinearSolution, u) NonlinearSolution{T, N, typeof(u), typeof(sol.resid), typeof(sol.prob), typeof(sol.alg), - typeof(sol.original), typeof(sol.left), typeof(sol.sym_map)}(u, - sol.resid, - sol.prob, - sol.alg, - sol.retcode, - sol.original, - sol.left, - sol.right, - sol.sym_map) + typeof(sol.original), typeof(sol.left), typeof(sol.sym_map), + typeof(dep_idxs)}(u, + sol.resid, + sol.prob, + sol.alg, + sol.retcode, + sol.original, + sol.left, + sol.right, + sol.sym_map, + sol.dep_idxs) end diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 41b62648a..a82dcf6a0 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -28,7 +28,7 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/ [the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes). """ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, DE, - AC <: Union{Nothing, Vector{Int}}, MType} <: + AC <: Union{Nothing, Vector{Int}}, MType, DI} <: AbstractODESolution{T, N, uType} u::uType u_analytic::uType2 @@ -43,18 +43,21 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, DE, sym_map::MType destats::DE alg_choice::AC + dep_idxs::DI retcode::ReturnCode.T end function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense, tslocation, destats, alg_choice, retcode, - sym_map = nothing) where {T, N} + sym_map = nothing, dep_idxs = Ref{Vector{Int}}(Int[])) where {T, N} return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t), typeof(k), typeof(prob), typeof(alg), typeof(interp), typeof(destats), typeof(alg_choice), - typeof(sym_map)}(u, u_analytic, errors, t, k, prob, alg, interp, - dense, tslocation, sym_map, destats, alg_choice, - retcode) + typeof(sym_map), typeof(dep_idxs)}(u, u_analytic, errors, t, k, prob, + alg, interp, + dense, tslocation, sym_map, + destats, alg_choice, dep_idxs, + retcode) end function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing, continuity = :left) where {deriv} @@ -165,7 +168,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, k = nothing, alg_choice = nothing, interp = LinearInterpolation(t, u), - sym_map = nothing, + sym_map = nothing, dep_idxs = Ref{Vector{Int}}(Int[]), retcode = ReturnCode.Default, destats = nothing, kwargs...) T = eltype(eltype(u)) @@ -196,7 +199,8 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, destats, alg_choice, retcode, - sym_map) + sym_map, + dep_idxs) if calculate_error calculate_solution_errors!(sol; timeseries_errors = timeseries_errors, dense_errors = dense_errors) @@ -214,7 +218,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, 0, destats, alg_choice, - retcode, sym_map) + retcode, sym_map, dep_idxs) end end @@ -271,7 +275,8 @@ function build_solution(sol::ODESolution{T, N}, u_analytic, errors) where {T, N} sol.destats, sol.alg_choice, sol.retcode, - sol.sym_map) + sol.sym_map, + sol.dep_idxs) end function solution_new_retcode(sol::ODESolution{T, N}, retcode) where {T, N} @@ -288,7 +293,8 @@ function solution_new_retcode(sol::ODESolution{T, N}, retcode) where {T, N} sol.destats, sol.alg_choice, retcode, - sol.sym_map) + sol.sym_map, + sol.dep_idxs) end function solution_new_tslocation(sol::ODESolution{T, N}, tslocation) where {T, N} @@ -305,7 +311,8 @@ function solution_new_tslocation(sol::ODESolution{T, N}, tslocation) where {T, N sol.destats, sol.alg_choice, sol.retcode, - sol.sym_map) + sol.sym_map, + sol.dep_idxs) end function solution_slice(sol::ODESolution{T, N}, I) where {T, N} @@ -322,7 +329,8 @@ function solution_slice(sol::ODESolution{T, N}, I) where {T, N} sol.destats, sol.alg_choice, sol.retcode, - sol.sym_map) + sol.sym_map, + sol.dep_idxs) end function sensitivity_solution(sol::ODESolution, u, t) @@ -340,5 +348,5 @@ function sensitivity_solution(sol::ODESolution, u, t) nothing, sol.prob, sol.alg, interp, sol.dense, sol.tslocation, - sol.destats, sol.alg_choice, sol.retcode, sol.sym_map) + sol.destats, sol.alg_choice, sol.retcode, sol.sym_map, sol.dep_idxs) end diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 91ccb2733..dd1bbecf3 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -142,19 +142,39 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s end end -function get_dep_idxs(A::AbstractTimeseriesSolution) - if has_sys(A.prob.f) && has_observed(A.prob.f) +function _get_dep_idxs(A::AbstractTimeseriesSolution) + idxs = if has_sys(A.prob.f) && has_observed(A.prob.f) + is_ODAE = hasfield(typeof(A.prob.f.sys), :unknown_states) && + !isnothing(getfield(A.prob.f.sys, :unknown_states)) if !isnothing(A.sym_map) - idxs = map(x -> A.sym_map[x], get_deps_of_observed(A.prob.f.sys)) + map(x -> A.sym_map[x], get_deps_of_observed(A.prob.f.sys)) + elseif is_ODAE + sts = getfield(A.prob.f.sys, :unknown_states) + map(x -> sym_to_index(x, A), + get_deps_of_observed(sts, SymbolicIndexingInterface.observed(A.prob.f.sys))) else - idxs = map(x -> sym_to_index(x, A), get_deps_of_observed(A.prob.f.sys)) + map(x -> sym_to_index(x, A), get_deps_of_observed(A.prob.f.sys)) end else - idxs = CartesianIndices(first(A.u)) + CartesianIndices(first(A.u)) end idxs end +function get_dep_idxs(A::AbstractTimeseriesSolution) + if hasfield(typeof(A), :dep_idxs) + if isnothing(A.dep_idxs[]) + A.dep_idxs[] + else + idxs = _get_dep_idxs(A) + A.dep_idxs[] = idxs + idxs + end + else + _get_dep_idxs(A) + end +end + function observed(A::AbstractTimeseriesSolution, sym, i::Int) idxs = get_dep_idxs(A) getobserved(A)(sym, A[i][idxs], A.prob.p, A.t[i]) From c432843a77a23392b46a21a1b0fc39dd7c075b1a Mon Sep 17 00:00:00 2001 From: xtalax Date: Mon, 27 Feb 2023 15:12:51 +0000 Subject: [PATCH 05/10] passing --- src/SciMLBase.jl | 2 +- src/problems/ode_problems.jl | 8 +- src/remake.jl | 6 +- src/solutions/nonlinear_solutions.jl | 2 +- src/solutions/ode_solutions.jl | 8 +- src/solutions/solution_interface.jl | 148 +++++++-------------------- 6 files changed, 57 insertions(+), 117 deletions(-) diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index aeb48470c..c58ca7527 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -791,7 +791,7 @@ export step!, deleteat!, addat!, get_tmp_cache, set_reltol!, get_du, get_du!, get_dt, get_proposed_dt, set_proposed_dt!, u_modified!, savevalues!, reinit!, auto_dt_reset!, set_t!, set_u!, check_error, change_t_via_interpolation!, addsteps!, - isdiscrete, reeval_internals_due_to_modification! + isdiscrete, reeval_internals_due_to_modification!, is_dense_output export ContinuousCallback, DiscreteCallback, CallbackSet, VectorContinuousCallback diff --git a/src/problems/ode_problems.jl b/src/problems/ode_problems.jl index c4e0849dd..1a3589723 100644 --- a/src/problems/ode_problems.jl +++ b/src/problems/ode_problems.jl @@ -108,15 +108,18 @@ struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <: kwargs::K """An internal argument for storing traits about the solving process.""" problem_type::PT + """Whether the output has all saved states.""" + dense_output::Bool @add_kwonly function ODEProblem{iip}(f::AbstractODEFunction{iip}, u0, tspan, p = NullParameters(), problem_type = StandardODEProblem(); + dense_output = true, kwargs...) where {iip} _tspan = promote_tspan(tspan) new{typeof(u0), typeof(_tspan), isinplace(f), typeof(p), typeof(f), typeof(kwargs), - typeof(problem_type)}(f, u0, _tspan, p, kwargs, problem_type) + typeof(problem_type)}(f, u0, _tspan, p, kwargs, problem_type, dense_output) end """ @@ -471,3 +474,6 @@ function IncrementingODEProblem{iip}(f::IncrementingODEFunction, u0, tspan, p = NullParameters(); kwargs...) where {iip} ODEProblem(f, u0, tspan, p, IncrementingODEProblem{iip}(); kwargs...) end + +is_dense_output(prob::ODEProblem) = prob.dense_output +is_dense_output(prob) = true diff --git a/src/remake.jl b/src/remake.jl index 024036a5a..d04e64fff 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -113,10 +113,12 @@ function remake(prob::ODEProblem; f = missing, end if kwargs === missing - ODEProblem{isinplace(prob)}(_f, u0, tspan, p, prob.problem_type; prob.kwargs..., + ODEProblem{isinplace(prob)}(_f, u0, tspan, p, prob.problem_type; + dense_output = prob.dense_output, prob.kwargs..., _kwargs...) else - ODEProblem{isinplace(prob)}(_f, u0, tspan, p, prob.problem_type; kwargs...) + ODEProblem{isinplace(prob)}(_f, u0, tspan, p, prob.problem_type; + dense_output = prob.dense_output, kwargs...) end end diff --git a/src/solutions/nonlinear_solutions.jl b/src/solutions/nonlinear_solutions.jl index fed9ba42d..a1ffa222c 100644 --- a/src/solutions/nonlinear_solutions.jl +++ b/src/solutions/nonlinear_solutions.jl @@ -42,7 +42,7 @@ function build_solution(prob::AbstractNonlinearProblem, left = nothing, right = nothing, sym_map = nothing, - dep_idxs = Ref{Vector{Int}}(Int[]), + dep_idxs = Ref{Vector{Union{Int, Nothing}}}(Union{Int, Nothing}[nothing]), kwargs...) T = eltype(eltype(u)) N = ndims(u) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index a82dcf6a0..6e68a8ae7 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -48,7 +48,11 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, DE, end function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense, tslocation, destats, alg_choice, retcode, - sym_map = nothing, dep_idxs = Ref{Vector{Int}}(Int[])) where {T, N} + sym_map = nothing, + dep_idxs = nothing) where {T, N} + if isnothing(dep_idxs) + dep_idxs = Ref{Vector{Union{Int, Nothing}}}(Union{Int, Nothing}[nothing]) + end return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t), typeof(k), typeof(prob), typeof(alg), typeof(interp), typeof(destats), @@ -168,7 +172,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, k = nothing, alg_choice = nothing, interp = LinearInterpolation(t, u), - sym_map = nothing, dep_idxs = Ref{Vector{Int}}(Int[]), + sym_map = nothing, dep_idxs = nothing, retcode = ReturnCode.Default, destats = nothing, kwargs...) T = eltype(eltype(u)) diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index dd1bbecf3..a90951ca9 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -143,50 +143,63 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s end function _get_dep_idxs(A::AbstractTimeseriesSolution) - idxs = if has_sys(A.prob.f) && has_observed(A.prob.f) - is_ODAE = hasfield(typeof(A.prob.f.sys), :unknown_states) && - !isnothing(getfield(A.prob.f.sys, :unknown_states)) + SII = SymbolicIndexingInterface + if has_sys(A.prob.f) && has_observed(A.prob.f) if !isnothing(A.sym_map) - map(x -> A.sym_map[x], get_deps_of_observed(A.prob.f.sys)) - elseif is_ODAE - sts = getfield(A.prob.f.sys, :unknown_states) - map(x -> sym_to_index(x, A), - get_deps_of_observed(sts, SymbolicIndexingInterface.observed(A.prob.f.sys))) - else - map(x -> sym_to_index(x, A), get_deps_of_observed(A.prob.f.sys)) + is_ODAE = hasfield(typeof(A.prob.f.sys), :unknown_states) && + !isnothing(getfield(A.prob.f.sys, :unknown_states)) + if is_ODAE + sts = getfield(A.prob.f.sys, :unknown_states) + return map(x -> sym_to_index(x, A), + get_deps_of_observed(sts, + SII.observed(A.prob.f.sys))) + end + return map(x -> A.sym_map[x], get_deps_of_observed(A.prob.f.sys)) end - else - CartesianIndices(first(A.u)) end - idxs + return [nothing] end +idxs_initialized(idxs) = isempty(idxs) || !isnothing(first(idxs)) + function get_dep_idxs(A::AbstractTimeseriesSolution) if hasfield(typeof(A), :dep_idxs) - if isnothing(A.dep_idxs[]) - A.dep_idxs[] + if idxs_initialized(A.dep_idxs[]) + return A.dep_idxs[] else idxs = _get_dep_idxs(A) A.dep_idxs[] = idxs - idxs + return A.dep_idxs[] end else - _get_dep_idxs(A) + return [nothing] end end function observed(A::AbstractTimeseriesSolution, sym, i::Int) - idxs = get_dep_idxs(A) + dense = is_dense_output(A.prob) + idxs = dense ? [nothing] : get_dep_idxs(A) + if dense || !idxs_initialized(idxs) + return getobserved(A)(sym, A.u[i], A.prob.p, A.t[i]) + end getobserved(A)(sym, A[i][idxs], A.prob.p, A.t[i]) end function observed(A::AbstractTimeseriesSolution, sym, is::AbstractArray{Int}) - idxs = get_dep_idxs(A) + dense = is_dense_output(A.prob) + idxs = dense ? [nothing] : get_dep_idxs(A) + if dense || !idxs_initialized(idxs) + return getobserved(A)(sym, A.u[i], A.prob.p, A.t[i]) + end getobserved(A).((sym,), map(j -> A.u[j][idxs], is), (A.prob.p,), A.t[is]) end function observed(A::AbstractTimeseriesSolution, sym, i::Colon) - idxs = get_dep_idxs(A) + dense = is_dense_output(A.prob) + idxs = dense ? [nothing] : get_dep_idxs(A) + if dense || !idxs_initialized(idxs) + return getobserved(A).((sym,), A.u, (A.prob.p,), A.t) + end getobserved(A).((sym,), map(j -> A.u[j][idxs], eachindex(A.t)), (A.prob.p,), A.t) end @@ -195,7 +208,11 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym) if sym isa AbstractArray return A[collect(sym)] end - i = sym_to_index(sym, A) + if hasfield(typeof(A), :sym_map) && !isnothing(A.sym_map) + i = get(A.sym_map, sym, nothing) + else + i = sym_to_index(sym, A) + end elseif all(issymbollike, sym) return reduce(vcat, map(s -> A[s]', sym)) else @@ -390,95 +407,6 @@ DEFAULT_PLOT_FUNC(x, y, z) = (x, y, z) # For v0.5.2 bug (plot_vecs...,) end -function getsyms(sol) - if has_syms(sol.prob.f) - return sol.prob.f.syms - else - return keys(sol.u[1]) - end -end -function getsyms(sol::AbstractOptimizationSolution) - if has_syms(sol) - return get_syms(sol) - else - return keys(sol.u[1]) - end -end - -function getindepsym(sol) - if has_indepsym(sol.prob.f) - return sol.prob.f.indepsym - else - return nothing - end -end - -function getparamsyms(sol) - if has_paramsyms(sol.prob.f) - return sol.prob.f.paramsyms - else - return nothing - end -end -function getparamsyms(sol::AbstractOptimizationSolution) - if has_paramsyms(sol) - return get_paramsyms(sol) - else - return nothing - end -end - -# Only for compatibility! -function getindepsym_defaultt(sol) - if has_indepsym(sol.prob.f) - return sol.prob.f.indepsym - else - return :t - end -end - -function getobserved(sol) - if has_observed(sol.prob.f) - return sol.prob.f.observed - else - return DEFAULT_OBSERVED - end -end -function getobserved(sol::AbstractOptimizationSolution) - if has_observed(sol) - return get_observed(sol) - else - return DEFAULT_OBSERVED - end -end - -cleansyms(syms::Nothing) = nothing -cleansyms(syms::Tuple) = collect(cleansym(sym) for sym in syms) -cleansyms(syms::Vector{Symbol}) = cleansym.(syms) -cleansyms(syms::LinearIndices) = nothing -cleansyms(syms::CartesianIndices) = nothing -cleansyms(syms::Base.OneTo) = nothing - -function cleansym(sym::Symbol) - str = String(sym) - # MTK generated names - rules = ("₊" => ".", "⦗" => "(", "⦘" => ")") - for r in rules - str = replace(str, r) - end - return str -end - -function sym_to_index(sym, sol::AbstractSciMLSolution) - if has_sys(sol.prob.f) && is_state_sym(sol.prob.f.sys, sym) - return state_sym_to_index(sol.prob.f.sys, sym) - else - return sym_to_index(sym, getsyms(sol)) - end -end -sym_to_index(sym, syms) = findfirst(isequal(Symbol(sym)), syms) -const issymbollike = RecursiveArrayTools.issymbollike - function diffeq_to_arrays(sol, plot_analytic, denseplot, plotdensity, tspan, axis_safety, vars, int_vars, tscale, strs) if tspan === nothing From 82df76ce68d5a02ebddb7b14b08e6143d11385e1 Mon Sep 17 00:00:00 2001 From: xtalax Date: Mon, 27 Feb 2023 17:55:17 +0000 Subject: [PATCH 06/10] fix nonlinear --- src/solutions/solution_interface.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index a90951ca9..c84f72d88 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -73,6 +73,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s if sym isa AbstractArray return A[collect(sym)] end + if hasfield(typeof(A), :sym_map) && !isnothing(A.sym_map) i = get(A.sym_map, sym, nothing) else @@ -234,7 +235,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym) end function observed(A::AbstractNoTimeSolution, sym) - observed(A, sym, :) + getobserved(A)(sym, A.u, A.prob.p) end function observed(A::AbstractOptimizationSolution, sym) From bb17abc876d124a78c9661138941fcd401deef6c Mon Sep 17 00:00:00 2001 From: xtalax Date: Thu, 9 Mar 2023 16:04:25 +0000 Subject: [PATCH 07/10] why this fail --- src/problems/ode_problems.jl | 8 +--- src/remake.jl | 6 +-- src/solutions/dae_solutions.jl | 2 +- src/solutions/nonlinear_solutions.jl | 2 +- src/solutions/ode_solutions.jl | 5 ++- src/solutions/optimization_solutions.jl | 2 +- src/solutions/rode_solutions.jl | 2 +- src/solutions/solution_interface.jl | 60 ++++++++++++------------- src/symbolic_utils.jl | 8 ++++ src/utils.jl | 9 ++++ 10 files changed, 57 insertions(+), 47 deletions(-) diff --git a/src/problems/ode_problems.jl b/src/problems/ode_problems.jl index 09e42940a..53beb2d7d 100644 --- a/src/problems/ode_problems.jl +++ b/src/problems/ode_problems.jl @@ -108,18 +108,15 @@ struct ODEProblem{uType, tType, isinplace, P, F, K, PT} <: kwargs::K """An internal argument for storing traits about the solving process.""" problem_type::PT - """Whether the output has all saved states.""" - dense_output::Bool @add_kwonly function ODEProblem{iip}(f::AbstractODEFunction{iip}, u0, tspan, p = NullParameters(), problem_type = StandardODEProblem(); - dense_output = true, kwargs...) where {iip} _tspan = promote_tspan(tspan) new{typeof(u0), typeof(_tspan), isinplace(f), typeof(p), typeof(f), typeof(kwargs), - typeof(problem_type)}(f, u0, _tspan, p, kwargs, problem_type, dense_output) + typeof(problem_type)}(f, u0, _tspan, p, kwargs, problem_type) end """ @@ -487,6 +484,3 @@ function IncrementingODEProblem{iip}(f::IncrementingODEFunction, u0, tspan, p = NullParameters(); kwargs...) where {iip} ODEProblem(f, u0, tspan, p, IncrementingODEProblem{iip}(); kwargs...) end - -is_dense_output(prob::ODEProblem) = prob.dense_output -is_dense_output(prob) = true diff --git a/src/remake.jl b/src/remake.jl index d04e64fff..024036a5a 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -113,12 +113,10 @@ function remake(prob::ODEProblem; f = missing, end if kwargs === missing - ODEProblem{isinplace(prob)}(_f, u0, tspan, p, prob.problem_type; - dense_output = prob.dense_output, prob.kwargs..., + ODEProblem{isinplace(prob)}(_f, u0, tspan, p, prob.problem_type; prob.kwargs..., _kwargs...) else - ODEProblem{isinplace(prob)}(_f, u0, tspan, p, prob.problem_type; - dense_output = prob.dense_output, kwargs...) + ODEProblem{isinplace(prob)}(_f, u0, tspan, p, prob.problem_type; kwargs...) end end diff --git a/src/solutions/dae_solutions.jl b/src/solutions/dae_solutions.jl index 6fcc35f46..6bb0859e9 100644 --- a/src/solutions/dae_solutions.jl +++ b/src/solutions/dae_solutions.jl @@ -75,7 +75,7 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; HermiteInterpolation(t, u, du), retcode = ReturnCode.Default, destats = nothing, - sym_map = nothing, + sym_map = default_sym_map(prob), kwargs...) T = eltype(eltype(u)) diff --git a/src/solutions/nonlinear_solutions.jl b/src/solutions/nonlinear_solutions.jl index 7037656e6..ae92fce54 100644 --- a/src/solutions/nonlinear_solutions.jl +++ b/src/solutions/nonlinear_solutions.jl @@ -53,7 +53,7 @@ function build_solution(prob::AbstractNonlinearProblem, original = nothing, left = nothing, right = nothing, - sym_map = nothing, + sym_map = default_sym_map(prob), dep_idxs = Ref{Vector{Union{Int, Nothing}}}(Union{Int, Nothing}[nothing]), kwargs...) T = eltype(eltype(u)) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index eacbef6df..5e7545528 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -53,6 +53,9 @@ function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense if isnothing(dep_idxs) dep_idxs = Ref{Vector{Union{Int, Nothing}}}(Union{Int, Nothing}[nothing]) end + if isnothing(sym_map) + sym_map = default_sym_map(prob) + end return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t), typeof(k), typeof(prob), typeof(alg), typeof(interp), typeof(destats), @@ -173,7 +176,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, k = nothing, alg_choice = nothing, interp = LinearInterpolation(t, u), - sym_map = nothing, dep_idxs = nothing, + sym_map = default_sym_map(prob), dep_idxs = nothing, retcode = ReturnCode.Default, destats = nothing, kwargs...) T = eltype(eltype(u)) diff --git a/src/solutions/optimization_solutions.jl b/src/solutions/optimization_solutions.jl index 3a56a87f8..1a24f7364 100644 --- a/src/solutions/optimization_solutions.jl +++ b/src/solutions/optimization_solutions.jl @@ -80,7 +80,7 @@ function build_solution(prob::AbstractOptimizationProblem, alg, u, objective; retcode = ReturnCode.Default, original = nothing, - sym_map = nothing, + sym_map = default_sym_map(prob), kwargs...) T = eltype(eltype(u)) N = ndims(u) diff --git a/src/solutions/rode_solutions.jl b/src/solutions/rode_solutions.jl index 5d7b83ff5..a8e77a4f2 100644 --- a/src/solutions/rode_solutions.jl +++ b/src/solutions/rode_solutions.jl @@ -80,7 +80,7 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem}, interp = LinearInterpolation(t, u), retcode = ReturnCode.Default, alg_choice = nothing, - sym_map = nothing, + sym_map = default_sym_map(prob), seed = UInt64(0), destats = nothing, kwargs...) T = eltype(eltype(u)) N = length((size(prob.u0)..., length(u))) diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index c84f72d88..eb2ecd0d7 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -73,12 +73,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s if sym isa AbstractArray return A[collect(sym)] end - - if hasfield(typeof(A), :sym_map) && !isnothing(A.sym_map) - i = get(A.sym_map, sym, nothing) - else - i = sym_to_index(sym, A) - end + i = state_sym_to_index(A, sym) elseif all(issymbollike, sym) if has_sys(A.prob.f) && all(Base.Fix1(is_param_sym, A.prob.f.sys), sym) || !has_sys(A.prob.f) && has_paramsyms(A.prob.f) && @@ -87,6 +82,8 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s else return [getindex.((A,), sym, i) for i in eachindex(A)] end + elseif is_symbolic_expr(sym) + return convert_to_getindex(A, sym) else i = sym end @@ -104,6 +101,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s return observed(A, sym, :) end else + @show sym observed(A, sym, :) end elseif i isa Base.Integer || i isa AbstractRange || i isa AbstractVector{<:Base.Integer} @@ -118,13 +116,11 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s if sym isa AbstractArray return A[collect(sym), args...] end - if hasfield(typeof(A), :sym_map) && !isnothing(A.sym_map) - i = get(A.sym_map, sym, nothing) - else - i = sym_to_index(sym, A) - end + i = state_sym_to_index(A, sym) elseif all(issymbollike, sym) return reduce(vcat, map(s -> A[s, args...]', sym)) + elseif is_symbolic_expr(sym) + return convert_to_getindex(A, sym, args...) else i = sym end @@ -134,6 +130,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s Symbol(sym) == getindepsym(A) A.t[args...] else + @show sym observed(A, sym, args...) end elseif i isa Base.Integer || i isa AbstractRange || i isa AbstractVector{<:Base.Integer} @@ -143,19 +140,20 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s end end -function _get_dep_idxs(A::AbstractTimeseriesSolution) +function _get_dep_idxs(A::AbstractSciMLSolution) SII = SymbolicIndexingInterface if has_sys(A.prob.f) && has_observed(A.prob.f) if !isnothing(A.sym_map) is_ODAE = hasfield(typeof(A.prob.f.sys), :unknown_states) && !isnothing(getfield(A.prob.f.sys, :unknown_states)) if is_ODAE - sts = getfield(A.prob.f.sys, :unknown_states) - return map(x -> sym_to_index(x, A), + sts = unknown_states(A.prob.f.sys) + return map(x -> state_sym_to_index(A, x), get_deps_of_observed(sts, SII.observed(A.prob.f.sys))) end - return map(x -> A.sym_map[x], get_deps_of_observed(A.prob.f.sys)) + @show "not ODAE" + return map(x -> A.sym_map[safe_unwrap(x)], get_deps_of_observed(A.prob.f.sys)) end end return [nothing] @@ -163,11 +161,12 @@ end idxs_initialized(idxs) = isempty(idxs) || !isnothing(first(idxs)) -function get_dep_idxs(A::AbstractTimeseriesSolution) +function get_dep_idxs(A::AbstractSciMLSolution) if hasfield(typeof(A), :dep_idxs) if idxs_initialized(A.dep_idxs[]) return A.dep_idxs[] else + @show "recomputing dep_idxs" idxs = _get_dep_idxs(A) A.dep_idxs[] = idxs return A.dep_idxs[] @@ -178,27 +177,25 @@ function get_dep_idxs(A::AbstractTimeseriesSolution) end function observed(A::AbstractTimeseriesSolution, sym, i::Int) - dense = is_dense_output(A.prob) - idxs = dense ? [nothing] : get_dep_idxs(A) - if dense || !idxs_initialized(idxs) + idxs = get_dep_idxs(A) + if !idxs_initialized(idxs) return getobserved(A)(sym, A.u[i], A.prob.p, A.t[i]) end getobserved(A)(sym, A[i][idxs], A.prob.p, A.t[i]) end function observed(A::AbstractTimeseriesSolution, sym, is::AbstractArray{Int}) - dense = is_dense_output(A.prob) - idxs = dense ? [nothing] : get_dep_idxs(A) - if dense || !idxs_initialized(idxs) + idxs = get_dep_idxs(A) + if !idxs_initialized(idxs) return getobserved(A)(sym, A.u[i], A.prob.p, A.t[i]) end getobserved(A).((sym,), map(j -> A.u[j][idxs], is), (A.prob.p,), A.t[is]) end function observed(A::AbstractTimeseriesSolution, sym, i::Colon) - dense = is_dense_output(A.prob) - idxs = dense ? [nothing] : get_dep_idxs(A) - if dense || !idxs_initialized(idxs) + idxs = get_dep_idxs(A) + @show idxs + if !idxs_initialized(idxs) return getobserved(A).((sym,), A.u, (A.prob.p,), A.t) end getobserved(A).((sym,), map(j -> A.u[j][idxs], eachindex(A.t)), (A.prob.p,), A.t) @@ -209,11 +206,8 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym) if sym isa AbstractArray return A[collect(sym)] end - if hasfield(typeof(A), :sym_map) && !isnothing(A.sym_map) - i = get(A.sym_map, sym, nothing) - else - i = sym_to_index(sym, A) - end + i = state_sym_to_index(A, sym) + elseif all(issymbollike, sym) return reduce(vcat, map(s -> A[s]', sym)) else @@ -235,7 +229,11 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym) end function observed(A::AbstractNoTimeSolution, sym) - getobserved(A)(sym, A.u, A.prob.p) + idxs = get_dep_idxs(A) + if !idxs_initialized(idxs) + return getobserved(A)(sym, A.u[i], A.prob.p, A.t[i]) + end + getobserved(A)(sym, A.u[idxs], A.prob.p) end function observed(A::AbstractOptimizationSolution, sym) diff --git a/src/symbolic_utils.jl b/src/symbolic_utils.jl index 5a103814a..1dfd529b8 100644 --- a/src/symbolic_utils.jl +++ b/src/symbolic_utils.jl @@ -50,6 +50,14 @@ function getparamsyms(sol::AbstractOptimizationSolution) end end +function SymbolicIndexingInterface.state_sym_to_index(A::S, sym) where {S <: AbstractSciMLSolution} + if hasfield(S, :sym_map) && !isnothing(A.sym_map) + return get(A.sym_map, sym, nothing) + else + return sym_to_index(sym, A) + end +end + # Only for compatibility! function getindepsym_defaultt(sol) if has_indepsym(sol.prob.f) diff --git a/src/utils.jl b/src/utils.jl index 03109d75f..d424ee51d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -488,3 +488,12 @@ end _unwrap_val(::Val{B}) where {B} = B _unwrap_val(B) = B + +function default_sym_map(prob) + if has_sys(prob.f) + sts = safe_unwrap.(unknown_states(prob.f.sys)) + return Dict(sts .=> eachindex(sts)) + else + return nothing + end +end From 1d20e970b48fbe7f8afc30c3918959479e094f0e Mon Sep 17 00:00:00 2001 From: xtalax Date: Fri, 10 Mar 2023 15:58:56 +0000 Subject: [PATCH 08/10] odae passing!! --- src/solutions/ode_solutions.jl | 6 +++-- src/solutions/solution_interface.jl | 34 +++++++++++++---------------- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 5e7545528..f1d78297a 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -176,10 +176,12 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, k = nothing, alg_choice = nothing, interp = LinearInterpolation(t, u), - sym_map = default_sym_map(prob), dep_idxs = nothing, + sym_map = nothing, dep_idxs = nothing, retcode = ReturnCode.Default, destats = nothing, kwargs...) T = eltype(eltype(u)) - + if isnothing(sym_map) + sym_map = default_sym_map(prob) + end if prob.u0 === nothing N = 2 else diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index eb2ecd0d7..2316c7fa3 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -69,6 +69,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, end Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, sym) + sym = safe_unwrap(sym) if issymbollike(sym) if sym isa AbstractArray return A[collect(sym)] @@ -82,13 +83,15 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s else return [getindex.((A,), sym, i) for i in eachindex(A)] end - elseif is_symbolic_expr(sym) - return convert_to_getindex(A, sym) else i = sym end - if i === nothing + if is_symbolic_expr(sym) + return convert_to_getindex(A, sym) + end + + if isnothing(i) if issymbollike(sym) if has_sys(A.prob.f) && is_indep_sym(A.prob.f.sys, sym) || Symbol(sym) == getindepsym(A) @@ -101,7 +104,6 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s return observed(A, sym, :) end else - @show sym observed(A, sym, :) end elseif i isa Base.Integer || i isa AbstractRange || i isa AbstractVector{<:Base.Integer} @@ -112,6 +114,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s end Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, sym, args...) + sym = safe_unwrap(sym) if issymbollike(sym) if sym isa AbstractArray return A[collect(sym), args...] @@ -119,18 +122,19 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s i = state_sym_to_index(A, sym) elseif all(issymbollike, sym) return reduce(vcat, map(s -> A[s, args...]', sym)) - elseif is_symbolic_expr(sym) - return convert_to_getindex(A, sym, args...) else i = sym end + if is_symbolic_expr(sym) + return convert_to_getindex(A, sym, args...) + end + if isnothing(i) if issymbollike(sym) && has_sys(A.prob.f) && is_indep_sym(A.prob.f.sys, sym) || Symbol(sym) == getindepsym(A) A.t[args...] else - @show sym observed(A, sym, args...) end elseif i isa Base.Integer || i isa AbstractRange || i isa AbstractVector{<:Base.Integer} @@ -144,16 +148,9 @@ function _get_dep_idxs(A::AbstractSciMLSolution) SII = SymbolicIndexingInterface if has_sys(A.prob.f) && has_observed(A.prob.f) if !isnothing(A.sym_map) - is_ODAE = hasfield(typeof(A.prob.f.sys), :unknown_states) && - !isnothing(getfield(A.prob.f.sys, :unknown_states)) - if is_ODAE - sts = unknown_states(A.prob.f.sys) - return map(x -> state_sym_to_index(A, x), - get_deps_of_observed(sts, - SII.observed(A.prob.f.sys))) - end - @show "not ODAE" - return map(x -> A.sym_map[safe_unwrap(x)], get_deps_of_observed(A.prob.f.sys)) + sys = A.prob.f.sys + return map(x -> A.sym_map[safe_unwrap(x)], + get_deps_of_observed(unknown_states(sys), SII.observed(sys))) end end return [nothing] @@ -166,7 +163,6 @@ function get_dep_idxs(A::AbstractSciMLSolution) if idxs_initialized(A.dep_idxs[]) return A.dep_idxs[] else - @show "recomputing dep_idxs" idxs = _get_dep_idxs(A) A.dep_idxs[] = idxs return A.dep_idxs[] @@ -194,7 +190,6 @@ end function observed(A::AbstractTimeseriesSolution, sym, i::Colon) idxs = get_dep_idxs(A) - @show idxs if !idxs_initialized(idxs) return getobserved(A).((sym,), A.u, (A.prob.p,), A.t) end @@ -202,6 +197,7 @@ function observed(A::AbstractTimeseriesSolution, sym, i::Colon) end Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym) + sym = safe_unwrap(sym) if issymbollike(sym) if sym isa AbstractArray return A[collect(sym)] From 92d307afbe3b680c77411957b46d07b2fbb406c7 Mon Sep 17 00:00:00 2001 From: xtalax Date: Mon, 13 Mar 2023 14:26:43 +0000 Subject: [PATCH 09/10] ode passing --- src/solutions/dae_solutions.jl | 161 +++++++++++++++------------ src/solutions/nonlinear_solutions.jl | 6 +- src/solutions/rode_solutions.jl | 6 +- 3 files changed, 101 insertions(+), 72 deletions(-) diff --git a/src/solutions/dae_solutions.jl b/src/solutions/dae_solutions.jl index 6bb0859e9..3b2d8d286 100644 --- a/src/solutions/dae_solutions.jl +++ b/src/solutions/dae_solutions.jl @@ -27,7 +27,7 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/ exited due to an error. For more details, see [the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes). """ -struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, DE, MType} <: +struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, DE, MType, DI} <: AbstractDAESolution{T, N, uType} u::uType du::duType @@ -42,6 +42,7 @@ struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, DE, MTyp sym_map::MType destats::DE retcode::ReturnCode.T + dep_idxs::DI end function Base.show(io::IO, @@ -75,7 +76,8 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; HermiteInterpolation(t, u, du), retcode = ReturnCode.Default, destats = nothing, - sym_map = default_sym_map(prob), + sym_map = nothing, + dep_idxs = nothing, kwargs...) T = eltype(eltype(u)) @@ -84,6 +86,12 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; else N = length((size(prob.u0)..., length(u))) end + if isnothing(dep_idxs) + dep_idxs = Ref{Vector{Union{Int, Nothing}}}(Union{Int, Nothing}[nothing]) + end + if isnothing(sym_map) + sym_map = default_sym_map(prob) + end if has_analytic(prob.f) u_analytic = Vector{typeof(prob.u0)}() @@ -91,9 +99,16 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; sol = DAESolution{T, N, typeof(u), typeof(du), typeof(u_analytic), typeof(errors), typeof(t), typeof(prob), typeof(alg), typeof(interp), - typeof(destats), typeof(sym_map)}(u, du, u_analytic, errors, t, - prob, alg, interp, dense, 0, - sym_map, destats, retcode) + typeof(destats), typeof(sym_map), typeof(dep_idxs)}(u, du, + u_analytic, + errors, t, + prob, alg, + interp, dense, + 0, + sym_map, + destats, + retcode, + dep_idxs) if calculate_error calculate_solution_errors!(sol; timeseries_errors = timeseries_errors, @@ -102,16 +117,18 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; sol else DAESolution{T, N, typeof(u), typeof(du), Nothing, Nothing, typeof(t), typeof(prob), - typeof(alg), typeof(interp), typeof(destats), typeof(sym_map)}(u, du, - nothing, - nothing, - t, prob, - alg, - interp, - dense, 0, - sym_map, - destats, - retcode) + typeof(alg), typeof(interp), typeof(destats), typeof(sym_map), + typeof(dep_idxs)}(u, du, + nothing, + nothing, + t, prob, + alg, + interp, + dense, 0, + sym_map, + destats, + retcode, + dep_idxs) end end @@ -157,77 +174,81 @@ function build_solution(sol::AbstractDAESolution{T, N}, u_analytic, errors) wher DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(u_analytic), typeof(errors), typeof(sol.t), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats), - typeof(sol.sym_map)}(sol.u, - sol.du, - u_analytic, - errors, - sol.t, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - sol.tslocation, - sol.sym_map, - sol.destats, - sol.retcode) + typeof(sol.sym_map), typeof(sol.dep_idxs)}(sol.u, + sol.du, + u_analytic, + errors, + sol.t, + sol.prob, + sol.alg, + sol.interp, + sol.dense, + sol.tslocation, + sol.sym_map, + sol.destats, + sol.retcode, + sol.dep_idxs) end function solution_new_retcode(sol::AbstractDAESolution{T, N}, retcode) where {T, N} DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic), typeof(sol.errors), typeof(sol.t), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats), - typeof(sol.sym_map)}(sol.u, - sol.du, - sol.u_analytic, - sol.errors, - sol.t, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - sol.tslocation, - sol.sym_map, - sol.destats, - retcode) + typeof(sol.sym_map), typeof(sol.dep_idxs)}(sol.u, + sol.du, + sol.u_analytic, + sol.errors, + sol.t, + sol.prob, + sol.alg, + sol.interp, + sol.dense, + sol.tslocation, + sol.sym_map, + sol.destats, + retcode, + sol.dep_idxs) end function solution_new_tslocation(sol::AbstractDAESolution{T, N}, tslocation) where {T, N} DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic), typeof(sol.errors), typeof(sol.t), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats), - typeof(sol.sym_map)}(sol.u, - sol.du, - sol.u_analytic, - sol.errors, - sol.t, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - tslocation, - sol.sym_map, - sol.destats, - sol.retcode) + typeof(sol.sym_map), typeof(sol.dep_idxs)}(sol.u, + sol.du, + sol.u_analytic, + sol.errors, + sol.t, + sol.prob, + sol.alg, + sol.interp, + sol.dense, + tslocation, + sol.sym_map, + sol.destats, + sol.retcode, + sol.dep_idxs) end function solution_slice(sol::AbstractDAESolution{T, N}, I) where {T, N} DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic), typeof(sol.errors), typeof(sol.t), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats), - typeof(sol.sym_map)}(sol.u[I], - sol.du[I], - sol.u_analytic === - nothing ? - nothing : - sol.u_analytic[I], - sol.errors, - sol.t[I], - sol.prob, - sol.alg, - sol.interp, - false, - sol.tslocation, - sol.sym_map, - sol.destats, - sol.retcode) + typeof(sol.sym_map), typeof(sol.dep_idxs)}(sol.u[I], + sol.du[I], + sol.u_analytic === + nothing ? + nothing : + sol.u_analytic[I], + sol.errors, + sol.t[I], + sol.prob, + sol.alg, + sol.interp, + false, + sol.tslocation, + sol.sym_map, + sol.destats, + sol.retcode, + sol.dep_idxs) end diff --git a/src/solutions/nonlinear_solutions.jl b/src/solutions/nonlinear_solutions.jl index ae92fce54..aa9033109 100644 --- a/src/solutions/nonlinear_solutions.jl +++ b/src/solutions/nonlinear_solutions.jl @@ -53,12 +53,16 @@ function build_solution(prob::AbstractNonlinearProblem, original = nothing, left = nothing, right = nothing, - sym_map = default_sym_map(prob), + sym_map = nothing, dep_idxs = Ref{Vector{Union{Int, Nothing}}}(Union{Int, Nothing}[nothing]), kwargs...) T = eltype(eltype(u)) N = ndims(u) + if isnothing(sym_map) + sym_map = default_sym_map(prob) + end + NonlinearSolution{T, N, typeof(u), typeof(resid), typeof(prob), typeof(alg), typeof(original), typeof(left), typeof(sym_map), typeof(dep_idxs)}(u, resid, diff --git a/src/solutions/rode_solutions.jl b/src/solutions/rode_solutions.jl index a8e77a4f2..44bba2159 100644 --- a/src/solutions/rode_solutions.jl +++ b/src/solutions/rode_solutions.jl @@ -80,7 +80,7 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem}, interp = LinearInterpolation(t, u), retcode = ReturnCode.Default, alg_choice = nothing, - sym_map = default_sym_map(prob), + sym_map = nothing, seed = UInt64(0), destats = nothing, kwargs...) T = eltype(eltype(u)) N = length((size(prob.u0)..., length(u))) @@ -91,6 +91,10 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem}, f = prob.f end + if isnothing(sym_map) + sym_map = default_sym_map(prob) + end + if has_analytic(f) u_analytic = Vector{typeof(prob.u0)}() errors = Dict{Symbol, real(eltype(prob.u0))}() From 59dd39329e1ba1fdee68620363d62e54353ca98d Mon Sep 17 00:00:00 2001 From: xtalax Date: Tue, 21 Mar 2023 15:18:30 +0000 Subject: [PATCH 10/10] add comment --- src/solutions/solution_interface.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 2316c7fa3..ec0b218f4 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -158,6 +158,12 @@ end idxs_initialized(idxs) = isempty(idxs) || !isnothing(first(idxs)) +# When you add symbolic save_idxs to a problem, give its solution type a dep_idxs +# field that is a Ref{Union{Vector{Nothing}, Vector{Int}}}. Then, when you call get_dep_idxs, it will +# check if the dep_idxs field is initialized. If it is, it will return the value. If it is not, +# it will call _get_dep_idxs to get the value, set the field, and return the value. +# You will also need to remove `dense_output = true` from the `build_explicit_observed_function` +# call in ModelingToolkit in its source system. function get_dep_idxs(A::AbstractSciMLSolution) if hasfield(typeof(A), :dep_idxs) if idxs_initialized(A.dep_idxs[])