From cea05367b3c77c0e1c787921dbc4af9f77cc1b71 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 31 May 2024 16:14:09 +0530 Subject: [PATCH 01/10] feat: implement new SII discrete saving interface --- src/solutions/ode_solutions.jl | 165 ++++++++++++++++++++++++---- src/solutions/solution_interface.jl | 5 +- 2 files changed, 146 insertions(+), 24 deletions(-) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index 5636c4993..ba339b972 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -105,7 +105,7 @@ https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/ exited due to an error. For more details, see [the return code documentation](https://docs.sciml.ai/SciMLBase/stable/interfaces/Solutions/#retcodes). """ -struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, S, +struct ODESolution{T, N, uType, uType2, DType, tType, rateType, discType, P, A, IType, S, AC <: Union{Nothing, Vector{Int}}, R, O} <: AbstractODESolution{T, N, uType} u::uType @@ -113,6 +113,7 @@ struct ODESolution{T, N, uType, uType2, DType, tType, rateType, P, A, IType, S, errors::DType t::tType k::rateType + discretes::discType prob::P alg::A interp::IType @@ -135,7 +136,7 @@ function ConstructionBase.setproperties(sol::ODESolution, patch::NamedTuple) T = eltype(eltype(u)) patch = merge(getproperties(sol), patch) return ODESolution{T, N}(patch.u, patch.u_analytic, patch.errors, patch.t, patch.k, - patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats, + patch.discretes, patch.prob, patch.alg, patch.interp, patch.dense, patch.tslocation, patch.stats, patch.alg_choice, patch.retcode, patch.resid, patch.original) end @@ -150,13 +151,14 @@ Base.@propagate_inbounds function Base.getproperty(x::AbstractODESolution, s::Sy end # FIXME: Remove the defaults for resid and original on a breaking release -function ODESolution{T, N}(u, u_analytic, errors, t, k, prob, alg, interp, dense, +function ODESolution{T, N}( + u, u_analytic, errors, t, k, discretes, prob, alg, interp, dense, tslocation, stats, alg_choice, retcode, resid = nothing, original = nothing) where {T, N} return ODESolution{T, N, typeof(u), typeof(u_analytic), typeof(errors), typeof(t), - typeof(k), typeof(prob), typeof(alg), typeof(interp), + typeof(k), typeof(discretes), typeof(prob), typeof(alg), typeof(interp), typeof(stats), typeof(alg_choice), typeof(resid), - typeof(original)}(u, u_analytic, errors, t, k, prob, alg, interp, + typeof(original)}(u, u_analytic, errors, t, k, discretes, prob, alg, interp, dense, tslocation, stats, alg_choice, retcode, resid, original) end @@ -172,6 +174,22 @@ function error_if_observed_derivative(sys, idx, ::Type) end end +function SymbolicIndexingInterface.is_parameter_timeseries(::Type{S}) where { + T1, T2, T3, T4, T5, T6, T7, + S <: ODESolution{T1, T2, T3, T4, T5, T6, T7, <:ParameterTimeseriesCollection}} + Timeseries() +end + +function get_interpolated_discretes(sol::AbstractODESolution, t, deriv, continuity) + is_parameter_timeseries(sol) == Timeseries() || return nothing + + discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol) + interp_discs = map(discs) do partition + ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity) + end + return ParameterTimeseriesCollection(interp_discs, parameter_values(discs)) +end + function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing, continuity = :left) where {deriv} sol(t, deriv, idxs, continuity) @@ -188,7 +206,8 @@ end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs::Nothing, continuity) where {deriv} - augment(sol.interp(t, idxs, deriv, sol.prob.p, continuity), sol) + discretes = get_interpolated_discretes(sol, t, deriv, continuity) + augment(sol.interp(t, idxs, deriv, sol.prob.p, continuity), sol; discretes) end function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::Integer, @@ -224,11 +243,23 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs, continuity) where {deriv} symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`") error_if_observed_derivative(sol, idxs, deriv) - if is_parameter(sol, idxs) - return getp(sol, idxs)(sol) - else - return augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol)[idxs][1] + ps = parameter_values(sol) + if is_parameter(sol, idxs) && !is_timeseries_parameter(sol, idxs) + return getp(sol, idxs)(ps) + end + # NOTE: This is basically SII.parameter_values_at_time but that isn't public API + # and once we move interpolation to SII, there's no reason for it to be + if is_parameter_timeseries(sol) == Timeseries() + discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol) + ps = parameter_values(discs) + for ts_idx in eachindex(discs) + partition = discs[ts_idx] + interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity) + ps = with_updated_parameter_timeseries_values(ps, ts_idx => interp_val) + end end + state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t) + return getu(sol, idxs)(state) end function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVector, @@ -238,21 +269,30 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect error("Incorrect specification of `idxs`") end error_if_observed_derivative(sol, idxs, deriv) - interp_sol = augment(sol.interp([t], nothing, deriv, sol.prob.p, continuity), sol) - first(interp_sol[idxs]) + ps = parameter_values(sol) + # NOTE: This is basically SII.parameter_values_at_time but that isn't public API + # and once we move interpolation to SII, there's no reason for it to be + if is_parameter_timeseries(sol) == Timeseries() + discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol) + ps = parameter_values(discs) + for ts_idx in eachindex(discs) + partition = discs[ts_idx] + interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity) + ps = with_updated_parameter_timeseries_values(ps, ts_idx => interp_val) + end + end + state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t) + return getu(sol, idxs)(state) end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, idxs, continuity) where {deriv} symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`") error_if_observed_derivative(sol, idxs, deriv) - if is_parameter(sol, idxs) - return getp(sol, idxs)(sol) - else - interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol) - p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - return DiffEqArray(interp_sol[idxs], t, p, sol) - end + p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing + discretes = get_interpolated_discretes(sol, t, deriv, continuity) + interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol; discretes) + return DiffEqArray(getu(interp_sol, idxs)(interp_sol), t, p, sol; discretes) end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, @@ -260,11 +300,58 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, all(!isequal(NotSymbolic()), symbolic_type.(idxs)) || error("Incorrect specification of `idxs`") error_if_observed_derivative(sol, idxs, deriv) - interp_sol = augment(sol.interp(t, nothing, deriv, sol.prob.p, continuity), sol) p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - indexed_sol = interp_sol[idxs] + discretes = get_interpolated_discretes(sol, t, deriv, continuity) + interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol; discretes) return DiffEqArray( - [indexed_sol[i] for i in 1:length(t)], t, p, sol) + getu(interp_sol, idxs)(interp_sol), t, p, sol; discretes) +end + +# public API, used by MTK +""" + create_parameter_timeseries_collection(sys, ps) + +Create a `SymbolicIndexingInterface.ParameterTimeseriesCollection` for the given system +`sys` and parameter object `ps`. Return `nothing` if there are no timeseries parameters. +Defaults to `nothing`. +""" +function create_parameter_timeseries_collection(sys, ps, tspan) + return nothing +end + +const PeriodicDiffEqArray = DiffEqArray{T, N, A, B} where {T, N, A, B <: AbstractRange} + +# public API, used by MTK +""" + get_saveable_values(ps, timeseries_idx) +""" +function get_saveable_values end + +function save_discretes!(integ::DEIntegrator, timeseries_idx) + save_discretes!(integ.sol, current_time(integ), get_saveable_values(parameter_values(integ), timeseries_idx), timeseries_idx) +end + +save_discretes!(args...) = nothing + +# public API, used by MTK +function save_discretes!(sol::AbstractODESolution, t, vals, timeseries_idx) + RecursiveArrayTools.has_discretes(sol) || return + disc = RecursiveArrayTools.get_discretes(sol) + _save_discretes_internal!(disc[timeseries_idx], t, vals) +end + +function _save_discretes_internal!(A::AbstractDiffEqArray, t, vals) + push!(A.t, t) + push!(A.u, vals) +end + +function _save_discretes_internal!(A::PeriodicDiffEqArray, t, vals) + # This is O(1) because A.t is a range + idx = searchsortedlast(A.t, t) + if idx == firstindex(A.t) - 1 || A.t[idx] ≉ t + error("Tried to save periodic discrete value with timeseries $(A.t) at time $t") + end + push!(A.u, vals) end function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, @@ -305,6 +392,8 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, Base.depwarn(msg, :build_solution) end + ps = parameter_values(prob) + discretes = create_parameter_timeseries_collection(prob.f.sys, ps, prob.tspan) if has_analytic(f) u_analytic = Vector{typeof(prob.u0)}() errors = Dict{Symbol, real(eltype(prob.u0))}() @@ -312,6 +401,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, u_analytic, errors, t, k, + discretes, prob, alg, interp, @@ -332,6 +422,7 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, nothing, nothing, t, k, + discretes, prob, alg, interp, @@ -413,6 +504,36 @@ function solution_slice(sol::ODESolution{T, N}, I) where {T, N} return @set sol.alg = false end +mask_discretes(::Nothing, _, _...) = nothing + +function mask_discretes(discretes::ParameterTimeseriesCollection, new_t, ::Union{Int, CartesianIndex}) + masked_discretes = map(discretes) do disc + i = searchsortedlast(disc.t, new_t) + disc[i:i] + end + return ParameterTimeseriesCollection(masked_discretes, parameter_values(discretes)) +end + +function mask_discretes(discretes::ParameterTimeseriesCollection, new_t, ::AbstractRange) + mint, maxt = extrema(new_t) + masked_discretes = map(discretes) do disc + mini = searchsortedfirst(disc.t, mint) + maxi = searchsortedlast(disc.t, maxt) + disc[mini:maxi] + end + return ParameterTimeseriesCollection(masked_discretes, parameter_values(discretes)) +end + +function mask_discretes(discretes::ParameterTimeseriesCollection, new_t, _) + masked_discretes = map(discretes) do disc + idxs = map(new_t) do t + searchsortedlast(disc.t, t) + end + disc[idxs] + end + return ParameterTimeseriesCollection(masked_discretes, parameter_values(discretes)) +end + function sensitivity_solution(sol::ODESolution, u, t) T = eltype(eltype(u)) diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 22fdb62a1..105522f6d 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -22,9 +22,10 @@ function Base.show(io::IO, m::MIME"text/plain", A::AbstractNoTimeSolution) end # For augmenting system information to enable symbol based indexing of interpolated solutions -function augment(A::DiffEqArray{T, N, Q, B}, sol::AbstractODESolution) where {T, N, Q, B} +function augment(A::DiffEqArray{T, N, Q, B}, sol::AbstractODESolution; + discretes = nothing) where {T, N, Q, B} p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing - return DiffEqArray(A.u, A.t, p, sol) + return DiffEqArray(A.u, A.t, p, sol; discretes) end # SymbolicIndexingInterface.jl From 854e4fd7c69c4d9d432e2d6da9fdd16b45ea4e37 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 11 Mar 2024 15:42:45 +0530 Subject: [PATCH 02/10] fix: fix ODESolution-related adjoints --- ext/SciMLBaseZygoteExt.jl | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/ext/SciMLBaseZygoteExt.jl b/ext/SciMLBaseZygoteExt.jl index 85ede1a0f..950b97dcb 100644 --- a/ext/SciMLBaseZygoteExt.jl +++ b/ext/SciMLBaseZygoteExt.jl @@ -33,7 +33,7 @@ import SciMLStructures N = length((size(dprob.u0)..., length(du))) end Δ′ = ODESolution{T, N}(du, nothing, nothing, - VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, + VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode) (Δ′, nothing, nothing) end @@ -60,7 +60,7 @@ end T = eltype(eltype(VA.u)) N = ndims(VA) Δ′ = ODESolution{T, N}(du, nothing, nothing, - VA.t, VA.k, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, + VA.t, VA.k, VA.discretes, dprob, VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode) (Δ′, nothing, nothing) end @@ -117,9 +117,11 @@ end elseif i === nothing throw(error("Zygote AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")) else - Δ′ = [[i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)] - for (x, j) in zip(VA.u, 1:length(VA))] - (Δ′, nothing) + VA = recursivecopy(VA) + recursivefill!(VA, zero(eltype(VA))) + v = view(VA, i, ntuple(_ -> :, ndims(VA) - 1)...) + copyto!(v, Δ) + (VA, nothing) end end VA[sym], ODESolution_getindex_pullback @@ -172,15 +174,15 @@ end VA[sym], ODESolution_getindex_pullback end -@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14 -}(u, +@adjoint function ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13 +, T14, T15}(u, args...) where {T1, T2, T3, T4, T5, T6, T7, T8, - T9, T10, T11, T12, T13, T14} + T9, T10, T11, T12, T13, T14, T15} function ODESolutionAdjoint(ȳ) (ȳ, ntuple(_ -> nothing, length(args))...) end - ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14}(u, args...), + ODESolution{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15}(u, args...), ODESolutionAdjoint end From 5de5d1be42045a3598015292cbf9e8561ee66c6d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 8 Jul 2024 02:23:17 +0530 Subject: [PATCH 03/10] feat: move clock interface here --- Project.toml | 2 ++ src/SciMLBase.jl | 4 +++ src/clock.jl | 91 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 97 insertions(+) create mode 100644 src/clock.jl diff --git a/Project.toml b/Project.toml index 0c55ac306..d2850146f 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56" +Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636" FunctionWrappersWrappers = "77dc65aa-8811-40c2-897b-53d922fa7daf" IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -62,6 +63,7 @@ DataFrames = "1.6" Distributed = "1.10" DocStringExtensions = "0.9" EnumX = "1" +Expronicon = "0.8" ForwardDiff = "0.10.36" FunctionWrappersWrappers = "0.1.3" IteratorInterfaceExtensions = "^1" diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index fc1623d25..18688eb35 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -23,6 +23,7 @@ import RuntimeGeneratedFunctions import EnumX import ADTypes: AbstractADType import Accessors: @set, @reset +using Expronicon.ADT: @match using Reexport using SciMLOperators @@ -717,6 +718,7 @@ include("problems/problem_traits.jl") include("problems/problem_interface.jl") include("problems/optimization_problems.jl") +include("clock.jl") include("solutions/basic_solutions.jl") include("solutions/nonlinear_solutions.jl") include("solutions/ode_solutions.jl") @@ -835,4 +837,6 @@ export step!, deleteat!, addat!, get_tmp_cache, export ContinuousCallback, DiscreteCallback, CallbackSet, VectorContinuousCallback +export Clocks, TimeDomain, is_discrete_time_domain, isclock, issolverstepclock, iscontinuous + end diff --git a/src/clock.jl b/src/clock.jl new file mode 100644 index 000000000..ba0b59278 --- /dev/null +++ b/src/clock.jl @@ -0,0 +1,91 @@ +module Clocks + +export TimeDomain + +using Expronicon.ADT: @adt, @match + +@adt TimeDomain begin + Continuous + struct PeriodicClock + dt::Union{Nothing, Float64, Rational{Int}} + end + SolverStepClock +end + +Base.Broadcast.broadcastable(d::TimeDomain) = Ref(d) + +end + +using .Clocks + +""" + Clock(dt) + Clock() + +The default periodic clock with tick interval `dt`. If `dt` is left unspecified, it will +be inferred (if possible). +""" +Clock(dt::Union{<:Rational, Float64}) = PeriodicClock(dt) +Clock(dt) = PeriodicClock(convert(Float64, dt)) +Clock() = PeriodicClock(nothing) + +@doc """ + SolverStepClock + +A clock that ticks at each solver step (sometimes referred to as "continuous sample time"). +This clock **does generally not have equidistant tick intervals**, instead, the tick +interval depends on the adaptive step-size selection of the continuous solver, as well as +any continuous event handling. If adaptivity of the solver is turned off and there are no +continuous events, the tick interval will be given by the fixed solver time step `dt`. + +Due to possibly non-equidistant tick intervals, this clock should typically not be used with +discrete-time systems that assume a fixed sample time, such as PID controllers and digital +filters. +""" SolverStepClock + +isclock(c) = @match c begin + PeriodicClock(_) => true + _ => false +end + +issolverstepclock(c) = @match c begin + &SolverStepClock => true + _ => false +end + +iscontinuous(c) = @match c begin + &Continuous => true + _ => false +end + +is_discrete_time_domain(c) = !iscontinuous(c) + +function first_clock_tick_time(c, t0) + @match c begin + PeriodicClock(dt) => ceil(t0 / dt) * dt + &SolverStepClock => t0 + &Continuous => error("Continuous is not a discrete clock") + end +end + +struct IndexedClock{I} + clock::TimeDomain + idx::I +end + +Base.getindex(c::TimeDomain, idx) = IndexedClock(c, idx) + +function canonicalize_indexed_clock(ic::IndexedClock, sol::AbstractTimeseriesSolution) + c = ic.clock + + return @match c begin + PeriodicClock(dt) => ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt + &SolverStepClock => begin + ssc_idx = findfirst(eachindex(sol.discretes)) do i + !isa(sol.discretes[i].t, AbstractRange) + end + sol.discretes[ssc_idx].t[ic.idx] + end + &Continuous => sol.t[ic.idx] + end +end From 51f554b85996b7d6ba5a8bd38f5756e153036584 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 8 Jul 2024 02:24:48 +0530 Subject: [PATCH 04/10] refactor: refactor interpolation with indexed clocks --- src/solutions/ode_solutions.jl | 95 ++++++++++++++++++++++++++-------- 1 file changed, 73 insertions(+), 22 deletions(-) diff --git a/src/solutions/ode_solutions.jl b/src/solutions/ode_solutions.jl index ba339b972..9d8f549e6 100644 --- a/src/solutions/ode_solutions.jl +++ b/src/solutions/ode_solutions.jl @@ -180,22 +180,45 @@ function SymbolicIndexingInterface.is_parameter_timeseries(::Type{S}) where { Timeseries() end +function _hold_discrete(disc_u, disc_t, t::Number) + idx = searchsortedlast(disc_t, t) + if idx == firstindex(disc_t) - 1 + error("Cannot access discrete variable at time $t before initial save $(first(disc_t))") + end + return disc_u[idx] +end + +function hold_discrete(disc_u, disc_t, t::Number) + val = _hold_discrete(disc_u, disc_t, t) + return DiffEqArray([val], [t]) +end + +function hold_discrete(disc_u, disc_t, t::AbstractVector{<:Number}) + return DiffEqArray(_hold_discrete.((disc_u,), (disc_t,), t), t) +end + function get_interpolated_discretes(sol::AbstractODESolution, t, deriv, continuity) is_parameter_timeseries(sol) == Timeseries() || return nothing discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol) interp_discs = map(discs) do partition - ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity) + hold_discrete(partition.u, partition.t, t) end return ParameterTimeseriesCollection(interp_discs, parameter_values(discs)) end function (sol::AbstractODESolution)(t, ::Type{deriv} = Val{0}; idxs = nothing, continuity = :left) where {deriv} + if t isa IndexedClock + t = canonicalize_indexed_clock(t, sol) + end sol(t, deriv, idxs, continuity) end function (sol::AbstractODESolution)(v, t, ::Type{deriv} = Val{0}; idxs = nothing, continuity = :left) where {deriv} + if t isa IndexedClock + t = canonicalize_indexed_clock(t, sol) + end sol.interp(v, t, idxs, deriv, sol.prob.p, continuity) end @@ -247,15 +270,13 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs, if is_parameter(sol, idxs) && !is_timeseries_parameter(sol, idxs) return getp(sol, idxs)(ps) end - # NOTE: This is basically SII.parameter_values_at_time but that isn't public API - # and once we move interpolation to SII, there's no reason for it to be if is_parameter_timeseries(sol) == Timeseries() discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol) ps = parameter_values(discs) for ts_idx in eachindex(discs) partition = discs[ts_idx] interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity) - ps = with_updated_parameter_timeseries_values(ps, ts_idx => interp_val) + ps = with_updated_parameter_timeseries_values(sol, ps, ts_idx => interp_val) end end state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t) @@ -270,15 +291,13 @@ function (sol::AbstractODESolution)(t::Number, ::Type{deriv}, idxs::AbstractVect end error_if_observed_derivative(sol, idxs, deriv) ps = parameter_values(sol) - # NOTE: This is basically SII.parameter_values_at_time but that isn't public API - # and once we move interpolation to SII, there's no reason for it to be if is_parameter_timeseries(sol) == Timeseries() discs::ParameterTimeseriesCollection = RecursiveArrayTools.get_discretes(sol) ps = parameter_values(discs) for ts_idx in eachindex(discs) partition = discs[ts_idx] interp_val = ConstantInterpolation(partition.t, partition.u)(t, nothing, deriv, nothing, continuity) - ps = with_updated_parameter_timeseries_values(ps, ts_idx => interp_val) + ps = with_updated_parameter_timeseries_values(sol, ps, ts_idx => interp_val) end end state = ProblemState(; u = sol.interp(t, nothing, deriv, ps, continuity), p = ps, t) @@ -290,9 +309,21 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, symbolic_type(idxs) == NotSymbolic() && error("Incorrect specification of `idxs`") error_if_observed_derivative(sol, idxs, deriv) p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing + getter = getu(sol, idxs) + if is_parameter_timeseries(sol) == NotTimeseries() + interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol) + return DiffEqArray(getter(interp_sol), t, p, sol) + end discretes = get_interpolated_discretes(sol, t, deriv, continuity) - interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol; discretes) - return DiffEqArray(getu(interp_sol, idxs)(interp_sol), t, p, sol; discretes) + interp_sol = sol.interp(t, nothing, deriv, p, continuity) + u = map(eachindex(t)) do ti + ps = parameter_values(discretes) + for i in eachindex(discretes) + ps = with_updated_parameter_timeseries_values(sol, ps, i => discretes[i, ti]) + end + return getter(ProblemState(; u = interp_sol.u[ti], p = ps, t = t[ti])) + end + return DiffEqArray(u, t, p, sol; discretes) end function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, @@ -301,34 +332,51 @@ function (sol::AbstractODESolution)(t::AbstractVector{<:Number}, ::Type{deriv}, error("Incorrect specification of `idxs`") error_if_observed_derivative(sol, idxs, deriv) p = hasproperty(sol.prob, :p) ? sol.prob.p : nothing + getter = getu(sol, idxs) + if is_parameter_timeseries(sol) == NotTimeseries() + interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol) + return DiffEqArray(getter(interp_sol), t, p, sol) + end discretes = get_interpolated_discretes(sol, t, deriv, continuity) - interp_sol = augment(sol.interp(t, nothing, deriv, p, continuity), sol; discretes) - return DiffEqArray( - getu(interp_sol, idxs)(interp_sol), t, p, sol; discretes) + interp_sol = sol.interp(t, nothing, deriv, p, continuity) + u = map(eachindex(t)) do ti + ps = parameter_values(discretes) + for i in eachindex(discretes) + ps = with_updated_parameter_timeseries_values(sol, ps, i => discretes[i, ti]) + end + return getter(ProblemState(; u = interp_sol.u[ti], p = ps, t = t[ti])) + end + return DiffEqArray(u, t, p, sol; discretes) end # public API, used by MTK """ - create_parameter_timeseries_collection(sys, ps) + create_parameter_timeseries_collection(sys, ps, tspan) Create a `SymbolicIndexingInterface.ParameterTimeseriesCollection` for the given system `sys` and parameter object `ps`. Return `nothing` if there are no timeseries parameters. -Defaults to `nothing`. +Defaults to `nothing`. Falls back on the basis of `symbolic_container`. """ function create_parameter_timeseries_collection(sys, ps, tspan) - return nothing + if hasmethod(symbolic_container, Tuple{typeof(sys)}) + return create_parameter_timeseries_collection(symbolic_container(sys), ps, tspan) + else + return nothing + end end const PeriodicDiffEqArray = DiffEqArray{T, N, A, B} where {T, N, A, B <: AbstractRange} # public API, used by MTK """ - get_saveable_values(ps, timeseries_idx) + get_saveable_values(sys, ps, timeseries_idx) """ -function get_saveable_values end +function get_saveable_values(sys, ps, timeseries_idx) + return get_saveable_values(symbolic_container(sys), ps, timeseries_idx) +end function save_discretes!(integ::DEIntegrator, timeseries_idx) - save_discretes!(integ.sol, current_time(integ), get_saveable_values(parameter_values(integ), timeseries_idx), timeseries_idx) + save_discretes!(integ.sol, current_time(integ), get_saveable_values(integ, parameter_values(integ), timeseries_idx), timeseries_idx) end save_discretes!(args...) = nothing @@ -346,9 +394,8 @@ function _save_discretes_internal!(A::AbstractDiffEqArray, t, vals) end function _save_discretes_internal!(A::PeriodicDiffEqArray, t, vals) - # This is O(1) because A.t is a range - idx = searchsortedlast(A.t, t) - if idx == firstindex(A.t) - 1 || A.t[idx] ≉ t + idx = length(A.u) + 1 + if A.t[idx] ≉ t error("Tried to save periodic discrete value with timeseries $(A.t) at time $t") end push!(A.u, vals) @@ -393,7 +440,11 @@ function build_solution(prob::Union{AbstractODEProblem, AbstractDDEProblem}, end ps = parameter_values(prob) - discretes = create_parameter_timeseries_collection(prob.f.sys, ps, prob.tspan) + if has_sys(prob.f) + discretes = create_parameter_timeseries_collection(prob.f.sys, ps, prob.tspan) + else + discretes = nothing + end if has_analytic(f) u_analytic = Vector{typeof(prob.u0)}() errors = Dict{Symbol, real(eltype(prob.u0))}() From 8af217d026ff985ebd0361362ed6a4bffe26c7c6 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 25 Jun 2024 11:21:24 +0530 Subject: [PATCH 05/10] test: rewrite hybrid system tests to not use MTK --- test/downstream/Project.toml | 1 + test/downstream/comprehensive_indexing.jl | 338 +++++++++++++++++++++- 2 files changed, 338 insertions(+), 1 deletion(-) diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 88f59794d..cce0055d6 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -1,5 +1,6 @@ [deps] BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d" +DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JumpProcesses = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5" ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" diff --git a/test/downstream/comprehensive_indexing.jl b/test/downstream/comprehensive_indexing.jl index 050ef7dde..b6fcce8a8 100644 --- a/test/downstream/comprehensive_indexing.jl +++ b/test/downstream/comprehensive_indexing.jl @@ -1,6 +1,7 @@ using ModelingToolkit, JumpProcesses, LinearAlgebra, NonlinearSolve, Optimization, OptimizationOptimJL, OrdinaryDiffEq, RecursiveArrayTools, SciMLBase, - SteadyStateDiffEq, StochasticDiffEq, SymbolicIndexingInterface, Test + SteadyStateDiffEq, StochasticDiffEq, SymbolicIndexingInterface, + DiffEqCallbacks, Test using ModelingToolkit: t_nounits as t, D_nounits as D # Sets rnd number. @@ -528,3 +529,338 @@ end @test_throws ErrorException sol(1.0, Val{1}, idxs = [w, w]) @test_throws ErrorException sol(1.0, Val{1}, idxs = [w, y]) end + +@testset "Discrete save indexing" begin + struct NumSymbolCache{S} + sc::S + end + SymbolicIndexingInterface.symbolic_container(s::NumSymbolCache) = s.sc + function SymbolicIndexingInterface.is_observed(s::NumSymbolCache, x) + return symbolic_type(x) != NotSymbolic() && !is_variable(s, x) && + !is_parameter(s, x) && !is_independent_variable(s, x) + end + function SymbolicIndexingInterface.observed(s::NumSymbolCache, x) + res = ModelingToolkit.build_function(x, + sort(variable_symbols(s); by = Base.Fix1(variable_index, s)), + sort(parameter_symbols(s), by = Base.Fix1(parameter_index, s)), + independent_variable_symbols(s)[]; expression = Val(false)) + if res isa Tuple + return let oopfn = res[1], iipfn = res[2] + fn(out, u, p, t) = iipfn(out, u, p, t) + fn(u, p, t) = oopfn(u, p, t) + fn + end + else + return res + end + end + function SymbolicIndexingInterface.parameter_observed(s::NumSymbolCache, x) + res = ModelingToolkit.build_function(x, + sort(parameter_symbols(s), by = Base.Fix1(parameter_index, s)), + independent_variable_symbols(s)[]; expression = Val(false)) + if res isa Tuple + return let oopfn = res[1], iipfn = res[2] + fn(out, p, t) = iipfn(out, p, t) + fn(p, t) = oopfn(p, t) + fn + end + else + return res + end + end + function SymbolicIndexingInterface.get_all_timeseries_indexes(s::NumSymbolCache, x) + if symbolic_type(x) == NotSymbolic() + x = ModelingToolkit.unwrap.(x) + else + x = ModelingToolkit.unwrap(x) + end + vars = ModelingToolkit.vars(x) + return mapreduce(union, vars; init = Set()) do sym + if is_variable(s, sym) + Set([ContinuousTimeseries()]) + elseif is_parameter(s, sym) && is_timeseries_parameter(s, sym) + Set([timeseries_parameter_index(s, sym).timeseries_idx]) + else + Set() + end + end + end + function SymbolicIndexingInterface.with_updated_parameter_timeseries_values( + ::NumSymbolCache, p::Vector{Float64}, args...) + for (idx, buf) in args + if idx == 1 + p[1:2] .= buf + else + p[3:4] .= buf + end + end + + return p + end + function SciMLBase.create_parameter_timeseries_collection(s::NumSymbolCache, ps, tspan) + trem = rem(tspan[1], 0.1, RoundDown) + if trem > 0 + trem = 0.1 - trem + end + dea1 = DiffEqArray(Vector{Float64}[], (tspan[1] + trem):0.1:tspan[2]) + dea2 = DiffEqArray(Vector{Float64}[], Float64[]) + return ParameterTimeseriesCollection((dea1, dea2), deepcopy(ps)) + end + function SciMLBase.get_saveable_values(::NumSymbolCache, p::Vector{Float64}, tsidx) + if tsidx == 1 + return p[1:2] + else + return p[3:4] + end + end + + @variables x(t) ud1(t) ud2(t) xd1(t) xd2(t) + @parameters kp + sc = SymbolCache([x], + Dict(ud1 => 1, xd1 => 2, ud2 => 3, xd2 => 4, kp => 5), + t; + timeseries_parameters = Dict( + ud1 => ParameterTimeseriesIndex(1, 1), xd1 => ParameterTimeseriesIndex(1, 2), + ud2 => ParameterTimeseriesIndex(2, 1), xd2 => ParameterTimeseriesIndex(2, 2))) + sys = NumSymbolCache(sc) + + function f!(du, u, p, t) + du .= u .* t .+ p[5] * sum(u) + end + fn = ODEFunction(f!; sys = sys) + prob = ODEProblem(fn, [1.0], (0.0, 1.0), [1.0, 2.0, 3.0, 4.0, 5.0]) + cb1 = PeriodicCallback(0.1; initial_affect = true, final_affect = true, + save_positions = (false, false)) do integ + integ.p[1:2] .+= exp(-integ.t) + SciMLBase.save_discretes!(integ, 1) + end + function affect2!(integ) + integ.p[3:4] .+= only(integ.u) + SciMLBase.save_discretes!(integ, 2) + end + cb2 = DiscreteCallback((args...) -> true, affect2!, save_positions = (false, false), + initialize = (c, u, t, integ) -> affect2!(integ)) + sol = solve(deepcopy(prob), Tsit5(); callback = CallbackSet(cb1, cb2)) + + ud1val = getindex.(sol.discretes.collection[1].u, 1) + xd1val = getindex.(sol.discretes.collection[1].u, 2) + ud2val = getindex.(sol.discretes.collection[2].u, 1) + xd2val = getindex.(sol.discretes.collection[2].u, 2) + + for (sym, timeseries_index, val, buffer, isobs, check_inference) in [(ud1, + 1, + ud1val, + zeros(length(ud1val)), + false, + true) + ([ud1, xd1], + 1, + vcat.(ud1val, + xd1val), + map( + _ -> zeros(2), + ud1val), + false, + true) + ((ud2, xd2), + 2, + tuple.(ud2val, + xd2val), + map( + _ -> zeros(2), + ud2val), + false, + true) + (ud2 + xd2, + 2, + ud2val .+ + xd2val, + zeros(length(ud2val)), + true, + true) + ( + [ud2 + xd2, + ud2 * xd2], + 2, + vcat.( + ud2val .+ + xd2val, + ud2val .* + xd2val), + map( + _ -> zeros(2), + ud2val), + true, + true) + ( + (ud1 + xd1, + ud1 * xd1), + 1, + tuple.( + ud1val .+ + xd1val, + ud1val .* + xd1val), + map( + _ -> zeros(2), + ud1val), + true, + true)] + getter = getp(sys, sym) + if check_inference + @inferred getter(sol) + @inferred getter(deepcopy(buffer), sol) + if !isobs + @inferred getter(parameter_values(sol)) + if !(eltype(val) <: Number) + @inferred getter(deepcopy(buffer[1]), parameter_values(sol)) + end + end + end + + @test getter(sol) == val + if eltype(val) <: Number + target = val + else + target = collect.(val) + end + tmp = deepcopy(buffer) + getter(tmp, sol) + @test tmp == target + + if !isobs + @test getter(parameter_values(sol)) == val[end] + if !(eltype(val) <: Number) + target = collect(val[end]) + tmp = deepcopy(buffer)[end] + getter(tmp, parameter_values(sol)) + @test tmp == target + end + end + + for subidx in [ + 1, CartesianIndex(2), :, rand(Bool, length(val)), rand(eachindex(val), 4), 2:5] + if check_inference + @inferred getter(sol, subidx) + if !isa(val[subidx], Number) + @inferred getter(deepcopy(buffer[subidx]), sol, subidx) + end + end + @test getter(sol, subidx) == val[subidx] + tmp = deepcopy(buffer[subidx]) + if val[subidx] isa Number + continue + end + target = val[subidx] + if eltype(target) <: Number + target = collect(target) + else + target = collect.(target) + end + getter(tmp, sol, subidx) + @test tmp == target + end + end + + for sym in [ + [ud1, xd1, ud2], + (ud2, xd1, xd2), + ud1 + ud2, + [ud1 + ud2, ud1 * xd1], + (ud1 + ud2, ud1 * xd1)] + getter = getp(sys, sym) + @test_throws Exception getter(sol) + @test_throws Exception getter([], sol) + for subidx in [1, CartesianIndex(1), :, rand(Bool, 4), rand(1:4, 3), 1:2] + @test_throws Exception getter(sol, subidx) + @test_throws Exception getter([], sol, subidx) + end + end + + kpval = sol.prob.p[5] + xval = getindex.(sol.u) + + for (sym, val_is_timeseries, val, check_inference) in [ + (kp, false, kpval, true), + ([kp, kp], false, [kpval, kpval], true), + ((kp, kp), false, (kpval, kpval), true), + (ud2, true, ud2val, true), + ([ud2, kp], true, vcat.(ud2val, kpval), false), + ((ud1, kp), true, tuple.(ud1val, kpval), false), + ([kp, x], true, vcat.(kpval, xval), false), + ((kp, x), true, tuple.(kpval, xval), false), + (2ud2, true, 2 .* ud2val, true), + ([kp, 2ud1], true, vcat.(kpval, 2 .* ud1val), false), + ((kp, 2ud1), true, tuple.(kpval, 2 .* ud1val), false) + ] + getter = getu(sys, sym) + if check_inference + @inferred getter(sol) + end + @test getter(sol) == val + reference = val_is_timeseries ? val : xval + for subidx in [ + 1, CartesianIndex(2), :, rand(Bool, length(reference)), + rand(eachindex(reference), 4), 2:6 + ] + if check_inference + @inferred getter(sol, subidx) + end + target = if val_is_timeseries + val[subidx] + else + val + end + @test getter(sol, subidx) == target + end + end + + _xval = xval[1] + _ud1val = ud1val[1] + _ud2val = ud2val[1] + _xd1val = xd1val[1] + _xd2val = xd2val[1] + integ = init(prob, Tsit5(); callback = CallbackSet(cb1, cb2)) + for (sym, val, check_inference) in [ + ([x, ud1], [_xval, _ud1val], false), + ((x, ud1), (_xval, _ud1val), true), + (x + ud2, _xval + _ud2val, true), + ([2x, 3xd1], [2_xval, 3_xd1val], true), + ((2x, 3xd2), (2_xval, 3_xd2val), true) + ] + getter = getu(sys, sym) + @test_throws Exception getter(sol) + for subidx in [1, CartesianIndex(1), :, rand(Bool, 4), rand(1:4, 3), 1:2] + @test_throws Exception getter(sol, subidx) + end + + if check_inference + @inferred getter(integ) + end + @test getter(integ) == val + end + + xinterp = sol(0.1:0.1:0.3, idxs = x) + xinterp2 = sol(sol.discretes.collection[2].t[2:4], idxs = x) + ud1interp = ud1val[2:4] + ud2interp = ud2val[2:4] + + c1 = SciMLBase.Clock(0.1) + c2 = SciMLBase.SolverStepClock + for (sym, t, val) in [ + (x, c1[2], xinterp[1]), + (x, c1[2:4], xinterp), + ([x, ud1], c1[2], [xinterp[1], ud1interp[1]]), + ([x, ud1], c1[2:4], vcat.(xinterp, ud1interp)), + (x, c2[2], xinterp2[1]), + (x, c2[2:4], xinterp2), + ([x, ud2], c2[2], [xinterp2[1], ud2interp[1]]), + ([x, ud2], c2[2:4], vcat.(xinterp2, ud2interp)) + ] + res = sol(t, idxs = sym) + if res isa DiffEqArray + res = res.u + end + @test res == val + end +end From 20efc758b4c72aeea14352bc2cf6bf238f766f2b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 2 Jul 2024 12:45:28 +0530 Subject: [PATCH 06/10] feat: add clock phase Co-authored-by: Fredrik Bagge Carlson --- src/clock.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/clock.jl b/src/clock.jl index ba0b59278..04d5dac43 100644 --- a/src/clock.jl +++ b/src/clock.jl @@ -8,6 +8,7 @@ using Expronicon.ADT: @adt, @match Continuous struct PeriodicClock dt::Union{Nothing, Float64, Rational{Int}} + phase::Float64 = 0.0 end SolverStepClock end @@ -25,9 +26,9 @@ using .Clocks The default periodic clock with tick interval `dt`. If `dt` is left unspecified, it will be inferred (if possible). """ -Clock(dt::Union{<:Rational, Float64}) = PeriodicClock(dt) -Clock(dt) = PeriodicClock(convert(Float64, dt)) -Clock() = PeriodicClock(nothing) +Clock(dt::Union{<:Rational, Float64}; phase = 0.0) = PeriodicClock(dt, phase) +Clock(dt; phase = 0.0) = PeriodicClock(convert(Float64, dt), phase) +Clock(; phase = 0.0) = PeriodicClock(nothing, phase) @doc """ SolverStepClock @@ -44,7 +45,7 @@ filters. """ SolverStepClock isclock(c) = @match c begin - PeriodicClock(_) => true + PeriodicClock(_...) => true _ => false end @@ -62,7 +63,7 @@ is_discrete_time_domain(c) = !iscontinuous(c) function first_clock_tick_time(c, t0) @match c begin - PeriodicClock(dt) => ceil(t0 / dt) * dt + PeriodicClock(dt, _...) => ceil(t0 / dt) * dt &SolverStepClock => t0 &Continuous => error("Continuous is not a discrete clock") end @@ -79,7 +80,7 @@ function canonicalize_indexed_clock(ic::IndexedClock, sol::AbstractTimeseriesSol c = ic.clock return @match c begin - PeriodicClock(dt) => ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt + PeriodicClock(dt, _...) => ceil(sol.prob.tspan[1] / dt) * dt .+ (ic.idx .- 1) .* dt &SolverStepClock => begin ssc_idx = findfirst(eachindex(sol.discretes)) do i !isa(sol.discretes[i].t, AbstractRange) From f7692068798f0ed9887b2b20bc32d2f021078017 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 3 Jul 2024 13:05:07 +0530 Subject: [PATCH 07/10] feat: add discrete variables support to solution plot recipe --- src/solutions/solution_interface.jl | 177 ++++++++++++++++++---------- 1 file changed, 112 insertions(+), 65 deletions(-) diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index 105522f6d..dfa5ce9ab 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -176,91 +176,138 @@ DEFAULT_PLOT_FUNC(x, y, z) = (x, y, z) # For v0.5.2 bug end idxs = idxs === nothing ? (1:length(sol.u[1])) : idxs - + disc_idxs = [] + cont_idxs = [] + for idx in idxs + tsidxs = get_all_timeseries_indexes(sol, idx) + if ContinuousTimeseries() in tsidxs + push!(cont_idxs, idx) + else + push!(disc_idxs, (idx, only(tsidxs))) + end + end + idxs = identity.(cont_idxs) if !(idxs isa Union{Tuple, AbstractArray}) vars = interpret_vars([idxs], sol) else vars = interpret_vars(idxs, sol) end - - tscale = get(plotattributes, :xscale, :identity) - plot_vecs, labels = diffeq_to_arrays(sol, plot_analytic, denseplot, - plotdensity, tspan, vars, tscale, plotat) - tdir = sign(sol.t[end] - sol.t[1]) xflip --> tdir < 0 seriestype --> :path - # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ... - if idxs isa Tuple && vars[1][1] === DEFAULT_PLOT_FUNC - val = hasname(vars[1][2]) ? String(getname(vars[1][2])) : vars[1][2] - if val isa Integer - if val == 0 - val = "t" - else - val = "u[$val]" + @series begin + if isempty(idxs) + label --> nothing + ([], []) + else + tscale = get(plotattributes, :xscale, :identity) + plot_vecs, labels = diffeq_to_arrays(sol, plot_analytic, denseplot, + plotdensity, tspan, vars, tscale, plotat) + + + # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ... + if idxs isa Tuple && vars[1][1] === DEFAULT_PLOT_FUNC + val = hasname(vars[1][2]) ? String(getname(vars[1][2])) : vars[1][2] + if val isa Integer + if val == 0 + val = "t" + else + val = "u[$val]" + end + end + xguide --> val + val = hasname(vars[1][3]) ? String(getname(vars[1][3])) : vars[1][3] + if val isa Integer + if val == 0 + val = "t" + else + val = "u[$val]" + end + end + yguide --> val + if length(idxs) > 2 + val = hasname(vars[1][4]) ? String(getname(vars[1][4])) : vars[1][4] + if val isa Integer + if val == 0 + val = "t" + else + val = "u[$val]" + end + end + zguide --> val + end end - end - xguide --> val - val = hasname(vars[1][3]) ? String(getname(vars[1][3])) : vars[1][3] - if val isa Integer - if val == 0 - val = "t" - else - val = "u[$val]" + + if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 1))) && + getindex.(vars, 1) == zeros(length(vars))) || + (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) && + getindex.(vars, 2) == zeros(length(vars))) || + all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 1)) || + all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2)) + xguide --> "$(getindepsym_defaultt(sol))" end - end - yguide --> val - if length(idxs) > 2 - val = hasname(vars[1][4]) ? String(getname(vars[1][4])) : vars[1][4] - if val isa Integer - if val == 0 - val = "t" + if length(vars[1]) >= 3 && + ((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 3))) && + getindex.(vars, 3) == zeros(length(vars))) || + all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 3))) + yguide --> "$(getindepsym_defaultt(sol))" + end + if length(vars[1]) >= 4 && + ((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 4))) && + getindex.(vars, 4) == zeros(length(vars))) || + all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 4))) + zguide --> "$(getindepsym_defaultt(sol))" + end + + if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) && + getindex.(vars, 2) == zeros(length(vars))) || + all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2)) + if tspan === nothing + if tdir > 0 + xlims --> (sol.t[1], sol.t[end]) + else + xlims --> (sol.t[end], sol.t[1]) + end else - val = "u[$val]" + xlims --> (tspan[1], tspan[end]) end end - zguide --> val - end - end - if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 1))) && - getindex.(vars, 1) == zeros(length(vars))) || - (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) && - getindex.(vars, 2) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 1)) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2)) - xguide --> "$(getindepsym_defaultt(sol))" - end - if length(vars[1]) >= 3 && - ((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 3))) && - getindex.(vars, 3) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 3))) - yguide --> "$(getindepsym_defaultt(sol))" - end - if length(vars[1]) >= 4 && - ((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 4))) && - getindex.(vars, 4) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 4))) - zguide --> "$(getindepsym_defaultt(sol))" + label --> reshape(labels, 1, length(labels)) + (plot_vecs...,) + end end - - if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) && - getindex.(vars, 2) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2)) - if tspan === nothing - if tdir > 0 - xlims --> (sol.t[1], sol.t[end]) - else - xlims --> (sol.t[end], sol.t[1]) + for (idx, tsidx) in disc_idxs + partition = sol.discretes[tsidx] + ts = current_time(partition) + if tspan !== nothing + tstart = searchsortedfirst(ts, tspan[1]) + tend = searchsortedlast(ts, tspan[2]) + if tstart == lastindex(ts) + 1 || tend == firstindex(ts) - 1 + continue end else - xlims --> (tspan[1], tspan[end]) + tstart = firstindex(ts) + tend = lastindex(ts) + end + ts = ts[tstart:tend] + + vals = getp(sol, idx)(sol, tstart:tend) + # Scatterplot of points + @series begin + seriestype := :line + linestyle --> :dash + markershape --> :o + markersize --> repeat([2, 0], length(ts)-1) + markeralpha --> repeat([1, 0], length(ts)-1) + label --> string(hasname(idx) ? getname(idx) : idx) + + x = vec([ts[1:end-1]'; ts[2:end]']) + y = repeat(vals, inner=2)[1:end-1] + x, y end end - - label --> reshape(labels, 1, length(labels)) - (plot_vecs...,) end function diffeq_to_arrays(sol, plot_analytic, denseplot, plotdensity, tspan, From cfe63d7071620c4b29e6b9f5b8e02a4f5541750b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 8 Jul 2024 02:27:11 +0530 Subject: [PATCH 08/10] build: bump RAT and SII compat --- Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index d2850146f..07a5d9223 100644 --- a/Project.toml +++ b/Project.toml @@ -80,7 +80,7 @@ PyCall = "1.96" PythonCall = "0.9.15" RCall = "0.14.0" RecipesBase = "1.3.4" -RecursiveArrayTools = "3.22.0" +RecursiveArrayTools = "3.26.0" Reexport = "1" RuntimeGeneratedFunctions = "0.5.12" SciMLOperators = "0.3.7" @@ -89,7 +89,7 @@ StableRNGs = "1.0" StaticArrays = "1.7" StaticArraysCore = "1.4" Statistics = "1.10" -SymbolicIndexingInterface = "0.3.20" +SymbolicIndexingInterface = "0.3.26" Tables = "1.11" Zygote = "0.6.67" julia = "1.10" From d71ec13f2d35ae9e84847deff32de0fb0c38ebb0 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 23 Jul 2024 11:55:03 +0530 Subject: [PATCH 09/10] fix: remove ambiguity in `_updated_u0_p_internal` --- src/remake.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/remake.jl b/src/remake.jl index ed227a3a2..621ca46a6 100644 --- a/src/remake.jl +++ b/src/remake.jl @@ -475,6 +475,10 @@ end anydict(d) = Dict{Any, Any}(d) anydict() = Dict{Any, Any}() +function _updated_u0_p_internal( + prob, ::Missing, ::Missing; interpret_symbolicmap = true, use_defaults = false) + return state_values(prob), parameter_values(prob) +end function _updated_u0_p_internal( prob, ::Missing, p; interpret_symbolicmap = true, use_defaults = false) u0 = state_values(prob) From fe68aaab7e738f6547ecde8dd98cb4cb19900911 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 23 Jul 2024 11:58:19 +0530 Subject: [PATCH 10/10] test: reduce ambiguities threshold --- test/aqua.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/aqua.jl b/test/aqua.jl index b403790aa..2c7f9215c 100644 --- a/test/aqua.jl +++ b/test/aqua.jl @@ -22,7 +22,7 @@ using Aqua # @show method_ambiguity # end @warn "Number of method ambiguities: $(length(ambs))" - @test length(ambs) ≤ 13 + @test length(ambs) ≤ 8 end @testset "Aqua tests (additional)" begin