def _update_troute_config(
    realization: "NgenRealization",
    troute_config: dict[str, Any],
) -> None:
    """Align a loaded t-route configuration with the realization's time window.

    The mapping is modified in place:

    * ``compute_parameters.restart_parameters.start_datetime`` is set to the
      realization start time formatted as ``%Y-%m-%d %H:%M:%S``.
    * ``compute_parameters.forcing_parameters.nts`` is set to the number of
      routing timesteps (``dt`` seconds each) that span the realization's
      ``start_time`` .. ``end_time`` interval.

    Parameters
    ----------
    realization:
        Object exposing ``time.start_time`` and ``time.end_time`` datetimes.
    troute_config:
        Parsed t-route configuration; must already contain the
        ``compute_parameters.restart_parameters`` and
        ``compute_parameters.forcing_parameters`` sections, the latter with a
        ``dt`` entry (seconds).

    Raises
    ------
    KeyError
        If one of the required configuration sections or ``dt`` is missing.
    ValueError
        If ``dt`` is not positive, or the simulation duration is not an exact
        multiple of ``dt``.
    """
    start = realization.time.start_time
    end = realization.time.end_time
    duration = (end - start).total_seconds()

    compute_parameters = troute_config["compute_parameters"]
    compute_parameters["restart_parameters"]["start_datetime"] = start.strftime(
        "%Y-%m-%d %H:%M:%S"
    )

    forcing_parameters = compute_parameters["forcing_parameters"]
    dt = forcing_parameters["dt"]
    if dt <= 0:
        raise ValueError(f"routing timestep `dt` must be positive, got {dt!r}")

    nts, remainder = divmod(duration, dt)
    # Explicit raise rather than `assert`: asserts are stripped under `python -O`.
    if remainder != 0:
        raise ValueError(
            f"simulation duration ({duration}s) is not evenly divisible by the "
            f"routing timestep (dt={dt}s)"
        )

    # `total_seconds()` returns a float, so `divmod` yields floats; a step count
    # is integral and would otherwise be serialized as e.g. `288.0`.
    # NOTE(review): presumably t-route requires an integer `nts` -- confirm.
    forcing_parameters["nts"] = int(nts)
agent.model.adjustables + + realization: NgenRealization = agent.model.unwrap().ngen_realization + assert realization is not None + + sim_start, sim_end = validation_parms.sim_interval() + eval_start, eval_end = validation_parms.evaluation_interval() + print(f"validation {sim_start=} {sim_end=}") + + # NOTE: do this before `update_config` is called so the right path is written to disk + realization.time.start_time = sim_start + realization.time.end_time = sim_end + + assert realization.routing is not None + + troute_config_path = realization.routing.config + + with troute_config_path.open() as fp: + troute_config = yaml.safe_load(fp) + + _update_troute_config(realization, troute_config) + + troute_config_path_validation = troute_config_path.with_name("troute_validation.yaml") + with troute_config_path_validation.open("w") as fp: + yaml.dump(troute_config, fp) + + # NOTE: do this before `update_config` is called so the right path is written to disk + realization.routing.config = troute_config_path_validation + + for calibration_object in adjustables: + best_df: pd.DataFrame = calibration_object.df[[str(agent.best_params), 'param', 'model']] + + agent.update_config(agent.best_params, best_df, calibration_object.id) + + # NOTE: importing here so its easier to refactor in the future + from ngen.cal.search import _execute, _objective_func + from ngen.cal.utils import pushd + + print("starting calibration") + # TODO: validation_parms.objective and target are not being correctly configured + _execute(agent) + with pushd(agent.job.workdir): + sim = calibration_object.output + + assert isinstance(calibration_object, CalibrationSet) + # TODO: get from realization config + simulation_interval = pd.Timedelta(3600, unit="s") + # TODO: need a way to get the nexus + nexus = calibration_object._eval_nexus + agent_pm = agent.model.unwrap()._plugin_manager + obs = agent_pm.hook.ngen_cal_model_observations( + nexus=nexus, + # NOTE: techinically start_time=`eval_start` + 
class ValidationOptions(BaseModel):
    """Options describing a validation run.

    The evaluation window (`evaluation_start` / `evaluation_stop`) is required.
    The simulation window (`sim_start` / `sim_stop`) is optional; each bound
    that is omitted falls back to the matching evaluation bound.
    """
    # Required evaluation window.
    evaluation_start: datetime
    evaluation_stop: datetime
    # Optional simulation window; see `_validate_periods` for the ordering
    # constraints enforced between the two windows.
    sim_start: Optional[datetime] = None
    sim_stop: Optional[datetime] = None
    objective: Optional[Union[Objective, PyObject]] = Objective.custom
    target: Union[Literal['min'], Literal['max'], float] = 'min'

    def sim_interval(self) -> tuple[datetime, datetime]:
        """Return `(start, stop)` of the simulation, defaulting to the evaluation bounds."""
        begin = self.evaluation_start if self.sim_start is None else self.sim_start
        finish = self.evaluation_stop if self.sim_stop is None else self.sim_stop
        return begin, finish

    def evaluation_interval(self) -> tuple[datetime, datetime]:
        """Return `(evaluation_start, evaluation_stop)`."""
        return (self.evaluation_start, self.evaluation_stop)

    @root_validator(skip_on_failure=True)
    @classmethod
    def _validate_periods(cls, values: dict[str, datetime | None]) -> dict[str, datetime | None]:
        """Check that the simulation window encloses a well-ordered evaluation window.

        All violations are collected and reported together, one per line, in a
        single `ValueError`.
        """
        eval_begin: datetime = values["evaluation_start"]  # type: ignore
        eval_end: datetime = values["evaluation_stop"]  # type: ignore
        sim_begin: datetime | None = values.get("sim_start")
        sim_end: datetime | None = values.get("sim_stop")

        # (violated?, message) pairs, evaluated and reported in this order
        checks = (
            (
                sim_begin is not None and sim_begin > eval_begin,
                "`sim_start` must be <= `evaluation_start`",
            ),
            (
                sim_end is not None and sim_end < eval_end,
                "`evaluation_stop` must be <= `sim_stop`",
            ),
            (
                eval_end < eval_begin,
                "`evaluation_start` must be <= `evaluation_stop`",
            ),
        )
        problems: list[str] = [message for violated, message in checks if violated]

        if problems:
            raise ValueError("\n".join(problems))

        return values
- for file in Path(path).glob("*NEXOUT.parquet"): + import itertools + to_remove = ( + # Path(path).glob("troute_output_*.*"), + # Path(path).glob("flowveldepth_*.*"), + Path(path).glob("*NEXOUT.parquet"), + # ngen files + # Path(path).glob("cat-*.csv"), + # Path(path).glob("nex-*_output.csv"), + ) + for file in itertools.chain(*to_remove): file.unlink() class NgenExplicit(NgenBase): diff --git a/python/ngen_cal/src/ngen/cal/ngen_hooks/ngen_output.py b/python/ngen_cal/src/ngen/cal/ngen_hooks/ngen_output.py index 40344d64..d6625f9a 100644 --- a/python/ngen_cal/src/ngen/cal/ngen_hooks/ngen_output.py +++ b/python/ngen_cal/src/ngen/cal/ngen_hooks/ngen_output.py @@ -7,22 +7,30 @@ import pandas as pd from ngen.cal import hookimpl +from pydantic import BaseModel if TYPE_CHECKING: from ngen.cal.meta import JobMeta - from ngen.cal.model import ModelExec + from ngen.cal.model import ModelExec, ValidationOptions, EvaluationOptions from ngen.config.realization import NgenRealization class _NgenCalModelOutputFn(typing.Protocol): def __call__(self, id: str) -> pd.Series: ... 
class TrouteOutputSettings(BaseModel):
    """Plugin settings for `TrouteOutput`.

    Parsed in `TrouteOutput.ngen_cal_model_configure` from the
    `"ngen_cal_troute_output"` entry of the model's `plugin_settings`, when
    that entry is present.
    """
    # Path of the routing output file produced by a validation run; when this
    # file exists, `TrouteOutput.get_output` reads from it in preference to the
    # regular output file.
    validation_routing_output: Path
def _output_handler_factory(self, output_file: Path) -> _NgenCalModelOutputFn:
    """Return the reader callable matching the suffix of `output_file`.

    The suffix is compared case-insensitively. `.csv`, `.nc` and `.parquet`
    are supported; any other suffix raises `RuntimeError`.
    """
    suffix = output_file.suffix.lower()

    if suffix == ".csv":
        return self._factory_handler_csv(output_file)
    # TODO: fix. dont know if this format still works
    # if suffix == ".hdf5":
    #     return _model_output_legacy_hdf5(output_file)
    if suffix == ".nc":
        return _stream_output_netcdf_v1(output_file)
    if suffix == ".parquet":
        return _stream_output_parquet_v1(output_file)

    raise RuntimeError(
        f"unsupported t-route output filetype: {output_file.suffix}"
    )
circle back to this + fn = self._output_handler_factory(output_file) ds = fn(id) ds.name = "sim_flow" @@ -74,7 +125,7 @@ def get_output(self, id: str) -> pd.Series | None: seconds=self._ngen_realization.time.output_interval ) start = self._ngen_realization.time.start_time - ds = ds.loc[start + ngen_dt :] + ds = ds.loc[start + ngen_dt :end] ds = ds.resample("1h").first() return ds diff --git a/python/ngen_cal/src/ngen/cal/search.py b/python/ngen_cal/src/ngen/cal/search.py index 7e29d920..5ca2b8b8 100644 --- a/python/ngen_cal/src/ngen/cal/search.py +++ b/python/ngen_cal/src/ngen/cal/search.py @@ -24,11 +24,14 @@ def _objective_func(simulated_hydrograph, observed_hydrograph, objective, eval_range: tuple[datetime, datetime] | None = None): df = pd.merge(simulated_hydrograph, observed_hydrograph, left_index=True, right_index=True) - if df.empty: - print("WARNING: Cannot compute objective function, do time indicies align?") if eval_range: df = df.loc[eval_range[0]:eval_range[1]] - #print( df ) + if df.empty: + print("WARNING: Cannot compute objective function, do time indicies align?") + if eval_range: + print(f"\teval range: [{eval_range[0]!s} : {eval_range[1]!s}]") + print(f"\tsim interval: [{simulated_hydrograph.index.min()!s} : {simulated_hydrograph.index.max()!s}]") + print(f"\tobs interval: [{observed_hydrograph.index.min()!s} : {observed_hydrograph.index.max()!s}]") #Evaluate custom objective function providing simulated, observed series return objective(df['obs_flow'], df['sim_flow'])