Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

get state dependencies, observed -> SymbolicIndexingInterface #2071

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@
Manifest.toml
.vscode
.vscode/*
scratch
scratch/*
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
Expand Down
1 change: 1 addition & 0 deletions docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pages = [
"basics/Events.md",
"basics/Linearization.md",
"basics/Validation.md",
"basics/saving_syms.md"
"basics/DependencyGraphs.md",
"basics/FAQ.md"],
"System Types" => Any["systems/ODESystem.md",
Expand Down
47 changes: 47 additions & 0 deletions docs/src/basics/saving_syms.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# [Saving Only Certain Symbols](@id sym_save_idxs)

It is possible to specify symbolically which states of an ODESystem to save, with the `save_idxs` keyword argument
to the `solve` call. This may be important to you to save memory for large problems.

```julia
@parameters t
@variables a(t) b(t) c(t) d(t) e(t)

D = Differential(t)

eqs = [D(a) ~ a,
D(b) ~ b,
D(c) ~ c,
D(d) ~ d,
e ~ d]

@named sys = ODESystem(eqs, t, [a, b, c, d, e], [];
defaults = Dict([a => 1.0,
b => 1.0,
c => 1.0,
d => 1.0,
e => 1.0]))
sys = structural_simplify(sys)
prob = ODEProblem(sys, [], (0, 1.0))
prob_sym = ODEProblem(sys, [], (0, 1.0))

sol = solve(prob, Tsit5())
sol_sym = solve(prob_sym, Tsit5(), save_idxs = [a, c, e])

@test sol_sym[a] ≈ sol[a]
@test sol_sym[c] ≈ sol[c]
@test sol_sym[d] ≈ sol[d] # `d` is automatically saved too, as `e` depends on it.
@test sol_sym[e] ≈ sol[e]

sola = sol[a]
sole = sol[e]
solc = sol[c]

expr1 = @. sola^2 + sole^2 - sin(solc) + 1

@test sol_sym[a^2 + e^2 - sin(c) + 1] ≈ expr1 # indexing with a symbolic expr, these can be quite complex.

@test sol.u != sol_sym.u

@test_throws Exception sol_sym[b] # `b` is not saved.
```
11 changes: 9 additions & 2 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,15 @@ RuntimeGeneratedFunctions.init(@__MODULE__)
using RecursiveArrayTools

import SymbolicIndexingInterface
import SymbolicIndexingInterface: independent_variables, states, parameters
export independent_variables, states, parameters

const SII = SymbolicIndexingInterface

import SymbolicIndexingInterface: independent_variables, states, state_sym_to_index,
parameters, observed,
observed_sym_to_index, get_state_dependencies,
get_deps_of_observed, is_observed_sym, unknown_states,
convert_to_getindex, is_symbolic_expr, safe_unwrap
export independent_variables, states, parameters, observed
import SymbolicUtils
import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype,
Symbolic, isadd, ismul, ispow, issym, FnType,
Expand Down
2 changes: 2 additions & 0 deletions src/structural_transformation/StructuralTransformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ using SparseArrays

using SimpleNonlinearSolve

import SymbolicIndexingInterface: get_deps_of_observed

export tearing, partial_state_selection, dae_index_lowering, check_consistency
export dummy_derivative
export build_torn_function, build_observed_function, ODAEProblem
Expand Down
14 changes: 10 additions & 4 deletions src/structural_transformation/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ function build_torn_function(sys;
jacobian_sparsity = true,
checkbounds = false,
max_inlining_size = nothing,
dense_output = true,
kw...)
max_inlining_size = something(max_inlining_size, MAX_INLINE_NLSOLVE_SIZE)
rhss = []
Expand Down Expand Up @@ -333,6 +334,7 @@ function build_torn_function(sys;
build_observed_function(state, obsvar, var_eq_matching, var_sccs,
is_solver_state_idxs, assignments, deps,
sol_states, var2assignment,
dense_output = dense_output,
checkbounds = checkbounds)
end
if args === ()
Expand Down Expand Up @@ -388,7 +390,8 @@ function build_observed_function(state, ts, var_eq_matching, var_sccs,
var2assignment;
expression = false,
output_type = Array,
checkbounds = true)
checkbounds = true,
dense_output = true)
is_not_prepended_assignment = trues(length(assignments))
if (isscalar = !(ts isa AbstractVector))
ts = [ts]
Expand Down Expand Up @@ -480,6 +483,9 @@ function build_observed_function(state, ts, var_eq_matching, var_sccs,
solves = []
end

unknown_state_deps = dense_output ? unknown_states :
get_deps_of_observed(unknown_states, obs)

subs = []
for sym in vars
eqidx = get(observed_idx, sym, nothing)
Expand All @@ -490,7 +496,7 @@ function build_observed_function(state, ts, var_eq_matching, var_sccs,
cpre = get_preprocess_constants([obs[1:maxidx];
isscalar ? ts[1] : MakeArray(ts, output_type)])
pre2 = x -> pre(cpre(x))
ex = Code.toexpr(Func([DestructuredArgs(unknown_states, inbounds = !checkbounds)
ex = Code.toexpr(Func([DestructuredArgs(unknown_state_deps, inbounds = !checkbounds)
DestructuredArgs(parameters(sys), inbounds = !checkbounds)
independent_variables(sys)],
[],
Expand Down Expand Up @@ -544,11 +550,11 @@ function ODAEProblem{iip}(sys,

has_difference = any(isdifferenceeq, eqs)
cbs = process_events(sys; callback, has_difference, kwargs...)

kwargs = filter_kwargs(kwargs)
if cbs === nothing
ODEProblem{iip}(fun, u0, tspan, p; kwargs...)
else
ODEProblem{iip}(fun, u0, tspan, p; callback = cbs, kwargs...)
ODEProblem{iip}(fun, u0, tspan, p; callback = cbs,
kwargs...)
end
end
166 changes: 138 additions & 28 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ function independent_variable(sys::AbstractSystem)
end

#Treat the result as a vector of symbols always
function SymbolicIndexingInterface.independent_variables(sys::AbstractSystem)
function SII.independent_variables(sys::AbstractSystem)
systype = typeof(sys)
@warn "Please declare ($systype) as a subtype of `AbstractTimeDependentSystem`, `AbstractTimeIndependentSystem` or `AbstractMultivariateSystem`."
if isdefined(sys, :iv)
Expand All @@ -173,11 +173,11 @@ function SymbolicIndexingInterface.independent_variables(sys::AbstractSystem)
end
end

function SymbolicIndexingInterface.independent_variables(sys::AbstractTimeDependentSystem)
function SII.independent_variables(sys::AbstractTimeDependentSystem)
[getfield(sys, :iv)]
end
SymbolicIndexingInterface.independent_variables(sys::AbstractTimeIndependentSystem) = []
function SymbolicIndexingInterface.independent_variables(sys::AbstractMultivariateSystem)
SII.independent_variables(sys::AbstractTimeIndependentSystem) = []
function SII.independent_variables(sys::AbstractMultivariateSystem)
getfield(sys, :ivs)
end

Expand Down Expand Up @@ -512,15 +512,22 @@ function namespace_expr(O, sys, n = nameof(sys))
end
end

function states(sys::AbstractSystem)
function SII.states(sys::AbstractSystem)
sts = get_states(sys)
systems = get_systems(sys)
unique(isempty(systems) ?
sts :
[sts; reduce(vcat, namespace_variables.(systems))])
end

function SymbolicIndexingInterface.parameters(sys::AbstractSystem)
"""
$(SIGNATURES)

Return a list of actual states needed to be solved by solvers.
"""
SII.unknown_states(sys::AbstractSystem) = SII.states(sys)

function SII.parameters(sys::AbstractSystem)
ps = get_ps(sys)
systems = get_systems(sys)
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
Expand All @@ -532,7 +539,7 @@ function controls(sys::AbstractSystem)
isempty(systems) ? ctrls : [ctrls; reduce(vcat, namespace_controls.(systems))]
end

function observed(sys::AbstractSystem)
function SII.observed(sys::AbstractSystem)
obs = get_observed(sys)
systems = get_systems(sys)
[obs;
Expand Down Expand Up @@ -615,7 +622,7 @@ function time_varying_as_func(x, sys::AbstractTimeDependentSystem)
# but is `x(t-1)` or something like that, pass in `x` as a callable function rather
# than pass in a value in place of x(t).
#
# This is done by just making `x` the argument of the function.
# This is*-+ done by just making `x` the argument of the function.
if istree(x) &&
issym(operation(x)) &&
!(length(arguments(x)) == 1 && isequal(arguments(x)[1], get_iv(sys)))
Expand All @@ -624,35 +631,138 @@ function time_varying_as_func(x, sys::AbstractTimeDependentSystem)
return x
end

SymbolicIndexingInterface.is_indep_sym(sys::AbstractSystem, sym) = isequal(sym, get_iv(sys))
###
### SymbolicIndexingInterface
###

"""
$(SIGNATURES)
SII.is_indep_sym(sys::AbstractSystem, sym) = isequal(sym, get_iv(sys))

Return a list of actual states needed to be solved by solvers.
"""
function unknown_states(sys::AbstractSystem)
sts = states(sys)
if has_unknown_states(sys)
sts = something(get_unknown_states(sys), sts)
function SII.state_sym_to_index(sys::AbstractSystem, sym)
findfirst(isequal(SII.safe_unwrap(sym)), SII.unknown_states(sys)) |> safe_unwrap
end
function SII.is_state_sym(sys::AbstractSystem, sym)
!isnothing(SII.state_sym_to_index(sys, sym))
end

function SII.param_sym_to_index(sys::AbstractSystem, sym)
findfirst(isequal((SII.safe_unwrap(sym))), parameters(sys)) |> safe_unwrap
end
function SII.is_param_sym(sys::AbstractSystem, sym)
!isnothing(SII.param_sym_to_index(sys, sym))
end

function SII.observed_sym_to_index(sys::AbstractSystem, sym)
findfirst(o -> isequal(SII.safe_unwrap(sym), o.lhs), observed(sys)) |>
safe_unwrap
end
function SII.observed_sym_to_index(sys::AbstractSystem, sym::Equation)
findfirst(isequal(SII.safe_unwrap(sym)), observed(sys)) |> safe_unwrap
end
function SII.observed_sym_to_index(obs::AbstractArray{<:Equation}, sym)
findfirst(o -> isequal(SII.safe_unwrap(sym), o.lhs), obs) |>
safe_unwrap
end
function SII.observed_sym_to_index(obs::AbstractArray{<:Equation}, sym::Equation)
findfirst(isequal(SII.safe_unwrap(sym)), obs) |> safe_unwrap
end
function SII.is_observed_sym(sys::AbstractSystem, sym)
!isnothing(SII.observed_sym_to_index(sys, sym))
end

function SII.get_deps_of_observed(sts, obs::AbstractArray{<:Equation})
deps = mapreduce(vcat, obs, init = []) do eq
get_state_dependencies(sts, obs, eq.lhs)
end |> unique

return deps
end

# ! this is broken vvvvvv
function SII.get_state_dependencies(sts, obs, sym)
sym = SII.safe_unwrap(sym)
i = observed_sym_to_index(obs, sym)
if isnothing(i)
return []
end
return sts

eq = obs[i]
varss = vars(eq.rhs)
out = mapreduce(vcat, varss, init = []) do u
if !isnothing(observed_sym_to_index(obs, u))
get_state_dependencies(sts, obs, u)
else
[u]
end
end |> unique

return filter(x -> any(isequal(x), sts), out)
end

function SymbolicIndexingInterface.state_sym_to_index(sys::AbstractSystem, sym)
findfirst(isequal(sym), unknown_states(sys))
function SII.get_state_dependencies(sys::AbstractSystem, sym)
obs = SII.observed(sys)
sts = SII.unknown_states(sys)
return get_state_dependencies(sts, obs, sym)
end

function SII.get_observed_dependencies(sys::AbstractSystem, sym)
obs = SII.observed(sys)

i = observed_sym_to_index(sys, sym)
if isnothing(i)
return []
end

eq = obs[i]
varss = vars(eq.rhs)
out = mapreduce(vcat, varss, init = []) do u
if is_observed_sym(sys, u)
[u]
else
[]
end
end |> unique

return out
end
function SymbolicIndexingInterface.is_state_sym(sys::AbstractSystem, sym)
!isnothing(SymbolicIndexingInterface.state_sym_to_index(sys, sym))

function operations(ex)
if istree(ex)
op = operation(ex)
return vcat(operations.(arguments(ex))..., op)
end
return []
end

function SymbolicIndexingInterface.param_sym_to_index(sys::AbstractSystem, sym)
findfirst(isequal(sym), SymbolicIndexingInterface.parameters(sys))
"""
Must be defined here because it needs Symbolics
"""
function SII.convert_to_getindex(A::SciMLBase.AbstractSciMLSolution, expr, is...)
expr = scalarize(expr)
ex_vars = vars(expr)

var_rules = [@rule v => A[v, is...] for v in ex_vars]
ex_ops = operations(expr)
ignore = vcat(operation.(ex_vars), getindex)
ex_ops = filter(x -> !any(isequal(x), ignore), ex_ops)
op_rules = [@rule $(op)(~~a) => broadcast(op, ~a...) for op in ex_ops]

ch = Postwalk(Chain([var_rules; op_rules]))
ex = ch(expr)
return ex
end
function SymbolicIndexingInterface.is_param_sym(sys::AbstractSystem, sym)
!isnothing(SymbolicIndexingInterface.param_sym_to_index(sys, sym))

function SII.is_symbolic_expr(ex::SymbolicUtils.Symbolic)
ex_vars = vars(ex)
if istree(ex)
return !any(isequal(ex), ex_vars)
end
return false
end

###
### System to Expr
###

struct AbstractSysToExpr
sys::AbstractSystem
states::Vector
Expand Down Expand Up @@ -1249,7 +1359,7 @@ function linearization_function(sys::AbstractSystem, inputs,
input_idxs = input_idxs,
sts = states(sys),
fun = ODEFunction{true, SciMLBase.FullSpecialize}(sys),
h = build_explicit_observed_function(sys, outputs),
h = build_explicit_observed_function(sys, outputs; dense_output = true),
chunk = ForwardDiff.Chunk(input_idxs)

function (u, p, t)
Expand All @@ -1259,7 +1369,7 @@ function linearization_function(sys::AbstractSystem, inputs,
uf = SciMLBase.UJacobianWrapper(fun, t, p)
fg_xz = ForwardDiff.jacobian(uf, u)
h_xz = ForwardDiff.jacobian(let p = p, t = t
xz -> h(xz, p, t)
(xz) -> h(xz, p, t)
end, u)
pf = SciMLBase.ParamJacobianWrapper(fun, t, u)
fg_u = jacobian_wrt_vars(pf, p, input_idxs, chunk)
Expand Down
Loading