Skip to content

Commit

Permalink
Speed up initialization (#1977)
Browse files Browse the repository at this point in the history
For lhm_2024_12_3 this speeds up initialization
(`Ribasim.Model(toml_path)`), from 234 seconds to 4 seconds. Large
models were especially slow to initialize since in `create_graph` we
were constructing NodeID structs for each node with a separate call to
the database.

This changes `get_node_ids` to return a `Vector{NodeID}` rather than a
`Vector{Int32}`. This means that the type and index of each ID is now
directly included, leading to simpler code.
  • Loading branch information
visr authored Dec 20, 2024
1 parent fc02cfb commit ac0920d
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 95 deletions.
2 changes: 1 addition & 1 deletion core/src/Ribasim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ using StructArrays: StructVector

# OrderedSet is used to store the order of the substances in the network.
# OrderedDict is used to store the order of the sources in a subnetwork.
using DataStructures: OrderedSet, OrderedDict
using DataStructures: OrderedSet, OrderedDict, counter, inc!

export libribasim

Expand Down
7 changes: 4 additions & 3 deletions core/src/graph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ and data of edges (EdgeMetadata):
[`EdgeMetadata`](@ref)
"""
function create_graph(db::DB, config::Config)::MetaGraph
node_table = get_node_ids(db)
node_rows = execute(
db,
"SELECT node_id, node_type, subnetwork_id FROM Node ORDER BY node_type, node_id",
Expand Down Expand Up @@ -40,7 +41,7 @@ function create_graph(db::DB, config::Config)::MetaGraph
graph_data = nothing,
)
for row in node_rows
node_id = NodeID(row.node_type, row.node_id, db)
node_id = NodeID(row.node_type, row.node_id, node_table)
# Process allocation network ID
if ismissing(row.subnetwork_id)
subnetwork_id = 0
Expand All @@ -63,8 +64,8 @@ function create_graph(db::DB, config::Config)::MetaGraph
catch
error("Invalid edge type $edge_type.")
end
id_src = NodeID(from_node_type, from_node_id, db)
id_dst = NodeID(to_node_type, to_node_id, db)
id_src = NodeID(from_node_type, from_node_id, node_table)
id_dst = NodeID(to_node_type, to_node_id, node_table)
edge_metadata =
EdgeMetadata(; id = edge_id, type = edge_type, edge = (id_src, id_dst))
if edge_type == EdgeType.flow
Expand Down
57 changes: 24 additions & 33 deletions core/src/parameter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,16 @@ function NodeType.T(s::Symbol)::NodeType.T
end

NodeType.T(str::AbstractString) = NodeType.T(Symbol(str))
NodeType.T(x::NodeType.T) = x
Base.convert(::Type{NodeType.T}, x::String) = NodeType.T(x)
Base.convert(::Type{NodeType.T}, x::Symbol) = NodeType.T(x)

SQLite.esc_id(x::NodeType.T) = esc_id(string(x))

"""
NodeID(type::Union{NodeType.T, Symbol, AbstractString}, value::Integer, idx::Int)
NodeID(type::Union{NodeType.T, Symbol, AbstractString}, value::Integer, db::DB)
NodeID(type::Union{NodeType.T, Symbol, AbstractString}, value::Integer, p::Parameters)
NodeID(type::Union{NodeType.T, Symbol, AbstractString}, value::Integer, node_ids::Vector{NodeID})
NodeID is a unique identifier for a node in the model, as well as an index into the internal node type struct.
Expand All @@ -52,42 +57,28 @@ This index can be passed directly, or calculated from the database or parameters
idx::Int
end

NodeID(type::Symbol, value::Integer, idx::Int) = NodeID(NodeType.T(type), value, idx)
NodeID(type::AbstractString, value::Integer, idx::Int) =
NodeID(NodeType.T(type), value, idx)

function NodeID(type::Union{Symbol, AbstractString}, value::Integer, db::DB)::NodeID
return NodeID(NodeType.T(type), value, db)
end

function NodeID(type::NodeType.T, value::Integer, db::DB)::NodeID
node_type_string = string(type)
# The index is equal to the number of nodes of the same type with a lower or equal ID
idx = only(
only(
execute(
columntable,
db,
"SELECT COUNT(*) FROM Node WHERE node_type == $(esc_id(node_type_string)) AND node_id <= $value",
),
),
)
if idx <= 0
error("Node ID #$value of type $type is not in the Node table.")
function NodeID(node_type, value::Integer, node_ids::Vector{NodeID})::NodeID
node_type = NodeType.T(node_type)
index = searchsortedfirst(node_ids, value; by = Int32)
if index == lastindex(node_ids) + 1
@error "Node ID $node_type #$value is not in the Node table."
error("Node ID not found")
end
node_id = node_ids[index]
if node_id.type !== node_type
@error "Requested node ID #$value is of type $(node_id.type), not $node_type"
error("Node ID is of the wrong type")
end
return NodeID(type, value, idx)
return node_id
end

function NodeID(value::Integer, db::DB)::NodeID
(idx, type) = execute(
columntable,
db,
"SELECT COUNT(*), node_type FROM Node WHERE node_type == (SELECT node_type FROM Node WHERE node_id == $value) AND node_id <= $value",
)
if only(idx) <= 0
error("Node ID #$value is not in the Node table.")
function NodeID(value::Integer, node_ids::Vector{NodeID})::NodeID
index = searchsortedfirst(node_ids, value; by = Int32)
if index == lastindex(node_ids) + 1
@error "Node ID #$value is not in the Node table."
error("Node ID not found")
end
return NodeID(only(type), value, only(idx))
return node_ids[index]
end

Base.Int32(id::NodeID) = id.value
Expand Down
99 changes: 71 additions & 28 deletions core/src/read.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ function parse_static_and_time(
# of the current type
vals_out = []

node_type_string = split(string(node_type), '.')[end]
ids = get_ids(db, node_type_string)
node_ids = NodeID.(node_type_string, ids, eachindex(ids))
node_type_string = String(split(string(node_type), '.')[end])
node_ids = get_node_ids(db, node_type_string)
ids = Int32.(node_ids)
n_nodes = length(node_ids)

# Initialize the vectors for the output
Expand Down Expand Up @@ -191,14 +191,14 @@ function static_and_time_node_ids(
db::DB,
static::StructVector,
time::StructVector,
node_type::String,
node_type::NodeType.T,
)::Tuple{Set{NodeID}, Set{NodeID}, Vector{NodeID}, Bool}
ids = get_ids(db, node_type)
node_ids = get_node_ids(db, node_type)
ids = Int32.(node_ids)
idx = searchsortedfirst.(Ref(ids), static.node_id)
static_node_ids = Set(NodeID.(Ref(node_type), static.node_id, idx))
idx = searchsortedfirst.(Ref(ids), time.node_id)
time_node_ids = Set(NodeID.(Ref(node_type), time.node_id, idx))
node_ids = NodeID.(Ref(node_type), ids, eachindex(ids))
doubles = intersect(static_node_ids, time_node_ids)
errors = false
if !isempty(doubles)
Expand Down Expand Up @@ -287,7 +287,7 @@ function TabulatedRatingCurve(
time = load_structvector(db, config, TabulatedRatingCurveTimeV1)

static_node_ids, time_node_ids, node_ids, valid =
static_and_time_node_ids(db, static, time, "TabulatedRatingCurve")
static_and_time_node_ids(db, static, time, NodeType.TabulatedRatingCurve)

if !valid
error(
Expand Down Expand Up @@ -418,7 +418,8 @@ function LevelBoundary(db::DB, config::Config)::LevelBoundary
time = load_structvector(db, config, LevelBoundaryTimeV1)
concentration_time = load_structvector(db, config, LevelBoundaryConcentrationV1)

_, _, node_ids, valid = static_and_time_node_ids(db, static, time, "LevelBoundary")
_, _, node_ids, valid =
static_and_time_node_ids(db, static, time, NodeType.LevelBoundary)

if !valid
error("Problems encountered when parsing LevelBoundary static and time node IDs.")
Expand Down Expand Up @@ -452,7 +453,8 @@ function FlowBoundary(db::DB, config::Config, graph::MetaGraph)::FlowBoundary
time = load_structvector(db, config, FlowBoundaryTimeV1)
concentration_time = load_structvector(db, config, FlowBoundaryConcentrationV1)

_, _, node_ids, valid = static_and_time_node_ids(db, static, time, "FlowBoundary")
_, _, node_ids, valid =
static_and_time_node_ids(db, static, time, NodeType.FlowBoundary)

if !valid
error("Problems encountered when parsing FlowBoundary static and time node IDs.")
Expand Down Expand Up @@ -567,8 +569,8 @@ function Outlet(db::DB, config::Config, graph::MetaGraph)::Outlet
end

function Terminal(db::DB, config::Config)::Terminal
node_id = get_ids(db, "Terminal")
return Terminal(NodeID.(NodeType.Terminal, node_id, eachindex(node_id)))
node_id = get_node_ids(db, NodeType.Terminal)
return Terminal(node_id)
end

function ConcentrationData(
Expand Down Expand Up @@ -662,7 +664,7 @@ function ConcentrationData(
end

function Basin(db::DB, config::Config, graph::MetaGraph)::Basin
node_id = get_ids(db, "Basin")
node_id = get_node_ids(db, NodeType.Basin)
n = length(node_id)

# both static and time are optional, but we need fallback defaults
Expand All @@ -683,9 +685,6 @@ function Basin(db::DB, config::Config, graph::MetaGraph)::Basin

vertical_flux = ComponentVector(; table...)

# Node IDs
node_id = NodeID.(NodeType.Basin, node_id, eachindex(node_id))

# Profiles
area, level = create_storage_tables(db, config)

Expand Down Expand Up @@ -742,9 +741,10 @@ function CompoundVariable(
weight::Float64,
look_ahead::Float64,
}[]
node_ids = get_node_ids(db)
# Each row defines a subvariable
for row in compound_variable_data
listen_node_id = NodeID(row.listen_node_id, db)
listen_node_id = NodeID(row.listen_node_id, node_ids)
# Placeholder until actual ref is known
variable_ref = PreallocationRef(placeholder_vector, 0)
variable = row.variable
Expand All @@ -757,7 +757,7 @@ function CompoundVariable(
end

# The ID of the node listening to this CompoundVariable
node_id = NodeID(node_type, only(unique(compound_variable_data.node_id)), db)
node_id = NodeID(node_type, only(unique(compound_variable_data.node_id)), node_ids)
return CompoundVariable(node_id, subvariables, greater_than)
end

Expand Down Expand Up @@ -811,8 +811,8 @@ function DiscreteControl(db::DB, config::Config, graph::MetaGraph)::DiscreteCont
condition = load_structvector(db, config, DiscreteControlConditionV1)
compound_variable = load_structvector(db, config, DiscreteControlVariableV1)

ids = get_ids(db, "DiscreteControl")
node_id = NodeID.(:DiscreteControl, ids, eachindex(ids))
node_id = get_node_ids(db, NodeType.DiscreteControl)
ids = Int32.(node_id)
compound_variables, valid =
parse_variables_and_conditions(compound_variable, condition, ids, db, graph)

Expand Down Expand Up @@ -913,8 +913,8 @@ end
function ContinuousControl(db::DB, config::Config, graph::MetaGraph)::ContinuousControl
compound_variable = load_structvector(db, config, ContinuousControlVariableV1)

ids = get_ids(db, "ContinuousControl")
node_id = NodeID.(:ContinuousControl, ids, eachindex(ids))
node_id = get_node_ids(db, NodeType.ContinuousControl)
ids = Int32.(node_id)

# Avoid using `function` as a variable name as that is recognized as a keyword
func, controlled_variable, errors = continuous_control_functions(db, config, ids)
Expand All @@ -940,7 +940,7 @@ function PidControl(db::DB, config::Config, graph::MetaGraph)::PidControl
static = load_structvector(db, config, PidControlStaticV1)
time = load_structvector(db, config, PidControlTimeV1)

_, _, node_ids, valid = static_and_time_node_ids(db, static, time, "PidControl")
_, _, node_ids, valid = static_and_time_node_ids(db, static, time, NodeType.PidControl)

if !valid
error("Problems encountered when parsing PidControl static and time node IDs.")
Expand Down Expand Up @@ -968,7 +968,8 @@ function PidControl(db::DB, config::Config, graph::MetaGraph)::PidControl
end
controlled_basins = collect(controlled_basins)

listen_node_id = NodeID.(parsed_parameters.listen_node_id, Ref(db))
all_node_ids = get_node_ids(db)
listen_node_id = NodeID.(parsed_parameters.listen_node_id, Ref(all_node_ids))

return PidControl(;
node_id = node_ids,
Expand Down Expand Up @@ -1087,9 +1088,9 @@ function UserDemand(db::DB, config::Config, graph::MetaGraph)::UserDemand
static = load_structvector(db, config, UserDemandStaticV1)
time = load_structvector(db, config, UserDemandTimeV1)
concentration_time = load_structvector(db, config, UserDemandConcentrationV1)
ids = get_ids(db, "UserDemand")

_, _, node_ids, valid = static_and_time_node_ids(db, static, time, "UserDemand")
_, _, node_ids, valid = static_and_time_node_ids(db, static, time, NodeType.UserDemand)
ids = Int32.(node_ids)

if !valid
error("Problems encountered when parsing UserDemand static and time node IDs.")
Expand Down Expand Up @@ -1229,14 +1230,15 @@ end
function Subgrid(db::DB, config::Config, basin::Basin)::Subgrid
node_to_basin = Dict(node_id => index for (index, node_id) in enumerate(basin.node_id))
tables = load_structvector(db, config, BasinSubgridV1)
node_table = get_node_ids(db, NodeType.Basin)

subgrid_ids = Int32[]
basin_index = Int32[]
interpolations = ScalarInterpolation[]
has_error = false
for group in IterTools.groupby(row -> row.subgrid_id, tables)
subgrid_id = first(getproperty.(group, :subgrid_id))
node_id = NodeID(NodeType.Basin, first(getproperty.(group, :node_id)), db)
node_id = NodeID(NodeType.Basin, first(getproperty.(group, :node_id)), node_table)
basin_level = getproperty.(group, :basin_level)
subgrid_level = getproperty.(group, :subgrid_level)

Expand Down Expand Up @@ -1395,11 +1397,52 @@ function Parameters(db::DB, config::Config)::Parameters
return p
end

function get_ids(db::DB, nodetype)::Vector{Int32}
sql = "SELECT node_id FROM Node WHERE node_type = $(esc_id(nodetype)) ORDER BY node_id"
function get_node_ids_int32(db::DB, node_type)::Vector{Int32}
sql = "SELECT node_id FROM Node WHERE node_type = $(esc_id(node_type)) ORDER BY node_id"
return only(execute(columntable, db, sql))
end

function get_node_ids_types(
db::DB,
)::@NamedTuple{node_id::Vector{Int32}, node_type::Vector{NodeType.T}}
sql = "SELECT node_id, node_type FROM Node ORDER BY node_id"
table = execute(columntable, db, sql)
# convert from String to NodeType
node_type = NodeType.T.(table.node_type)
return (; table.node_id, node_type)
end

function get_node_ids(db::DB)::Vector{NodeID}
nt = get_node_ids_types(db)
node_ids = Vector{Ribasim.NodeID}(undef, length(nt.node_id))
count = counter(Ribasim.NodeType.T)
for (i, (; node_id, node_type)) in enumerate(Tables.rows(nt))
index = inc!(count, node_type)
node_ids[i] = NodeID(node_type, node_id, index)
end
return node_ids
end

# Convenience method for tests
function get_node_ids(toml_path::String)::Vector{NodeID}
cfg = Config(toml_path)
db_path = database_path(cfg)
db = SQLite.DB(db_path)
node_ids = get_node_ids(db)
close(db)
return node_ids
end

function get_node_ids(db::DB, node_type)::Vector{NodeID}
node_type = NodeType.T(node_type)
node_ints = get_node_ids_int32(db, node_type)
node_ids = Vector{Ribasim.NodeID}(undef, length(node_ints))
for (index, node_int) in enumerate(node_ints)
node_ids[index] = NodeID(node_type, node_int, index)
end
return node_ids
end

function exists(db::DB, tablename::String)
query = execute(
db,
Expand Down
8 changes: 4 additions & 4 deletions core/src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,11 +180,11 @@ Data is matched based on the node_id, which is sorted.
"""
function set_static_value!(
table::NamedTuple,
node_id::Vector{Int32},
node_id::Vector{NodeID},
static::StructVector,
)::NamedTuple
for (i, id) in enumerate(node_id)
idx = findsorted(static.node_id, id)
idx = findsorted(static.node_id, Int32(id))
idx === nothing && continue
row = static[idx]
set_table_row!(table, row, i)
Expand All @@ -199,7 +199,7 @@ The most recent applicable data is non-NaN data for a given ID that is on or bef
"""
function set_current_value!(
table::NamedTuple,
node_id::Vector{Int32},
node_id::Vector{NodeID},
time::StructVector,
t::DateTime,
)::NamedTuple
Expand All @@ -209,7 +209,7 @@ function set_current_value!(
for (i, id) in enumerate(node_id)
for (symbol, vector) in pairs(table)
idx = findlast(
row -> row.node_id == id && !ismissing(getproperty(row, symbol)),
row -> row.node_id == Int32(id) && !ismissing(getproperty(row, symbol)),
pre_table,
)
if idx !== nothing
Expand Down
Loading

0 comments on commit ac0920d

Please sign in to comment.