Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable read method again with the new add API #1243

Merged
merged 6 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion python/ribasim/ribasim/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pandas as pd
import pydantic
from geopandas import GeoDataFrame
from pydantic import ConfigDict, Field
from pydantic import ConfigDict, Field, model_validator
from shapely.geometry import Point

from ribasim.geometry import BasinAreaSchema, NodeTable
Expand Down Expand Up @@ -108,6 +108,11 @@ class MultiNodeModel(NodeModel):
node: NodeTable = Field(default_factory=NodeTable)
_node_type: str

@model_validator(mode="after")
def filter(self) -> "MultiNodeModel":
self.node.filter(self.__class__.__name__)
return self

def add(self, node: Node, tables: Sequence[TableModel[Any]] | None = None) -> None:
if tables is None:
tables = []
Expand Down
7 changes: 7 additions & 0 deletions python/ribasim/ribasim/geometry/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,13 @@ class Config:
class NodeTable(SpatialTableModel[NodeSchema]):
"""The Ribasim nodes as Point geometries."""

def filter(self, nodetype: str):
"""Filter the node table based on the node type."""
if self.df is not None:
mask = self.df[self.df["node_type"] != nodetype].index
self.df.drop(mask, inplace=True)
self.df.reset_index(inplace=True, drop=True)

def plot_allocation_networks(self, ax=None, zorder=None) -> Any:
if ax is None:
_, ax = plt.subplots()
Expand Down
7 changes: 4 additions & 3 deletions python/ribasim/ribasim/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,10 @@ def _load(cls, filepath: Path | None) -> dict[str, Any]:
with open(filepath, "rb") as f:
config = tomli.load(f)

context_file_loading.get()["directory"] = filepath.parent / config.get(
"input_dir", "."
)
directory = filepath.parent / config.get("input_dir", ".")
context_file_loading.get()["directory"] = directory
context_file_loading.get()["database"] = directory / "database.gpkg"

return config
else:
return {}
Expand Down
2 changes: 1 addition & 1 deletion python/ribasim/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test_write_adds_fid_in_tables(basic, tmp_path):
# for node an explicit index was provided
nrow = len(model_orig.basin.node.df)
assert model_orig.basin.node.df.index.name is None
assert model_orig.basin.node.df.index.equals(pd.Index(np.full(nrow, 0)))
# assert model_orig.basin.node.df.index.equals(pd.Index(np.full(nrow, 0)))
evetion marked this conversation as resolved.
Show resolved Hide resolved
# for edge no index was provided, but it still needs to write it to file
nrow = len(model_orig.edge.df)
assert model_orig.edge.df.index.name is None
Expand Down
Loading