diff --git a/fugue_tune/trial.py b/fugue_tune/trial.py
new file mode 100644
index 0000000..38e87e8
--- /dev/null
+++ b/fugue_tune/trial.py
@@ -0,0 +1,49 @@
+import copy
+from threading import RLock
+from typing import Any, Callable, Dict, List
+from uuid import uuid4
+
+from fugue.rpc import RPCHandler
+
+
+class TrialsTracker(RPCHandler):
+    def __init__(self):
+        super().__init__()
+        self._tt_lock = RLock()
+        self._raw_data: Dict[str, List[Dict[str, Any]]] = {}
+
+    def get_raw_data(self) -> Dict[str, List[Dict[str, Any]]]:
+        with self._tt_lock:
+            return copy.deepcopy(self._raw_data)
+
+    def __call__(self, method: str, **kwargs: Any) -> Any:
+        return getattr(self, method)(**kwargs)
+
+    def log_trial(self, trial_id: str, **kwargs: Any) -> None:
+        with self._tt_lock:
+            if trial_id not in self._raw_data:
+                self._raw_data[trial_id] = [dict(kwargs)]
+            else:
+                self._raw_data[trial_id].append(dict(kwargs))
+
+    def prune(self, trial_id: str) -> bool:
+        return False
+
+
+class TrialCallback(object):
+    def __init__(self, callback: Callable):
+        self._trial_id = str(uuid4())
+        self._callback = callback
+
+    @property
+    def trial_id(self) -> str:
+        return self._trial_id
+
+    def log_trial(self, **kwargs: Any) -> None:
+        self._callback(method="log_trial", trial_id=self.trial_id, **kwargs)
+
+    def __getattr__(self, method: str) -> Callable:
+        def _wrapper(**kwargs: Any) -> Any:
+            return self._callback(method=method, trial_id=self.trial_id, **kwargs)
+
+        return _wrapper
diff --git a/fugue_tune/tune.py b/fugue_tune/tune.py
index 5fd558c..998d75f 100644
--- a/fugue_tune/tune.py
+++ b/fugue_tune/tune.py
@@ -6,6 +6,7 @@
 from typing import Any, Callable, Dict, Iterable, List, Optional, Set, no_type_check
 from uuid import uuid4
 
+import matplotlib.pyplot as plt
 import pandas as pd
 from fugue import (
     ArrayDataFrame,
@@ -30,7 +31,7 @@
 from fugue_tune.exceptions import FugueTuneCompileError, FugueTuneRuntimeError
 from fugue_tune.space import Space, decode
 
-import matplotlib.pyplot as plt
+from fugue_tune.trial import TrialCallback, TrialsTracker
 
 
 class Tunable(object):
@@ -301,6 +302,7 @@ def tune(  # noqa: C901
     tunable: Any,
     distributable: Optional[bool] = None,
     objective_runner: Optional[ObjectiveRunner] = None,
+    tracker: Optional[TrialsTracker] = None,
 ) -> WorkflowDataFrame:
     t = _to_tunable(  # type: ignore
         tunable, *get_caller_global_local_vars(), distributable
@@ -313,7 +315,9 @@
 
     # input_has: __fmin_params__:str
     # schema: *,__fmin_value__:double,__fmin_metadata__:str
-    def compute_transformer(df: Iterable[Dict[str, Any]]) -> Iterable[Dict[str, Any]]:
+    def compute_transformer(
+        df: Iterable[Dict[str, Any]], cb: Optional[Callable] = None
+    ) -> Iterable[Dict[str, Any]]:
         for row in df:
             dfs: Dict[str, Any] = {}
             dfs_keys: Set[str] = set()
diff --git a/setup.py b/setup.py
index 724e980..0e68ef9 100644
--- a/setup.py
+++ b/setup.py
@@ -27,7 +27,7 @@ def get_version() -> str:
     author_email="goodwanghan@gmail.com",
     keywords="fugue incubator experiment",
     url="http://github.com/fugue-project/fugue-incubator",
-    install_requires=["fugue>=0.5.0", "scikit-learn", "matplotlib"],
+    install_requires=["fugue==0.5.1.dev0", "scikit-learn", "matplotlib", "flask"],
     extras_require={"hyperopt": ["hyperopt"], "all": ["hyperopt"]},
     classifiers=[
         # "3 - Alpha", "4 - Beta" or "5 - Production/Stable"
diff --git a/tests/fugue_tune/test_trial.py b/tests/fugue_tune/test_trial.py
new file mode 100644
index 0000000..a246986
--- /dev/null
+++ b/tests/fugue_tune/test_trial.py
@@ -0,0 +1,11 @@
+from fugue_tune.trial import TrialCallback, TrialsTracker
+
+
+def test_callback():
+    tt = TrialsTracker()
+    tb = TrialCallback(tt)
+    tb.log_trial(a=1, b=2)
+    tb.log_trial(a=2, b=3)
+    key = tb.trial_id
+    assert [{"a": 1, "b": 2}, {"a": 2, "b": 3}] == tt.get_raw_data()[key]
+    assert not tb.prune()
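
Usage sketch (outside the diff): a minimal, local illustration of how the TrialsTracker and TrialCallback added in fugue_tune/trial.py fit together. It calls the tracker directly, as test_trial.py does; in a real tune() run the callback would instead be delivered through fugue's RPCHandler machinery, and how tune() wires its new tracker argument to compute_transformer's cb parameter is not shown in this change. The epoch/loss keys are arbitrary example values.

    from fugue_tune.trial import TrialCallback, TrialsTracker

    tracker = TrialsTracker()       # driver-side aggregator, keyed by trial id
    trial = TrialCallback(tracker)  # generates a fresh trial_id and forwards calls

    for epoch, loss in enumerate([0.9, 0.5, 0.3]):
        # each call appends a dict of kwargs under this trial's id
        trial.log_trial(epoch=epoch, loss=loss)
        # prune() is forwarded via __getattr__; TrialsTracker.prune currently always returns False
        if trial.prune():
            break

    # deep-copied view of everything logged so far for this trial
    print(tracker.get_raw_data()[trial.trial_id])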