Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

symbolic save idxs, new observed #392

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/SciMLBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ export step!, deleteat!, addat!, get_tmp_cache,
set_reltol!, get_du, get_du!, get_dt, get_proposed_dt, set_proposed_dt!,
u_modified!, savevalues!, reinit!, auto_dt_reset!, set_t!,
set_u!, check_error, change_t_via_interpolation!, addsteps!,
isdiscrete, reeval_internals_due_to_modification!
isdiscrete, reeval_internals_due_to_modification!, is_dense_output

export ContinuousCallback, DiscreteCallback, CallbackSet, VectorContinuousCallback

Expand Down
171 changes: 98 additions & 73 deletions src/solutions/dae_solutions.jl

Large diffs are not rendered by default.

43 changes: 31 additions & 12 deletions src/solutions/nonlinear_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ or the steady state solution to a differential equation defined by a SteadyState
- `original`: if the solver is wrapped from an alternative solver ecosystem, such as
NLsolve.jl, then this is the original return from said solver library.
- `retcode`: the return code from the solver. Used to determine whether the solver solved
successfully or whether it exited due to an error. For more details, see
successfully or whether it exited due to an error. For more details, see
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
- `left`: if the solver is bracketing method, this is the final left bracket value.
- `right`: if the solver is bracketing method, this is the final right bracket value.
- `sym_map`: Map of symbols to their index in the solution
"""
struct NonlinearSolution{T, N, uType, R, P, A, O, uType2} <: AbstractNonlinearSolution{T, N}
struct NonlinearSolution{T, N, uType, R, P, A, O, uType2, MType, DI} <:
AbstractNonlinearSolution{T, N}
u::uType
resid::R
prob::P
Expand All @@ -27,6 +29,8 @@ struct NonlinearSolution{T, N, uType, R, P, A, O, uType2} <: AbstractNonlinearSo
original::O
left::uType2
right::uType2
sym_map::MType
dep_idxs::DI
end

function Base.show(io::IO,
Expand All @@ -49,17 +53,25 @@ function build_solution(prob::AbstractNonlinearProblem,
original = nothing,
left = nothing,
right = nothing,
sym_map = nothing,
dep_idxs = Ref{Vector{Union{Int, Nothing}}}(Union{Int, Nothing}[nothing]),
kwargs...)
T = eltype(eltype(u))
N = ndims(u)

if isnothing(sym_map)
sym_map = default_sym_map(prob)
end

NonlinearSolution{T, N, typeof(u), typeof(resid),
typeof(prob), typeof(alg), typeof(original), typeof(left)}(u, resid,
prob, alg,
retcode,
original,
left,
right)
typeof(prob), typeof(alg), typeof(original), typeof(left),
typeof(sym_map), typeof(dep_idxs)}(u, resid,
prob, alg,
retcode,
original, left,
right,
sym_map,
dep_idxs)
end

function sensitivity_solution(sol::AbstractNonlinearSolution, u)
Expand All @@ -68,8 +80,15 @@ function sensitivity_solution(sol::AbstractNonlinearSolution, u)

NonlinearSolution{T, N, typeof(u), typeof(sol.resid),
typeof(sol.prob), typeof(sol.alg),
typeof(sol.original), typeof(sol.left)}(u, sol.resid, sol.prob,
sol.alg, sol.retcode,
sol.original, sol.left,
sol.right)
typeof(sol.original), typeof(sol.left), typeof(sol.sym_map),
typeof(dep_idxs)}(u,
sol.resid,
sol.prob,
sol.alg,
sol.retcode,
sol.original,
sol.left,
sol.right,
sol.sym_map,
sol.dep_idxs)
end
56 changes: 42 additions & 14 deletions src/solutions/ode_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,14 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/
- `alg`: the algorithm type used by the solver.
- `destats`: statistics of the solver, such as the number of function evaluations required,
number of Jacobians computed, and more.
- `sym_map`: Map of symbols to their index in the solution
- `retcode`: the return code from the solver. Used to determine whether the solver solved
successfully, whether it terminated early due to a user-defined callback, or whether it
exited due to an error. For more details, see
successfully, whether it terminated early due to a user-defined callback, or whether it
exited due to an error. For more details, see
[the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes).
"""
struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, DE,
AC <: Union{Nothing, Vector{Int}}} <:
AC <: Union{Nothing, Vector{Int}}, MType, DI} <:
AbstractODESolution{T, N, uType}
u::uType
u_analytic::uType2
Expand All @@ -39,17 +40,31 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, DE,
interp::IType
dense::Bool
tslocation::Int
sym_map::MType
destats::DE
alg_choice::AC
dep_idxs::DI
retcode::ReturnCode.T
end
function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense,
tslocation, destats, alg_choice, retcode) where {T, N}
tslocation, destats, alg_choice, retcode,
sym_map = nothing,
dep_idxs = nothing) where {T, N}
if isnothing(dep_idxs)
dep_idxs = Ref{Vector{Union{Int, Nothing}}}(Union{Int, Nothing}[nothing])
end
if isnothing(sym_map)
sym_map = default_sym_map(prob)
end
return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t),
typeof(k), typeof(prob), typeof(alg), typeof(interp),
typeof(destats),
typeof(alg_choice)}(u, u_analytic, errors, t, k, prob, alg, interp,
dense, tslocation, destats, alg_choice, retcode)
typeof(alg_choice),
typeof(sym_map), typeof(dep_idxs)}(u, u_analytic, errors, t, k, prob,
alg, interp,
dense, tslocation, sym_map,
destats, alg_choice, dep_idxs,
retcode)
end

function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing,
Expand Down Expand Up @@ -161,9 +176,12 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
k = nothing,
alg_choice = nothing,
interp = LinearInterpolation(t, u),
sym_map = nothing, dep_idxs = nothing,
retcode = ReturnCode.Default, destats = nothing, kwargs...)
T = eltype(eltype(u))

if isnothing(sym_map)
sym_map = default_sym_map(prob)
end
if prob.u0 === nothing
N = 2
else
Expand All @@ -190,7 +208,9 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
0,
destats,
alg_choice,
retcode)
retcode,
sym_map,
dep_idxs)
if calculate_error
calculate_solution_errors!(sol; timeseries_errors = timeseries_errors,
dense_errors = dense_errors)
Expand All @@ -208,7 +228,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem},
0,
destats,
alg_choice,
retcode)
retcode, sym_map, dep_idxs)
end
end

Expand Down Expand Up @@ -264,7 +284,9 @@ function build_solution(sol::ODESolution{T, N}, u_analytic, errors) where {T, N}
sol.tslocation,
sol.destats,
sol.alg_choice,
sol.retcode)
sol.retcode,
sol.sym_map,
sol.dep_idxs)
end

function solution_new_retcode(sol::ODESolution{T, N}, retcode) where {T, N}
Expand All @@ -280,7 +302,9 @@ function solution_new_retcode(sol::ODESolution{T, N}, retcode) where {T, N}
sol.tslocation,
sol.destats,
sol.alg_choice,
retcode)
retcode,
sol.sym_map,
sol.dep_idxs)
end

function solution_new_tslocation(sol::ODESolution{T, N}, tslocation) where {T, N}
Expand All @@ -296,7 +320,9 @@ function solution_new_tslocation(sol::ODESolution{T, N}, tslocation) where {T, N
tslocation,
sol.destats,
sol.alg_choice,
sol.retcode)
sol.retcode,
sol.sym_map,
sol.dep_idxs)
end

function solution_slice(sol::ODESolution{T, N}, I) where {T, N}
Expand All @@ -312,7 +338,9 @@ function solution_slice(sol::ODESolution{T, N}, I) where {T, N}
sol.tslocation,
sol.destats,
sol.alg_choice,
sol.retcode)
sol.retcode,
sol.sym_map,
sol.dep_idxs)
end

function sensitivity_solution(sol::ODESolution, u, t)
Expand All @@ -330,5 +358,5 @@ function sensitivity_solution(sol::ODESolution, u, t)
nothing, sol.prob,
sol.alg, interp,
sol.dense, sol.tslocation,
sol.destats, sol.alg_choice, sol.retcode)
sol.destats, sol.alg_choice, sol.retcode, sol.sym_map, sol.dep_idxs)
end
29 changes: 19 additions & 10 deletions src/solutions/optimization_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ Representation of the solution to a non-linear optimization defined by an Optimi
- `original`: if the solver is wrapped from an alternative solver ecosystem, such as
Optim.jl, then this is the original return from said solver library.
- `solve_time`: Solve time from the solver in Seconds
- `sym_map`: Map of symbols to their index in the solution
"""
struct OptimizationSolution{T, N, uType, C <: AbstractOptimizationCache, A, OV, O, S} <:
struct OptimizationSolution{T, N, uType, C <: AbstractOptimizationCache, A, OV, O, S, MType
} <:
AbstractOptimizationSolution{T, N}
u::uType # minimizer
cache::C # optimization cache
Expand All @@ -25,13 +27,15 @@ struct OptimizationSolution{T, N, uType, C <: AbstractOptimizationCache, A, OV,
retcode::ReturnCode.T
original::O # original output of the optimizer
solve_time::S # [s] solve time from the solver
sym_map::MType
end

function build_solution(cache::AbstractOptimizationCache,
alg, u, objective;
retcode = ReturnCode.Default,
original = nothing,
solve_time = nothing,
sym_map = nothing,
kwargs...)
T = eltype(eltype(u))
N = ndims(u)
Expand All @@ -40,12 +44,13 @@ function build_solution(cache::AbstractOptimizationCache,
retcode = symbol_to_ReturnCode(retcode)

OptimizationSolution{T, N, typeof(u), typeof(cache), typeof(alg),
typeof(objective), typeof(original), typeof(solve_time)}(u, cache,
alg,
objective,
retcode,
original,
solve_time)
typeof(objective), typeof(original), typeof(solve_time),
typeof(sym_map)}(u, cache,
alg,
objective,
retcode,
original,
solve_time, sym_map)
end

function Base.show(io::IO,
Expand Down Expand Up @@ -75,6 +80,7 @@ function build_solution(prob::AbstractOptimizationProblem,
alg, u, objective;
retcode = ReturnCode.Default,
original = nothing,
sym_map = default_sym_map(prob),
kwargs...)
T = eltype(eltype(u))
N = ndims(u)
Expand All @@ -89,9 +95,12 @@ function build_solution(prob::AbstractOptimizationProblem,
retcode = symbol_to_ReturnCode(retcode)

OptimizationSolution{T, N, typeof(u), typeof(cache), typeof(alg),
typeof(objective), typeof(original)}(u, cache, alg, objective,
retcode,
original)
typeof(objective), typeof(original), typeof(sym_map)}(u, cache,
alg,
objective,
retcode,
original,
sym_map)
end

get_p(sol::OptimizationSolution) = sol.cache.p
Expand Down
Loading