From f0ad9e093b3b97f5925d8201a623eb3c607ab2c3 Mon Sep 17 00:00:00 2001 From: Maarten Pronk Date: Wed, 13 Mar 2024 11:31:22 +0100 Subject: [PATCH 1/5] Re-introduce database context and split Node tables per Node. --- python/ribasim/ribasim/config.py | 7 ++++++- python/ribasim/ribasim/geometry/node.py | 7 +++++++ python/ribasim/ribasim/model.py | 7 ++++--- python/ribasim/tests/test_model.py | 2 +- 4 files changed, 18 insertions(+), 5 deletions(-) diff --git a/python/ribasim/ribasim/config.py b/python/ribasim/ribasim/config.py index e28c73ed2..393ffd548 100644 --- a/python/ribasim/ribasim/config.py +++ b/python/ribasim/ribasim/config.py @@ -5,7 +5,7 @@ import pandas as pd import pydantic from geopandas import GeoDataFrame -from pydantic import ConfigDict, Field +from pydantic import ConfigDict, Field, model_validator from shapely.geometry import Point from ribasim.geometry import BasinAreaSchema, NodeTable @@ -108,6 +108,11 @@ class MultiNodeModel(NodeModel): node: NodeTable = Field(default_factory=NodeTable) _node_type: str + @model_validator(mode="after") + def filter(self) -> "MultiNodeModel": + self.node.filter(self.__class__.__name__) + return self + def add(self, node: Node, tables: Sequence[TableModel[Any]] | None = None) -> None: if tables is None: tables = [] diff --git a/python/ribasim/ribasim/geometry/node.py b/python/ribasim/ribasim/geometry/node.py index 72f610745..6999e917e 100644 --- a/python/ribasim/ribasim/geometry/node.py +++ b/python/ribasim/ribasim/geometry/node.py @@ -31,6 +31,13 @@ class Config: class NodeTable(SpatialTableModel[NodeSchema]): """The Ribasim nodes as Point geometries.""" + def filter(self, nodetype: str): + """Filter the node table based on the node type.""" + if self.df is not None: + mask = self.df[self.df["node_type"] != nodetype].index + self.df.drop(mask, inplace=True) + self.df.reset_index(inplace=True, drop=True) + def plot_allocation_networks(self, ax=None, zorder=None) -> Any: if ax is None: _, ax = plt.subplots() diff --git a/python/ribasim/ribasim/model.py b/python/ribasim/ribasim/model.py index bcc7fb43e..5394774e4 100644 --- a/python/ribasim/ribasim/model.py +++ b/python/ribasim/ribasim/model.py @@ -195,9 +195,10 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]: with open(filepath, "rb") as f: config = tomli.load(f) - context_file_loading.get()["directory"] = filepath.parent / config.get( - "input_dir", "." - ) + directory = filepath.parent / config.get("input_dir", ".") + context_file_loading.get()["directory"] = directory + context_file_loading.get()["database"] = directory / "database.gpkg" + return config else: return {} diff --git a/python/ribasim/tests/test_model.py b/python/ribasim/tests/test_model.py index 4b7d8023f..3519d17e8 100644 --- a/python/ribasim/tests/test_model.py +++ b/python/ribasim/tests/test_model.py @@ -155,7 +155,7 @@ def test_write_adds_fid_in_tables(basic, tmp_path): # for node an explicit index was provided nrow = len(model_orig.basin.node.df) assert model_orig.basin.node.df.index.name is None - assert model_orig.basin.node.df.index.equals(pd.Index(np.full(nrow, 0))) + # assert model_orig.basin.node.df.index.equals(pd.Index(np.full(nrow, 0))) # for edge no index was provided, but it still needs to write it to file nrow = len(model_orig.edge.df) assert model_orig.edge.df.index.name is None From 0d73901a2d35dd0efc129abacdf1ae2f6061fe86 Mon Sep 17 00:00:00 2001 From: Maarten Pronk Date: Wed, 13 Mar 2024 12:08:57 +0100 Subject: [PATCH 2/5] Updated tests. --- python/ribasim/ribasim/geometry/edge.py | 6 +----- python/ribasim/tests/test_io.py | 28 +++++-------------------- python/ribasim/tests/test_model.py | 3 +-- 3 files changed, 7 insertions(+), 30 deletions(-) diff --git a/python/ribasim/ribasim/geometry/edge.py b/python/ribasim/ribasim/geometry/edge.py index c95e5d96a..0a5d82656 100644 --- a/python/ribasim/ribasim/geometry/edge.py +++ b/python/ribasim/ribasim/geometry/edge.py @@ -8,7 +8,7 @@ from geopandas import GeoDataFrame from matplotlib.axes import Axes from numpy.typing import NDArray -from pandera.typing import DataFrame, Series +from pandera.typing import Series from pandera.typing.geopandas import GeoSeries from shapely.geometry import LineString, MultiLineString, Point @@ -42,10 +42,6 @@ class Config: class EdgeTable(SpatialTableModel[EdgeSchema]): """Defines the connections between nodes.""" - def __init__(self, **kwargs): - kwargs.setdefault("df", DataFrame[EdgeSchema]()) - super().__init__(**kwargs) - def add( self, from_node: NodeData, diff --git a/python/ribasim/tests/test_io.py b/python/ribasim/tests/test_io.py index 7c6d34b51..683b49274 100644 --- a/python/ribasim/tests/test_io.py +++ b/python/ribasim/tests/test_io.py @@ -1,7 +1,6 @@ import pytest import ribasim import tomli -from numpy.testing import assert_array_equal from pandas import DataFrame from pandas.testing import assert_frame_equal from pydantic import ValidationError @@ -17,9 +16,9 @@ def __assert_equal(a: DataFrame, b: DataFrame, is_network=False) -> None: # We set this on write, needed for GeoPackage. a.index.name = "fid" a.index.name = "fid" - else: - a = a.reset_index(drop=True) - b = b.reset_index(drop=True) + + a = a.reset_index(drop=True) + b = b.reset_index(drop=True) # avoid comparing datetime64[ns] with datetime64[ms] if "time" in a: @@ -34,7 +33,6 @@ def __assert_equal(a: DataFrame, b: DataFrame, is_network=False) -> None: return assert_frame_equal(a, b) -@pytest.mark.xfail(reason="Needs Model read implementation") def test_basic(basic, tmp_path): model_orig = basic toml_path = tmp_path / "basic/ribasim.toml" @@ -46,19 +44,10 @@ def test_basic(basic, tmp_path): assert toml_dict["ribasim_version"] == ribasim.__version__ - index_a = model_orig.network.node.df.index.to_numpy(int) - index_b = model_loaded.network.node.df.index.to_numpy(int) - assert_array_equal(index_a, index_b) - __assert_equal( - model_orig.network.node.df, model_loaded.network.node.df, is_network=True - ) - __assert_equal( - model_orig.network.edge.df, model_loaded.network.edge.df, is_network=True - ) + __assert_equal(model_orig.edge.df, model_loaded.edge.df, is_network=True) assert model_loaded.basin.time.df is None -@pytest.mark.xfail(reason="Needs Model read implementation") def test_basic_arrow(basic_arrow, tmp_path): model_orig = basic_arrow model_orig.write(tmp_path / "basic_arrow/ribasim.toml") @@ -67,18 +56,12 @@ def test_basic_arrow(basic_arrow, tmp_path): __assert_equal(model_orig.basin.profile.df, model_loaded.basin.profile.df) -@pytest.mark.xfail(reason="Needs Model read implementation") def test_basic_transient(basic_transient, tmp_path): model_orig = basic_transient model_orig.write(tmp_path / "basic_transient/ribasim.toml") model_loaded = ribasim.Model(filepath=tmp_path / "basic_transient/ribasim.toml") - __assert_equal( - model_orig.network.node.df, model_loaded.network.node.df, is_network=True - ) - __assert_equal( - model_orig.network.edge.df, model_loaded.network.edge.df, is_network=True - ) + __assert_equal(model_orig.edge.df, model_loaded.edge.df, is_network=True) time = model_loaded.basin.time assert model_orig.basin.time.df.time[0] == time.df.time[0] @@ -111,7 +94,6 @@ def test_extra_columns(basic_transient): terminal.Static(meta_id=[-1, -2, -3], extra=[-1, -2, -3]) -@pytest.mark.xfail(reason="Needs Model read implementation") def test_sort(level_setpoint_with_minmax, tmp_path): model = level_setpoint_with_minmax table = model.discrete_control.condition diff --git a/python/ribasim/tests/test_model.py b/python/ribasim/tests/test_model.py index 3519d17e8..796030925 100644 --- a/python/ribasim/tests/test_model.py +++ b/python/ribasim/tests/test_model.py @@ -132,7 +132,6 @@ def test_node_ids_unsequential(basic): model.validate_model_node_field_ids() -@pytest.mark.xfail(reason="Needs Model read implementation") def test_tabulated_rating_curve_model(tabulated_rating_curve, tmp_path): model_orig = tabulated_rating_curve basin_area = tabulated_rating_curve.basin.area.df @@ -155,7 +154,7 @@ def test_write_adds_fid_in_tables(basic, tmp_path): # for node an explicit index was provided nrow = len(model_orig.basin.node.df) assert model_orig.basin.node.df.index.name is None - # assert model_orig.basin.node.df.index.equals(pd.Index(np.full(nrow, 0))) + # for edge no index was provided, but it still needs to write it to file nrow = len(model_orig.edge.df) assert model_orig.edge.df.index.name is None From b4a74b89d872a2af4963c4cf480968f15a0393bb Mon Sep 17 00:00:00 2001 From: Maarten Pronk Date: Wed, 13 Mar 2024 12:45:17 +0100 Subject: [PATCH 3/5] Fix missing Edge table. --- python/ribasim/ribasim/geometry/edge.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/python/ribasim/ribasim/geometry/edge.py b/python/ribasim/ribasim/geometry/edge.py index 0a5d82656..f3da67c77 100644 --- a/python/ribasim/ribasim/geometry/edge.py +++ b/python/ribasim/ribasim/geometry/edge.py @@ -8,8 +8,9 @@ from geopandas import GeoDataFrame from matplotlib.axes import Axes from numpy.typing import NDArray -from pandera.typing import Series +from pandera.typing import DataFrame, Series from pandera.typing.geopandas import GeoSeries +from pydantic import model_validator from shapely.geometry import LineString, MultiLineString, Point from ribasim.input_base import SpatialTableModel @@ -42,6 +43,12 @@ class Config: class EdgeTable(SpatialTableModel[EdgeSchema]): """Defines the connections between nodes.""" + @model_validator(mode="after") + def empty_table(self) -> "EdgeTable": + if self.df is None: + self.df = DataFrame[EdgeSchema]() + return self + def add( self, from_node: NodeData, From bfc882e08d692ebdfba83f7cfca49fe347cf42b4 Mon Sep 17 00:00:00 2001 From: Maarten Pronk Date: Wed, 13 Mar 2024 13:20:10 +0100 Subject: [PATCH 4/5] Set default table to be spatial. --- python/ribasim/ribasim/geometry/edge.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/ribasim/ribasim/geometry/edge.py b/python/ribasim/ribasim/geometry/edge.py index f3da67c77..74cb4ec4d 100644 --- a/python/ribasim/ribasim/geometry/edge.py +++ b/python/ribasim/ribasim/geometry/edge.py @@ -5,11 +5,10 @@ import pandas as pd import pandera as pa import shapely -from geopandas import GeoDataFrame from matplotlib.axes import Axes from numpy.typing import NDArray -from pandera.typing import DataFrame, Series -from pandera.typing.geopandas import GeoSeries +from pandera.typing import Series +from pandera.typing.geopandas import GeoDataFrame, GeoSeries from pydantic import model_validator from shapely.geometry import LineString, MultiLineString, Point @@ -46,7 +45,8 @@ class EdgeTable(SpatialTableModel[EdgeSchema]): @model_validator(mode="after") def empty_table(self) -> "EdgeTable": if self.df is None: - self.df = DataFrame[EdgeSchema]() + self.df = GeoDataFrame[EdgeSchema]() + self.df.set_geometry("geometry", inplace=True) return self def add( From c119d2c0abd45a9a6c86c09ac8af69ad5fa6275f Mon Sep 17 00:00:00 2001 From: Maarten Pronk Date: Wed, 13 Mar 2024 13:51:13 +0100 Subject: [PATCH 5/5] Make mypy happy. --- python/ribasim/ribasim/geometry/edge.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/ribasim/ribasim/geometry/edge.py b/python/ribasim/ribasim/geometry/edge.py index 74cb4ec4d..cdfbb3e54 100644 --- a/python/ribasim/ribasim/geometry/edge.py +++ b/python/ribasim/ribasim/geometry/edge.py @@ -63,7 +63,7 @@ def add( if geometry is None else [geometry] ) - table_to_append = GeoDataFrame( + table_to_append = GeoDataFrame[EdgeSchema]( data={ "from_node_type": pd.Series([from_node.node_type], dtype=str), "from_node_id": pd.Series([from_node.node_id], dtype=int), @@ -79,7 +79,7 @@ def add( if self.df is None: self.df = table_to_append else: - self.df = pd.concat([self.df, table_to_append]) + self.df = GeoDataFrame[EdgeSchema](pd.concat([self.df, table_to_append])) def get_where_edge_type(self, edge_type: str) -> NDArray[np.bool_]: assert self.df is not None