diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index a5fe82714..6adf36c19 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -199,17 +199,17 @@ end end -Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, s) - return [xi[s] for xi in x] +Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon) + return VectorOfArray([xi[s] for xi in x]) end Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...) return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...) end -Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Int...) - return [xi[args...] for xi in x] -end +#Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, args::Int..., ::Colon) +# return VectorOfArray([xi[args...] for xi in x]) +#end function (sol::AbstractEnsembleSolution)(args...; kwargs...) - [s(args...; kwargs...) for s in sol] + VectorOfArray([s(args...; kwargs...) for s in sol]) end diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index e9be869af..3e4bf5774 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -78,9 +78,9 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s if has_sys(A.prob.f) && all(Base.Fix1(is_param_sym, A.prob.f.sys), sym) || !has_sys(A.prob.f) && has_paramsyms(A.prob.f) && all(in(getparamsyms(A)), Symbol.(sym)) - return getindex.((A,), sym) + return VectorOfArray(getindex.((A,), sym)) else - return [getindex.((A,), sym, i) for i in eachindex(A)] + return VectorOfArray([getindex.((A,), sym, i) for i in eachindex(A)]) end else i = sym @@ -110,7 +110,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s observed(A, sym, :) end elseif i isa Base.Integer || i isa AbstractRange || i isa AbstractVector{<:Base.Integer} - A[i, :] + VectorOfArray(A[i, :]) else error("Invalid indexing of solution") end diff --git a/test/downstream/ensemble_multi_prob.jl b/test/downstream/ensemble_multi_prob.jl index e8e7640ed..96facd548 100644 --- a/test/downstream/ensemble_multi_prob.jl +++ b/test/downstream/ensemble_multi_prob.jl @@ -1,19 +1,27 @@ using ModelingToolkit, OrdinaryDiffEq, Test -@variables t, x(t) +@variables t, x(t), y(t) D = Differential(t) -@named sys1 = ODESystem([D(x) ~ 1.1*x]) -@named sys2 = ODESystem([D(x) ~ 1.2*x]) +@named sys1 = ODESystem([D(x) ~ x, + D(y) ~ -y]) +@named sys2 = ODESystem([D(x) ~ 2x, + D(y) ~ -2y]) +@named sys3 = ODESystem([D(x) ~ 3x, + D(y) ~ -3y]) -prob1 = ODEProblem(sys1, [2.0], (0.0, 1.0)) -prob2 = ODEProblem(sys2, [1.0], (0.0, 1.0)) +prob1 = ODEProblem(sys1, [1.0, 1.0], (0.0, 1.0)) +prob2 = ODEProblem(sys2, [2.0, 2.0], (0.0, 1.0)) +prob3 = ODEProblem(sys3, [3.0, 3.0], (0.0, 1.0)) # test that when passing a vector of problems, trajectories and the prob_func are chosen appropriately -ensemble_prob = EnsembleProblem([prob1, prob2]) +ensemble_prob = EnsembleProblem([prob1, prob2, prob3]) sol = solve(ensemble_prob, Tsit5(), EnsembleThreads()) -@test isapprox(sol[:, x], [2,1] .* map(Base.Fix1(map, exp), [1.1, 1.2] .* sol[:, t]), rtol=1e-4) +for i in 1:3 + @test sol[x, :][i] == sol[i][x] + @test sol[y, :][i] == sol[i][y] +end # Ensemble is a recursive array -@test sol(0.0, idxs=[x]) == sol[:, 1] == first.(sol[:, x], 1) +@test Matrix(sol(0.0, idxs=[x])) == sol[1:1, 1, :] == Matrix(first(eachrow(sol[x, :]))') # TODO: fix the interpolation -@test sol(1.0, idxs=[x]) ≈ last.(sol[:, x], 1) +@test vec(sol(1.0, idxs=[x])) ≈ last.(sol[x, :].u)