diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 69e4e0699..a6dd9e9fd 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -811,7 +811,7 @@ export step!, deleteat!, addat!, get_tmp_cache, set_reltol!, get_du, get_du!, get_dt, get_proposed_dt, set_proposed_dt!, u_modified!, savevalues!, reinit!, auto_dt_reset!, set_t!, set_u!, check_error, change_t_via_interpolation!, addsteps!, - isdiscrete, reeval_internals_due_to_modification! + isdiscrete, reeval_internals_due_to_modification!, is_dense_output export ContinuousCallback, DiscreteCallback, CallbackSet, VectorContinuousCallback diff --git a/src/solutions/dae_solutions.jl b/src/solutions/dae_solutions.jl index 9c60af570..3b2d8d286 100644 --- a/src/solutions/dae_solutions.jl +++ b/src/solutions/dae_solutions.jl @@ -27,7 +27,7 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/ exited due to an error. For more details, see [the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes). """ -struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, DE} <: +struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, DE, MType, DI} <: AbstractDAESolution{T, N, uType} u::uType du::duType @@ -39,8 +39,10 @@ struct DAESolution{T, N, uType, duType, uType2, DType, tType, P, A, ID, DE} <: interp::ID dense::Bool tslocation::Int + sym_map::MType destats::DE retcode::ReturnCode.T + dep_idxs::DI end function Base.show(io::IO, @@ -74,6 +76,8 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; HermiteInterpolation(t, u, du), retcode = ReturnCode.Default, destats = nothing, + sym_map = nothing, + dep_idxs = nothing, kwargs...) T = eltype(eltype(u)) @@ -82,24 +86,29 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; else N = length((size(prob.u0)..., length(u))) end + if isnothing(dep_idxs) + dep_idxs = Ref{Vector{Union{Int, Nothing}}}(Union{Int, Nothing}[nothing]) + end + if isnothing(sym_map) + sym_map = default_sym_map(prob) + end if has_analytic(prob.f) u_analytic = Vector{typeof(prob.u0)}() errors = Dict{Symbol, real(eltype(prob.u0))}() sol = DAESolution{T, N, typeof(u), typeof(du), typeof(u_analytic), typeof(errors), - typeof(t), - typeof(prob), typeof(alg), typeof(interp), typeof(destats)}(u, du, - u_analytic, - errors, - t, - prob, - alg, - interp, - dense, - 0, - destats, - retcode) + typeof(t), typeof(prob), typeof(alg), typeof(interp), + typeof(destats), typeof(sym_map), typeof(dep_idxs)}(u, du, + u_analytic, + errors, t, + prob, alg, + interp, dense, + 0, + sym_map, + destats, + retcode, + dep_idxs) if calculate_error calculate_solution_errors!(sol; timeseries_errors = timeseries_errors, @@ -107,15 +116,19 @@ function build_solution(prob::AbstractDAEProblem, alg, t, u, du = nothing; end sol else - DAESolution{T, N, typeof(u), typeof(du), Nothing, Nothing, typeof(t), - typeof(prob), typeof(alg), typeof(interp), typeof(destats)}(u, du, - nothing, - nothing, t, - prob, alg, - interp, - dense, 0, - destats, - retcode) + DAESolution{T, N, typeof(u), typeof(du), Nothing, Nothing, typeof(t), typeof(prob), + typeof(alg), typeof(interp), typeof(destats), typeof(sym_map), + typeof(dep_idxs)}(u, du, + nothing, + nothing, + t, prob, + alg, + interp, + dense, 0, + sym_map, + destats, + retcode, + dep_idxs) end end @@ -160,70 +173,82 @@ end function build_solution(sol::AbstractDAESolution{T, N}, u_analytic, errors) where {T, N} DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(u_analytic), typeof(errors), typeof(sol.t), - typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats)}(sol.u, - sol.du, - u_analytic, - errors, - sol.t, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - sol.tslocation, - sol.destats, - sol.retcode) + typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats), + typeof(sol.sym_map), typeof(sol.dep_idxs)}(sol.u, + sol.du, + u_analytic, + errors, + sol.t, + sol.prob, + sol.alg, + sol.interp, + sol.dense, + sol.tslocation, + sol.sym_map, + sol.destats, + sol.retcode, + sol.dep_idxs) end function solution_new_retcode(sol::AbstractDAESolution{T, N}, retcode) where {T, N} DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic), typeof(sol.errors), typeof(sol.t), - typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats)}(sol.u, - sol.du, - sol.u_analytic, - sol.errors, - sol.t, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - sol.tslocation, - sol.destats, - retcode) + typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats), + typeof(sol.sym_map), typeof(sol.dep_idxs)}(sol.u, + sol.du, + sol.u_analytic, + sol.errors, + sol.t, + sol.prob, + sol.alg, + sol.interp, + sol.dense, + sol.tslocation, + sol.sym_map, + sol.destats, + retcode, + sol.dep_idxs) end function solution_new_tslocation(sol::AbstractDAESolution{T, N}, tslocation) where {T, N} DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic), typeof(sol.errors), typeof(sol.t), - typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats)}(sol.u, - sol.du, - sol.u_analytic, - sol.errors, - sol.t, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - tslocation, - sol.destats, - sol.retcode) + typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats), + typeof(sol.sym_map), typeof(sol.dep_idxs)}(sol.u, + sol.du, + sol.u_analytic, + sol.errors, + sol.t, + sol.prob, + sol.alg, + sol.interp, + sol.dense, + tslocation, + sol.sym_map, + sol.destats, + sol.retcode, + sol.dep_idxs) end function solution_slice(sol::AbstractDAESolution{T, N}, I) where {T, N} DAESolution{T, N, typeof(sol.u), typeof(sol.du), typeof(sol.u_analytic), typeof(sol.errors), typeof(sol.t), - typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats)}(sol.u[I], - sol.du[I], - sol.u_analytic === - nothing ? - nothing : - sol.u_analytic[I], - sol.errors, - sol.t[I], - sol.prob, - sol.alg, - sol.interp, - false, - sol.tslocation, - sol.destats, - sol.retcode) + typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), typeof(sol.destats), + typeof(sol.sym_map), typeof(sol.dep_idxs)}(sol.u[I], + sol.du[I], + sol.u_analytic === + nothing ? + nothing : + sol.u_analytic[I], + sol.errors, + sol.t[I], + sol.prob, + sol.alg, + sol.interp, + false, + sol.tslocation, + sol.sym_map, + sol.destats, + sol.retcode, + sol.dep_idxs) end diff --git a/src/solutions/nonlinear_solutions.jl b/src/solutions/nonlinear_solutions.jl index 3317dac86..aa9033109 100644 --- a/src/solutions/nonlinear_solutions.jl +++ b/src/solutions/nonlinear_solutions.jl @@ -13,12 +13,14 @@ or the steady state solution to a differential equation defined by a SteadyState - `original`: if the solver is wrapped from an alternative solver ecosystem, such as NLsolve.jl, then this is the original return from said solver library. - `retcode`: the return code from the solver. Used to determine whether the solver solved - successfully or whether it exited due to an error. For more details, see + successfully or whether it exited due to an error. For more details, see [the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes). - `left`: if the solver is bracketing method, this is the final left bracket value. - `right`: if the solver is bracketing method, this is the final right bracket value. +- `sym_map`: Map of symbols to their index in the solution """ -struct NonlinearSolution{T, N, uType, R, P, A, O, uType2} <: AbstractNonlinearSolution{T, N} +struct NonlinearSolution{T, N, uType, R, P, A, O, uType2, MType, DI} <: + AbstractNonlinearSolution{T, N} u::uType resid::R prob::P @@ -27,6 +29,8 @@ struct NonlinearSolution{T, N, uType, R, P, A, O, uType2} <: AbstractNonlinearSo original::O left::uType2 right::uType2 + sym_map::MType + dep_idxs::DI end function Base.show(io::IO, @@ -49,17 +53,25 @@ function build_solution(prob::AbstractNonlinearProblem, original = nothing, left = nothing, right = nothing, + sym_map = nothing, + dep_idxs = Ref{Vector{Union{Int, Nothing}}}(Union{Int, Nothing}[nothing]), kwargs...) T = eltype(eltype(u)) N = ndims(u) + if isnothing(sym_map) + sym_map = default_sym_map(prob) + end + NonlinearSolution{T, N, typeof(u), typeof(resid), - typeof(prob), typeof(alg), typeof(original), typeof(left)}(u, resid, - prob, alg, - retcode, - original, - left, - right) + typeof(prob), typeof(alg), typeof(original), typeof(left), + typeof(sym_map), typeof(dep_idxs)}(u, resid, + prob, alg, + retcode, + original, left, + right, + sym_map, + dep_idxs) end function sensitivity_solution(sol::AbstractNonlinearSolution, u) @@ -68,8 +80,15 @@ function sensitivity_solution(sol::AbstractNonlinearSolution, u) NonlinearSolution{T, N, typeof(u), typeof(sol.resid), typeof(sol.prob), typeof(sol.alg), - typeof(sol.original), typeof(sol.left)}(u, sol.resid, sol.prob, - sol.alg, sol.retcode, - sol.original, sol.left, - sol.right) + typeof(sol.original), typeof(sol.left), typeof(sol.sym_map), + typeof(dep_idxs)}(u, + sol.resid, + sol.prob, + sol.alg, + sol.retcode, + sol.original, + sol.left, + sol.right, + sol.sym_map, + sol.dep_idxs) end diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 288b94f25..f1d78297a 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -21,13 +21,14 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/ - `alg`: the algorithm type used by the solver. - `destats`: statistics of the solver, such as the number of function evaluations required, number of Jacobians computed, and more. + - `sym_map`: Map of symbols to their index in the solution - `retcode`: the return code from the solver. Used to determine whether the solver solved - successfully, whether it terminated early due to a user-defined callback, or whether it - exited due to an error. For more details, see + successfully, whether it terminated early due to a user-defined callback, or whether it + exited due to an error. For more details, see [the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes). """ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, DE, - AC <: Union{Nothing, Vector{Int}}} <: + AC <: Union{Nothing, Vector{Int}}, MType, DI} <: AbstractODESolution{T, N, uType} u::uType u_analytic::uType2 @@ -39,17 +40,31 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, DE, interp::IType dense::Bool tslocation::Int + sym_map::MType destats::DE alg_choice::AC + dep_idxs::DI retcode::ReturnCode.T end function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense, - tslocation, destats, alg_choice, retcode) where {T, N} + tslocation, destats, alg_choice, retcode, + sym_map = nothing, + dep_idxs = nothing) where {T, N} + if isnothing(dep_idxs) + dep_idxs = Ref{Vector{Union{Int, Nothing}}}(Union{Int, Nothing}[nothing]) + end + if isnothing(sym_map) + sym_map = default_sym_map(prob) + end return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t), typeof(k), typeof(prob), typeof(alg), typeof(interp), typeof(destats), - typeof(alg_choice)}(u, u_analytic, errors, t, k, prob, alg, interp, - dense, tslocation, destats, alg_choice, retcode) + typeof(alg_choice), + typeof(sym_map), typeof(dep_idxs)}(u, u_analytic, errors, t, k, prob, + alg, interp, + dense, tslocation, sym_map, + destats, alg_choice, dep_idxs, + retcode) end function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing, @@ -161,9 +176,12 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, k = nothing, alg_choice = nothing, interp = LinearInterpolation(t, u), + sym_map = nothing, dep_idxs = nothing, retcode = ReturnCode.Default, destats = nothing, kwargs...) T = eltype(eltype(u)) - + if isnothing(sym_map) + sym_map = default_sym_map(prob) + end if prob.u0 === nothing N = 2 else @@ -190,7 +208,9 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, 0, destats, alg_choice, - retcode) + retcode, + sym_map, + dep_idxs) if calculate_error calculate_solution_errors!(sol; timeseries_errors = timeseries_errors, dense_errors = dense_errors) @@ -208,7 +228,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, 0, destats, alg_choice, - retcode) + retcode, sym_map, dep_idxs) end end @@ -264,7 +284,9 @@ function build_solution(sol::ODESolution{T, N}, u_analytic, errors) where {T, N} sol.tslocation, sol.destats, sol.alg_choice, - sol.retcode) + sol.retcode, + sol.sym_map, + sol.dep_idxs) end function solution_new_retcode(sol::ODESolution{T, N}, retcode) where {T, N} @@ -280,7 +302,9 @@ function solution_new_retcode(sol::ODESolution{T, N}, retcode) where {T, N} sol.tslocation, sol.destats, sol.alg_choice, - retcode) + retcode, + sol.sym_map, + sol.dep_idxs) end function solution_new_tslocation(sol::ODESolution{T, N}, tslocation) where {T, N} @@ -296,7 +320,9 @@ function solution_new_tslocation(sol::ODESolution{T, N}, tslocation) where {T, N tslocation, sol.destats, sol.alg_choice, - sol.retcode) + sol.retcode, + sol.sym_map, + sol.dep_idxs) end function solution_slice(sol::ODESolution{T, N}, I) where {T, N} @@ -312,7 +338,9 @@ function solution_slice(sol::ODESolution{T, N}, I) where {T, N} sol.tslocation, sol.destats, sol.alg_choice, - sol.retcode) + sol.retcode, + sol.sym_map, + sol.dep_idxs) end function sensitivity_solution(sol::ODESolution, u, t) @@ -330,5 +358,5 @@ function sensitivity_solution(sol::ODESolution, u, t) nothing, sol.prob, sol.alg, interp, sol.dense, sol.tslocation, - sol.destats, sol.alg_choice, sol.retcode) + sol.destats, sol.alg_choice, sol.retcode, sol.sym_map, sol.dep_idxs) end diff --git a/src/solutions/optimization_solutions.jl b/src/solutions/optimization_solutions.jl index 5b22f7bd6..1a24f7364 100644 --- a/src/solutions/optimization_solutions.jl +++ b/src/solutions/optimization_solutions.jl @@ -15,8 +15,10 @@ Representation of the solution to a non-linear optimization defined by an Optimi - `original`: if the solver is wrapped from an alternative solver ecosystem, such as Optim.jl, then this is the original return from said solver library. - `solve_time`: Solve time from the solver in Seconds +- `sym_map`: Map of symbols to their index in the solution """ -struct OptimizationSolution{T, N, uType, C <: AbstractOptimizationCache, A, OV, O, S} <: +struct OptimizationSolution{T, N, uType, C <: AbstractOptimizationCache, A, OV, O, S, MType + } <: AbstractOptimizationSolution{T, N} u::uType # minimizer cache::C # optimization cache @@ -25,6 +27,7 @@ struct OptimizationSolution{T, N, uType, C <: AbstractOptimizationCache, A, OV, retcode::ReturnCode.T original::O # original output of the optimizer solve_time::S # [s] solve time from the solver + sym_map::MType end function build_solution(cache::AbstractOptimizationCache, @@ -32,6 +35,7 @@ function build_solution(cache::AbstractOptimizationCache, retcode = ReturnCode.Default, original = nothing, solve_time = nothing, + sym_map = nothing, kwargs...) T = eltype(eltype(u)) N = ndims(u) @@ -40,12 +44,13 @@ function build_solution(cache::AbstractOptimizationCache, retcode = symbol_to_ReturnCode(retcode) OptimizationSolution{T, N, typeof(u), typeof(cache), typeof(alg), - typeof(objective), typeof(original), typeof(solve_time)}(u, cache, - alg, - objective, - retcode, - original, - solve_time) + typeof(objective), typeof(original), typeof(solve_time), + typeof(sym_map)}(u, cache, + alg, + objective, + retcode, + original, + solve_time, sym_map) end function Base.show(io::IO, @@ -75,6 +80,7 @@ function build_solution(prob::AbstractOptimizationProblem, alg, u, objective; retcode = ReturnCode.Default, original = nothing, + sym_map = default_sym_map(prob), kwargs...) T = eltype(eltype(u)) N = ndims(u) @@ -89,9 +95,12 @@ function build_solution(prob::AbstractOptimizationProblem, retcode = symbol_to_ReturnCode(retcode) OptimizationSolution{T, N, typeof(u), typeof(cache), typeof(alg), - typeof(objective), typeof(original)}(u, cache, alg, objective, - retcode, - original) + typeof(objective), typeof(original), typeof(sym_map)}(u, cache, + alg, + objective, + retcode, + original, + sym_map) end get_p(sol::OptimizationSolution) = sol.cache.p diff --git a/src/solutions/rode_solutions.jl b/src/solutions/rode_solutions.jl index 8fdbcdb07..44bba2159 100644 --- a/src/solutions/rode_solutions.jl +++ b/src/solutions/rode_solutions.jl @@ -27,13 +27,14 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/ - `alg`: the algorithm type used by the solver. - `destats`: statistics of the solver, such as the number of function evaluations required, number of Jacobians computed, and more. +- `sym_map`: Map of symbols to their index in the solution - `retcode`: the return code from the solver. Used to determine whether the solver solved successfully, whether it terminated early due to a user-defined callback, or whether it exited due to an error. For more details, see [the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes). """ struct RODESolution{T, N, uType, uType2, DType, tType, randType, P, A, IType, DE, - AC <: Union{Nothing, Vector{Int}}} <: + AC <: Union{Nothing, Vector{Int}}, MType} <: AbstractRODESolution{T, N, uType} u::uType u_analytic::uType2 @@ -45,6 +46,7 @@ struct RODESolution{T, N, uType, uType2, DType, tType, randType, P, A, IType, DE interp::IType dense::Bool tslocation::Int + sym_map::MType destats::DE alg_choice::AC retcode::ReturnCode.T @@ -78,6 +80,7 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem}, interp = LinearInterpolation(t, u), retcode = ReturnCode.Default, alg_choice = nothing, + sym_map = nothing, seed = UInt64(0), destats = nothing, kwargs...) T = eltype(eltype(u)) N = length((size(prob.u0)..., length(u))) @@ -88,25 +91,30 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem}, f = prob.f end + if isnothing(sym_map) + sym_map = default_sym_map(prob) + end + if has_analytic(f) u_analytic = Vector{typeof(prob.u0)}() errors = Dict{Symbol, real(eltype(prob.u0))}() sol = RODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t), typeof(W), typeof(prob), typeof(alg), typeof(interp), typeof(destats), - typeof(alg_choice)}(u, - u_analytic, - errors, - t, W, - prob, - alg, - interp, - dense, - 0, - destats, - alg_choice, - retcode, - seed) + typeof(alg_choice), typeof(sym_map)}(u, + u_analytic, + errors, + t, W, + prob, + alg, + interp, + dense, + 0, + sym_map, + destats, + alg_choice, + retcode, + seed) if calculate_error calculate_solution_errors!(sol; timeseries_errors = timeseries_errors, @@ -117,10 +125,18 @@ function build_solution(prob::Union{AbstractRODEProblem, AbstractSDDEProblem}, else return RODESolution{T, N, typeof(u), Nothing, Nothing, typeof(t), typeof(W), typeof(prob), typeof(alg), typeof(interp), - typeof(destats), typeof(alg_choice)}(u, nothing, nothing, t, W, - prob, alg, interp, - dense, 0, destats, - alg_choice, retcode, seed) + typeof(destats), typeof(alg_choice), typeof(sym_map)}(u, + nothing, + nothing, + t, W, + prob, alg, + interp, + dense, 0, + sym_map, + destats, + alg_choice, + retcode, + seed) end end @@ -174,53 +190,87 @@ end function build_solution(sol::AbstractRODESolution{T, N}, u_analytic, errors) where {T, N} RODESolution{T, N, typeof(sol.u), typeof(u_analytic), typeof(errors), typeof(sol.t), typeof(sol.W), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), - typeof(sol.destats), typeof(sol.alg_choice)}(sol.u, u_analytic, errors, - sol.t, sol.W, sol.prob, - sol.alg, sol.interp, - sol.dense, sol.tslocation, - sol.destats, sol.alg_choice, - sol.retcode, sol.seed) + typeof(sol.destats), typeof(sol.alg_choice), typeof(sol.sym_map)}(sol.u, + u_analytic, + errors, + sol.t, + sol.W, + sol.prob, + sol.alg, + sol.interp, + sol.dense, + sol.tslocation, + sol.sym_map, + sol.destats, + sol.alg_choice, + sol.retcode, + sol.seed) end function solution_new_retcode(sol::AbstractRODESolution{T, N}, retcode) where {T, N} RODESolution{T, N, typeof(sol.u), typeof(sol.u_analytic), typeof(sol.errors), typeof(sol.t), typeof(sol.W), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), - typeof(sol.destats), typeof(sol.alg_choice)}(sol.u, sol.u_analytic, - sol.errors, sol.t, sol.W, - sol.prob, sol.alg, sol.interp, - sol.dense, sol.tslocation, - sol.destats, sol.alg_choice, - retcode, - sol.seed) + typeof(sol.destats), typeof(sol.alg_choice), typeof(sol.sym_map)}(sol.u, + sol.u_analytic, + sol.errors, + sol.t, + sol.W, + sol.prob, + sol.alg, + sol.interp, + sol.dense, + sol.tslocation, + sol.sym_map, + sol.destats, + sol.alg_choice, + retcode, + sol.seed) end function solution_new_tslocation(sol::AbstractRODESolution{T, N}, tslocation) where {T, N} RODESolution{T, N, typeof(sol.u), typeof(sol.u_analytic), typeof(sol.errors), typeof(sol.t), typeof(sol.W), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), - typeof(sol.destats), typeof(sol.alg_choice)}(sol.u, sol.u_analytic, - sol.errors, sol.t, sol.W, - sol.prob, sol.alg, sol.interp, - sol.dense, tslocation, - sol.destats, sol.alg_choice, - sol.retcode, - sol.seed) + typeof(sol.destats), typeof(sol.alg_choice), typeof(sol.sym_map)}(sol.u, + sol.u_analytic, + sol.errors, + sol.t, + sol.W, + sol.prob, + sol.alg, + sol.interp, + sol.dense, + tslocation, + sol.sym_map, + sol.destats, + sol.alg_choice, + sol.retcode, + sol.seed) end function solution_slice(sol::AbstractRODESolution{T, N}, I) where {T, N} RODESolution{T, N, typeof(sol.u), typeof(sol.u_analytic), typeof(sol.errors), typeof(sol.t), typeof(sol.W), typeof(sol.prob), typeof(sol.alg), typeof(sol.interp), - typeof(sol.destats), typeof(sol.alg_choice)}(sol.u[I], - sol.u_analytic === nothing ? - nothing : sol.u_analytic, - sol.errors, sol.t[I], - sol.W, sol.prob, - sol.alg, sol.interp, - false, sol.tslocation, - sol.destats, sol.alg_choice, - sol.retcode, sol.seed) + typeof(sol.destats), typeof(sol.alg_choice), typeof(sol.sym_map)}(sol.u[I], + sol.u_analytic === + nothing ? + nothing : + sol.u_analytic, + sol.errors, + sol.t[I], + sol.W, + sol.prob, + sol.alg, + sol.interp, + false, + sol.tslocation, + sol.sym_map, + sol.destats, + sol.alg_choice, + sol.retcode, + sol.seed) end function sensitivity_solution(sol::AbstractRODESolution, u, t) @@ -237,18 +287,20 @@ function sensitivity_solution(sol::AbstractRODESolution, u, t) RODESolution{T, N, typeof(u), typeof(sol.u_analytic), typeof(sol.errors), typeof(t), typeof(nothing), typeof(sol.prob), typeof(sol.alg), - typeof(sol.interp), typeof(sol.destats), typeof(sol.alg_choice)}(u, - sol.u_analytic, - sol.errors, - t, - nothing, - sol.prob, - sol.alg, - sol.interp, - sol.dense, - sol.tslocation, - sol.destats, - sol.alg_choice, - sol.retcode, - sol.seed) + typeof(sol.interp), typeof(sol.destats), typeof(sol.alg_choice), + typeof(sol.sym_map)}(u, + sol.u_analytic, + sol.errors, + t, + nothing, + sol.prob, + sol.alg, + sol.interp, + sol.dense, + sol.tslocation, + sol.sym_map, + sol.destats, + sol.alg_choice, + sol.retcode, + sol.seed) end diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 87ba5d358..ec0b218f4 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -69,11 +69,12 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, end Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, sym) + sym = safe_unwrap(sym) if issymbollike(sym) if sym isa AbstractArray return A[collect(sym)] end - i = sym_to_index(sym, A) + i = state_sym_to_index(A, sym) elseif all(issymbollike, sym) if has_sys(A.prob.f) && all(Base.Fix1(is_param_sym, A.prob.f.sys), sym) || !has_sys(A.prob.f) && has_paramsyms(A.prob.f) && @@ -86,7 +87,11 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s i = sym end - if i === nothing + if is_symbolic_expr(sym) + return convert_to_getindex(A, sym) + end + + if isnothing(i) if issymbollike(sym) if has_sys(A.prob.f) && is_indep_sym(A.prob.f.sys, sym) || Symbol(sym) == getindepsym(A) @@ -109,18 +114,23 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s end Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, sym, args...) + sym = safe_unwrap(sym) if issymbollike(sym) if sym isa AbstractArray return A[collect(sym), args...] end - i = sym_to_index(sym, A) + i = state_sym_to_index(A, sym) elseif all(issymbollike, sym) return reduce(vcat, map(s -> A[s, args...]', sym)) else i = sym end - if i === nothing + if is_symbolic_expr(sym) + return convert_to_getindex(A, sym, args...) + end + + if isnothing(i) if issymbollike(sym) && has_sys(A.prob.f) && is_indep_sym(A.prob.f.sys, sym) || Symbol(sym) == getindepsym(A) A.t[args...] @@ -134,31 +144,79 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s end end +function _get_dep_idxs(A::AbstractSciMLSolution) + SII = SymbolicIndexingInterface + if has_sys(A.prob.f) && has_observed(A.prob.f) + if !isnothing(A.sym_map) + sys = A.prob.f.sys + return map(x -> A.sym_map[safe_unwrap(x)], + get_deps_of_observed(unknown_states(sys), SII.observed(sys))) + end + end + return [nothing] +end + +idxs_initialized(idxs) = isempty(idxs) || !isnothing(first(idxs)) + +# When you add symbolic save_idxs to a problem, give its solution type a dep_idxs +# field that is a Ref{Union{Vector{Nothing}, Vector{Int}}}. Then, when you call get_dep_idxs, it will +# check if the dep_idxs field is initialized. If it is, it will return the value. If it is not, +# it will call _get_dep_idxs to get the value, set the field, and return the value. +# You will also need to remove `dense_output = true` from the `build_explicit_observed_function` +# call in ModelingToolkit in its source system. +function get_dep_idxs(A::AbstractSciMLSolution) + if hasfield(typeof(A), :dep_idxs) + if idxs_initialized(A.dep_idxs[]) + return A.dep_idxs[] + else + idxs = _get_dep_idxs(A) + A.dep_idxs[] = idxs + return A.dep_idxs[] + end + else + return [nothing] + end +end + function observed(A::AbstractTimeseriesSolution, sym, i::Int) - getobserved(A)(sym, A[i], A.prob.p, A.t[i]) + idxs = get_dep_idxs(A) + if !idxs_initialized(idxs) + return getobserved(A)(sym, A.u[i], A.prob.p, A.t[i]) + end + getobserved(A)(sym, A[i][idxs], A.prob.p, A.t[i]) end -function observed(A::AbstractTimeseriesSolution, sym, i::AbstractArray{Int}) - getobserved(A).((sym,), A.u[i], (A.prob.p,), A.t[i]) +function observed(A::AbstractTimeseriesSolution, sym, is::AbstractArray{Int}) + idxs = get_dep_idxs(A) + if !idxs_initialized(idxs) + return getobserved(A)(sym, A.u[i], A.prob.p, A.t[i]) + end + getobserved(A).((sym,), map(j -> A.u[j][idxs], is), (A.prob.p,), A.t[is]) end function observed(A::AbstractTimeseriesSolution, sym, i::Colon) - getobserved(A).((sym,), A.u, (A.prob.p,), A.t) + idxs = get_dep_idxs(A) + if !idxs_initialized(idxs) + return getobserved(A).((sym,), A.u, (A.prob.p,), A.t) + end + getobserved(A).((sym,), map(j -> A.u[j][idxs], eachindex(A.t)), (A.prob.p,), A.t) end Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym) + sym = safe_unwrap(sym) if issymbollike(sym) if sym isa AbstractArray return A[collect(sym)] end - i = sym_to_index(sym, A) + i = state_sym_to_index(A, sym) + elseif all(issymbollike, sym) return reduce(vcat, map(s -> A[s]', sym)) else i = sym end - if i == nothing + if isnothing(i) paramsyms = getparamsyms(A) if issymbollike(sym) && paramsyms !== nothing && Symbol(sym) in paramsyms get_p(A)[findfirst(x -> isequal(x, Symbol(sym)), paramsyms)] @@ -173,7 +231,11 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractNoTimeSolution, sym) end function observed(A::AbstractNoTimeSolution, sym) - getobserved(A)(sym, A.u, A.prob.p) + idxs = get_dep_idxs(A) + if !idxs_initialized(idxs) + return getobserved(A)(sym, A.u[i], A.prob.p, A.t[i]) + end + getobserved(A)(sym, A.u[idxs], A.prob.p) end function observed(A::AbstractOptimizationSolution, sym) diff --git a/src/symbolic_utils.jl b/src/symbolic_utils.jl index 5a103814a..1dfd529b8 100644 --- a/src/symbolic_utils.jl +++ b/src/symbolic_utils.jl @@ -50,6 +50,14 @@ function getparamsyms(sol::AbstractOptimizationSolution) end end +function SymbolicIndexingInterface.state_sym_to_index(A::S, sym) where {S <: AbstractSciMLSolution} + if hasfield(S, :sym_map) && !isnothing(A.sym_map) + return get(A.sym_map, sym, nothing) + else + return sym_to_index(sym, A) + end +end + # Only for compatibility! function getindepsym_defaultt(sol) if has_indepsym(sol.prob.f) diff --git a/src/utils.jl b/src/utils.jl index 03109d75f..d424ee51d 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -488,3 +488,12 @@ end _unwrap_val(::Val{B}) where {B} = B _unwrap_val(B) = B + +function default_sym_map(prob) + if has_sys(prob.f) + sts = safe_unwrap.(unknown_states(prob.f.sys)) + return Dict(sts .=> eachindex(sts)) + else + return nothing + end +end