Skip to content

Commit

Permalink
Support conditions on linear combinations of variables for `DiscreteC…
Browse files Browse the repository at this point in the history
…ontrol` (#1371)

Fixes #1346.
  • Loading branch information
SouthEndMusic authored Apr 16, 2024
1 parent 624b33a commit f6108a0
Show file tree
Hide file tree
Showing 21 changed files with 492 additions and 134 deletions.
108 changes: 78 additions & 30 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,19 @@ function set_initial_discrete_controlled_parameters!(
(; p) = integrator
(; discrete_control) = p

n_conditions = length(discrete_control.condition_value)
n_conditions = sum(length(vec) for vec in discrete_control.condition_value; init = 0)
condition_diffs = zeros(Float64, n_conditions)
discrete_control_condition(condition_diffs, storage0, integrator.t, integrator)
discrete_control.condition_value .= (condition_diffs .> 0.0)

# Set the discrete control value (bool) per compound variable
idx_start = 1
for (compound_variable_idx, vec) in enumerate(discrete_control.condition_value)
l = length(vec)
idx_end = idx_start + l - 1
discrete_control.condition_value[compound_variable_idx] .=
(condition_diffs[idx_start:idx_end] .> 0)
idx_start += l
end

# For every discrete_control node find a condition_idx it listens to
for discrete_control_node_id in unique(discrete_control.node_id)
Expand Down Expand Up @@ -78,7 +87,7 @@ function create_callbacks(

saved = SavedResults(saved_flow, saved_vertical_flux, saved_subgrid_level)

n_conditions = length(discrete_control.node_id)
n_conditions = sum(length(vec) for vec in discrete_control.greater_than; init = 0)
if n_conditions > 0
discrete_control_cb = VectorContinuousCallback(
discrete_control_condition,
Expand Down Expand Up @@ -183,18 +192,27 @@ Listens for changes in condition truths.
function discrete_control_condition(out, u, t, integrator)
(; p) = integrator
(; discrete_control) = p

for (i, (listen_node_id, variable, greater_than, look_ahead)) in enumerate(
zip(
discrete_control.listen_node_id,
discrete_control.variable,
discrete_control.greater_than,
discrete_control.look_ahead,
),
condition_idx = 0

# Loop over compound variables
for (listen_node_ids, variables, weights, greater_thans, look_aheads) in zip(
discrete_control.listen_node_id,
discrete_control.variable,
discrete_control.weight,
discrete_control.greater_than,
discrete_control.look_ahead,
)
value = get_value(p, listen_node_id, variable, look_ahead, u, t)
diff = value - greater_than
out[i] = diff
value = 0.0
for (listen_node_id, variable, weight, look_ahead) in
zip(listen_node_ids, variables, weights, look_aheads)
value += weight * get_value(p, listen_node_id, variable, look_ahead, u, t)
end
# Loop over greater_than values for this compound_variable
for greater_than in greater_thans
condition_idx += 1
diff = value - greater_than
out[condition_idx] = diff
end
end
end

Expand Down Expand Up @@ -252,7 +270,9 @@ function discrete_control_affect_upcrossing!(integrator, condition_idx)
(; discrete_control, basin) = p
(; variable, condition_value, listen_node_id) = discrete_control

condition_value[condition_idx] = true
compound_variable_idx, greater_than_idx =
get_discrete_control_indices(discrete_control, condition_idx)
condition_value[compound_variable_idx][greater_than_idx] = true

control_state_change = discrete_control_affect!(integrator, condition_idx, true)

Expand All @@ -262,19 +282,24 @@ function discrete_control_affect_upcrossing!(integrator, condition_idx)
# only possibly the du. Parameter changes can change the flow on an edge discontinuously,
# giving the possibility of logical paradoxes where certain parameter changes immediately
# undo the truth state that caused that parameter change.
is_basin = id_index(basin.node_id, discrete_control.listen_node_id[condition_idx])[1]
listen_node_ids = discrete_control.listen_node_id[compound_variable_idx]
is_basin =
length(listen_node_ids) == 1 ? id_index(basin.node_id, only(listen_node_ids))[1] :
false

# NOTE: The above no longer works when listen feature ids can be something other than node ids
# I think the more durable option is to give all possible condition types a different variable string,
# e.g. basin.level and level_boundary.level
if variable[condition_idx] == "level" && control_state_change && is_basin
if variable[compound_variable_idx][1] == "level" && control_state_change && is_basin
# Calling water_balance is expensive, but it is a sure way of getting
# du for the basin of this level condition
du = zero(u)
water_balance!(du, u, p, t)
_, condition_basin_idx = id_index(basin.node_id, listen_node_id[condition_idx])
_, condition_basin_idx =
id_index(basin.node_id, listen_node_id[compound_variable_idx][1])

if du[condition_basin_idx] < 0.0
condition_value[condition_idx] = false
condition_value[compound_variable_idx][greater_than_idx] = false
discrete_control_affect!(integrator, condition_idx, false)
end
end
Expand All @@ -288,7 +313,9 @@ function discrete_control_affect_downcrossing!(integrator, condition_idx)
(; discrete_control, basin) = p
(; variable, condition_value, listen_node_id) = discrete_control

condition_value[condition_idx] = false
compound_variable_idx, greater_than_idx =
get_discrete_control_indices(discrete_control, condition_idx)
condition_value[compound_variable_idx][greater_than_idx] = false

control_state_change = discrete_control_affect!(integrator, condition_idx, false)

Expand All @@ -298,16 +325,23 @@ function discrete_control_affect_downcrossing!(integrator, condition_idx)
# only possibly the du. Parameter changes can change the flow on an edge discontinuously,
# giving the possibility of logical paradoxes where certain parameter changes immediately
# undo the truth state that caused that parameter change.
if variable[condition_idx] == "level" && control_state_change
compound_variable_idx, greater_than_idx =
get_discrete_control_indices(discrete_control, condition_idx)
listen_node_ids = discrete_control.listen_node_id[compound_variable_idx]
is_basin =
length(listen_node_ids) == 1 ? id_index(basin.node_id, only(listen_node_ids))[1] :
false

if variable[compound_variable_idx][1] == "level" && control_state_change && is_basin
# Calling water_balance is expensive, but it is a sure way of getting
# du for the basin of this level condition
du = zero(u)
water_balance!(du, u, p, t)
has_index, condition_basin_idx =
id_index(basin.node_id, listen_node_id[condition_idx])
id_index(basin.node_id, listen_node_id[compound_variable_idx][1])

if has_index && du[condition_basin_idx] > 0.0
condition_value[condition_idx] = true
condition_value[compound_variable_idx][greater_than_idx] = true
discrete_control_affect!(integrator, condition_idx, true)
end
end
Expand All @@ -325,20 +359,34 @@ function discrete_control_affect!(
(; discrete_control, graph) = p

# Get the discrete_control node that listens to this condition
discrete_control_node_id = discrete_control.node_id[condition_idx]

compound_variable_idx, _ = get_discrete_control_indices(discrete_control, condition_idx)
discrete_control_node_id = discrete_control.node_id[compound_variable_idx]

# Get the indices of all conditions that this control node listens to
condition_ids = discrete_control.node_id .== discrete_control_node_id
where_node_id = searchsorted(discrete_control.node_id, discrete_control_node_id)

# Get the truth state for this discrete_control node
truth_values = [ifelse(b, "T", "F") for b in discrete_control.condition_value]
truth_state = join(truth_values[condition_ids], "")
truth_values = cat(
[
[ifelse(b, "T", "F") for b in discrete_control.condition_value[i]] for
i in where_node_id
]...;
dims = 1,
)
truth_state = join(truth_values, "")

# Get the truth specific about the latest crossing
if !ismissing(upcrossing)
truth_values[condition_idx] = upcrossing ? "U" : "D"
truth_value_idx =
condition_idx - sum(
length(vec) for
vec in discrete_control.condition_value[1:(where_node_id.start - 1)];
init = 0,
)
truth_values[truth_value_idx] = upcrossing ? "U" : "D"
end
truth_state_crossing_specific = join(truth_values[condition_ids], "")
truth_state_crossing_specific = join(truth_values, "")

# What the local control state should be
control_state_new =
Expand All @@ -359,7 +407,7 @@ function discrete_control_affect!(
discrete_control.logic_mapping[(discrete_control_node_id, truth_state)]
else
error(
"Control state specified for neither $truth_state_crossing_specific nor $truth_state for DiscreteControl node $discrete_control_node_id.",
"Control state specified for neither $truth_state_crossing_specific nor $truth_state for $discrete_control_node_id.",
)
end

Expand Down
27 changes: 16 additions & 11 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -443,23 +443,28 @@ struct Terminal <: AbstractParameterNode
end

"""
node_id: node ID of the DiscreteControl node; these are not unique but repeated
by the amount of conditions of this DiscreteControl node
listen_node_id: the ID of the node being condition on
variable: the name of the variable in the condition
greater_than: The threshold value in the condition
condition_value: The current value of each condition
node_id: node ID of the DiscreteControl node per compound variable (can contain repeats)
listen_node_id: the IDs of the nodes being condition on per compound variable
variable: the names of the variables in the condition per compound variable
weight: the weight of the variables in the condition per compound variable
look_ahead: the look ahead of variables in the condition in seconds per compound_variable
greater_than: The threshold values per compound variable
condition_value: The current truth value of each condition per compound_variable per greater_than
control_state: Dictionary: node ID => (control state, control state start)
logic_mapping: Dictionary: (control node ID, truth state) => control state
record: Namedtuple with discrete control information for results
"""
struct DiscreteControl <: AbstractParameterNode
node_id::Vector{NodeID}
listen_node_id::Vector{NodeID}
variable::Vector{String}
look_ahead::Vector{Float64}
greater_than::Vector{Float64}
condition_value::Vector{Bool}
# Definition of compound variables
listen_node_id::Vector{Vector{NodeID}}
variable::Vector{Vector{String}}
weight::Vector{Vector{Float64}}
look_ahead::Vector{Vector{Float64}}
# Definition of conditions (one or more greater_than per compound variable)
greater_than::Vector{Vector{Float64}}
condition_value::Vector{BitVector}
# Definition of logic
control_state::Dict{NodeID, Tuple{String, Float64}}
logic_mapping::Dict{Tuple{NodeID, String}, String}
record::@NamedTuple{
Expand Down
84 changes: 77 additions & 7 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -542,10 +542,81 @@ function Basin(db::DB, config::Config, chunk_sizes::Vector{Int})::Basin
)
end

function parse_variables_and_conditions(compound_variable, condition)
node_id = NodeID[]
listen_node_id = Vector{NodeID}[]
variable = Vector{String}[]
weight = Vector{Float64}[]
look_ahead = Vector{Float64}[]
greater_than = Vector{Float64}[]
condition_value = BitVector[]
errors = false

# Loop over unique discrete_control node IDs (on which at least one condition is defined)
for id in unique(condition.node_id)
condition_group_id = filter(row -> row.node_id == id, condition)
variable_group_id = filter(row -> row.node_id == id, compound_variable)
# Loop over compound variables for this node ID
for compound_variable_id in unique(condition_group_id.compound_variable_id)
condition_group_variable = filter(
row -> row.compound_variable_id == compound_variable_id,
condition_group_id,
)
variable_group_variable = filter(
row -> row.compound_variable_id == compound_variable_id,
variable_group_id,
)
discrete_control_id = NodeID(NodeType.DiscreteControl, id)
if isempty(variable_group_variable)
errors = true
@error "compound_variable_id $compound_variable_id for $discrete_control_id in condition table but not in variable table"
else
push!(node_id, discrete_control_id)
push!(
listen_node_id,
NodeID.(
variable_group_variable.listen_node_type,
variable_group_variable.listen_node_id,
),
)
push!(variable, variable_group_variable.variable)
push!(weight, coalesce.(variable_group_variable.weight, 1.0))
push!(look_ahead, coalesce.(variable_group_variable.look_ahead, 0.0))
push!(greater_than, condition_group_variable.greater_than)
push!(
condition_value,
BitVector(zeros(length(condition_group_variable.greater_than))),
)
end
end
end
return node_id,
listen_node_id,
variable,
weight,
look_ahead,
greater_than,
condition_value,
!errors
end

function DiscreteControl(db::DB, config::Config)::DiscreteControl
condition = load_structvector(db, config, DiscreteControlConditionV1)
compound_variable = load_structvector(db, config, DiscreteControlVariableV1)

node_id,
listen_node_id,
variable,
weight,
look_ahead,
greater_than,
condition_value,
valid = parse_variables_and_conditions(compound_variable, condition)

if !valid
error("Problems encountered when parsing DiscreteControl variables and conditions.")
end

condition_value = fill(false, length(condition.node_id))
control_state::Dict{NodeID, Tuple{String, Float64}} = Dict()

rows = execute(db, "SELECT from_node_id, edge_type FROM Edge ORDER BY fid")
Expand All @@ -557,7 +628,6 @@ function DiscreteControl(db::DB, config::Config)::DiscreteControl
end

logic = load_structvector(db, config, DiscreteControlLogicV1)

logic_mapping = Dict{Tuple{NodeID, String}, String}()

for (node_id, truth_state, control_state_) in
Expand All @@ -567,7 +637,6 @@ function DiscreteControl(db::DB, config::Config)::DiscreteControl
end

logic_mapping = expand_logic_mapping(logic_mapping)
look_ahead = coalesce.(condition.look_ahead, 0.0)

record = (
time = Float64[],
Expand All @@ -577,11 +646,12 @@ function DiscreteControl(db::DB, config::Config)::DiscreteControl
)

return DiscreteControl(
NodeID.(NodeType.DiscreteControl, condition.node_id), # Not unique
NodeID.(condition.listen_node_type, condition.listen_node_id),
condition.variable,
node_id, # Not unique
listen_node_id,
variable,
weight,
look_ahead,
condition.greater_than,
greater_than,
condition_value,
control_state,
logic_mapping,
Expand Down
12 changes: 10 additions & 2 deletions core/src/schema.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# These schemas define the name of database tables and the configuration file structure
# The identifier is parsed as ribasim.nodetype.kind, no capitals or underscores are allowed.
@schema "ribasim.discretecontrol.variable" DiscreteControlVariable
@schema "ribasim.discretecontrol.condition" DiscreteControlCondition
@schema "ribasim.discretecontrol.logic" DiscreteControlLogic
@schema "ribasim.basin.static" BasinStatic
Expand Down Expand Up @@ -183,15 +184,22 @@ end
node_id::Int32
end

@version DiscreteControlConditionV1 begin
@version DiscreteControlVariableV1 begin
node_id::Int32
compound_variable_id::Int32
listen_node_type::String
listen_node_id::Int32
variable::String
greater_than::Float64
weight::Union{Missing, Float64}
look_ahead::Union{Missing, Float64}
end

@version DiscreteControlConditionV1 begin
node_id::Int32
compound_variable_id::Int32
greater_than::Float64
end

@version DiscreteControlLogicV1 begin
node_id::Int32
truth_state::String
Expand Down
Loading

0 comments on commit f6108a0

Please sign in to comment.