Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ngen.cal optional post calibration validation run #200

Merged
merged 4 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions python/ngen_cal/src/ngen/cal/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from types import ModuleType

if TYPE_CHECKING:
from ngen.config.realization import NgenRealization
from typing import Mapping, Any
from pluggy import PluginManager

Expand All @@ -35,6 +36,25 @@ def _loaded_plugins(pm: PluginManager) -> str:
return f"Plugins Loaded: {', '.join(plugins)}"


def _update_troute_config(
realization: NgenRealization,
troute_config: dict[str, Any],
):
start = realization.time.start_time
end = realization.time.end_time
duration = (end - start).total_seconds()

troute_config["compute_parameters"]["restart_parameters"]["start_datetime"] = (
start.strftime("%Y-%m-%d %H:%M:%S")
)

forcing_parameters = troute_config["compute_parameters"]["forcing_parameters"]
dt = forcing_parameters["dt"]
nts, r = divmod(duration, dt)
assert r == 0, "routing timestep is not evenly divisible by ngen_timesteps"
forcing_parameters["nts"] = nts


def main(general: General, model_conf: Mapping[str, Any]):
#seed the random number generators if requested
if general.random_seed is not None:
Expand Down Expand Up @@ -108,6 +128,80 @@ def main(general: General, model_conf: Mapping[str, Any]):
#for catchment_set in agent.model.adjustables:
# func(start_iteration, general.iterations, catchment_set, agent)
func(start_iteration, general.iterations, agent)

if (validation_parms := model.model.unwrap().val_params) is not None:
print("configuring calibration")
# NOTE: importing here so its easier to refactor in the future
from ngen.cal.calibration_set import CalibrationSet
import pandas as pd
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Sequence
import pandas as pd
from ngen.cal.calibration_cathment import CalibrationCatchment

adjustables: Sequence[CalibrationCatchment] = agent.model.adjustables

realization: NgenRealization = agent.model.unwrap().ngen_realization
assert realization is not None

sim_start, sim_end = validation_parms.sim_interval()
eval_start, eval_end = validation_parms.evaluation_interval()
print(f"validation {sim_start=} {sim_end=}")

# NOTE: do this before `update_config` is called so the right path is written to disk
realization.time.start_time = sim_start
realization.time.end_time = sim_end

assert realization.routing is not None

troute_config_path = realization.routing.config

with troute_config_path.open() as fp:
troute_config = yaml.safe_load(fp)

_update_troute_config(realization, troute_config)

troute_config_path_validation = troute_config_path.with_name("troute_validation.yaml")
with troute_config_path_validation.open("w") as fp:
yaml.dump(troute_config, fp)

# NOTE: do this before `update_config` is called so the right path is written to disk
realization.routing.config = troute_config_path_validation

for calibration_object in adjustables:
best_df: pd.DataFrame = calibration_object.df[[str(agent.best_params), 'param', 'model']]

agent.update_config(agent.best_params, best_df, calibration_object.id)

# NOTE: importing here so its easier to refactor in the future
from ngen.cal.search import _execute, _objective_func
from ngen.cal.utils import pushd

print("starting calibration")
# TODO: validation_parms.objective and target are not being correctly configured
_execute(agent)
with pushd(agent.job.workdir):
sim = calibration_object.output

assert isinstance(calibration_object, CalibrationSet)
# TODO: get from realization config
simulation_interval = pd.Timedelta(3600, unit="s")
# TODO: need a way to get the nexus
nexus = calibration_object._eval_nexus
agent_pm = agent.model.unwrap()._plugin_manager
obs = agent_pm.hook.ngen_cal_model_observations(
nexus=nexus,
# NOTE: techinically start_time=`eval_start` + `simulation_interval`
start_time=eval_start,
end_time=eval_end,
simulation_interval=simulation_interval,
)
print(f"{sim=}")
print(f"{obs=}")
score = _objective_func(sim, obs, validation_parms.objective, (sim_start, sim_end))
print(f"validation run score: {score}")

# call `ngen_cal_finish` plugin hook functions
except Exception as e:
plugin_manager.hook.ngen_cal_finish(exception=e)
Expand Down
47 changes: 47 additions & 0 deletions python/ngen_cal/src/ngen/cal/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,51 @@ def restart(self) -> int:

return start_iteration


class ValidationOptions(BaseModel):
    """Options controlling an optional post-calibration validation run.

    ``evaluation_start``/``evaluation_stop`` bound the period scored by the
    objective function.  ``sim_start``/``sim_stop`` bound the simulation
    itself and fall back to the evaluation period when omitted; their
    ordering constraints are enforced by `_validate_periods`.
    """
    evaluation_start: datetime
    evaluation_stop: datetime
    # simulation window; each endpoint defaults to its evaluation counterpart
    sim_start: Optional[datetime] = None
    sim_stop: Optional[datetime] = None
    objective: Optional[Union[Objective, PyObject]] = Objective.custom
    target: Union[Literal['min'], Literal['max'], float] = 'min'

    def sim_interval(self) -> tuple[datetime, datetime]:
        """Return (start, stop) of the simulation period."""
        # datetimes are always truthy, so `or` safely expresses the fallback
        return (
            self.sim_start or self.evaluation_start,
            self.sim_stop or self.evaluation_stop,
        )

    def evaluation_interval(self) -> tuple[datetime, datetime]:
        """Return (start, stop) of the evaluation period."""
        return self.evaluation_start, self.evaluation_stop

    @root_validator(skip_on_failure=True)
    @classmethod
    def _validate_periods(cls, values: dict[str, datetime | None]) -> dict[str, datetime | None]:
        """Require sim_start <= evaluation_start <= evaluation_stop <= sim_stop."""
        eval_start: datetime = values["evaluation_start"]  # type: ignore
        eval_stop: datetime = values["evaluation_stop"]  # type: ignore
        sim_start = values.get("sim_start")
        sim_stop = values.get("sim_stop")

        problems: list[str] = []
        if sim_start is not None and sim_start > eval_start:
            problems.append("`sim_start` must be <= `evaluation_start`")
        if sim_stop is not None and sim_stop < eval_stop:
            problems.append("`evaluation_stop` must be <= `sim_stop`")
        if eval_stop < eval_start:
            problems.append("`evaluation_start` must be <= `evaluation_stop`")

        if problems:
            raise ValueError("\n".join(problems))

        return values


class ModelExec(BaseModel, Configurable):
"""
The data class for a given model, which must also be Configurable
Expand All @@ -220,6 +265,8 @@ class ModelExec(BaseModel, Configurable):
args: Optional[str]
workdir: DirectoryPath = Path("./") #FIXME test the various workdirs
eval_params: Optional[EvaluationOptions] = Field(default_factory=EvaluationOptions)
# TODO: likely want to move this into `NgenBase` instead of here
val_params: Optional[ValidationOptions] = None
plugins: List[PyObjectOrModule] = Field(default_factory=list)
plugin_settings: Dict[str, Dict[str, Any]] = Field(default_factory=dict)

Expand Down
11 changes: 10 additions & 1 deletion python/ngen_cal/src/ngen/cal/ngen.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,16 @@ def update_config(self, i: int, params: pd.DataFrame, id: str = None, path=Path(
# Cleanup any t-route parquet files between runs
# TODO this may not be _the_ best place to do this, but for now,
# it works, so here it be...
for file in Path(path).glob("*NEXOUT.parquet"):
import itertools
to_remove = (
# Path(path).glob("troute_output_*.*"),
# Path(path).glob("flowveldepth_*.*"),
Path(path).glob("*NEXOUT.parquet"),
# ngen files
# Path(path).glob("cat-*.csv"),
# Path(path).glob("nex-*_output.csv"),
)
for file in itertools.chain(*to_remove):
file.unlink()

class NgenExplicit(NgenBase):
Expand Down
85 changes: 68 additions & 17 deletions python/ngen_cal/src/ngen/cal/ngen_hooks/ngen_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,30 @@

import pandas as pd
from ngen.cal import hookimpl
from pydantic import BaseModel

if TYPE_CHECKING:
from ngen.cal.meta import JobMeta
from ngen.cal.model import ModelExec
from ngen.cal.model import ModelExec, ValidationOptions, EvaluationOptions
from ngen.config.realization import NgenRealization


class _NgenCalModelOutputFn(typing.Protocol):
def __call__(self, id: str) -> pd.Series: ...

class TrouteOutputSettings(BaseModel):
    """Settings for the t-route output plugin, parsed from
    ``plugin_settings["ngen_cal_troute_output"]`` in the model config."""
    # path to the routing output file produced by the validation run
    validation_routing_output: Path


@typing.final
class TrouteOutput:
    def __init__(self, filepath: Path) -> None:
        """Create the plugin for the t-route output file at ``filepath``.

        The remaining attributes start as ``None`` and are populated later
        by the `ngen_cal_model_configure` hook when the model config is
        available.
        """
        self._output_file = filepath
        # optional settings from `plugin_settings["ngen_cal_troute_output"]`
        self._settings: TrouteOutputSettings | None = None

        # filled in from the ModelExec config by `ngen_cal_model_configure`
        self._ngen_realization: NgenRealization | None = None
        self._validation_options: ValidationOptions | None = None
        self._eval_options: EvaluationOptions | None = None

@hookimpl
def ngen_cal_model_configure(self, config: ModelExec) -> None:
Expand All @@ -33,36 +41,79 @@ def ngen_cal_model_configure(self, config: ModelExec) -> None:
assert config.ngen_realization is not None
self._ngen_realization = config.ngen_realization

# Try external provided output hooks, if those fail, try this one
# this will only execute if all other hooks return None (or they don't exist)
@hookimpl(specname="ngen_cal_model_output", trylast=True)
def get_output(self, id: str) -> pd.Series | None:
if (eval_options := config.eval_params) is not None:
self._eval_options = eval_options

if (validation_config := config.val_params) is not None:
self._validation_options = validation_config

# maybe pull in plugin settings
if (plugin_settings := config.plugin_settings.get("ngen_cal_troute_output")) is not None:
self._settings = TrouteOutputSettings.parse_obj(plugin_settings)

def _sim_eval_interval(self) -> tuple[datetime.datetime, datetime.datetime]:
assert (
self._ngen_realization is not None
), "ngen realization required; ensure `ngen_cal_model_configure` was called and the plugin was properly configured"

if not self._output_file.exists():
print(
f"{self._output_file} not found. Current working directory is {Path.cwd()!s}"
)
print("Setting output to None")
return None
if self._eval_options is not None and self._eval_options.evaluation_start is not None:
assert self._eval_options.evaluation_stop is not None
return self._eval_options.evaluation_start, self._eval_options.evaluation_stop

return self._ngen_realization.time.start_time, self._ngen_realization.time.end_time

filetype = self._output_file.suffix.lower()
def _validation_eval_interval(self) -> tuple[datetime.datetime, datetime.datetime]:
if self._validation_options is None:
print("validation options not provided, using sim evaluation interval")
return self._sim_eval_interval()
return self._validation_options.evaluation_interval()

def _output_handler_factory(self, output_file: Path) -> _NgenCalModelOutputFn:
filetype = output_file.suffix.lower()
if filetype == ".csv":
fn = self._factory_handler_csv(self._output_file)
fn = self._factory_handler_csv(output_file)
# TODO: fix. dont know if this format still works
# elif filetype == ".hdf5":
# fn = _model_output_legacy_hdf5(self._output_file)
elif filetype == ".nc":
fn = _stream_output_netcdf_v1(self._output_file)
fn = _stream_output_netcdf_v1(output_file)
elif filetype == ".parquet":
fn = _stream_output_parquet_v1(self._output_file)
fn = _stream_output_parquet_v1(output_file)
else:
raise RuntimeError(
f"unsupported t-route output filetype: {self._output_file.suffix}"
f"unsupported t-route output filetype: {output_file.suffix}"
)
return fn

# Try external provided output hooks, if those fail, try this one
# this will only execute if all other hooks return None (or they don't exist)
@hookimpl(specname="ngen_cal_model_output", trylast=True)
def get_output(self, id: str) -> pd.Series | None:
assert (
self._ngen_realization is not None
), "ngen realization required; ensure `ngen_cal_model_configure` was called and the plugin was properly configured"

if self._settings is not None and self._settings.validation_routing_output.exists():
output_file = self._settings.validation_routing_output
print(f"retrieving simulation data from validation output file: {output_file!s}")

start, end = self._validation_eval_interval()
print(f"validation: {start=} {end=}")
elif self._output_file.exists():
output_file = self._output_file
print(f"retrieving simulation data from output file: {output_file!s}")

start, end = self._sim_eval_interval()
print(f"{start=} {end=}")
else:
print(
f"{self._output_file} not found. Current working directory is {Path.cwd()!s}"
)
print("Setting output to None")
return None

# TODO: I dont think all output handlers can handle validation (csv comes to mind). circle back to this
fn = self._output_handler_factory(output_file)
ds = fn(id)
ds.name = "sim_flow"

Expand All @@ -74,7 +125,7 @@ def get_output(self, id: str) -> pd.Series | None:
seconds=self._ngen_realization.time.output_interval
)
start = self._ngen_realization.time.start_time
ds = ds.loc[start + ngen_dt :]
ds = ds.loc[start + ngen_dt :end]
ds = ds.resample("1h").first()
return ds

Expand Down
9 changes: 6 additions & 3 deletions python/ngen_cal/src/ngen/cal/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,14 @@

def _objective_func(simulated_hydrograph, observed_hydrograph, objective, eval_range: tuple[datetime, datetime] | None = None):
df = pd.merge(simulated_hydrograph, observed_hydrograph, left_index=True, right_index=True)
if df.empty:
print("WARNING: Cannot compute objective function, do time indicies align?")
if eval_range:
df = df.loc[eval_range[0]:eval_range[1]]
#print( df )
if df.empty:
print("WARNING: Cannot compute objective function, do time indicies align?")
if eval_range:
print(f"\teval range: [{eval_range[0]!s} : {eval_range[1]!s}]")
print(f"\tsim interval: [{simulated_hydrograph.index.min()!s} : {simulated_hydrograph.index.max()!s}]")
print(f"\tobs interval: [{observed_hydrograph.index.min()!s} : {observed_hydrograph.index.max()!s}]")
#Evaluate custom objective function providing simulated, observed series
return objective(df['obs_flow'], df['sim_flow'])

Expand Down
Loading