From 605824aae9573728e328111a7062ebc74c59cf4a Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Thu, 13 Jul 2023 16:22:31 -0400 Subject: [PATCH 1/2] make symbolic indexing return VectorOfArray --- src/ensemble/ensemble_solutions.jl | 6 +++--- src/solutions/solution_interface.jl | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index a5fe82714..7f52a32d9 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -200,16 +200,16 @@ end Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, s) - return [xi[s] for xi in x] + return VectorOfArray([xi[s] for xi in x]) end Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Colon...) return invoke(getindex, Tuple{RecursiveArrayTools.AbstractVectorOfArray, Colon, typeof.(args)...}, x, :, args...) end Base.@propagate_inbounds function Base.getindex(x::AbstractEnsembleSolution, ::Colon, args::Int...) - return [xi[args...] for xi in x] + return VectorOfArray([xi[args...] for xi in x]) end function (sol::AbstractEnsembleSolution)(args...; kwargs...) - [s(args...; kwargs...) for s in sol] + VectorOfArray([s(args...; kwargs...) for s in sol]) end diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index e9be869af..3e4bf5774 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -78,9 +78,9 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s if has_sys(A.prob.f) && all(Base.Fix1(is_param_sym, A.prob.f.sys), sym) || !has_sys(A.prob.f) && has_paramsyms(A.prob.f) && all(in(getparamsyms(A)), Symbol.(sym)) - return getindex.((A,), sym) + return VectorOfArray(getindex.((A,), sym)) else - return [getindex.((A,), sym, i) for i in eachindex(A)] + return VectorOfArray([getindex.((A,), sym, i) for i in eachindex(A)]) end else i = sym @@ -110,7 +110,7 @@ Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution, s observed(A, sym, :) end elseif i isa Base.Integer || i isa AbstractRange || i isa AbstractVector{<:Base.Integer} - A[i, :] + VectorOfArray(A[i, :]) else error("Invalid indexing of solution") end From 40abe8d57d3367b1ab17a2d4241b8cd358f523ec Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Fri, 14 Jul 2023 11:13:01 -0400 Subject: [PATCH 2/2] add ODESolution (and SciMLEnsembleSolution) to DataFrame support (conditional on 1.9) --- Project.toml | 6 ++++++ ext/SciMLBaseDataFramesExt.jl | 28 ++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100644 ext/SciMLBaseDataFramesExt.jl diff --git a/Project.toml b/Project.toml index a5361fc3a..b69472707 100644 --- a/Project.toml +++ b/Project.toml @@ -29,6 +29,12 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" +[weakdeps] +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" + +[extensions] +SciMLBaseDataFramesExt = "DataFrames" + [compat] ADTypes = "0.1.3" ArrayInterface = "6, 7" diff --git a/ext/SciMLBaseDataFramesExt.jl b/ext/SciMLBaseDataFramesExt.jl new file mode 100644 index 000000000..7059992bd --- /dev/null +++ b/ext/SciMLBaseDataFramesExt.jl @@ -0,0 +1,28 @@ +module SciMLBaseDataFramesExt +using SciMLBase, DataFrames + +function DataFrames.DataFrame(sol::EnsembleSolution, idxs::AbstractVector) + @assert allequal(getproperty.(sol.u, :t)) "solutions must have shared timesteps" + data = sol[:, idxs] + i_max, j_max, _ = size(data) + v = ["t"=>sol[1].t] + for (i, s) in enumerate(sol.u) + for idx in idxs + push!(v, string("sol ", i, ": ", idx)=>s[idx]) + end + end + DataFrame(v) +end + +function DataFrames.DataFrame(sol::ODESolution, idxs::AbstractVector) + @assert allequal(getproperty.(sol.u, :t)) "solutions must have shared timesteps" + data = sol[:, idxs] + i_max, j_max, _ = size(data) + v = ["t"=>sol[1].t] + for idx in idxs + push!(v, string("sol ", i, ": ", idx)=>s[idx]) + end + DataFrame(v) +end + +end