Skip to content

Commit

Permalink
Merge pull request #472 from oscardssmith/ensemble-indexing-fixes
Browse files Browse the repository at this point in the history
fix ensemble indexing
  • Loading branch information
ChrisRackauckas authored Jul 21, 2023
2 parents 091c91a + 30c3f74 commit 1aee21d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
5 changes: 1 addition & 4 deletions src/ensemble/ensemble_solutions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,16 +208,13 @@ end
end


Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, s)
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, s, ::Colon)
return [xi[s] for xi in x]
end

Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...)
return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...)
end
Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Int...)
return [xi[args...] for xi in x]
end

function (sol::AbstractEnsembleSolution)(args...; kwargs...)
[s(args...; kwargs...) for s in sol]
Expand Down
26 changes: 17 additions & 9 deletions test/downstream/ensemble_multi_prob.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
using ModelingToolkit, OrdinaryDiffEq, Test

@variables t, x(t)
@variables t, x(t), y(t)
D = Differential(t)

@named sys1 = ODESystem([D(x) ~ 1.1*x])
@named sys2 = ODESystem([D(x) ~ 1.2*x])
@named sys1 = ODESystem([D(x) ~ x,
D(y) ~ -y])
@named sys2 = ODESystem([D(x) ~ 2x,
D(y) ~ -2y])
@named sys3 = ODESystem([D(x) ~ 3x,
D(y) ~ -3y])

prob1 = ODEProblem(sys1, [2.0], (0.0, 1.0))
prob2 = ODEProblem(sys2, [1.0], (0.0, 1.0))
prob1 = ODEProblem(sys1, [1.0, 1.0], (0.0, 1.0))
prob2 = ODEProblem(sys2, [2.0, 2.0], (0.0, 1.0))
prob3 = ODEProblem(sys3, [3.0, 3.0], (0.0, 1.0))

# test that when passing a vector of problems, trajectories and the prob_func are chosen appropriately
ensemble_prob = EnsembleProblem([prob1, prob2])
ensemble_prob = EnsembleProblem([prob1, prob2, prob3])
sol = solve(ensemble_prob, Tsit5(), EnsembleThreads())
@test isapprox(sol[:, x], [2,1] .* map(Base.Fix1(map, exp), [1.1, 1.2] .* sol[:, t]), rtol=1e-4)
for i in 1:3
@test sol[x, :][i] == sol[i][x]
@test sol[y, :][i] == sol[i][y]
end
# Ensemble is a recursive array
@test sol(0.0, idxs=[x]) == sol[:, 1] == first.(sol[:, x], 1)
@test only.(sol(0.0, idxs=[x])) == sol[1, 1, :] == first.(sol[x, :])
# TODO: fix the interpolation
@test sol(1.0, idxs=[x]) last.(sol[:, x], 1)
@test only.(sol(1.0, idxs=[x])) last.(sol[x, :])

0 comments on commit 1aee21d

Please sign in to comment.