Skip to content

Commit

Permalink
add 'name' attribute to identify the model
Browse files Browse the repository at this point in the history
  • Loading branch information
duynguyen02 committed Dec 22, 2024
1 parent 6ac922e commit 0d4e456
Showing 1 changed file with 29 additions and 18 deletions.
47 changes: 29 additions & 18 deletions hydtank/hydtank.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import io
import uuid
from datetime import datetime
from queue import Queue
from typing import Optional, List, Literal, Union
Expand Down Expand Up @@ -34,13 +35,13 @@

class HydTANK:
def __init__(
self,
dataset: Dataset,
basin_defs: List[BasinDef],
root_node: List[BasinDef],
interval: float = 24.0,
start: Optional[datetime] = None,
end: Optional[datetime] = None,
self,
dataset: Dataset,
basin_defs: List[BasinDef],
root_node: List[BasinDef],
interval: float = 24.0,
start: Optional[datetime] = None,
end: Optional[datetime] = None,
):
self._dataset = dataset
self._basin_defs = basin_defs
Expand All @@ -55,6 +56,8 @@ def __init__(
self._E: Optional[np.ndarray] = None
self._q_obs: Optional[np.ndarray] = None

self._name = uuid.uuid4().hex

self._logs: List[str] = []

self._run()
Expand Down Expand Up @@ -87,6 +90,14 @@ def end(self):
def logs(self):
return self._logs

@property
def name(self):
return self._name

@name.setter
def name(self, name: str):
self._name = name

def copy(self):
return copy.deepcopy(self)

Expand Down Expand Up @@ -152,7 +163,7 @@ def reload(self):
self._run()

def _reconfig_by_stacked_parameters(
self, basin_defs: List[Union[Subbasin, Reach]], stacked_parameters: List[float]
self, basin_defs: List[Union[Subbasin, Reach]], stacked_parameters: List[float]
):
subbasin_steps = len(SubbasinParameters().to_initial_params())
reach_steps = len(ReachParameters().to_initial_params())
Expand All @@ -172,7 +183,7 @@ def _reconfig_by_stacked_parameters(
self._run()

def _optimize_operator(
self, stacked_parameters: List[float], basin_defs: List[Union[Subbasin, Reach]]
self, stacked_parameters: List[float], basin_defs: List[Union[Subbasin, Reach]]
):
self._reconfig_by_stacked_parameters(basin_defs, stacked_parameters)
_nse = 0
Expand Down Expand Up @@ -237,8 +248,8 @@ def reconfig_parameters(self, name: str, parameters: Parameters):
basin_def = self.get_basin_def_by_name(name)

if (
isinstance(basin_def, Subbasin)
and isinstance(parameters, SubbasinParameters)
isinstance(basin_def, Subbasin)
and isinstance(parameters, SubbasinParameters)
) or (isinstance(basin_def, Reach) and isinstance(parameters, ReachParameters)):
basin_def.parameters = parameters

Expand All @@ -250,7 +261,7 @@ def reconfig_parameters(self, name: str, parameters: Parameters):
self._run()

def _generate_plot(
self, plot_data, xlabel="Timeseries", ylabel="Flow", figsize=(12, 6), title=None
self, plot_data, xlabel="Timeseries", ylabel="Flow", figsize=(12, 6), title=None
):
plt.figure(figsize=figsize)

Expand Down Expand Up @@ -325,12 +336,12 @@ def plot_headwater_q(self):
return self._plot_basin_q(include_all=False)

def plot_basin_network(
self,
layout_type: Literal[
"hierarchical", "circular", "spring", "kamada-kawai", "multipartite"
] = "hierarchical",
figsize=(15, 10),
node_spacing=1.0,
self,
layout_type: Literal[
"hierarchical", "circular", "spring", "kamada-kawai", "multipartite"
] = "hierarchical",
figsize=(15, 10),
node_spacing=1.0,
):
color_map = {
"Subbasin": "#7FB3D5",
Expand Down

0 comments on commit 0d4e456

Please sign in to comment.