diff --git a/src/ribasim_nl/ribasim_nl/model.py b/src/ribasim_nl/ribasim_nl/model.py index bce94e4..5510fea 100644 --- a/src/ribasim_nl/ribasim_nl/model.py +++ b/src/ribasim_nl/ribasim_nl/model.py @@ -72,6 +72,7 @@ def df(self) -> pd.DataFrame: class Model(Model): _basin_results: BasinResults | None = None + _graph: nx.Graph | None = None @property def basin_results(self): @@ -82,12 +83,48 @@ def basin_results(self): @property def graph(self): - return nx.from_pandas_edgelist(self.edge.df[["from_node_id", "to_node_id"]], "from_node_id", "to_node_id") + # create a DiGraph from edge-table + if self._graph is None: + self._graph = nx.from_pandas_edgelist( + df=self.edge.df[["from_node_id", "to_node_id"]], + source="from_node_id", + target="to_node_id", + create_using=nx.DiGraph, + ) + return self._graph @property def next_node_id(self): return self.node_table().df.index.max() + 1 + # methods relying on networkx. Discuss making this all in a subclass of Model + def _upstream_nodes(self, node_id): + # get upstream nodes + return list(nx.traversal.bfs_tree(self.graph, node_id, reverse=True)) + + def _downstream_nodes(self, node_id): + # get downstream nodes + return list(nx.traversal.bfs_tree(self.graph, node_id)) + + def get_upstream_basins(self, node_id): + # get upstream basin area + upstream_node_ids = self._upstream_nodes(node_id) + return self.basin.area.df[self.basin.area.df.node_id.isin(upstream_node_ids)] + + def get_upstream_edges(self, node_id): + # get upstream edges + upstream_node_ids = self._upstream_nodes(node_id) + mask = self.edge.df.from_node_id.isin(upstream_node_ids[1:]) & self.edge.df.to_node_id.isin(upstream_node_ids) + return self.edge.df[mask] + + def get_downstream_edges(self, node_id): + # get upstream edges + downstream_node_ids = self._downstream_nodes(node_id) + mask = self.edge.df.from_node_id.isin(downstream_node_ids) & self.edge.df.to_node_id.isin( + downstream_node_ids[1:] + ) + return self.edge.df[mask] + def find_node_id(self, ds_node_id=None, us_node_id=None, **kwargs) -> int: """Find a node_id by it's properties""" # get node table