Skip to content

Commit

Permalink
DO NOT MERGE
Browse files Browse the repository at this point in the history
  • Loading branch information
hellkite500 committed Jul 19, 2024
1 parent 50d75c6 commit d677838
Show file tree
Hide file tree
Showing 13 changed files with 301 additions and 78 deletions.
7 changes: 5 additions & 2 deletions python/ngen_cal/src/ngen/cal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@ def pyobject_schema(cls, field_schema):

PyObject.__modify_schema__ = classmethod(pyobject_schema)

PROJECT_SLUG: Final = "ngen.cal"
hookimpl = pluggy.HookimplMarker(PROJECT_SLUG)

from .configuration import General, Model
from .calibratable import Calibratable, Adjustable, Evaluatable
from .calibration_set import CalibrationSet, UniformCalibrationSet
from .meta import JobMeta
from .plot import *

PROJECT_SLUG: Final = "ngen.cal"

hookimpl = pluggy.HookimplMarker(PROJECT_SLUG)


20 changes: 19 additions & 1 deletion python/ngen_cal/src/ngen/cal/_hookspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@

if TYPE_CHECKING:
from ngen.cal.configuration import General
from pandas import Series
from pathlib import Path

hookspec = pluggy.HookspecMarker(PROJECT_SLUG)

#hookspec_model = pluggy.HookspecMarker(PROJECT_SLUG+".model")

@hookspec
def ngen_cal_configure(config: General) -> None:
Expand Down Expand Up @@ -42,3 +44,19 @@ def ngen_cal_finish(exception: Exception | None) -> None:
raised during the calibration loop.
`exception` will be non-none if an exception was raised during calibration.
"""

class ModelHooks():
    """
    Hook specifications for model-level plugins: one hook that supplies the
    model output and one that runs after each model iteration.
    """
    # firstresult=True: per the pluggy documentation, a call stops at the first
    # implementation returning a non-None value and yields that single value
    # rather than a list of results.
    # NOTE(review): confirm that consumers of this hook expect a single Series.
    @hookspec(firstresult=True)
    def ngen_cal_model_output(id: str | None) -> Series:
        """
        Called during each calibration iteration to provide the model output in the form
        of a pandas Series, indexed by time.
        Output series should be in units of cubic meters per second.

        `id` is presumably the identifier of the feature whose output is
        requested; the meaning of `None` is not defined here -- TODO confirm.
        """

    @hookspec
    def ngen_cal_model_post_iteration(path: Path, iteration: int) -> None:
        """
        Called after each model iteration is completed and evaluated,
        and before the next iteration is configured and started.

        `path` and `iteration` are supplied by the caller; presumably the
        checkpoint directory and the index of the iteration that just
        finished -- TODO confirm.
        """
13 changes: 13 additions & 0 deletions python/ngen_cal/src/ngen/cal/_plugin_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,16 @@ def setup_plugin_manager(plugins: list[Callable | ModuleType]) -> PluginManager:
assert_never(plugin)

return pm

def before(hook_name, hook_impls: list[pluggy.HookImpl], kwargs):
    """
    Trace callback invoked ahead of a hook call.

    Writes four lines to stdout: the literal marker "before", the hook name,
    the list of implementing function names, and the keyword arguments of
    the call.
    """
    print("before", hook_name, sep="\n")
    impl_names = [impl.function.__name__ for impl in hook_impls]
    print(impl_names, kwargs, sep="\n")

def after(outcome, hook_name, hook_impls: list[pluggy.HookImpl], kwargs):
    """
    Trace callback invoked once a hook call has finished.

    Writes, one per line, the literal marker "after", the call outcome, the
    hook name, the hook implementations, and the keyword arguments of the call.
    """
    for item in ("after", outcome, hook_name, hook_impls, kwargs):
        print(item)
14 changes: 7 additions & 7 deletions python/ngen_cal/src/ngen/cal/calibratable.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def check_point_file(self) -> 'Path':
"""
return Path('{}_parameter_df_state.parquet'.format(self.id))

def check_point(self, path: 'Path') -> None:
def check_point(self, path: 'Path', iteration: int) -> None:
"""
Save calibration information
"""
Expand All @@ -92,12 +92,12 @@ def load_df(self, path: 'Path') -> None:
"""
self._df = read_parquet(path/self.check_point_file)

@abstractmethod
def save_output(self, i: int) -> None:
"""
Save the last output of the runtime for iteration i
"""
pass
# @abstractmethod
# def save_output(self, i: int) -> None:
# """
# Save the last output of the runtime for iteration i
# """
# pass

def restart(self) -> None:
self.load_df('./')
Expand Down
4 changes: 3 additions & 1 deletion python/ngen_cal/src/ngen/cal/calibration_cathment.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from hypy.nexus import Nexus

from .calibratable import Adjustable, Evaluatable

#TODO get rid of calibration_catchment, should be able to do everything
#via calibration_set of size 1 with proper configuration of output hooks
class AdjustableCatchment(FormulatableCatchment, Adjustable):
"""
A Formulatable catchment that has an Adjustable interface to adjust
Expand All @@ -34,6 +35,7 @@ def __init__(self, workdir: 'Path', id: str, nexus, params: dict = {}):
FormulatableCatchment.__init__(self=self, catchment_id=id, params=params, outflow=nexus)
Adjustable.__init__(self=self, df=DataFrame(params).rename(columns={'init': '0'}))
#FIXME parameterize
#TODO move to explicit/catchment plugin function
self._output_file = workdir/'{}.csv'.format(self.id)
self._workdir = workdir

Expand Down
98 changes: 44 additions & 54 deletions python/ngen_cal/src/ngen/cal/calibration_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import shutil
from typing import TYPE_CHECKING, Sequence
import pandas as pd
from pluggy import HookRelay
if TYPE_CHECKING:
from pandas import DataFrame
from pathlib import Path
Expand All @@ -10,24 +11,29 @@
from .model import EvaluationOptions
import os
from pathlib import Path
import warnings
from hypy.nexus import Nexus
from .calibratable import Adjustable, Evaluatable

from .output_handler import OutputHandler

class CalibrationSet(Evaluatable):
"""
A HY_Features based catchment with additional calibration information/functionality
"""

def __init__(self, adjustables: Sequence[Adjustable], eval_nexus: Nexus, routing_output: 'Path', start_time: str, end_time: str, eval_params: 'EvaluationOptions'):
def __init__(self, adjustables: Sequence[Adjustable], eval_nexus: Nexus, hooks: HookRelay, start_time: str, end_time: str, eval_params: 'EvaluationOptions'):
"""
"""
super().__init__(eval_params)
self._eval_nexus = eval_nexus
self._adjustables = adjustables
self._output_file = routing_output
#TODO this becomes a function handler to a plugin
self._output_hook = hooks.ngen_cal_model_output
self._post_hook = hooks.ngen_cal_model_post_iteration

#TODO this becomes a function handler to an InputHandler
# the model class configures this function based on possible available plugins
#use the nwis location to get observation data
obs =self._eval_nexus._hydro_location.get_data(start_time, end_time)
#make sure data is hourly
Expand All @@ -53,40 +59,22 @@ def output(self) -> 'DataFrame':
This re-reads the output file each call, as the output for given calibration catchment changes
for each calibration iteration. If it doesn't exist, should return None
"""
try:
#in this case, look for routed data
#this is really model specific, so not as generalizable the way this
#is coded right now =(
#would be better to hook this from the model object???
#read the routed flow at the eval_nexus
df = pd.read_csv(self._output_file, index_col=0)
df.index = df.index.map(lambda x: 'wb-'+str(x))
tuples = [ eval(x) for x in df.columns ]
df.columns = pd.MultiIndex.from_tuples(tuples)
# TODO should contributing_catchments be singular??? assuming it is for now...
df = df.loc[self._eval_nexus.contributing_catchments[0].replace('cat', 'wb')]
self._output = df.xs('q', level=1, drop_level=False)
#This is a hacky way to get the time index...pass the time around???
tnx_file = list(Path(self._output_file).parent.glob("nex*.csv"))[0]
tnx_df = pd.read_csv(tnx_file, index_col=0, parse_dates=[1], names=['ts', 'time', 'Q']).set_index('time')
dt_range = pd.date_range(tnx_df.index[0], tnx_df.index[-1], len(self._output.index)).round('min')
self._output.index = dt_range
#this may not be strictly necessary...I think the _evaluate will align these...
self._output = self._output.resample('1H').first()
self._output.name="sim_flow"
# self._output = read_csv(self._output_file, usecols=["Time", self._output_var], parse_dates=['Time'], index_col='Time', dtype={self._output_var: 'float64'})
# self._output.rename(columns={self._output_var:'sim_flow'}, inplace=True)
hydrograph = self._output

except FileNotFoundError:
print("{} not found. Current working directory is {}".format(self._output_file, os.getcwd()))
print("Setting output to None")
hydrograph = None
except Exception as e:
raise(e)
#if hydrograph is None:
# raise(RuntimeError("Error reading output: {}".format(self._output_file)))
return hydrograph
# TODO should contributing_catchments be singular??? assuming it is for now...
#TODO call output handler plugin function here...
#TODO call plugin, then resample time...
df = self._output_hook(id=self._eval_nexus.contributing_catchments[0].replace('cat', 'wb'))
print(df)
print(type(df))
if not df:
# list of results is empty
print("No suitable output found from output hooks...")
df = None
elif len(df) > 1:
warnings.warn("Multiple output data found, using first registered")
df = df[0]
else:
df = df[0]
return df

@output.setter
def output(self, df):
Expand All @@ -106,20 +94,22 @@ def observed(self) -> 'DataFrame':
def observed(self, df):
self._observed = df

def save_output(self, i) -> None:
"""
Save the last output to output for iteration i
"""
#FIXME ensure _output_file exists
#FIXME re-enable this once more complete
shutil.move(self._output_file, '{}_last'.format(self._output_file))
# def save_output(self, i) -> None:
# """
# Save the last output to output for iteration i
# """
# #FIXME ensure _output_file exists
# #FIXME re-enable this once more complete
# shutil.move(self._output_file, '{}_last'.format(self._output_file))

def check_point(self, path: 'Path') -> None:
def check_point(self, path: 'Path', iteration: int) -> None:
    """
    Write each adjustable's parameter table to its check point file under
    `path`, then invoke the model post-iteration hook with `path` and
    `iteration`.
    """
    for member in self.adjustables:
        write_parquet = member.df.to_parquet
        write_parquet(path / member.check_point_file)
    # check pointing is done -- let any model post hooks run
    self._post_hook(path=path, iteration=iteration)

def restart(self) -> int:
try:
Expand All @@ -134,11 +124,11 @@ class UniformCalibrationSet(CalibrationSet, Adjustable):
A HY_Features based catchment with additional calibration information/functionality
"""

def __init__(self, eval_nexus: Nexus, routing_output: 'Path', start_time: str, end_time: str, eval_params: 'EvaluationOptions', params: dict = {}):
def __init__(self, eval_nexus: Nexus, output: HookRelay, start_time: str, end_time: str, eval_params: 'EvaluationOptions', params: dict = {}):
"""
"""
super().__init__(adjustables=[self], eval_nexus=eval_nexus, routing_output=routing_output, start_time=start_time, end_time=end_time, eval_params=eval_params)
super().__init__(adjustables=[self], eval_nexus=eval_nexus, output=output, start_time=start_time, end_time=end_time, eval_params=eval_params)
Adjustable.__init__(self=self, df=DataFrame(params).rename(columns={'init': '0'}))

#For now, set this to None so meta update does the right thing
Expand All @@ -153,13 +143,13 @@ def id(self) -> str:
"""
return self._id

def save_output(self, i) -> None:
"""
Save the last output to output for iteration i
"""
#FIXME ensure _output_file exists
#FIXME re-enable this once more complete
shutil.move(self._output_file, '{}_last'.format(self._output_file))
# def save_output(self, i) -> None:
# """
# Save the last output to output for iteration i
# """
# #FIXME ensure _output_file exists
# #FIXME re-enable this once more complete
# shutil.move(self._output_file, '{}_last'.format(self._output_file))

#update handled in meta, TODO remove this method???
def update_params(self, iteration: int) -> None:
Expand Down
26 changes: 25 additions & 1 deletion python/ngen_cal/src/ngen/cal/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pydantic import BaseModel, DirectoryPath, conint, PyObject, validator, Field, root_validator
from typing import Optional, Tuple, Union
from typing import Any, cast, Callable, Dict, List, Optional, Tuple, Union
from types import ModuleType, FunctionType
try: #to get literal in python 3.7, it was added to typing in 3.8
from typing import Literal
except ImportError:
Expand All @@ -8,6 +9,10 @@
from pathlib import Path
from abc import ABC, abstractmethod
from .strategy import Objective
from ngen.cal._plugin_system import setup_plugin_manager
from .utils import PyObjectOrModule, type_as_import_string
from pluggy import PluginManager
from ._plugin_system import before, after
# additional constrained types
PosInt = conint(gt=-1)

Expand Down Expand Up @@ -213,7 +218,26 @@ class ModelExec(BaseModel, Configurable):
args: Optional[str]
workdir: DirectoryPath = Path("./") #FIXME test the various workdirs
eval_params: Optional[EvaluationOptions] = Field(default_factory=EvaluationOptions)
output_plugin: List[PyObjectOrModule] = Field(default_factory=list)
plugin_settings: Dict[str, Dict[str, Any]] = Field(default_factory=dict)

_plugin_manager: PluginManager
class Config(BaseModel.Config):
    # Model configuration: how plugin objects are JSON-encoded and how
    # underscore-prefixed attributes are treated.
    # properly serialize plugins
    json_encoders = {
        # classes and plain functions are encoded via type_as_import_string
        type: type_as_import_string,
        # modules are encoded as their `__name__`
        ModuleType: lambda mod: mod.__name__,
        FunctionType: type_as_import_string,
    }
    # pydantic option: attributes whose names start with an underscore are
    # handled as private attributes instead of model fields
    # NOTE(review): presumably needed for the `_plugin_manager` attribute -- confirm
    underscore_attrs_are_private = True

def __init__(self, *args, **kwargs):
    """
    Build the model from keyword arguments, then create the plugin manager
    from the configured `output_plugin` entries and attach the `before` /
    `after` hook-call monitors to it.

    Positional arguments are accepted but are not forwarded to the parent
    initializer.
    """
    super().__init__(**kwargs)
    # cast only informs the type checker; the list is passed through unchanged
    manager = setup_plugin_manager(
        cast(List[Union[Callable, ModuleType]], self.output_plugin)
    )
    self._plugin_manager = manager
    self._plugin_manager.add_hookcall_monitoring(before, after)

#TODO use a model exec
#FIXME formalize type: str = "ModelName"
def get_binary(self)->str:
"""Get the binary string to execute
Expand Down
23 changes: 19 additions & 4 deletions python/ngen_cal/src/ngen/cal/ngen.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from .parameter import Parameter, Parameters
from .calibration_cathment import CalibrationCatchment, AdjustableCatchment
from .calibration_set import CalibrationSet, UniformCalibrationSet
#HyFeatures components
from .ngen_output import TrouteOutput, NgenSaveOutput
# HyFeatures components
from hypy.hydrolocation import NWISLocation # type: ignore
from hypy.nexus import Nexus # type: ignore

Expand Down Expand Up @@ -81,6 +82,7 @@ class NgenBase(ModelExec):
nexus: Optional[FilePath]
crosswalk: Optional[FilePath]
ngen_realization: Optional[NgenRealization]
#TODO move output from file to plugin, implement default via stubbed TrouteOutput handler
routing_output: Optional[Path] = Field(default=Path("flowveldepth_Ngen.csv"))
#optional fields
partitions: Optional[FilePath]
Expand Down Expand Up @@ -108,6 +110,8 @@ def __init__(self, **kwargs):
#Let pydantic work its magic
super().__init__(**kwargs)
#now we work ours
self._plugin_manager.register(TrouteOutput(self.routing_output))
self._plugin_manager.register(NgenSaveOutput())
#Make a copy of the config file, just in case
shutil.copy(self.realization, str(self.realization)+'_original')

Expand Down Expand Up @@ -341,6 +345,9 @@ def __init__(self, **kwargs):
#TODO define these extra params in the realization config and parse them out explicitly per catchment, cause why not?
eval_params = self.eval_params.copy()
eval_params.id = id
#TODO can this become a CalibrationSet
#and essentially have a list of sets of size 1, with each set
#configured with an OutputHandler related to the explicit catchment output?
self._catchments.append(CalibrationCatchment(self.workdir, id, nexus, start_t, end_t, fabric, output_var, eval_params, params))

def update_config(self, i: int, params: 'pd.DataFrame', id: str, **kwargs):
Expand Down Expand Up @@ -384,6 +391,8 @@ def __init__(self, **kwargs):
#Need to fix the forcing definition or ngen will not work
#for individual catchment configs, it doesn't apply pattern resolution
#and will read the directory `path` key as the file key and will segfault
# FIXME handle netcdf inputs??? Check ngen to see what happens when catchment overrides
# try to use a netcdf
pattern = catchment_realizations[id].forcing.file_pattern
path = catchment_realizations[id].forcing.path
catchment_realizations[id].forcing.file_pattern = None
Expand Down Expand Up @@ -436,8 +445,14 @@ def __init__(self, **kwargs):
break

if len(eval_nexus) != 1:
raise RuntimeError( "Currently only a single nexus in the hydrfabric can be gaged, set the eval_feature key to pick one.")
self._catchments.append(CalibrationSet(catchments, eval_nexus[0], self.routing_output, start_t, end_t, self.eval_params))
raise RuntimeError( "Currently only a single nexus in the hydrfabric can be gaged, set the eval_feature key to pick one.")
#TODO pass CalibrationSet an ObsHandler and an OutputHandler
# These need to be configured based on possible user plugin specs
# with defaults based on the current usage...
# so eval_nexus[0] simply becomes ObsHandler and self.routing_output will be the OutputHandler
# TODO probably want a configureOutput function in the NgenBase so once each strategy identifies

self._catchments.append(CalibrationSet(catchments, eval_nexus[0], self._plugin_manager.hook, start_t, end_t, self.eval_params))

def _strip_global_params(self) -> None:
module = self.ngen_realization.global_config.formulations[0].params
Expand Down Expand Up @@ -493,7 +508,7 @@ def __init__(self, **kwargs):
if len(eval_nexus) != 1:
raise RuntimeError( "Currently only a single nexus in the hydrfabric can be gaged, set the eval_feature key to pick one.")
params = _params_as_df(self.params)
self._catchments.append(UniformCalibrationSet(eval_nexus=eval_nexus[0], routing_output=self.routing_output, start_time=start_t, end_time=end_t, eval_params=self.eval_params, params=params))
self._catchments.append(UniformCalibrationSet(eval_nexus=eval_nexus[0], output=self._plugin_manager.hook.ngen_cal_model_output, start_time=start_t, end_time=end_t, eval_params=self.eval_params, params=params))

class Ngen(BaseModel, Configurable, smart_union=True):
__root__: Union[NgenExplicit, NgenIndependent, NgenUniform] = Field(discriminator="strategy")
Expand Down
Loading

0 comments on commit d677838

Please sign in to comment.