diff --git a/python/ribasim/ribasim/config.py b/python/ribasim/ribasim/config.py index b34710bb0..94a3a9b9f 100644 --- a/python/ribasim/ribasim/config.py +++ b/python/ribasim/ribasim/config.py @@ -5,7 +5,7 @@ import pandas as pd import pydantic from geopandas import GeoDataFrame -from pydantic import ConfigDict, Field +from pydantic import ConfigDict, Field, model_validator from shapely.geometry import Point from ribasim.geometry import BasinAreaSchema, NodeTable @@ -108,6 +108,11 @@ class MultiNodeModel(NodeModel): node: NodeTable = Field(default_factory=NodeTable) _node_type: str + @model_validator(mode="after") + def filter(self) -> "MultiNodeModel": + self.node.filter(self.__class__.__name__) + return self + def add(self, node: Node, tables: Sequence[TableModel[Any]] | None = None) -> None: if tables is None: tables = [] diff --git a/python/ribasim/ribasim/geometry/edge.py b/python/ribasim/ribasim/geometry/edge.py index c95e5d96a..cdfbb3e54 100644 --- a/python/ribasim/ribasim/geometry/edge.py +++ b/python/ribasim/ribasim/geometry/edge.py @@ -5,11 +5,11 @@ import pandas as pd import pandera as pa import shapely -from geopandas import GeoDataFrame from matplotlib.axes import Axes from numpy.typing import NDArray -from pandera.typing import DataFrame, Series -from pandera.typing.geopandas import GeoSeries +from pandera.typing import Series +from pandera.typing.geopandas import GeoDataFrame, GeoSeries +from pydantic import model_validator from shapely.geometry import LineString, MultiLineString, Point from ribasim.input_base import SpatialTableModel @@ -42,9 +42,12 @@ class Config: class EdgeTable(SpatialTableModel[EdgeSchema]): """Defines the connections between nodes.""" - def __init__(self, **kwargs): - kwargs.setdefault("df", DataFrame[EdgeSchema]()) - super().__init__(**kwargs) + @model_validator(mode="after") + def empty_table(self) -> "EdgeTable": + if self.df is None: + self.df = GeoDataFrame[EdgeSchema]() + self.df.set_geometry("geometry", inplace=True) + return self def add( self, @@ -60,7 +63,7 @@ def add( if geometry is None else [geometry] ) - table_to_append = GeoDataFrame( + table_to_append = GeoDataFrame[EdgeSchema]( data={ "from_node_type": pd.Series([from_node.node_type], dtype=str), "from_node_id": pd.Series([from_node.node_id], dtype=int), @@ -76,7 +79,7 @@ def add( if self.df is None: self.df = table_to_append else: - self.df = pd.concat([self.df, table_to_append]) + self.df = GeoDataFrame[EdgeSchema](pd.concat([self.df, table_to_append])) def get_where_edge_type(self, edge_type: str) -> NDArray[np.bool_]: assert self.df is not None diff --git a/python/ribasim/ribasim/geometry/node.py b/python/ribasim/ribasim/geometry/node.py index 72f610745..6999e917e 100644 --- a/python/ribasim/ribasim/geometry/node.py +++ b/python/ribasim/ribasim/geometry/node.py @@ -31,6 +31,13 @@ class Config: class NodeTable(SpatialTableModel[NodeSchema]): """The Ribasim nodes as Point geometries.""" + def filter(self, nodetype: str): + """Filter the node table based on the node type.""" + if self.df is not None: + mask = self.df[self.df["node_type"] != nodetype].index + self.df.drop(mask, inplace=True) + self.df.reset_index(inplace=True, drop=True) + def plot_allocation_networks(self, ax=None, zorder=None) -> Any: if ax is None: _, ax = plt.subplots() diff --git a/python/ribasim/ribasim/model.py b/python/ribasim/ribasim/model.py index bcc7fb43e..5394774e4 100644 --- a/python/ribasim/ribasim/model.py +++ b/python/ribasim/ribasim/model.py @@ -195,9 +195,10 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]: with open(filepath, "rb") as f: config = tomli.load(f) - context_file_loading.get()["directory"] = filepath.parent / config.get( - "input_dir", "." - ) + directory = filepath.parent / config.get("input_dir", ".") + context_file_loading.get()["directory"] = directory + context_file_loading.get()["database"] = directory / "database.gpkg" + return config else: return {} diff --git a/python/ribasim/tests/test_io.py b/python/ribasim/tests/test_io.py index 7c6d34b51..683b49274 100644 --- a/python/ribasim/tests/test_io.py +++ b/python/ribasim/tests/test_io.py @@ -1,7 +1,6 @@ import pytest import ribasim import tomli -from numpy.testing import assert_array_equal from pandas import DataFrame from pandas.testing import assert_frame_equal from pydantic import ValidationError @@ -17,9 +16,9 @@ def __assert_equal(a: DataFrame, b: DataFrame, is_network=False) -> None: # We set this on write, needed for GeoPackage. a.index.name = "fid" a.index.name = "fid" - else: - a = a.reset_index(drop=True) - b = b.reset_index(drop=True) + + a = a.reset_index(drop=True) + b = b.reset_index(drop=True) # avoid comparing datetime64[ns] with datetime64[ms] if "time" in a: @@ -34,7 +33,6 @@ def __assert_equal(a: DataFrame, b: DataFrame, is_network=False) -> None: return assert_frame_equal(a, b) -@pytest.mark.xfail(reason="Needs Model read implementation") def test_basic(basic, tmp_path): model_orig = basic toml_path = tmp_path / "basic/ribasim.toml" @@ -46,19 +44,10 @@ def test_basic(basic, tmp_path): assert toml_dict["ribasim_version"] == ribasim.__version__ - index_a = model_orig.network.node.df.index.to_numpy(int) - index_b = model_loaded.network.node.df.index.to_numpy(int) - assert_array_equal(index_a, index_b) - __assert_equal( - model_orig.network.node.df, model_loaded.network.node.df, is_network=True - ) - __assert_equal( - model_orig.network.edge.df, model_loaded.network.edge.df, is_network=True - ) + __assert_equal(model_orig.edge.df, model_loaded.edge.df, is_network=True) assert model_loaded.basin.time.df is None -@pytest.mark.xfail(reason="Needs Model read implementation") def test_basic_arrow(basic_arrow, tmp_path): model_orig = basic_arrow model_orig.write(tmp_path / "basic_arrow/ribasim.toml") @@ -67,18 +56,12 @@ def test_basic_arrow(basic_arrow, tmp_path): __assert_equal(model_orig.basin.profile.df, model_loaded.basin.profile.df) -@pytest.mark.xfail(reason="Needs Model read implementation") def test_basic_transient(basic_transient, tmp_path): model_orig = basic_transient model_orig.write(tmp_path / "basic_transient/ribasim.toml") model_loaded = ribasim.Model(filepath=tmp_path / "basic_transient/ribasim.toml") - __assert_equal( - model_orig.network.node.df, model_loaded.network.node.df, is_network=True - ) - __assert_equal( - model_orig.network.edge.df, model_loaded.network.edge.df, is_network=True - ) + __assert_equal(model_orig.edge.df, model_loaded.edge.df, is_network=True) time = model_loaded.basin.time assert model_orig.basin.time.df.time[0] == time.df.time[0] @@ -111,7 +94,6 @@ def test_extra_columns(basic_transient): terminal.Static(meta_id=[-1, -2, -3], extra=[-1, -2, -3]) -@pytest.mark.xfail(reason="Needs Model read implementation") def test_sort(level_setpoint_with_minmax, tmp_path): model = level_setpoint_with_minmax table = model.discrete_control.condition diff --git a/python/ribasim/tests/test_model.py b/python/ribasim/tests/test_model.py index 4b7d8023f..796030925 100644 --- a/python/ribasim/tests/test_model.py +++ b/python/ribasim/tests/test_model.py @@ -132,7 +132,6 @@ def test_node_ids_unsequential(basic): model.validate_model_node_field_ids() -@pytest.mark.xfail(reason="Needs Model read implementation") def test_tabulated_rating_curve_model(tabulated_rating_curve, tmp_path): model_orig = tabulated_rating_curve basin_area = tabulated_rating_curve.basin.area.df @@ -155,7 +154,7 @@ def test_write_adds_fid_in_tables(basic, tmp_path): # for node an explicit index was provided nrow = len(model_orig.basin.node.df) assert model_orig.basin.node.df.index.name is None - assert model_orig.basin.node.df.index.equals(pd.Index(np.full(nrow, 0))) + # for edge no index was provided, but it still needs to write it to file nrow = len(model_orig.edge.df) assert model_orig.edge.df.index.name is None