From 605824aae9573728e328111a7062ebc74c59cf4a Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 13 Jul 2023 16:22:31 -0400 Subject: [PATCH 1/2] make symbolic indexing return VectorOfArray --- src/ensemble/ensemble_solutions.jl | 6 +++--- src/solutions/solution_interface.jl | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index a5fe82714..7f52a32d9 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -200,16 +200,16 @@ end Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, s) - return [xi[s] for xi in x] + return VectorOfArray([xi[s] for xi in x]) end Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...) return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...) end Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Int...) - return [xi[args...] for xi in x] + return VectorOfArray([xi[args...] for xi in x]) end function (sol::AbstractEnsembleSolution)(args...; kwargs...) - [s(args...; kwargs...) for s in sol] + VectorOfArray([s(args...; kwargs...) for s in sol]) end diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index e9be869af..3e4bf5774 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -78,9 +78,9 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s if has_sys(A.prob.f) && all(Base.Fix1(is_param_sym, A.prob.f.sys), sym) || !has_sys(A.prob.f) && has_paramsyms(A.prob.f) && all(in(getparamsyms(A)), Symbol.(sym)) - return getindex.((A,), sym) + return VectorOfArray(getindex.((A,), sym)) else - return [getindex.((A,), sym, i) for i in eachindex(A)] + return VectorOfArray([getindex.((A,), sym, i) for i in eachindex(A)]) end else i = sym @@ -110,7 +110,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s observed(A, sym, :) end elseif i isa Base.Integer || i isa AbstractRange || i isa AbstractVector{<:Base.Integer} - A[i, :] + VectorOfArray(A[i, :]) else error("Invalid indexing of solution") end From 7ff39aee279f1ebac44161ac4fa634547f6f2420 Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Mon, 17 Jul 2023 14:25:34 -0400 Subject: [PATCH 2/2] fix indexing --- src/ensemble/ensemble_solutions.jl | 8 ++++---- test/downstream/ensemble_multi_prob.jl | 26 +++++++++++++++++--------- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index 7f52a32d9..6adf36c19 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -199,16 +199,16 @@ end end -Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, s) +Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon) return VectorOfArray([xi[s] for xi in x]) end Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...) return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...) end -Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Int...) - return VectorOfArray([xi[args...] for xi in x]) -end +#Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, args::Int..., ::Colon) +# return VectorOfArray([xi[args...] for xi in x]) +#end function (sol::AbstractEnsembleSolution)(args...; kwargs...) VectorOfArray([s(args...; kwargs...) for s in sol]) diff --git a/test/downstream/ensemble_multi_prob.jl b/test/downstream/ensemble_multi_prob.jl index e8e7640ed..96facd548 100644 --- a/test/downstream/ensemble_multi_prob.jl +++ b/test/downstream/ensemble_multi_prob.jl @@ -1,19 +1,27 @@ using ModelingToolkit, OrdinaryDiffEq, Test -@variables t, x(t) +@variables t, x(t), y(t) D = Differential(t) -@named sys1 = ODESystem([D(x) ~ 1.1*x]) -@named sys2 = ODESystem([D(x) ~ 1.2*x]) +@named sys1 = ODESystem([D(x) ~ x, + D(y) ~ -y]) +@named sys2 = ODESystem([D(x) ~ 2x, + D(y) ~ -2y]) +@named sys3 = ODESystem([D(x) ~ 3x, + D(y) ~ -3y]) -prob1 = ODEProblem(sys1, [2.0], (0.0, 1.0)) -prob2 = ODEProblem(sys2, [1.0], (0.0, 1.0)) +prob1 = ODEProblem(sys1, [1.0, 1.0], (0.0, 1.0)) +prob2 = ODEProblem(sys2, [2.0, 2.0], (0.0, 1.0)) +prob3 = ODEProblem(sys3, [3.0, 3.0], (0.0, 1.0)) # test that when passing a vector of problems, trajectories and the prob_func are chosen appropriately -ensemble_prob = EnsembleProblem([prob1, prob2]) +ensemble_prob = EnsembleProblem([prob1, prob2, prob3]) sol = solve(ensemble_prob, Tsit5(), EnsembleThreads()) -@test isapprox(sol[:, x], [2,1] .* map(Base.Fix1(map, exp), [1.1, 1.2] .* sol[:, t]), rtol=1e-4) +for i in 1:3 + @test sol[x, :][i] == sol[i][x] + @test sol[y, :][i] == sol[i][y] +end # Ensemble is a recursive array -@test sol(0.0, idxs=[x]) == sol[:, 1] == first.(sol[:, x], 1) +@test Matrix(sol(0.0, idxs=[x])) == sol[1:1, 1, :] == Matrix(first(eachrow(sol[x, :]))') # TODO: fix the interpolation -@test sol(1.0, idxs=[x]) ≈ last.(sol[:, x], 1) +@test vec(sol(1.0, idxs=[x])) ≈ last.(sol[x, :].u)