Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

trial #20

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions fugue_tune/trial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import copy
from threading import RLock
from typing import Any, Callable, Dict, List
from uuid import uuid4

from fugue.rpc import RPCHandler


class TrialsTracker(RPCHandler):
def __init__(self):
super().__init__()
self._tt_lock = RLock()
self._raw_data: Dict[str, List[Dict[str, Any]]] = {}

def get_raw_data(self) -> Dict[str, List[Dict[str, Any]]]:
with self._tt_lock:
return copy.deepcopy(self._raw_data)

def __call__(self, method: str, **kwargs: Any) -> Any:
return getattr(self, method)(**kwargs)

def log_trial(self, trial_id: str, **kwargs: Any) -> None:
with self._tt_lock:
if trial_id not in self._raw_data:
self._raw_data[trial_id] = [dict(kwargs)]
else:
self._raw_data[trial_id].append(dict(kwargs))

def prune(self, trial_id: str) -> bool:
return False


class TrialCallback(object):
def __init__(self, callback: Callable):
self._trial_id = str(uuid4())
self._callback = callback

@property
def trial_id(self) -> str:
return self._trial_id

def log_trial(self, **kwargs: Any) -> None:
self._callback(method="log_trial", trial_id=self.trial_id, **kwargs)

def __getattr__(self, method: str) -> Callable:
def _wrapper(**kwargs: Any) -> Any:
return self._callback(method=method, trial_id=self.trial_id, **kwargs)

return _wrapper
8 changes: 6 additions & 2 deletions fugue_tune/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, no_type_check
from uuid import uuid4

import matplotlib.pyplot as plt
import pandas as pd
from fugue import (
ArrayDataFrame,
Expand All @@ -30,7 +31,7 @@

from fugue_tune.exceptions import FugueTuneCompileError, FugueTuneRuntimeError
from fugue_tune.space import Space, decode
import matplotlib.pyplot as plt
from fugue_tune.trial import TrialCallback, TrialsTracker


class Tunable(object):
Expand Down Expand Up @@ -301,6 +302,7 @@ def tune( # noqa: C901
tunable: Any,
distributable: Optional[bool] = None,
objective_runner: Optional[ObjectiveRunner] = None,
tracker: Optional[TrialsTracker] = None,
) -> WorkflowDataFrame:
t = _to_tunable( # type: ignore
tunable, *get_caller_global_local_vars(), distributable
Expand All @@ -313,7 +315,9 @@ def tune( # noqa: C901

# input_has: __fmin_params__:str
# schema: *,__fmin_value__:double,__fmin_metadata__:str
def compute_transformer(df: Iterable[Dict[str, Any]]) -> Iterable[Dict[str, Any]]:
def compute_transformer(
df: Iterable[Dict[str, Any]], cb: Optional[Callable] = None
) -> Iterable[Dict[str, Any]]:
for row in df:
dfs: Dict[str, Any] = {}
dfs_keys: Set[str] = set()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def get_version() -> str:
author_email="[email protected]",
keywords="fugue incubator experiment",
url="http://github.com/fugue-project/fugue-incubator",
install_requires=["fugue>=0.5.0", "scikit-learn", "matplotlib"],
install_requires=["fugue==0.5.1.dev0", "scikit-learn", "matplotlib", "flask"],
extras_require={"hyperopt": ["hyperopt"], "all": ["hyperopt"]},
classifiers=[
# "3 - Alpha", "4 - Beta" or "5 - Production/Stable"
Expand Down
11 changes: 11 additions & 0 deletions tests/fugue_tune/test_trial.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from fugue_tune.trial import TrialCallback, TrialsTracker


def test_callback():
tt = TrialsTracker()
tb = TrialCallback(tt)
tb.log_trial(a=1, b=2)
tb.log_trial(a=2, b=3)
key = tb.trial_id
assert [{"a": 1, "b": 2}, {"a": 2, "b": 3}] == tt.get_raw_data()[key]
assert not tb.prune()