From e31bfdeec01f3e9d0d4f76eb0ae1c7873ccd49bf Mon Sep 17 00:00:00 2001 From: xtalax Date: Wed, 8 Feb 2023 15:40:07 +0000 Subject: [PATCH 01/18] get state dependencies --- src/systems/abstractsystem.jl | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index a4e24fbdea..f8d6208f11 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -517,7 +517,7 @@ function controls(sys::AbstractSystem) isempty(systems) ? ctrls : [ctrls; reduce(vcat, namespace_controls.(systems))] end -function observed(sys::AbstractSystem) +function SymbolicIndexingInterface.observed(sys::AbstractSystem) obs = get_observed(sys) systems = get_systems(sys) [obs; @@ -638,6 +638,37 @@ function SymbolicIndexingInterface.is_param_sym(sys::AbstractSystem, sym) !isnothing(SymbolicIndexingInterface.param_sym_to_index(sys, sym)) end +function SymbolicIndexingInterface.observed_sym_to_index(sys::AbstractSystem, sym) + findfirst(isequal(sym), SymbolicIndexingInterface.observed(sys)) +end +function SymbolicIndexingInterface.is_observed_sym(sys::AbstractSystem, sym) + !isnothing(SymbolicIndexingInterface.obvserved_sym_to_index(sys, sym)) +end + +function SymbolicIndexingInterface.get_state_dependencies(sys::AbstractSystem, sym) + obs = observed(sys) + lhss = map(obs) do eq + eq.lhs + end + sts = states(sys) + i = observed_sym_to_index(sys, sym) + if isnothing(i) + return [] + end + + eq = obs[i] + varss = vars(eq) + out = mapreduce(vcat, varss, init = []) do u + if any(isequal(u), lhss) + get_state_dependencies(sys, u) + else + [u] + end + end |> unique + + return filter(x -> any(isequal(x), sts), out) +end + struct AbstractSysToExpr sys::AbstractSystem states::Vector From cdf8834f0e9d1fa040bacf335e3f5d62a20f03eb Mon Sep 17 00:00:00 2001 From: xtalax Date: Wed, 8 Feb 2023 15:44:01 +0000 Subject: [PATCH 02/18] question --- src/systems/diffeqs/odesystem.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index ea98e8c17f..881cd3dc20 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -328,6 +328,7 @@ function build_explicit_observed_function(sys, ts; # FIXME: This is a rather rough estimate of dependencies. We assume # the expression depends on everything before the `maxidx`. + # ? subs = Dict() maxidx = 0 for s in dep_vars From 6dcccb6aed3d7b7d6adb81d2b94f73eb8a687d46 Mon Sep 17 00:00:00 2001 From: xtalax Date: Wed, 15 Feb 2023 18:58:58 +0000 Subject: [PATCH 03/18] state_deps --- .gitignore | 2 ++ Project.toml | 1 + src/ModelingToolkit.jl | 7 ++++-- src/systems/abstractsystem.jl | 28 ++++++++++++++++++++-- src/systems/diffeqs/odesystem.jl | 4 +++- test/odesystem.jl | 40 ++++++++++++++++++++++++++++++++ 6 files changed, 77 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 3401a5a4b3..b56768d112 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ Manifest.toml .vscode .vscode/* +scratch +scratch/* \ No newline at end of file diff --git a/Project.toml b/Project.toml index af67ae8564..6c16a48f34 100644 --- a/Project.toml +++ b/Project.toml @@ -30,6 +30,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" +OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 6427337a7e..6543a68bfc 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -35,8 +35,11 @@ RuntimeGeneratedFunctions.init(@__MODULE__) using RecursiveArrayTools import SymbolicIndexingInterface -import SymbolicIndexingInterface: independent_variables, states, parameters -export independent_variables, states, parameters + +import SymbolicIndexingInterface: independent_variables, states, parameters, observed, + observed_sym_to_index, get_state_dependencies, + get_dependencies_of_observed +export independent_variables, states, parameters, observed import SymbolicUtils import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype, Symbolic, Term, Add, Mul, Pow, Sym, FnType, diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index f8d6208f11..ab31ca4e09 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -522,7 +522,7 @@ function SymbolicIndexingInterface.observed(sys::AbstractSystem) systems = get_systems(sys) [obs; reduce(vcat, - (map(o -> namespace_equation(o, s), observed(s)) for s in systems), + (map(o -> namespace_equation(o, s), SII.observed(s)) for s in systems), init = Equation[])] end @@ -600,7 +600,7 @@ function time_varying_as_func(x, sys::AbstractTimeDependentSystem) # but is `x(t-1)` or something like that, pass in `x` as a callable function rather # than pass in a value in place of x(t). # - # This is done by just making `x` the argument of the function. + # This is*-+ done by just making `x` the argument of the function. if istree(x) && operation(x) isa Sym && !(length(arguments(x)) == 1 && isequal(arguments(x)[1], get_iv(sys))) @@ -669,6 +669,30 @@ function SymbolicIndexingInterface.get_state_dependencies(sys::AbstractSystem, s return filter(x -> any(isequal(x), sts), out) end +function SymbolicIndexingInterface.get_observed_dependencies(sys::AbstractSystem, sym) + obs = observed(sys) + lhss = map(obs) do eq + eq.lhs + end + + i = observed_sym_to_index(sys, sym) + if isnothing(i) + return [] + end + + eq = obs[i] + varss = vars(eq) + out = mapreduce(vcat, varss, init = []) do u + if any(isequal(u), lhss) + [u] + else + [] + end + end |> unique + + return out +end + struct AbstractSysToExpr sys::AbstractSystem states::Vector diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 881cd3dc20..cab18f9cc6 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -372,7 +372,9 @@ function build_explicit_observed_function(sys, ts; push!(obsexprs, lhs ← rhs) end - dvs = DestructuredArgs(states(sys), inbounds = !checkbounds) + statedeps = mapreduce(x -> get_state_dependencies(sys, x.lhs), vcat, obs) |> unique + + dvs = DestructuredArgs(statedeps, inbounds = !checkbounds) ps = DestructuredArgs(parameters(sys), inbounds = !checkbounds) args = [dvs, ps, ivs...] pre = get_postprocess_fbody(sys) diff --git a/test/odesystem.jl b/test/odesystem.jl index e304f32de2..ddfeac51fd 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -603,6 +603,11 @@ RHS2 = RHS @unpack RHS = fol @test isequal(RHS, RHS2) +using SymbolicIndexingInterface +using SymbolicIndexingInterface: get_state_dependencies + +@test isequal(get_state_dependencies(fol, observed(fol)[1]), x) + #1413 and 1389 @parameters t α β @variables x(t) y(t) z(t) @@ -1010,3 +1015,38 @@ let prob = ODAEProblem(sys4s, [x => 1.0, D(x) => 1.0], (0, 1.0)) @test !isnothing(prob.f.sys) end + +@testset "Symbolic save_idxs" begin + @parameters t + @variables a(t) b(t) c(t) d(t) e(t) + + D = Differential(t) + + eqs = [D(a) ~ a, + D(b) ~ b, + D(c) ~ c, + D(d) ~ d, + e ~ d] + + @named sys = ODESystem(eqs, t, [a, b, c, d, e], []; + defaults = Dict([a => 1.0, + b => 1.0, + c => 1.0, + d => 1.0, + e => 1.0])) + sys = structural_simplify(sys) + prob = ODEProblem(sys, [], (0, 1.0)) + prob_sym = ODEProblem(sys, [], (0, 1.0), save_idxs = [a, c]) + + sol = solve(prob, Tsit5()) + sol_sym = solve(prob_sym, Tsit5()) + + @test sol_sym[a] ≈ sol[a] + @test sol_sym[c] ≈ sol[c] + @test sol_sym[d] ≈ sol[d] + @test sol_sym[e] ≈ sol[e] + + @test sol.u != sol_sym.u + + @test_throws Exception sol_sym[b] +end From 5c25c88685176315e24a60c40dc436894f050c1c Mon Sep 17 00:00:00 2001 From: xtalax Date: Thu, 16 Feb 2023 15:29:55 +0000 Subject: [PATCH 04/18] fixes --- src/ModelingToolkit.jl | 2 +- src/systems/abstractsystem.jl | 6 +++--- src/systems/diffeqs/odesystem.jl | 2 +- src/utils.jl | 2 ++ 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 6543a68bfc..04b012b23b 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -38,7 +38,7 @@ import SymbolicIndexingInterface import SymbolicIndexingInterface: independent_variables, states, parameters, observed, observed_sym_to_index, get_state_dependencies, - get_dependencies_of_observed + get_deps_of_observed export independent_variables, states, parameters, observed import SymbolicUtils import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype, diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index ab31ca4e09..fa3ad7c38d 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -625,21 +625,21 @@ function unknown_states(sys::AbstractSystem) end function SymbolicIndexingInterface.state_sym_to_index(sys::AbstractSystem, sym) - findfirst(isequal(sym), unknown_states(sys)) + findfirst(isequal(sym), unknown_states(sys)) |> safe_unwrap end function SymbolicIndexingInterface.is_state_sym(sys::AbstractSystem, sym) !isnothing(SymbolicIndexingInterface.state_sym_to_index(sys, sym)) end function SymbolicIndexingInterface.param_sym_to_index(sys::AbstractSystem, sym) - findfirst(isequal(sym), SymbolicIndexingInterface.parameters(sys)) + findfirst(isequal(sym), SymbolicIndexingInterface.parameters(sys)) |> safe_unwrap end function SymbolicIndexingInterface.is_param_sym(sys::AbstractSystem, sym) !isnothing(SymbolicIndexingInterface.param_sym_to_index(sys, sym)) end function SymbolicIndexingInterface.observed_sym_to_index(sys::AbstractSystem, sym) - findfirst(isequal(sym), SymbolicIndexingInterface.observed(sys)) + findfirst(isequal(sym), SymbolicIndexingInterface.observed(sys)) |> safe_unwrap end function SymbolicIndexingInterface.is_observed_sym(sys::AbstractSystem, sym) !isnothing(SymbolicIndexingInterface.obvserved_sym_to_index(sys, sym)) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index cab18f9cc6..bb0414cf59 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -372,7 +372,7 @@ function build_explicit_observed_function(sys, ts; push!(obsexprs, lhs ← rhs) end - statedeps = mapreduce(x -> get_state_dependencies(sys, x.lhs), vcat, obs) |> unique + statedeps = get_deps_of_observed(sys) dvs = DestructuredArgs(statedeps, inbounds = !checkbounds) ps = DestructuredArgs(parameters(sys), inbounds = !checkbounds) diff --git a/src/utils.jl b/src/utils.jl index d677cf30a0..3572457b5a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -879,3 +879,5 @@ function fast_substitute(expr, pair::Pair) end normalize_to_differential(s) = s + +safe_unwrap(x) = x isa Num ? unwrap(x) : x From 9c267aad811ca2e275f4dfe5c53626a1dddb185a Mon Sep 17 00:00:00 2001 From: xtalax Date: Fri, 17 Feb 2023 12:18:19 +0000 Subject: [PATCH 05/18] passing --- src/ModelingToolkit.jl | 2 +- src/systems/abstractsystem.jl | 23 ++++++++++----------- src/utils.jl | 2 +- test/odesystem.jl | 39 +++++++++++++++++++++++++++++++++-- 4 files changed, 50 insertions(+), 16 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index b7e1f883bd..2890d02a7e 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -38,7 +38,7 @@ import SymbolicIndexingInterface import SymbolicIndexingInterface: independent_variables, states, parameters, observed, observed_sym_to_index, get_state_dependencies, - get_deps_of_observed + get_deps_of_observed, is_observed_sym export independent_variables, states, parameters, observed import SymbolicUtils import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype, diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index d3e8465260..2a0b528f4d 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -654,17 +654,19 @@ function SymbolicIndexingInterface.is_param_sym(sys::AbstractSystem, sym) end function SymbolicIndexingInterface.observed_sym_to_index(sys::AbstractSystem, sym) - findfirst(isequal(sym), SymbolicIndexingInterface.observed(sys)) |> safe_unwrap + findfirst(isequal(safe_unwrap(sym)), getfield.(observed(sys), (:lhs,))) |> + safe_unwrap +end +function SymbolicIndexingInterface.observed_sym_to_index(sys::AbstractSystem, sym::Equation) + findfirst(isequal(sym), observed(sys)) |> + safe_unwrap end function SymbolicIndexingInterface.is_observed_sym(sys::AbstractSystem, sym) - !isnothing(SymbolicIndexingInterface.obvserved_sym_to_index(sys, sym)) + !isnothing(SymbolicIndexingInterface.observed_sym_to_index(sys, sym)) end function SymbolicIndexingInterface.get_state_dependencies(sys::AbstractSystem, sym) obs = observed(sys) - lhss = map(obs) do eq - eq.lhs - end sts = states(sys) i = observed_sym_to_index(sys, sym) if isnothing(i) @@ -672,9 +674,9 @@ function SymbolicIndexingInterface.get_state_dependencies(sys::AbstractSystem, s end eq = obs[i] - varss = vars(eq) + varss = vars(eq.rhs) out = mapreduce(vcat, varss, init = []) do u - if any(isequal(u), lhss) + if is_observed_sym(sys, u) get_state_dependencies(sys, u) else [u] @@ -686,9 +688,6 @@ end function SymbolicIndexingInterface.get_observed_dependencies(sys::AbstractSystem, sym) obs = observed(sys) - lhss = map(obs) do eq - eq.lhs - end i = observed_sym_to_index(sys, sym) if isnothing(i) @@ -696,9 +695,9 @@ function SymbolicIndexingInterface.get_observed_dependencies(sys::AbstractSystem end eq = obs[i] - varss = vars(eq) + varss = vars(eq.rhs) out = mapreduce(vcat, varss, init = []) do u - if any(isequal(u), lhss) + if is_observed_sym(sys, u) [u] else [] diff --git a/src/utils.jl b/src/utils.jl index 93246a8c91..f70a70812b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -878,4 +878,4 @@ end normalize_to_differential(s) = s -safe_unwrap(x) = x isa Num ? unwrap(x) : x +safe_unwrap(x) = x isa Num ? x.val : x diff --git a/test/odesystem.jl b/test/odesystem.jl index 02a4fec1d4..f3db536a7b 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1017,7 +1017,7 @@ let @test !isnothing(prob.f.sys) end -@testset "Symbolic save_idxs" begin +@testset "Symbolic save_idxs - Save Observed" begin @parameters t @variables a(t) b(t) c(t) d(t) e(t) @@ -1037,7 +1037,7 @@ end e => 1.0])) sys = structural_simplify(sys) prob = ODEProblem(sys, [], (0, 1.0)) - prob_sym = ODEProblem(sys, [], (0, 1.0), save_idxs = [a, c]) + prob_sym = ODEProblem(sys, [], (0, 1.0), save_idxs = [a, c, e]) sol = solve(prob, Tsit5()) sol_sym = solve(prob_sym, Tsit5()) @@ -1051,3 +1051,38 @@ end @test_throws Exception sol_sym[b] end + +@testset "Symbolic save_idxs - No observed" begin + @parameters t + @variables a(t) b(t) c(t) d(t) e(t) + + D = Differential(t) + + eqs = [D(a) ~ a, + D(b) ~ b, + D(c) ~ c, + D(d) ~ d, + e ~ d] + + @named sys = ODESystem(eqs, t, [a, b, c, d, e], []; + defaults = Dict([a => 1.0, + b => 1.0, + c => 1.0, + d => 1.0, + e => 1.0])) + sys = structural_simplify(sys) + prob = ODEProblem(sys, [], (0, 1.0)) + prob_sym = ODEProblem(sys, [], (0, 1.0), save_idxs = [a, c, b]) + + sol = solve(prob, Tsit5()) + sol_sym = solve(prob_sym, Tsit5()) + + @test sol_sym[a] ≈ sol[a] + @test sol_sym[b] ≈ sol[b] + @test sol_sym[c] ≈ sol[c] + + @test sol.u != sol_sym.u + + @test_throws Exception sol_sym[d] + @test_throws Exception sol_sym[e] +end From d1399304282e977372f99b8a297b293af4ade7d9 Mon Sep 17 00:00:00 2001 From: xtalax Date: Fri, 17 Feb 2023 14:54:28 +0000 Subject: [PATCH 06/18] fixes and comment --- src/systems/abstractsystem.jl | 4 ++-- src/systems/diffeqs/odesystem.jl | 2 +- src/utils.jl | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 2a0b528f4d..32ec6f4944 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -537,7 +537,7 @@ function SymbolicIndexingInterface.observed(sys::AbstractSystem) systems = get_systems(sys) [obs; reduce(vcat, - (map(o -> namespace_equation(o, s), SII.observed(s)) for s in systems), + (map(o -> namespace_equation(o, s), observed(s)) for s in systems), init = Equation[])] end @@ -1281,7 +1281,7 @@ function linearization_function(sys::AbstractSystem, inputs, uf = SciMLBase.UJacobianWrapper(fun, t, p) fg_xz = ForwardDiff.jacobian(uf, u) h_xz = ForwardDiff.jacobian(let p = p, t = t - xz -> h(xz, p, t) + xz -> h(xz, p, t) #! signature must change end, u) pf = SciMLBase.ParamJacobianWrapper(fun, t, u) fg_u = jacobian_wrt_vars(pf, p, input_idxs, chunk) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 819a557e59..6a017e7a6d 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -379,7 +379,7 @@ function build_explicit_observed_function(sys, ts; push!(obsexprs, lhs ← rhs) end - statedeps = get_deps_of_observed(sys) + statedeps = safe_unwrap.(get_deps_of_observed(sys)) dvs = DestructuredArgs(statedeps, inbounds = !checkbounds) ps = DestructuredArgs(parameters(sys), inbounds = !checkbounds) diff --git a/src/utils.jl b/src/utils.jl index f70a70812b..93246a8c91 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -878,4 +878,4 @@ end normalize_to_differential(s) = s -safe_unwrap(x) = x isa Num ? x.val : x +safe_unwrap(x) = x isa Num ? unwrap(x) : x From 4edcd79acefccb30aa1044f3fd58991d9061d458 Mon Sep 17 00:00:00 2001 From: xtalax Date: Tue, 21 Feb 2023 15:52:57 +0000 Subject: [PATCH 07/18] linearization and odae --- src/ModelingToolkit.jl | 3 +- src/structural_transformation/codegen.jl | 4 ++- src/systems/abstractsystem.jl | 39 +++++++++++++++++------- test/runtests.jl | 2 +- 4 files changed, 34 insertions(+), 14 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 2890d02a7e..9f53c20fe1 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -36,7 +36,8 @@ using RecursiveArrayTools import SymbolicIndexingInterface -import SymbolicIndexingInterface: independent_variables, states, parameters, observed, +import SymbolicIndexingInterface: independent_variables, states, state_sym_to_index, + parameters, observed, observed_sym_to_index, get_state_dependencies, get_deps_of_observed, is_observed_sym export independent_variables, states, parameters, observed diff --git a/src/structural_transformation/codegen.jl b/src/structural_transformation/codegen.jl index 1e9bea8088..aa76734849 100644 --- a/src/structural_transformation/codegen.jl +++ b/src/structural_transformation/codegen.jl @@ -480,6 +480,8 @@ function build_observed_function(state, ts, var_eq_matching, var_sccs, solves = [] end + unknown_state_deps = get_deps_of_observed(unknown_states, obs) + subs = [] for sym in vars eqidx = get(observed_idx, sym, nothing) @@ -490,7 +492,7 @@ function build_observed_function(state, ts, var_eq_matching, var_sccs, cpre = get_preprocess_constants([obs[1:maxidx]; isscalar ? ts[1] : MakeArray(ts, output_type)]) pre2 = x -> pre(cpre(x)) - ex = Code.toexpr(Func([DestructuredArgs(unknown_states, inbounds = !checkbounds) + ex = Code.toexpr(Func([DestructuredArgs(unknown_state_deps, inbounds = !checkbounds) DestructuredArgs(parameters(sys), inbounds = !checkbounds) independent_variables(sys)], [], diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 32ec6f4944..962fa250ba 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -658,17 +658,26 @@ function SymbolicIndexingInterface.observed_sym_to_index(sys::AbstractSystem, sy safe_unwrap end function SymbolicIndexingInterface.observed_sym_to_index(sys::AbstractSystem, sym::Equation) - findfirst(isequal(sym), observed(sys)) |> - safe_unwrap + findfirst(isequal(sym), observed(sys)) |> safe_unwrap end function SymbolicIndexingInterface.is_observed_sym(sys::AbstractSystem, sym) !isnothing(SymbolicIndexingInterface.observed_sym_to_index(sys, sym)) end -function SymbolicIndexingInterface.get_state_dependencies(sys::AbstractSystem, sym) - obs = observed(sys) - sts = states(sys) - i = observed_sym_to_index(sys, sym) +function SymbolicIndexingInterface.get_deps_of_observed(sts, obs::AbstractArray{<:Equation}) + deps = mapreduce(vcat, obs, init = []) do eq + get_state_dependencies(sts, obs, eq.lhs) + end |> unique + + return deps +end + +function SymbolicIndexingInterface.get_state_dependencies(sts, obs, sym) + s2i = (sym) -> findfirst(isequal(safe_unwrap(sym)), getfield.(obs, (:lhs,))) |> + safe_unwrap + + i = s2i(sym) + if isnothing(i) return [] end @@ -676,8 +685,8 @@ function SymbolicIndexingInterface.get_state_dependencies(sys::AbstractSystem, s eq = obs[i] varss = vars(eq.rhs) out = mapreduce(vcat, varss, init = []) do u - if is_observed_sym(sys, u) - get_state_dependencies(sys, u) + if !isnothing(s2i(u)) + get_state_dependencies(sts, obs, u) else [u] end @@ -686,6 +695,12 @@ function SymbolicIndexingInterface.get_state_dependencies(sys::AbstractSystem, s return filter(x -> any(isequal(x), sts), out) end +function SymbolicIndexingInterface.get_state_dependencies(sys::AbstractSystem, sym) + obs = observed(sys) + sts = states(sys) + return get_state_dependencies(sts, obs, sym) +end + function SymbolicIndexingInterface.get_observed_dependencies(sys::AbstractSystem, sym) obs = observed(sys) @@ -1272,6 +1287,8 @@ function linearization_function(sys::AbstractSystem, inputs, sts = states(sys), fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys), h = build_explicit_observed_function(sys, outputs), + deps = get_deps_of_observed(sys) + dep_idxs = state_sym_to_index.((sys,), deps) chunk = ForwardDiff.Chunk(input_idxs) function (u, p, t) @@ -1281,7 +1298,7 @@ function linearization_function(sys::AbstractSystem, inputs, uf = SciMLBase.UJacobianWrapper(fun, t, p) fg_xz = ForwardDiff.jacobian(uf, u) h_xz = ForwardDiff.jacobian(let p = p, t = t - xz -> h(xz, p, t) #! signature must change + (xz) -> h(xz[dep_idxs], p, t) #! signature must change end, u) pf = SciMLBase.ParamJacobianWrapper(fun, t, u) fg_u = jacobian_wrt_vars(pf, p, input_idxs, chunk) @@ -1291,8 +1308,8 @@ function linearization_function(sys::AbstractSystem, inputs, fg_xz = zeros(0, 0) h_xz = fg_u = zeros(0, length(inputs)) end - hp = let u = u, t = t - p -> h(u, p, t) + hp = let u = u[dep_idxs], t = t + p -> h(u, p, t) #! signature must change end h_u = jacobian_wrt_vars(hp, p, input_idxs, chunk) (f_x = fg_xz[diff_idxs, diff_idxs], diff --git a/test/runtests.jl b/test/runtests.jl index c222a9c340..3893d5aadb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using SafeTestsets, Test +@safetestset "Linearization Tests" begin include("linearize.jl") end @safetestset "AliasGraph Test" begin include("alias.jl") end @safetestset "Linear Algebra Test" begin include("linalg.jl") end @safetestset "AbstractSystem Test" begin include("abstractsystem.jl") end @@ -9,7 +10,6 @@ using SafeTestsets, Test @safetestset "Simplify Test" begin include("simplify.jl") end @safetestset "Direct Usage Test" begin include("direct.jl") end @safetestset "System Linearity Test" begin include("linearity.jl") end -@safetestset "Linearization Tests" begin include("linearize.jl") end @safetestset "Input Output Test" begin include("input_output_handling.jl") end @safetestset "Clock Test" begin include("clock.jl") end @safetestset "DiscreteSystem Test" begin include("discretesystem.jl") end From 789350b41ca652a075a4f9dc70885a4cddcf45a8 Mon Sep 17 00:00:00 2001 From: xtalax Date: Tue, 21 Feb 2023 17:24:28 +0000 Subject: [PATCH 08/18] fix errors --- src/structural_transformation/codegen.jl | 5 +++-- src/systems/abstractsystem.jl | 23 ++++++++++++----------- src/systems/diffeqs/odesystem.jl | 5 +++-- test/odaeproblem.jl | 12 +++++++++++- test/odesystem.jl | 2 +- test/runtests.jl | 6 +++--- 6 files changed, 33 insertions(+), 20 deletions(-) diff --git a/src/structural_transformation/codegen.jl b/src/structural_transformation/codegen.jl index aa76734849..13b8fc7d31 100644 --- a/src/structural_transformation/codegen.jl +++ b/src/structural_transformation/codegen.jl @@ -388,7 +388,8 @@ function build_observed_function(state, ts, var_eq_matching, var_sccs, var2assignment; expression = false, output_type = Array, - checkbounds = true) + checkbounds = true, + dense_states = false) is_not_prepended_assignment = trues(length(assignments)) if (isscalar = !(ts isa AbstractVector)) ts = [ts] @@ -480,7 +481,7 @@ function build_observed_function(state, ts, var_eq_matching, var_sccs, solves = [] end - unknown_state_deps = get_deps_of_observed(unknown_states, obs) + unknown_state_deps = dense_states ? unknown_states : get_deps_of_observed(unknown_states, obs) subs = [] for sym in vars diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 962fa250ba..7fe0c65f8e 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -660,6 +660,13 @@ end function SymbolicIndexingInterface.observed_sym_to_index(sys::AbstractSystem, sym::Equation) findfirst(isequal(sym), observed(sys)) |> safe_unwrap end +function SymbolicIndexingInterface.observed_sym_to_index(obs::AbstractArray{<:Equation}, sym) + findfirst(isequal(safe_unwrap(sym)), getfield.(obs, (:lhs,))) |> + safe_unwrap +end +function SymbolicIndexingInterface.observed_sym_to_index(obs::AbstractArray{<:Equation}, sym::Equation) + findfirst(isequal(sym), obs) |> safe_unwrap +end function SymbolicIndexingInterface.is_observed_sym(sys::AbstractSystem, sym) !isnothing(SymbolicIndexingInterface.observed_sym_to_index(sys, sym)) end @@ -673,11 +680,7 @@ function SymbolicIndexingInterface.get_deps_of_observed(sts, obs::AbstractArray{ end function SymbolicIndexingInterface.get_state_dependencies(sts, obs, sym) - s2i = (sym) -> findfirst(isequal(safe_unwrap(sym)), getfield.(obs, (:lhs,))) |> - safe_unwrap - - i = s2i(sym) - + i = observed_sym_to_index(obs, sym) if isnothing(i) return [] end @@ -685,7 +688,7 @@ function SymbolicIndexingInterface.get_state_dependencies(sts, obs, sym) eq = obs[i] varss = vars(eq.rhs) out = mapreduce(vcat, varss, init = []) do u - if !isnothing(s2i(u)) + if !isnothing(observed_sym_to_index(obs, u)) get_state_dependencies(sts, obs, u) else [u] @@ -1286,9 +1289,7 @@ function linearization_function(sys::AbstractSystem, inputs, input_idxs = input_idxs, sts = states(sys), fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys), - h = build_explicit_observed_function(sys, outputs), - deps = get_deps_of_observed(sys) - dep_idxs = state_sym_to_index.((sys,), deps) + h = build_explicit_observed_function(sys, outputs; dense_states = true), chunk = ForwardDiff.Chunk(input_idxs) function (u, p, t) @@ -1298,7 +1299,7 @@ function linearization_function(sys::AbstractSystem, inputs, uf = SciMLBase.UJacobianWrapper(fun, t, p) fg_xz = ForwardDiff.jacobian(uf, u) h_xz = ForwardDiff.jacobian(let p = p, t = t - (xz) -> h(xz[dep_idxs], p, t) #! signature must change + (xz) -> h(xz, p, t) #! signature must change end, u) pf = SciMLBase.ParamJacobianWrapper(fun, t, u) fg_u = jacobian_wrt_vars(pf, p, input_idxs, chunk) @@ -1308,7 +1309,7 @@ function linearization_function(sys::AbstractSystem, inputs, fg_xz = zeros(0, 0) h_xz = fg_u = zeros(0, length(inputs)) end - hp = let u = u[dep_idxs], t = t + hp = let u = u, t = t p -> h(u, p, t) #! signature must change end h_u = jacobian_wrt_vars(hp, p, input_idxs, chunk) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 6a017e7a6d..742288e72c 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -307,7 +307,8 @@ function build_explicit_observed_function(sys, ts; expression = false, output_type = Array, checkbounds = true, - throw = true) + throw = true, + dense_states = false) if (isscalar = !(ts isa AbstractVector)) ts = [ts] end @@ -379,7 +380,7 @@ function build_explicit_observed_function(sys, ts; push!(obsexprs, lhs ← rhs) end - statedeps = safe_unwrap.(get_deps_of_observed(sys)) + statedeps = dense_states ? states(sys) : safe_unwrap.(get_deps_of_observed(sys)) dvs = DestructuredArgs(statedeps, inbounds = !checkbounds) ps = DestructuredArgs(parameters(sys), inbounds = !checkbounds) diff --git a/test/odaeproblem.jl b/test/odaeproblem.jl index 33db56dd36..bbaac204d9 100644 --- a/test/odaeproblem.jl +++ b/test/odaeproblem.jl @@ -55,5 +55,15 @@ rc_eqs = [connect(c.output, source.V) @named rc_model = ODESystem(rc_eqs, t, systems = [strip, c, source, ground]) sys = structural_simplify(rc_model) +@show ModelingToolkit.observed(sys) + +ref_prob = ODEProblem(sys, [], (0, 10)) prob = ODAEProblem(sys, [], (0, 10)) -@test_nowarn solve(prob, Tsit5()) + +ref_sol = solve(ref_prob, Tsit5()) +@test_nowarn sol = solve(prob, Tsit5()) + + + +# test that the observed variables are correct +@test sol[strip₊St_1₊r₊n₊v] ≈ sol[strip₊St_1₊r₊n₊v] diff --git a/test/odesystem.jl b/test/odesystem.jl index f3db536a7b..68381a305d 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -607,7 +607,7 @@ RHS2 = RHS using SymbolicIndexingInterface using SymbolicIndexingInterface: get_state_dependencies -@test isequal(get_state_dependencies(fol, observed(fol)[1]), x) +@test isequal(get_state_dependencies(fol, observed(fol)[1]), [x]) #1413 and 1389 @parameters t α β diff --git a/test/runtests.jl b/test/runtests.jl index 3893d5aadb..21a5737314 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using SafeTestsets, Test -@safetestset "Linearization Tests" begin include("linearize.jl") end +@safetestset "ODAEProblem Test" begin include("odaeproblem.jl") end +@safetestset "ODESystem Test" begin include("odesystem.jl") end @safetestset "AliasGraph Test" begin include("alias.jl") end @safetestset "Linear Algebra Test" begin include("linalg.jl") end @safetestset "AbstractSystem Test" begin include("abstractsystem.jl") end @@ -10,10 +11,10 @@ using SafeTestsets, Test @safetestset "Simplify Test" begin include("simplify.jl") end @safetestset "Direct Usage Test" begin include("direct.jl") end @safetestset "System Linearity Test" begin include("linearity.jl") end +@safetestset "Linearization Tests" begin include("linearize.jl") end @safetestset "Input Output Test" begin include("input_output_handling.jl") end @safetestset "Clock Test" begin include("clock.jl") end @safetestset "DiscreteSystem Test" begin include("discretesystem.jl") end -@safetestset "ODESystem Test" begin include("odesystem.jl") end @safetestset "Unitful Quantities Test" begin include("units.jl") end @safetestset "LabelledArrays Test" begin include("labelledarrays.jl") end @safetestset "Mass Matrix Test" begin include("mass_matrix.jl") end @@ -24,7 +25,6 @@ using SafeTestsets, Test @safetestset "JumpSystem Test" begin include("jumpsystem.jl") end @safetestset "Constraints Test" begin include("constraints.jl") end @safetestset "Reduction Test" begin include("reduction.jl") end -@safetestset "ODAEProblem Test" begin include("odaeproblem.jl") end @safetestset "Components Test" begin include("components.jl") end @safetestset "print_tree" begin include("print_tree.jl") end @safetestset "Error Handling" begin include("error_handling.jl") end From 0663652f0900552b3ca6701434c8563ba1947ada Mon Sep 17 00:00:00 2001 From: xtalax Date: Tue, 21 Feb 2023 20:31:49 +0000 Subject: [PATCH 09/18] passing --- .../StructuralTransformations.jl | 2 ++ src/structural_transformation/codegen.jl | 3 ++- test/odaeproblem.jl | 13 +++++-------- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/structural_transformation/StructuralTransformations.jl b/src/structural_transformation/StructuralTransformations.jl index 96800d4042..2f991b42cb 100644 --- a/src/structural_transformation/StructuralTransformations.jl +++ b/src/structural_transformation/StructuralTransformations.jl @@ -42,6 +42,8 @@ using SparseArrays using SimpleNonlinearSolve +import SymbolicIndexingInterface: get_deps_of_observed + export tearing, partial_state_selection, dae_index_lowering, check_consistency export dummy_derivative export build_torn_function, build_observed_function, ODAEProblem diff --git a/src/structural_transformation/codegen.jl b/src/structural_transformation/codegen.jl index 13b8fc7d31..ce6ac64653 100644 --- a/src/structural_transformation/codegen.jl +++ b/src/structural_transformation/codegen.jl @@ -481,7 +481,8 @@ function build_observed_function(state, ts, var_eq_matching, var_sccs, solves = [] end - unknown_state_deps = dense_states ? unknown_states : get_deps_of_observed(unknown_states, obs) + unknown_state_deps = dense_states ? unknown_states : + get_deps_of_observed(unknown_states, obs) subs = [] for sym in vars diff --git a/test/odaeproblem.jl b/test/odaeproblem.jl index bbaac204d9..595961814f 100644 --- a/test/odaeproblem.jl +++ b/test/odaeproblem.jl @@ -55,15 +55,12 @@ rc_eqs = [connect(c.output, source.V) @named rc_model = ODESystem(rc_eqs, t, systems = [strip, c, source, ground]) sys = structural_simplify(rc_model) -@show ModelingToolkit.observed(sys) - ref_prob = ODEProblem(sys, [], (0, 10)) prob = ODAEProblem(sys, [], (0, 10)) -ref_sol = solve(ref_prob, Tsit5()) -@test_nowarn sol = solve(prob, Tsit5()) - - +ref_sol = solve(ref_prob, FBDF(), saveat = 1.0) +@test_nowarn solve(prob, Tsit5()) +sol = solve(prob, Tsit5(), saveat = 1.0) -# test that the observed variables are correct -@test sol[strip₊St_1₊r₊n₊v] ≈ sol[strip₊St_1₊r₊n₊v] +# test observed +@test sol[observed(sys)[1].lhs] ≈ ref_sol[observed(sys)[1].lhs] atol = 0.001 From f72d2f31d2c5476f29a5b951438d74ea056c1320 Mon Sep 17 00:00:00 2001 From: xtalax Date: Mon, 27 Feb 2023 15:19:26 +0000 Subject: [PATCH 10/18] passing --- docs/pages.jl | 1 + docs/src/basics/saving_syms.md | 43 ++++++++++++++++++++++++ src/ModelingToolkit.jl | 2 +- src/structural_transformation/codegen.jl | 12 ++++--- src/systems/abstractsystem.jl | 7 ++-- src/systems/diffeqs/abstractodesystem.jl | 14 +++++--- src/systems/diffeqs/odesystem.jl | 5 ++- test/odesystem.jl | 8 ++--- 8 files changed, 73 insertions(+), 19 deletions(-) create mode 100644 docs/src/basics/saving_syms.md diff --git a/docs/pages.jl b/docs/pages.jl index d79d9ee79b..ee48f06a48 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -21,6 +21,7 @@ pages = [ "basics/Events.md", "basics/Linearization.md", "basics/Validation.md", + "basics/saving_syms.md" "basics/DependencyGraphs.md", "basics/FAQ.md"], "System Types" => Any["systems/ODESystem.md", diff --git a/docs/src/basics/saving_syms.md b/docs/src/basics/saving_syms.md new file mode 100644 index 0000000000..0e70f9c749 --- /dev/null +++ b/docs/src/basics/saving_syms.md @@ -0,0 +1,43 @@ +# [Saving Only Certain Symbols](@id sym_save_idxs) + +It is possible to specify symbolically which states of an ODESystem to save, with the `save_idxs` keyword argument +to the `solve` call. This may be important to you to save memory for large problems. + +Take care to disable the `dense_output` flag when constructing the problem to ensure that observed +variables can be properly reconstructed. Failing to do this may cause incorrect construction of observed variables. + +```julia +using ModelingToolkit, OrdinaryDiffEq, Test +@parameters t +@variables a(t) b(t) c(t) d(t) e(t) + +D = Differential(t) + +eqs = [D(a) ~ a, + D(b) ~ b, + D(c) ~ c, + D(d) ~ d, + e ~ d] + +@named sys = ODESystem(eqs, t, [a, b, c, d, e], []; + defaults = Dict([a => 1.0, + b => 1.0, + c => 1.0, + d => 1.0, + e => 1.0])) +sys = structural_simplify(sys) +prob = ODEProblem(sys, [], (0, 1.0)) +prob_sym = ODEProblem(sys, [], (0, 1.0), dense_output = false) + +sol = solve(prob, Tsit5()) +sol_sym = solve(prob_sym, Tsit5(), save_idxs = [a, c, e]) + +@test sol_sym[a] ≈ sol[a] +@test sol_sym[c] ≈ sol[c] +@test sol_sym[d] ≈ sol[d] # Dependency `d` of the observed variable `e` is automatically saved too. +@test sol_sym[e] ≈ sol[e] + +@test sol.u != sol_sym.u + +@test_throws Exception sol_sym[b] +``` \ No newline at end of file diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index a53a88e40d..07a3346f0d 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -39,7 +39,7 @@ import SymbolicIndexingInterface import SymbolicIndexingInterface: independent_variables, states, state_sym_to_index, parameters, observed, observed_sym_to_index, get_state_dependencies, - get_deps_of_observed, is_observed_sym + get_deps_of_observed, is_observed_sym, unknown_states export independent_variables, states, parameters, observed import SymbolicUtils import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype, diff --git a/src/structural_transformation/codegen.jl b/src/structural_transformation/codegen.jl index ce6ac64653..00b34b212d 100644 --- a/src/structural_transformation/codegen.jl +++ b/src/structural_transformation/codegen.jl @@ -230,6 +230,7 @@ function build_torn_function(sys; jacobian_sparsity = true, checkbounds = false, max_inlining_size = nothing, + dense_output = true, kw...) max_inlining_size = something(max_inlining_size, MAX_INLINE_NLSOLVE_SIZE) rhss = [] @@ -333,6 +334,7 @@ function build_torn_function(sys; build_observed_function(state, obsvar, var_eq_matching, var_sccs, is_solver_state_idxs, assignments, deps, sol_states, var2assignment, + dense_output = dense_output, checkbounds = checkbounds) end if args === () @@ -389,7 +391,7 @@ function build_observed_function(state, ts, var_eq_matching, var_sccs, expression = false, output_type = Array, checkbounds = true, - dense_states = false) + dense_output = true) is_not_prepended_assignment = trues(length(assignments)) if (isscalar = !(ts isa AbstractVector)) ts = [ts] @@ -481,7 +483,7 @@ function build_observed_function(state, ts, var_eq_matching, var_sccs, solves = [] end - unknown_state_deps = dense_states ? unknown_states : + unknown_state_deps = dense_output ? unknown_states : get_deps_of_observed(unknown_states, obs) subs = [] @@ -532,6 +534,7 @@ function ODAEProblem{iip}(sys, parammap = DiffEqBase.NullParameters(); callback = nothing, use_union = false, + dense_output = true, check = true, kwargs...) where {iip} eqs = equations(sys) @@ -551,8 +554,9 @@ function ODAEProblem{iip}(sys, kwargs = filter_kwargs(kwargs) if cbs === nothing - ODEProblem{iip}(fun, u0, tspan, p; kwargs...) + ODEProblem{iip}(fun, u0, tspan, p; dense_output = dense_output, kwargs...) else - ODEProblem{iip}(fun, u0, tspan, p; callback = cbs, kwargs...) + ODEProblem{iip}(fun, u0, tspan, p; dense_output = dense_output, callback = cbs, + kwargs...) end end diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 1c2c584e1b..b8ac2c05bc 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -631,7 +631,7 @@ $(SIGNATURES) Return a list of actual states needed to be solved by solvers. """ -function unknown_states(sys::AbstractSystem) +function SymbolicIndexingInterface.unknown_states(sys::AbstractSystem) sts = states(sys) if has_unknown_states(sys) sts = something(get_unknown_states(sys), sts) @@ -640,7 +640,7 @@ function unknown_states(sys::AbstractSystem) end function SymbolicIndexingInterface.state_sym_to_index(sys::AbstractSystem, sym) - findfirst(isequal(sym), unknown_states(sys)) |> safe_unwrap + findfirst(isequal(safe_unwrap(sym)), unknown_states(sys)) |> safe_unwrap end function SymbolicIndexingInterface.is_state_sym(sys::AbstractSystem, sym) !isnothing(SymbolicIndexingInterface.state_sym_to_index(sys, sym)) @@ -680,6 +680,7 @@ function SymbolicIndexingInterface.get_deps_of_observed(sts, obs::AbstractArray{ end function SymbolicIndexingInterface.get_state_dependencies(sts, obs, sym) + sym = safe_unwrap(sym) i = observed_sym_to_index(obs, sym) if isnothing(i) return [] @@ -700,7 +701,7 @@ end function SymbolicIndexingInterface.get_state_dependencies(sys::AbstractSystem, sym) obs = observed(sys) - sts = states(sys) + sts = unknown_states(sys) return get_state_dependencies(sts, obs, sym) end diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index aef90b6c48..633be8f1ea 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -275,6 +275,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s steady_state = false, checkbounds = false, sparsity = false, + dense_output = true, analytic = nothing, kwargs...) where {iip, specialize} f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression}, @@ -337,7 +338,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s let sys = sys, dict = Dict() function generated_observed(obsvar, args...) obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar) + build_explicit_observed_function(sys, obsvar; + dense_output = dense_output) end if args === () let obs = obs @@ -352,7 +354,8 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s let sys = sys, dict = Dict() function generated_observed(obsvar, args...) obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds) + build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds, + dense_output = dense_output) end if args === () let obs = obs @@ -685,6 +688,7 @@ DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem, u0map, tspan, checkbounds = false, sparse = false, simplify = false, linenumbers = true, parallel = SerialForm(), + dense_output = true, kwargs...) where {iip} ``` @@ -714,12 +718,13 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = parammap = DiffEqBase.NullParameters(); callback = nothing, check_length = true, + dense_output = true, kwargs...) where {iip, specialize} has_difference = any(isdifferenceeq, equations(sys)) f, u0, p = process_DEProblem(ODEFunction{iip, specialize}, sys, u0map, parammap; t = tspan !== nothing ? tspan[1] : tspan, has_difference = has_difference, - check_length, kwargs...) + check_length, dense_output = dense_output, kwargs...) cbs = process_events(sys; callback, has_difference, kwargs...) if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...) @@ -752,7 +757,8 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = if svs !== nothing kwargs1 = merge(kwargs1, (disc_saved_values = svs,)) end - ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...) + ODEProblem{iip}(f, u0, tspan, p, pt; dense_output = dense_output, kwargs1..., + kwargs...) end get_callback(prob::ODEProblem) = prob.kwargs[:callback] diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 742288e72c..eeddabfd94 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -308,7 +308,7 @@ function build_explicit_observed_function(sys, ts; output_type = Array, checkbounds = true, throw = true, - dense_states = false) + dense_output = true) if (isscalar = !(ts isa AbstractVector)) ts = [ts] end @@ -379,8 +379,7 @@ function build_explicit_observed_function(sys, ts; rhs = eq.rhs push!(obsexprs, lhs ← rhs) end - - statedeps = dense_states ? states(sys) : safe_unwrap.(get_deps_of_observed(sys)) + statedeps = dense_output ? states(sys) : safe_unwrap.(get_deps_of_observed(sys)) dvs = DestructuredArgs(statedeps, inbounds = !checkbounds) ps = DestructuredArgs(parameters(sys), inbounds = !checkbounds) diff --git a/test/odesystem.jl b/test/odesystem.jl index 68381a305d..36d8bba630 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1037,10 +1037,10 @@ end e => 1.0])) sys = structural_simplify(sys) prob = ODEProblem(sys, [], (0, 1.0)) - prob_sym = ODEProblem(sys, [], (0, 1.0), save_idxs = [a, c, e]) + prob_sym = ODEProblem(sys, [], (0, 1.0), dense_output = false) sol = solve(prob, Tsit5()) - sol_sym = solve(prob_sym, Tsit5()) + sol_sym = solve(prob_sym, Tsit5(), save_idxs = [a, c, e]) @test sol_sym[a] ≈ sol[a] @test sol_sym[c] ≈ sol[c] @@ -1072,10 +1072,10 @@ end e => 1.0])) sys = structural_simplify(sys) prob = ODEProblem(sys, [], (0, 1.0)) - prob_sym = ODEProblem(sys, [], (0, 1.0), save_idxs = [a, c, b]) + prob_sym = ODEProblem(sys, [], (0, 1.0), dense_output = false) sol = solve(prob, Tsit5()) - sol_sym = solve(prob_sym, Tsit5()) + sol_sym = solve(prob_sym, Tsit5(), save_idxs = [a, c, b]) @test sol_sym[a] ≈ sol[a] @test sol_sym[b] ≈ sol[b] From 39ab8f9f6b7c54e5350cff31cdcc47346afcf91d Mon Sep 17 00:00:00 2001 From: xtalax Date: Mon, 27 Feb 2023 15:22:17 +0000 Subject: [PATCH 11/18] fix kwarg --- src/systems/abstractsystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index b8ac2c05bc..9b80913ea8 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1290,7 +1290,7 @@ function linearization_function(sys::AbstractSystem, inputs, input_idxs = input_idxs, sts = states(sys), fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys), - h = build_explicit_observed_function(sys, outputs; dense_states = true), + h = build_explicit_observed_function(sys, outputs; dense_output = true), chunk = ForwardDiff.Chunk(input_idxs) function (u, p, t) From 6ba93f48dd19a610ba8b764ba5d07873572a231b Mon Sep 17 00:00:00 2001 From: xtalax Date: Tue, 7 Mar 2023 17:03:53 +0000 Subject: [PATCH 12/18] add `convert_to_getindex` --- src/systems/abstractsystem.jl | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 95d1665abc..00c5900785 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -647,7 +647,7 @@ function SymbolicIndexingInterface.is_state_sym(sys::AbstractSystem, sym) end function SymbolicIndexingInterface.param_sym_to_index(sys::AbstractSystem, sym) - findfirst(isequal(sym), SymbolicIndexingInterface.parameters(sys)) |> safe_unwrap + findfirst(isequal((safe_unwrap(sym)), SymbolicIndexingInterface.parameters(sys)) |> safe_unwrap end function SymbolicIndexingInterface.is_param_sym(sys::AbstractSystem, sym) !isnothing(SymbolicIndexingInterface.param_sym_to_index(sys, sym)) @@ -726,6 +726,28 @@ function SymbolicIndexingInterface.get_observed_dependencies(sys::AbstractSystem return out end +function operations(ex; ignore = []) + if istree(ex) + op = operation(ex) + if !any(isequal(op), ignore) + return vcat(operations.(arguments(ex), ignore = ignore)..., op) + end + end + return [] +end + +function SymbolicIndexingInterface.convert_to_getindex(A::SciMLBase.AbstractSciMLSolution, expr, is...) + ex_vars = vars(expr) + var_rules = [@rule v => A[v, is...] for v in ex_vars] + + ex_ops = operations(expr, ignore = vcat(operation.(ex_vars), getindex)) + op_rules = [@rule op(~~a) => broadcast(op, ~a...)for op in ex_ops] + + ch = Chain(vcat(var_rules, op_rules)) + ex = ch(ex) + return ex +end + struct AbstractSysToExpr sys::AbstractSystem states::Vector From b4b67c0c82cd4dd3f2f7a228a77a309c20cbd185 Mon Sep 17 00:00:00 2001 From: xtalax Date: Fri, 10 Mar 2023 21:15:17 +0000 Subject: [PATCH 13/18] ode passing --- src/ModelingToolkit.jl | 5 +- src/structural_transformation/codegen.jl | 6 +- src/systems/abstractsystem.jl | 137 +++++++++++++---------- src/systems/diffeqs/abstractodesystem.jl | 11 +- src/systems/diffeqs/odesystem.jl | 3 +- src/utils.jl | 2 +- test/odaeproblem.jl | 3 +- test/odesystem.jl | 20 +++- test/runtests.jl | 2 +- 9 files changed, 112 insertions(+), 77 deletions(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 167cc403e8..96830a2e37 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -36,10 +36,13 @@ using RecursiveArrayTools import SymbolicIndexingInterface +const SII = SymbolicIndexingInterface + import SymbolicIndexingInterface: independent_variables, states, state_sym_to_index, parameters, observed, observed_sym_to_index, get_state_dependencies, - get_deps_of_observed, is_observed_sym, unknown_states + get_deps_of_observed, is_observed_sym, unknown_states, + convert_to_getindex, is_symbolic_expr, safe_unwrap export independent_variables, states, parameters, observed import SymbolicUtils import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype, diff --git a/src/structural_transformation/codegen.jl b/src/structural_transformation/codegen.jl index 00b34b212d..3d8033a5a4 100644 --- a/src/structural_transformation/codegen.jl +++ b/src/structural_transformation/codegen.jl @@ -534,7 +534,6 @@ function ODAEProblem{iip}(sys, parammap = DiffEqBase.NullParameters(); callback = nothing, use_union = false, - dense_output = true, check = true, kwargs...) where {iip} eqs = equations(sys) @@ -551,12 +550,11 @@ function ODAEProblem{iip}(sys, has_difference = any(isdifferenceeq, eqs) cbs = process_events(sys; callback, has_difference, kwargs...) - kwargs = filter_kwargs(kwargs) if cbs === nothing - ODEProblem{iip}(fun, u0, tspan, p; dense_output = dense_output, kwargs...) + ODEProblem{iip}(fun, u0, tspan, p; kwargs...) else - ODEProblem{iip}(fun, u0, tspan, p; dense_output = dense_output, callback = cbs, + ODEProblem{iip}(fun, u0, tspan, p; callback = cbs, kwargs...) end end diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 00c5900785..0a7651c1db 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -161,7 +161,7 @@ function independent_variable(sys::AbstractSystem) end #Treat the result as a vector of symbols always -function SymbolicIndexingInterface.independent_variables(sys::AbstractSystem) +function SII.independent_variables(sys::AbstractSystem) systype = typeof(sys) @warn "Please declare ($systype) as a subtype of `AbstractTimeDependentSystem`, `AbstractTimeIndependentSystem` or `AbstractMultivariateSystem`." if isdefined(sys, :iv) @@ -173,11 +173,11 @@ function SymbolicIndexingInterface.independent_variables(sys::AbstractSystem) end end -function SymbolicIndexingInterface.independent_variables(sys::AbstractTimeDependentSystem) +function SII.independent_variables(sys::AbstractTimeDependentSystem) [getfield(sys, :iv)] end -SymbolicIndexingInterface.independent_variables(sys::AbstractTimeIndependentSystem) = [] -function SymbolicIndexingInterface.independent_variables(sys::AbstractMultivariateSystem) +SII.independent_variables(sys::AbstractTimeIndependentSystem) = [] +function SII.independent_variables(sys::AbstractMultivariateSystem) getfield(sys, :ivs) end @@ -512,7 +512,7 @@ function namespace_expr(O, sys, n = nameof(sys)) end end -function states(sys::AbstractSystem) +function SII.states(sys::AbstractSystem) sts = get_states(sys) systems = get_systems(sys) unique(isempty(systems) ? @@ -520,7 +520,20 @@ function states(sys::AbstractSystem) [sts; reduce(vcat, namespace_variables.(systems))]) end -function SymbolicIndexingInterface.parameters(sys::AbstractSystem) +""" +$(SIGNATURES) + +Return a list of actual states needed to be solved by solvers. +""" +function SII.unknown_states(sys::AbstractSystem) + sts = something(get_unknown_states(sys), get_states(sys)) + systems = get_systems(sys) + return unique(isempty(systems) ? + sts : + [sts; reduce(vcat, SII.unknown_states.(systems))]) +end + +function SII.parameters(sys::AbstractSystem) ps = get_ps(sys) systems = get_systems(sys) unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))]) @@ -532,7 +545,7 @@ function controls(sys::AbstractSystem) isempty(systems) ? ctrls : [ctrls; reduce(vcat, namespace_controls.(systems))] end -function SymbolicIndexingInterface.observed(sys::AbstractSystem) +function SII.observed(sys::AbstractSystem) obs = get_observed(sys) systems = get_systems(sys) [obs; @@ -624,54 +637,45 @@ function time_varying_as_func(x, sys::AbstractTimeDependentSystem) return x end -SymbolicIndexingInterface.is_indep_sym(sys::AbstractSystem, sym) = isequal(sym, get_iv(sys)) - -""" -$(SIGNATURES) +### +### SymbolicIndexingInterface +### -Return a list of actual states needed to be solved by solvers. -""" -function SymbolicIndexingInterface.unknown_states(sys::AbstractSystem) - sts = states(sys) - if has_unknown_states(sys) - sts = something(get_unknown_states(sys), sts) - end - return sts -end +SII.is_indep_sym(sys::AbstractSystem, sym) = isequal(sym, get_iv(sys)) -function SymbolicIndexingInterface.state_sym_to_index(sys::AbstractSystem, sym) - findfirst(isequal(safe_unwrap(sym)), unknown_states(sys)) |> safe_unwrap +function SII.state_sym_to_index(sys::AbstractSystem, sym) + findfirst(isequal(SII.safe_unwrap(sym)), SII.unknown_states(sys)) |> safe_unwrap end -function SymbolicIndexingInterface.is_state_sym(sys::AbstractSystem, sym) - !isnothing(SymbolicIndexingInterface.state_sym_to_index(sys, sym)) +function SII.is_state_sym(sys::AbstractSystem, sym) + !isnothing(SII.state_sym_to_index(sys, sym)) end -function SymbolicIndexingInterface.param_sym_to_index(sys::AbstractSystem, sym) - findfirst(isequal((safe_unwrap(sym)), SymbolicIndexingInterface.parameters(sys)) |> safe_unwrap +function SII.param_sym_to_index(sys::AbstractSystem, sym) + findfirst(isequal((SII.safe_unwrap(sym))), parameters(sys)) |> safe_unwrap end -function SymbolicIndexingInterface.is_param_sym(sys::AbstractSystem, sym) - !isnothing(SymbolicIndexingInterface.param_sym_to_index(sys, sym)) +function SII.is_param_sym(sys::AbstractSystem, sym) + !isnothing(SII.param_sym_to_index(sys, sym)) end -function SymbolicIndexingInterface.observed_sym_to_index(sys::AbstractSystem, sym) - findfirst(isequal(safe_unwrap(sym)), getfield.(observed(sys), (:lhs,))) |> +function SII.observed_sym_to_index(sys::AbstractSystem, sym) + findfirst(o -> isequal(SII.safe_unwrap(sym), o.lhs), observed(sys)) |> safe_unwrap end -function SymbolicIndexingInterface.observed_sym_to_index(sys::AbstractSystem, sym::Equation) - findfirst(isequal(sym), observed(sys)) |> safe_unwrap +function SII.observed_sym_to_index(sys::AbstractSystem, sym::Equation) + findfirst(isequal(SII.safe_unwrap(sym)), observed(sys)) |> safe_unwrap end -function SymbolicIndexingInterface.observed_sym_to_index(obs::AbstractArray{<:Equation}, sym) - findfirst(isequal(safe_unwrap(sym)), getfield.(obs, (:lhs,))) |> +function SII.observed_sym_to_index(obs::AbstractArray{<:Equation}, sym) + findfirst(o -> isequal(SII.safe_unwrap(sym), o.lhs), obs) |> safe_unwrap end -function SymbolicIndexingInterface.observed_sym_to_index(obs::AbstractArray{<:Equation}, sym::Equation) - findfirst(isequal(sym), obs) |> safe_unwrap +function SII.observed_sym_to_index(obs::AbstractArray{<:Equation}, sym::Equation) + findfirst(isequal(SII.safe_unwrap(sym)), obs) |> safe_unwrap end -function SymbolicIndexingInterface.is_observed_sym(sys::AbstractSystem, sym) - !isnothing(SymbolicIndexingInterface.observed_sym_to_index(sys, sym)) +function SII.is_observed_sym(sys::AbstractSystem, sym) + !isnothing(SII.observed_sym_to_index(sys, sym)) end -function SymbolicIndexingInterface.get_deps_of_observed(sts, obs::AbstractArray{<:Equation}) +function SII.get_deps_of_observed(sts, obs::AbstractArray{<:Equation}) deps = mapreduce(vcat, obs, init = []) do eq get_state_dependencies(sts, obs, eq.lhs) end |> unique @@ -679,8 +683,9 @@ function SymbolicIndexingInterface.get_deps_of_observed(sts, obs::AbstractArray{ return deps end -function SymbolicIndexingInterface.get_state_dependencies(sts, obs, sym) - sym = safe_unwrap(sym) +# ! this is broken vvvvvv +function SII.get_state_dependencies(sts, obs, sym) + sym = SII.safe_unwrap(sym) i = observed_sym_to_index(obs, sym) if isnothing(i) return [] @@ -699,14 +704,14 @@ function SymbolicIndexingInterface.get_state_dependencies(sts, obs, sym) return filter(x -> any(isequal(x), sts), out) end -function SymbolicIndexingInterface.get_state_dependencies(sys::AbstractSystem, sym) - obs = observed(sys) - sts = unknown_states(sys) +function SII.get_state_dependencies(sys::AbstractSystem, sym) + obs = SII.observed(sys) + sts = SII.unknown_states(sys) return get_state_dependencies(sts, obs, sym) end -function SymbolicIndexingInterface.get_observed_dependencies(sys::AbstractSystem, sym) - obs = observed(sys) +function SII.get_observed_dependencies(sys::AbstractSystem, sym) + obs = SII.observed(sys) i = observed_sym_to_index(sys, sym) if isnothing(i) @@ -726,28 +731,44 @@ function SymbolicIndexingInterface.get_observed_dependencies(sys::AbstractSystem return out end -function operations(ex; ignore = []) +function operations(ex) if istree(ex) op = operation(ex) - if !any(isequal(op), ignore) - return vcat(operations.(arguments(ex), ignore = ignore)..., op) - end + return vcat(operations.(arguments(ex))..., op) end return [] end -function SymbolicIndexingInterface.convert_to_getindex(A::SciMLBase.AbstractSciMLSolution, expr, is...) +""" +Must be defined here because it needs Symbolics +""" +function SII.convert_to_getindex(A::SciMLBase.AbstractSciMLSolution, expr, is...) + expr = scalarize(expr) ex_vars = vars(expr) - var_rules = [@rule v => A[v, is...] for v in ex_vars] - ex_ops = operations(expr, ignore = vcat(operation.(ex_vars), getindex)) - op_rules = [@rule op(~~a) => broadcast(op, ~a...)for op in ex_ops] + var_rules = [@rule v => A[v, is...] for v in ex_vars] + ex_ops = operations(expr) + ignore = vcat(operation.(ex_vars), getindex) + ex_ops = filter(x -> !any(isequal(x), ignore), ex_ops) + op_rules = [@rule $(op)(~~a) => broadcast(op, ~a...) for op in ex_ops] - ch = Chain(vcat(var_rules, op_rules)) - ex = ch(ex) + ch = Postwalk(Chain([var_rules; op_rules])) + ex = ch(expr) return ex end +function SII.is_symbolic_expr(ex::SymbolicUtils.Symbolic) + ex_vars = vars(ex) + if istree(ex) + return !any(isequal(ex), ex_vars) + end + return false +end + +### +### System to Expr +### + struct AbstractSysToExpr sys::AbstractSystem states::Vector @@ -1354,7 +1375,7 @@ function linearization_function(sys::AbstractSystem, inputs, uf = SciMLBase.UJacobianWrapper(fun, t, p) fg_xz = ForwardDiff.jacobian(uf, u) h_xz = ForwardDiff.jacobian(let p = p, t = t - (xz) -> h(xz, p, t) #! signature must change + (xz) -> h(xz, p, t) end, u) pf = SciMLBase.ParamJacobianWrapper(fun, t, u) fg_u = jacobian_wrt_vars(pf, p, input_idxs, chunk) @@ -1365,7 +1386,7 @@ function linearization_function(sys::AbstractSystem, inputs, h_xz = fg_u = zeros(0, length(inputs)) end hp = let u = u, t = t - p -> h(u, p, t) #! signature must change + p -> h(u, p, t) end h_u = jacobian_wrt_vars(hp, p, input_idxs, chunk) (f_x = fg_xz[diff_idxs, diff_idxs], diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 633be8f1ea..6bb7f3234f 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -275,7 +275,6 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s steady_state = false, checkbounds = false, sparsity = false, - dense_output = true, analytic = nothing, kwargs...) where {iip, specialize} f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression}, @@ -338,8 +337,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s let sys = sys, dict = Dict() function generated_observed(obsvar, args...) obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar; - dense_output = dense_output) + build_explicit_observed_function(sys, obsvar) end if args === () let obs = obs @@ -354,8 +352,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s let sys = sys, dict = Dict() function generated_observed(obsvar, args...) obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds, - dense_output = dense_output) + build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds) end if args === () let obs = obs @@ -576,7 +573,7 @@ end """ u0, p, defs = get_u0_p(sys, u0map, parammap; use_union=false, tofloat=!use_union) -Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point. +Take dictionaries with initial conditions and parameters and convert them to numeric arrays `u0` and `p`. Also return the merged dictionary `defs` containing the entire operating point. """ function get_u0_p(sys, u0map, parammap; use_union = false, tofloat = !use_union) eqs = equations(sys) @@ -757,7 +754,7 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = if svs !== nothing kwargs1 = merge(kwargs1, (disc_saved_values = svs,)) end - ODEProblem{iip}(f, u0, tspan, p, pt; dense_output = dense_output, kwargs1..., + ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...) end get_callback(prob::ODEProblem) = prob.kwargs[:callback] diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index eeddabfd94..49d7e79d3f 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -308,7 +308,7 @@ function build_explicit_observed_function(sys, ts; output_type = Array, checkbounds = true, throw = true, - dense_output = true) + dense_output = false) if (isscalar = !(ts isa AbstractVector)) ts = [ts] end @@ -336,7 +336,6 @@ function build_explicit_observed_function(sys, ts; # FIXME: This is a rather rough estimate of dependencies. We assume # the expression depends on everything before the `maxidx`. - # ? subs = Dict() maxidx = 0 for s in dep_vars diff --git a/src/utils.jl b/src/utils.jl index 93246a8c91..6527abd860 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -878,4 +878,4 @@ end normalize_to_differential(s) = s -safe_unwrap(x) = x isa Num ? unwrap(x) : x +SymbolicIndexingInterface.safe_unwrap(x::Num) = unwrap(x) diff --git a/test/odaeproblem.jl b/test/odaeproblem.jl index 595961814f..d1d601e0eb 100644 --- a/test/odaeproblem.jl +++ b/test/odaeproblem.jl @@ -63,4 +63,5 @@ ref_sol = solve(ref_prob, FBDF(), saveat = 1.0) sol = solve(prob, Tsit5(), saveat = 1.0) # test observed -@test sol[observed(sys)[1].lhs] ≈ ref_sol[observed(sys)[1].lhs] atol = 0.001 + +@test sol[observed(sys)[1].lhs]≈ref_sol[observed(sys)[1].lhs] atol=0.001 diff --git a/test/odesystem.jl b/test/odesystem.jl index 36d8bba630..a98f710d6c 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1037,7 +1037,7 @@ end e => 1.0])) sys = structural_simplify(sys) prob = ODEProblem(sys, [], (0, 1.0)) - prob_sym = ODEProblem(sys, [], (0, 1.0), dense_output = false) + prob_sym = ODEProblem(sys, [], (0, 1.0)) sol = solve(prob, Tsit5()) sol_sym = solve(prob_sym, Tsit5(), save_idxs = [a, c, e]) @@ -1047,6 +1047,14 @@ end @test sol_sym[d] ≈ sol[d] @test sol_sym[e] ≈ sol[e] + sola = sol[a] + sole = sol[e] + solc = sol[c] + + expr1 = @. sola^2 + sole^2 - sin(solc) + 1 + + @test sol_sym[a^2 + e^2 - sin(c) + 1] ≈ expr1 + @test sol.u != sol_sym.u @test_throws Exception sol_sym[b] @@ -1072,7 +1080,7 @@ end e => 1.0])) sys = structural_simplify(sys) prob = ODEProblem(sys, [], (0, 1.0)) - prob_sym = ODEProblem(sys, [], (0, 1.0), dense_output = false) + prob_sym = ODEProblem(sys, [], (0, 1.0)) sol = solve(prob, Tsit5()) sol_sym = solve(prob_sym, Tsit5(), save_idxs = [a, c, b]) @@ -1081,6 +1089,14 @@ end @test sol_sym[b] ≈ sol[b] @test sol_sym[c] ≈ sol[c] + sola = sol[a] + solb = sol[b] + solc = sol[c] + + expr1 = sola[2]^2 + solb[2]^2 - sin(solc[2]) + + @test sol_sym[a^2 + b^2 - sin(c), 2] ≈ expr1 + @test sol.u != sol_sym.u @test_throws Exception sol_sym[d] diff --git a/test/runtests.jl b/test/runtests.jl index 5b0a824872..7972f34de0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,7 @@ using SafeTestsets, Test -@safetestset "ODAEProblem Test" begin include("odaeproblem.jl") end @safetestset "ODESystem Test" begin include("odesystem.jl") end +@safetestset "ODAEProblem Test" begin include("odaeproblem.jl") end @safetestset "AliasGraph Test" begin include("alias.jl") end @safetestset "Pantelides Test" begin include("pantelides.jl") end @safetestset "Linear Algebra Test" begin include("linalg.jl") end From 5c859bd3ae6ff778c71754ac9575c3cbfc51df4f Mon Sep 17 00:00:00 2001 From: xtalax Date: Mon, 13 Mar 2023 14:05:53 +0000 Subject: [PATCH 14/18] sde failing? --- src/systems/abstractsystem.jl | 8 +---- src/systems/diffeqs/abstractodesystem.jl | 8 +++++ test/runtests.jl | 40 ++++++++++++------------ 3 files changed, 29 insertions(+), 27 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 0a7651c1db..c111e9bfea 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -525,13 +525,7 @@ $(SIGNATURES) Return a list of actual states needed to be solved by solvers. """ -function SII.unknown_states(sys::AbstractSystem) - sts = something(get_unknown_states(sys), get_states(sys)) - systems = get_systems(sys) - return unique(isempty(systems) ? - sts : - [sts; reduce(vcat, SII.unknown_states.(systems))]) -end +SII.unknown_states(sys::AbstractSystem) = SII.states(sys) function SII.parameters(sys::AbstractSystem) ps = get_ps(sys) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 6bb7f3234f..1a447ca3be 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -13,6 +13,14 @@ function gen_quoted_kwargs(kwargs) kwargparam end +function SII.unknown_states(sys::ODESystem) + sts = something(get_unknown_states(sys), get_states(sys)) + systems = get_systems(sys) + return unique(isempty(systems) ? + sts : + [sts; reduce(vcat, SII.unknown_states.(systems))]) +end + function calculate_tgrad(sys::AbstractODESystem; simplify = false) isempty(get_tgrad(sys)[]) || return get_tgrad(sys)[] # use cached tgrad, if possible diff --git a/test/runtests.jl b/test/runtests.jl index 7972f34de0..17bd67b0b3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,25 +1,5 @@ using SafeTestsets, Test -@safetestset "ODESystem Test" begin include("odesystem.jl") end -@safetestset "ODAEProblem Test" begin include("odaeproblem.jl") end -@safetestset "AliasGraph Test" begin include("alias.jl") end -@safetestset "Pantelides Test" begin include("pantelides.jl") end -@safetestset "Linear Algebra Test" begin include("linalg.jl") end -@safetestset "AbstractSystem Test" begin include("abstractsystem.jl") end -@safetestset "Variable Scope Tests" begin include("variable_scope.jl") end -@safetestset "Symbolic Parameters Test" begin include("symbolic_parameters.jl") end -@safetestset "Parsing Test" begin include("variable_parsing.jl") end -@safetestset "Simplify Test" begin include("simplify.jl") end -@safetestset "Direct Usage Test" begin include("direct.jl") end -@safetestset "System Linearity Test" begin include("linearity.jl") end -@safetestset "Linearization Tests" begin include("linearize.jl") end -@safetestset "Input Output Test" begin include("input_output_handling.jl") end -@safetestset "Clock Test" begin include("clock.jl") end -@safetestset "DiscreteSystem Test" begin include("discretesystem.jl") end -@safetestset "Unitful Quantities Test" begin include("units.jl") end -@safetestset "LabelledArrays Test" begin include("labelledarrays.jl") end -@safetestset "Mass Matrix Test" begin include("mass_matrix.jl") end -@safetestset "SteadyStateSystem Test" begin include("steadystatesystems.jl") end @safetestset "SDESystem Test" begin include("sdesystem.jl") end @safetestset "NonlinearSystem Test" begin include("nonlinearsystem.jl") end @safetestset "PDE Construction Test" begin include("pde.jl") end @@ -50,5 +30,25 @@ println("Last test requires gcc available in the path!") @safetestset "OptimizationSystem Test" begin include("optimizationsystem.jl") end @safetestset "FuncAffect Test" begin include("funcaffect.jl") end @safetestset "Constants Test" begin include("constants.jl") end +@safetestset "ODESystem Test" begin include("odesystem.jl") end +@safetestset "ODAEProblem Test" begin include("odaeproblem.jl") end +@safetestset "AliasGraph Test" begin include("alias.jl") end +@safetestset "Pantelides Test" begin include("pantelides.jl") end +@safetestset "Linear Algebra Test" begin include("linalg.jl") end +@safetestset "AbstractSystem Test" begin include("abstractsystem.jl") end +@safetestset "Variable Scope Tests" begin include("variable_scope.jl") end +@safetestset "Symbolic Parameters Test" begin include("symbolic_parameters.jl") end +@safetestset "Parsing Test" begin include("variable_parsing.jl") end +@safetestset "Simplify Test" begin include("simplify.jl") end +@safetestset "Direct Usage Test" begin include("direct.jl") end +@safetestset "System Linearity Test" begin include("linearity.jl") end +@safetestset "Linearization Tests" begin include("linearize.jl") end +@safetestset "Input Output Test" begin include("input_output_handling.jl") end +@safetestset "Clock Test" begin include("clock.jl") end +@safetestset "DiscreteSystem Test" begin include("discretesystem.jl") end +@safetestset "Unitful Quantities Test" begin include("units.jl") end +@safetestset "LabelledArrays Test" begin include("labelledarrays.jl") end +@safetestset "Mass Matrix Test" begin include("mass_matrix.jl") end +@safetestset "SteadyStateSystem Test" begin include("steadystatesystems.jl") end # Reference tests go Last @safetestset "Latexify recipes Test" begin include("latexify.jl") end From 839e5a5a9aa8c19db6b364ed8b997d2e45380d7a Mon Sep 17 00:00:00 2001 From: xtalax Date: Mon, 13 Mar 2023 14:43:05 +0000 Subject: [PATCH 15/18] move tests --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 17bd67b0b3..adcb1243c2 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,5 @@ using SafeTestsets, Test -@safetestset "SDESystem Test" begin include("sdesystem.jl") end @safetestset "NonlinearSystem Test" begin include("nonlinearsystem.jl") end @safetestset "PDE Construction Test" begin include("pde.jl") end @safetestset "JumpSystem Test" begin include("jumpsystem.jl") end @@ -50,5 +49,6 @@ println("Last test requires gcc available in the path!") @safetestset "LabelledArrays Test" begin include("labelledarrays.jl") end @safetestset "Mass Matrix Test" begin include("mass_matrix.jl") end @safetestset "SteadyStateSystem Test" begin include("steadystatesystems.jl") end +@safetestset "SDESystem Test" begin include("sdesystem.jl") end # Reference tests go Last @safetestset "Latexify recipes Test" begin include("latexify.jl") end From 5739506a3b5c87d52a502eb8efe603803d753586 Mon Sep 17 00:00:00 2001 From: xtalax Date: Wed, 15 Mar 2023 16:45:54 +0000 Subject: [PATCH 16/18] fix sde --- docs/src/basics/saving_syms.md | 76 +++++++++++++++++--------------- src/systems/diffeqs/sdesystem.jl | 2 +- test/runtests.jl | 2 +- 3 files changed, 42 insertions(+), 38 deletions(-) diff --git a/docs/src/basics/saving_syms.md b/docs/src/basics/saving_syms.md index 0e70f9c749..e3d78f3088 100644 --- a/docs/src/basics/saving_syms.md +++ b/docs/src/basics/saving_syms.md @@ -3,41 +3,45 @@ It is possible to specify symbolically which states of an ODESystem to save, with the `save_idxs` keyword argument to the `solve` call. This may be important to you to save memory for large problems. -Take care to disable the `dense_output` flag when constructing the problem to ensure that observed -variables can be properly reconstructed. Failing to do this may cause incorrect construction of observed variables. - ```julia -using ModelingToolkit, OrdinaryDiffEq, Test -@parameters t -@variables a(t) b(t) c(t) d(t) e(t) - -D = Differential(t) - -eqs = [D(a) ~ a, - D(b) ~ b, - D(c) ~ c, - D(d) ~ d, - e ~ d] - -@named sys = ODESystem(eqs, t, [a, b, c, d, e], []; - defaults = Dict([a => 1.0, - b => 1.0, - c => 1.0, - d => 1.0, - e => 1.0])) -sys = structural_simplify(sys) -prob = ODEProblem(sys, [], (0, 1.0)) -prob_sym = ODEProblem(sys, [], (0, 1.0), dense_output = false) - -sol = solve(prob, Tsit5()) -sol_sym = solve(prob_sym, Tsit5(), save_idxs = [a, c, e]) - -@test sol_sym[a] ≈ sol[a] -@test sol_sym[c] ≈ sol[c] -@test sol_sym[d] ≈ sol[d] # Dependency `d` of the observed variable `e` is automatically saved too. -@test sol_sym[e] ≈ sol[e] - -@test sol.u != sol_sym.u - -@test_throws Exception sol_sym[b] + @parameters t + @variables a(t) b(t) c(t) d(t) e(t) + + D = Differential(t) + + eqs = [D(a) ~ a, + D(b) ~ b, + D(c) ~ c, + D(d) ~ d, + e ~ d] + + @named sys = ODESystem(eqs, t, [a, b, c, d, e], []; + defaults = Dict([a => 1.0, + b => 1.0, + c => 1.0, + d => 1.0, + e => 1.0])) + sys = structural_simplify(sys) + prob = ODEProblem(sys, [], (0, 1.0)) + prob_sym = ODEProblem(sys, [], (0, 1.0)) + + sol = solve(prob, Tsit5()) + sol_sym = solve(prob_sym, Tsit5(), save_idxs = [a, c, e]) + + @test sol_sym[a] ≈ sol[a] + @test sol_sym[c] ≈ sol[c] + @test sol_sym[d] ≈ sol[d] # `d` is automatically saved too, as `e` depends on it. + @test sol_sym[e] ≈ sol[e] + + sola = sol[a] + sole = sol[e] + solc = sol[c] + + expr1 = @. sola^2 + sole^2 - sin(solc) + 1 + + @test sol_sym[a^2 + e^2 - sin(c) + 1] ≈ expr1 # indexing with a symbolic expr, these can be quite complex. + + @test sol.u != sol_sym.u + + @test_throws Exception sol_sym[b] # `b` is not saved. ``` \ No newline at end of file diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 680c24deb0..2dd6f512c1 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -442,7 +442,7 @@ function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = states(sys), observedfun = let sys = sys, dict = Dict() function generated_observed(obsvar, u, p, t) obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds) + build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds, dense_output = true) end obs(u, p, t) end diff --git a/test/runtests.jl b/test/runtests.jl index adcb1243c2..17bd67b0b3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using SafeTestsets, Test +@safetestset "SDESystem Test" begin include("sdesystem.jl") end @safetestset "NonlinearSystem Test" begin include("nonlinearsystem.jl") end @safetestset "PDE Construction Test" begin include("pde.jl") end @safetestset "JumpSystem Test" begin include("jumpsystem.jl") end @@ -49,6 +50,5 @@ println("Last test requires gcc available in the path!") @safetestset "LabelledArrays Test" begin include("labelledarrays.jl") end @safetestset "Mass Matrix Test" begin include("mass_matrix.jl") end @safetestset "SteadyStateSystem Test" begin include("steadystatesystems.jl") end -@safetestset "SDESystem Test" begin include("sdesystem.jl") end # Reference tests go Last @safetestset "Latexify recipes Test" begin include("latexify.jl") end From 3ca3658cff9f3bf7276d7fab228da440a0e15539 Mon Sep 17 00:00:00 2001 From: xtalax Date: Wed, 15 Mar 2023 17:39:40 +0000 Subject: [PATCH 17/18] observed fix --- src/systems/clock_inference.jl | 6 ++++-- src/systems/discrete_system/discrete_system.jl | 2 +- src/systems/jumps/jumpsystem.jl | 3 ++- src/systems/nonlinear/nonlinearsystem.jl | 2 +- src/systems/optimization/optimizationsystem.jl | 2 +- test/runtests.jl | 4 ++-- 6 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/systems/clock_inference.jl b/src/systems/clock_inference.jl index 8b7aaf19ab..d9d457aa86 100644 --- a/src/systems/clock_inference.jl +++ b/src/systems/clock_inference.jl @@ -187,12 +187,14 @@ function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock; needed_cont_to_disc_obs, throw = false, expression = true, - output_type = SVector) + output_type = SVector, + dense_output = true) @set! sys.ps = appended_parameters disc_to_cont_obs = build_explicit_observed_function(sys, needed_disc_to_cont_obs, throw = false, expression = true, - output_type = SVector) + output_type = SVector, + dense_output = true) ni = length(input) ns = length(states(sys)) disc = Func([ diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 83c8204ddc..1046fa098a 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -353,7 +353,7 @@ function SciMLBase.DiscreteFunction{iip, specialize}(sys::DiscreteSystem, observedfun = let sys = sys, dict = Dict() function generate_observed(obsvar, u, p, t) obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar) + build_explicit_observed_function(sys, obsvar, dense_output = true) end obs(u, p, t) end diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index be772c7003..ddab0dcf88 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -313,7 +313,8 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, observedfun = let sys = sys, dict = Dict() function generated_observed(obsvar, u, p, t) obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds) + build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds, + dense_output = true) end obs(u, p, t) end diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index a0ab0e58f5..9c79fa2fae 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -247,7 +247,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = states(sys observedfun = let sys = sys, dict = Dict() function generated_observed(obsvar, u, p) obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar) + build_explicit_observed_function(sys, obsvar, dense_output = true) end obs(u, p) end diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index 24fc68e8cd..3c8ac60c52 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -308,7 +308,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map, observedfun = let sys = sys, dict = Dict() function generated_observed(obsvar, args...) obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar) + build_explicit_observed_function(sys, obsvar, dense_output = true) end if args === () let obs = obs diff --git a/test/runtests.jl b/test/runtests.jl index 17bd67b0b3..df135271f3 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -9,8 +9,6 @@ using SafeTestsets, Test @safetestset "Components Test" begin include("components.jl") end @safetestset "print_tree" begin include("print_tree.jl") end @safetestset "Error Handling" begin include("error_handling.jl") end -@safetestset "StructuralTransformations" begin include("structural_transformation/runtests.jl") end -@safetestset "State Selection Test" begin include("state_selection.jl") end @safetestset "Symbolic Event Test" begin include("symbolic_events.jl") end @safetestset "Stream Connnect Test" begin include("stream_connectors.jl") end @safetestset "Lowering Integration Test" begin include("lowering_solving.jl") end @@ -50,5 +48,7 @@ println("Last test requires gcc available in the path!") @safetestset "LabelledArrays Test" begin include("labelledarrays.jl") end @safetestset "Mass Matrix Test" begin include("mass_matrix.jl") end @safetestset "SteadyStateSystem Test" begin include("steadystatesystems.jl") end +@safetestset "State Selection Test" begin include("state_selection.jl") end +@safetestset "StructuralTransformations" begin include("structural_transformation/runtests.jl") end # Reference tests go Last @safetestset "Latexify recipes Test" begin include("latexify.jl") end From e3e2d966c62d7c9657a56ba0b3a3d570bc12af9d Mon Sep 17 00:00:00 2001 From: xtalax Date: Thu, 16 Mar 2023 13:58:27 +0000 Subject: [PATCH 18/18] fix jump --- src/systems/jumps/jumpsystem.jl | 3 +-- test/runtests.jl | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index ddab0dcf88..be772c7003 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -313,8 +313,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, observedfun = let sys = sys, dict = Dict() function generated_observed(obsvar, u, p, t) obs = get!(dict, value(obsvar)) do - build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds, - dense_output = true) + build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds) end obs(u, p, t) end diff --git a/test/runtests.jl b/test/runtests.jl index df135271f3..2f55e7be16 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -25,7 +25,6 @@ println("Last test requires gcc available in the path!") @safetestset "C Compilation Test" begin include("ccompile.jl") end @testset "Serialization" begin include("serialization.jl") end @safetestset "Modelingtoolkitize Test" begin include("modelingtoolkitize.jl") end -@safetestset "OptimizationSystem Test" begin include("optimizationsystem.jl") end @safetestset "FuncAffect Test" begin include("funcaffect.jl") end @safetestset "Constants Test" begin include("constants.jl") end @safetestset "ODESystem Test" begin include("odesystem.jl") end @@ -50,5 +49,6 @@ println("Last test requires gcc available in the path!") @safetestset "SteadyStateSystem Test" begin include("steadystatesystems.jl") end @safetestset "State Selection Test" begin include("state_selection.jl") end @safetestset "StructuralTransformations" begin include("structural_transformation/runtests.jl") end +@safetestset "OptimizationSystem Test" begin include("optimizationsystem.jl") end # Reference tests go Last @safetestset "Latexify recipes Test" begin include("latexify.jl") end