Skip to content

Commit

Permalink
Run allocation first after BMI stop (#1390)
Browse files Browse the repository at this point in the history
This allows allocation over period `(t, t + dt)` to use variables set
over BMI at time `t`.

We no longer run allocation as a callback, but call it ourselves, such
that we can control that BMI runs allocation first before running the
physical layer.
  • Loading branch information
visr authored Apr 16, 2024
1 parent 5e46eae commit 184827a
Show file tree
Hide file tree
Showing 7 changed files with 67 additions and 27 deletions.
1 change: 1 addition & 0 deletions core/src/Ribasim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ using SciMLBase:
solve!,
step!,
SciMLBase,
ReturnCode,
successful_retcode,
CallbackSet,
ODEFunction,
Expand Down
7 changes: 3 additions & 4 deletions core/src/bmi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,15 @@ function BMI.update(model::Model)::Model
return model
end

function BMI.update_until(model::Model, time)::Model
integrator = model.integrator
t = integrator.t
function BMI.update_until(model::Model, time::Float64)::Model
(; t) = model.integrator
dt = time - t
if dt < 0
error("The model has already passed the given timestamp.")
elseif dt == 0
return model
else
step!(integrator, dt, true)
step!(model, dt)
end
return model
end
Expand Down
10 changes: 0 additions & 10 deletions core/src/callback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,6 @@ function create_callbacks(
)
push!(callbacks, tabulated_rating_curve_cb)

if config.allocation.use_allocation
allocation_cb = PeriodicCallback(
update_allocation!,
config.allocation.timestep;
initial_affect = false,
save_positions = (false, false),
)
push!(callbacks, allocation_cb)
end

# If saveat is a vector which contains 0.0 this callback will still be called
# at t = 0.0 despite save_start = false
saveat = saveat isa Vector ? filter(x -> x != 0.0, saveat) : saveat
Expand Down
7 changes: 3 additions & 4 deletions core/src/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,22 +54,21 @@ function main(ARGS::Vector{String})::Cint
config = Config(arg)
mkpath(results_path(config, "."))
open(results_path(config, "ribasim.log"), "w") do io
logger =
Ribasim.setup_logger(; verbosity = config.logging.verbosity, stream = io)
logger = setup_logger(; verbosity = config.logging.verbosity, stream = io)
with_logger(logger) do
cli = (; ribasim_version = string(pkgversion(Ribasim)))
(; starttime, endtime) = config
if config.ribasim_version != cli.ribasim_version
@warn "The Ribasim version in the TOML config file does not match the used Ribasim CLI version." config.ribasim_version cli.ribasim_version
end
@info "Starting a Ribasim simulation." cli.ribasim_version starttime endtime
model = Ribasim.run(config)
model = run(config)
if successful_retcode(model)
@info "The model finished successfully"
return 0
end

t = Ribasim.datetime_since(model.integrator.t, starttime)
t = datetime_since(model.integrator.t, starttime)
retcode = model.integrator.sol.retcode
@error "The model exited at model time $t with return code $retcode.\nSee https://docs.sciml.ai/DiffEqDocs/stable/basics/solution/#retcodes"
return 1
Expand Down
56 changes: 48 additions & 8 deletions core/src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ end

function Model(config_path::AbstractString)::Model
config = Config(config_path)
if !valid_config(config)
error("Invalid configuration in TOML.")
end
return Model(config)
end

Expand All @@ -44,11 +47,6 @@ function Model(config::Config)::Model
TimerOutputs.enable_debug_timings(Ribasim) # causes recompilation (!)
end

t_end = seconds_since(config.endtime, config.starttime)
if t_end <= 0
error("Model starttime is not before endtime.")
end

# All data from the database that we need during runtime is copied into memory,
# so we can directly close it again.
db = SQLite.DB(db_path)
Expand Down Expand Up @@ -114,6 +112,7 @@ function Model(config::Config)::Model
integral = zeros(length(parameters.pid_control.node_id))
u0 = ComponentVector{Float64}(; storage, integral)
# for Float32 this method allows max ~1000 year simulations without accuracy issues
t_end = seconds_since(config.endtime, config.starttime)
@assert eps(t_end) < 3600 "Simulation time too long"
t0 = zero(t_end)
timespan = (t0, t_end)
Expand Down Expand Up @@ -188,10 +187,51 @@ function SciMLBase.successful_retcode(model::Model)::Bool
end

"""
solve!(model::Model)::ODESolution
step!(model::Model, dt::Float64)::Model
Take Model timesteps until `t + dt` is reached exactly.
"""
function SciMLBase.step!(model::Model, dt::Float64)::Model
(; config, integrator) = model
(; t) = integrator
# If we are at an allocation time, run allocation before the next physical
# layer timestep. This allows allocation over period (t, t + dt) to use variables
# set over BMI at time t before calling this function.
# Also, don't run allocation at t = 0 since there are no flows yet (#1389).
ntimes = t / config.allocation.timestep
if t > 0 && round(ntimes) ntimes
update_allocation!(integrator)
end
step!(integrator, dt, true)
return model
end

"""
solve!(model::Model)::Model
Solve a Model until the configured `endtime`.
"""
function SciMLBase.solve!(model::Model)::ODESolution
return solve!(model.integrator)
function SciMLBase.solve!(model::Model)::Model
(; config, integrator) = model
if config.allocation.use_allocation
(; tspan) = integrator.sol.prob
(; timestep) = config.allocation
allocation_times = timestep:timestep:(tspan[end] - timestep)
n_allocation_times = length(allocation_times)
# Don't run allocation at t = 0 since there are no flows yet (#1389).
step!(integrator, timestep, true)
for _ in 1:n_allocation_times
update_allocation!(integrator)
step!(integrator, timestep, true)
end

if integrator.sol.retcode != ReturnCode.Default
return model
end
# TODO replace with `check_error!` https://github.com/SciML/SciMLBase.jl/issues/669
integrator.sol = SciMLBase.solution_new_retcode(integrator.sol, ReturnCode.Success)
else
solve!(integrator)
end
return model
end
2 changes: 1 addition & 1 deletion core/src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ function get_Δt(integrator)::Float64
elseif isinf(saveat)
t
else
t_end = integrator.sol.prob.tspan[2]
t_end = integrator.sol.prob.tspan[end]
if t_end - t > saveat
saveat
else
Expand Down
11 changes: 11 additions & 0 deletions core/src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,17 @@ function sorted_table!(
return table
end

function valid_config(config::Config)::Bool
errors = false

if config.starttime >= config.endtime
errors = true
@error "The model starttime must be before the endtime."
end

return !errors
end

"""
Test for each node given its node type whether the nodes that
# are downstream ('down-edge') of this node are of an allowed type
Expand Down

0 comments on commit 184827a

Please sign in to comment.