diff --git a/Project.toml b/Project.toml index a5361fc3a..b69472707 100644 --- a/Project.toml +++ b/Project.toml @@ -29,6 +29,12 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" +[weakdeps] +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + +[extensions] +SciMLBaseDataFramesExt = "DataFrames" + [compat] ADTypes = "0.1.3" ArrayInterface = "6, 7" diff --git a/ext/SciMLBaseDataFramesExt.jl b/ext/SciMLBaseDataFramesExt.jl new file mode 100644 index 000000000..7059992bd --- /dev/null +++ b/ext/SciMLBaseDataFramesExt.jl @@ -0,0 +1,28 @@ +module SciMLBaseDataFramesExt +using SciMLBase, DataFrames + +function DataFrames.DataFrame(sol::EnsembleSolution, idxs::AbstractVector) + @assert allequal(getproperty.(sol.u, :t)) "solutions must have shared timesteps" + data = sol[:, idxs] + i_max, j_max, _ = size(data) + v = ["t"=>sol[1].t] + for (i, s) in enumerate(sol.u) + for idx in idxs + push!(v, string("sol ", i, ": ", idx)=>s[idx]) + end + end + DataFrame(v) +end + +function DataFrames.DataFrame(sol::ODESolution, idxs::AbstractVector) + @assert allequal(getproperty.(sol.u, :t)) "solutions must have shared timesteps" + data = sol[:, idxs] + i_max, j_max, _ = size(data) + v = ["t"=>sol[1].t] + for idx in idxs + push!(v, string("sol ", i, ": ", idx)=>s[idx]) + end + DataFrame(v) +end + +end diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index a5fe82714..7f52a32d9 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -200,16 +200,16 @@ end Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, s) - return [xi[s] for xi in x] + return VectorOfArray([xi[s] for xi in x]) end Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...) return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...) end Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Int...) - return [xi[args...] for xi in x] + return VectorOfArray([xi[args...] for xi in x]) end function (sol::AbstractEnsembleSolution)(args...; kwargs...) - [s(args...; kwargs...) for s in sol] + VectorOfArray([s(args...; kwargs...) for s in sol]) end diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index e9be869af..3e4bf5774 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -78,9 +78,9 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s if has_sys(A.prob.f) && all(Base.Fix1(is_param_sym, A.prob.f.sys), sym) || !has_sys(A.prob.f) && has_paramsyms(A.prob.f) && all(in(getparamsyms(A)), Symbol.(sym)) - return getindex.((A,), sym) + return VectorOfArray(getindex.((A,), sym)) else - return [getindex.((A,), sym, i) for i in eachindex(A)] + return VectorOfArray([getindex.((A,), sym, i) for i in eachindex(A)]) end else i = sym @@ -110,7 +110,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s observed(A, sym, :) end elseif i isa Base.Integer || i isa AbstractRange || i isa AbstractVector{<:Base.Integer} - A[i, :] + VectorOfArray(A[i, :]) else error("Invalid indexing of solution") end