v0.2.6 (#32)
* xgboost notebook

* finetuning notebook

* finetuning test

* experimental nni support

* support nested search space

* log file name

* record training_iteration

* eps

* reset times

* std set to default step size if 0
sonichi authored Feb 28, 2021
1 parent 6ff0ed4 commit 7bd231e
Showing 12 changed files with 1,370 additions and 220 deletions.
4 changes: 2 additions & 2 deletions flaml/automl.py
@@ -845,7 +845,7 @@ def custom_metric(X_test, y_test, estimator, labels,
if eval_method == 'auto' or self._state.X_val is not None:
eval_method = self._decide_eval_method(time_budget)
self._state.eval_method = eval_method
if not mlflow or not mlflow.active_run() and not logger.handler:
if (not mlflow or not mlflow.active_run()) and not logger.handlers:
# Add the console handler.
_ch = logging.StreamHandler()
_ch.setFormatter(logger_formatter)
@@ -1074,7 +1074,7 @@ def _search(self):
search_state.best_config,
estimator,
search_state.sample_size)
if mlflow is not None:
if mlflow is not None and mlflow.active_run():
with mlflow.start_run(nested=True) as run:
mlflow.log_metric('iter_counter',
self._iter_per_learner[estimator])
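
Note on the guard change above: `and` binds tighter than `or` in Python, so the added parentheses change when the console handler is attached. A minimal sketch with hypothetical values (not part of the commit):

mlflow = None            # mlflow is not installed
active_run = None        # no active MLflow run
handlers = ["console"]   # logger already has a console handler attached
old = not mlflow or (not active_run and not handlers)   # True  -> duplicate handler added
new = (not mlflow or not active_run) and not handlers   # False -> handler added only once
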
90 changes: 88 additions & 2 deletions flaml/searcher/blendsearch.py
@@ -25,6 +25,8 @@ class BlendSearch(Searcher):
'''class for BlendSearch algorithm
'''

cost_attr = "time_total_s" # cost attribute in result

def __init__(self,
metric: Optional[str] = None,
mode: Optional[str] = None,
@@ -193,7 +195,7 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
self._search_thread_pool[self._thread_count] = SearchThread(
self._ls.mode,
self._ls.create(config, result[self._metric], cost=result[
"time_total_s"])
self.cost_attr])
)
thread_id = self._thread_count
self._thread_count += 1
@@ -393,7 +395,89 @@ def _valid(self, config: Dict) -> bool:
return True


class CFO(BlendSearch):
try:
from nni.tuner import Tuner as NNITuner
from nni.utils import extract_scalar_reward
try:
from ray.tune import (uniform, quniform, choice, randint, qrandint, randn,
qrandn, loguniform, qloguniform)
except ImportError:
from .sample import (uniform, quniform, choice, randint, qrandint, randn,
qrandn, loguniform, qloguniform)

class BlendSearchTuner(BlendSearch, NNITuner):
'''Tuner class for NNI
'''

def receive_trial_result(self, parameter_id, parameters, value,
**kwargs):
'''
Receive trial's final result.
parameter_id: int
parameters: object created by 'generate_parameters()'
value: final metrics of the trial, including default metric
'''
result = {}
for key, param in parameters.items():
result['config/'+key] = param
reward = extract_scalar_reward(value)
result[self._metric] = reward
# if nni does not report training cost,
# use the sequence number as an approximation;
# if there is no sequence, use a constant 1
result[self.cost_attr] = value.get(self.cost_attr, value.get(
'sequence', 1))
self.on_trial_complete(str(parameter_id), result)
...

def generate_parameters(self, parameter_id, **kwargs) -> Dict:
'''
Returns a set of trial (hyper-)parameters, as a serializable object
parameter_id: int
'''
return self.suggest(str(parameter_id))
...

def update_search_space(self, search_space):
'''
Tuners are advised to support updating search space at run-time.
If a tuner can only set search space once before generating first hyper-parameters,
it should explicitly document this behaviour.
search_space: JSON object created by experiment owner
'''
config = {}
for key, value in search_space.items():
v = value.get("_value")
_type = value['_type']
if _type == 'choice':
config[key] = choice(v)
elif _type == 'randint':
config[key] = randint(v[0], v[1]-1)
elif _type == 'uniform':
config[key] = uniform(v[0], v[1])
elif _type == 'quniform':
config[key] = quniform(v[0], v[1], v[2])
elif _type == 'loguniform':
config[key] = loguniform(v[0], v[1])
elif _type == 'qloguniform':
config[key] = qloguniform(v[0], v[1], v[2])
elif _type == 'normal':
config[key] = randn(v[1], v[2])
elif _type == 'qnormal':
config[key] = qrandn(v[1], v[2], v[3])
else:
raise ValueError(
f'unsupported type in search_space {_type}')
self._ls.set_search_properties(None, None, config)
if self._gs is not None:
self._gs.set_search_properties(None, None, config)
self._init_search()

except ImportError:
class BlendSearchTuner(BlendSearch): pass


class CFO(BlendSearchTuner):
''' class for CFO algorithm
'''

@@ -416,3 +500,5 @@ def _create_condition(self, result: Dict) -> bool:
''' create thread condition
'''
return len(self._search_thread_pool) < 2
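
A hedged usage sketch of the BlendSearchTuner added in this file, driven through NNI's tuner interface (the search space, metric name, and reward below are made-up examples; in a real experiment NNI itself makes these calls):

space = {
    "learning_rate": {"_type": "loguniform", "_value": [1e-5, 1e-1]},
    "num_leaves": {"_type": "randint", "_value": [4, 256]},
    "booster": {"_type": "choice", "_value": ["gbtree", "dart"]},
}
tuner = BlendSearchTuner(metric="default", mode="max")
tuner.update_search_space(space)                    # NNI JSON types -> tune domains
params = tuner.generate_parameters(parameter_id=0)  # one suggested configuration
# report the trial's final result; 'sequence' stands in for the training cost
tuner.receive_trial_result(0, params, value={"default": 0.93, "sequence": 1})
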


35 changes: 20 additions & 15 deletions flaml/searcher/flow2.py
@@ -9,9 +9,10 @@
from ray.tune.suggest import Searcher
from ray.tune.suggest.variant_generator import generate_variants
from ray.tune import sample
from ray.tune.utils.util import flatten_dict, unflatten_dict
except ImportError:
from .suggestion import Searcher
from .variant_generator import generate_variants
from .variant_generator import generate_variants, flatten_dict, unflatten_dict
from ..tune import sample


@@ -86,6 +87,7 @@ def __init__(self,
elif mode == "min":
self.metric_op = 1.
self.space = space or {}
self.space = flatten_dict(self.space, prevent_delimiter=True)
self._random = np.random.RandomState(seed)
self._seed = seed
if not init_config:
@@ -95,7 +97,8 @@
"consider providing init values for cost-related hps via "
"'init_config'."
)
self.init_config = self.best_config = init_config
self.init_config = init_config
self.best_config = flatten_dict(init_config)
self.cat_hp_cost = cat_hp_cost
self.prune_attr = prune_attr
self.min_resource = min_resource
@@ -171,7 +174,7 @@ def _init_search(self):
# logger.info(self._resource)
else: self._resource = None
self.incumbent = {}
self.incumbent = self.normalize(self.init_config)
self.incumbent = self.normalize(self.best_config) # flattened
self.best_obj = self.cost_incumbent = None
self.dim = len(self._tunable_keys) # total # tunable dimensions
self._direction_tried = None
@@ -247,7 +250,7 @@ def complete_config(self, partial_config: Dict,
if key not in self._unordered_cat_hp:
if upper and lower:
u, l = upper[key], lower[key]
gauss_std = u-l
gauss_std = u-l or self.STEPSIZE
# allowed bound
u += self.STEPSIZE
l -= self.STEPSIZE
@@ -261,11 +264,11 @@
normalized[key] = max(l, min(u, normalized[key] + delta))
# use best config for unordered cat choice
config = self.denormalize(normalized)
self._reset_times += 1
else:
# first time init_config, or other configs, take as is
config = partial_config.copy()

if partial_config == self.init_config: self._reset_times += 1
config = flatten_dict(config)
for key, value in self.space.items():
if key not in config:
config[key] = value
@@ -277,13 +280,13 @@

if self._resource:
config[self.prune_attr] = self.min_resource
return config
return unflatten_dict(config)

def create(self, init_config: Dict, obj: float, cost: float) -> Searcher:
flow2 = FLOW2(init_config, self.metric, self.mode, self._cat_hp_cost,
self.space, self.prune_attr, self.min_resource,
self.max_resource, self.resource_multiple_factor,
self._seed+1)
unflatten_dict(self.space), self.prune_attr,
self.min_resource, self.max_resource,
self.resource_multiple_factor, self._seed+1)
flow2.best_obj = obj * self.metric_op # minimize internally
flow2.cost_incumbent = cost
return flow2
@@ -292,7 +295,7 @@ def normalize(self, config) -> Dict:
''' normalize each dimension in config to [0,1]
'''
config_norm = {}
for key, value in config.items():
for key, value in flatten_dict(config).items():
if key in self.space:
# domain: sample.Categorical/Integer/Float/Function
domain = self.space[key]
@@ -426,7 +429,7 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
obj = result.get(self._metric)
if obj:
obj *= self.metric_op
if obj < self.best_obj:
if self.best_obj is None or obj < self.best_obj:
self.best_obj, self.best_config = obj, self._configs[
trial_id]
self.incumbent = self.normalize(self.best_config)
@@ -437,7 +440,8 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
self._cost_complete4incumbent = 0
self._num_allowed4incumbent = 2 * self.dim
self._proposed_by.clear()
if self._K > 0:
if self._K > 0:
# self._oldK must have been set when self._K>0
self.step *= np.sqrt(self._K/self._oldK)
if self.step > self.step_ub: self.step = self.step_ub
self._iter_best_config = self.trial_count
@@ -474,7 +478,7 @@ def on_trial_result(self, trial_id: str, result: Dict):
obj = result.get(self._metric)
if obj:
obj *= self.metric_op
if obj < self.best_obj:
if self.best_obj is None or obj < self.best_obj:
self.best_obj = obj
config = self._configs[trial_id]
if self.best_config != config:
@@ -533,7 +537,7 @@ def suggest(self, trial_id: str) -> Optional[Dict]:
config = self.denormalize(move)
self._proposed_by[trial_id] = self.incumbent
self._configs[trial_id] = config
return config
return unflatten_dict(config)

def _project(self, config):
''' project normalized config in the feasible region and set prune_attr
@@ -553,6 +557,7 @@ def can_suggest(self) -> bool:
def config_signature(self, config) -> tuple:
''' return the signature tuple of a config
'''
config = flatten_dict(config)
value_list = []
for key in self._space_keys:
if key in config:
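
The `gauss_std = u-l or self.STEPSIZE` change above guards against a degenerate sampling range. A tiny numeric sketch (STEPSIZE stands in for FLOW2's default step constant):

STEPSIZE = 0.1
u = l = 0.5                      # observed upper and lower bounds coincide
gauss_std = u - l or STEPSIZE    # 0.0 is falsy, so the std falls back to 0.1
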
3 changes: 2 additions & 1 deletion flaml/searcher/search_thread.py
@@ -20,6 +20,7 @@ class SearchThread:
'''

cost_attr = 'time_total_s'
eps = 1e-10

def __init__(self, mode: str = "min",
search_alg: Optional[Searcher] = None):
@@ -70,7 +71,7 @@ def _update_speed(self):
# calculate speed; use 0 for invalid speed temporarily
if self.obj_best2 > self.obj_best1:
self.speed = (self.obj_best2 - self.obj_best1) / (
self.cost_total - self.cost_best2)
self.cost_total - self.cost_best2 + self.eps)
else: self.speed = 0

def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
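
The added `eps` term keeps the speed estimate finite when the best result arrives at the latest recorded cost. A small numeric sketch with made-up objective and cost values:

eps = 1e-10
obj_best2, obj_best1 = 0.30, 0.25   # previous best and current best (minimized)
cost_total = cost_best2 = 12.0      # no cost spent since the previous best
speed = (obj_best2 - obj_best1) / (cost_total - cost_best2 + eps)  # finite, no ZeroDivisionError
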
40 changes: 40 additions & 0 deletions flaml/searcher/variant_generator.py
@@ -28,6 +28,46 @@
logger = logging.getLogger(__name__)


def flatten_dict(dt, delimiter="/", prevent_delimiter=False):
dt = copy.deepcopy(dt)
if prevent_delimiter and any(delimiter in key for key in dt):
# Raise if the delimiter is in any of the keys
raise ValueError(
"Found delimiter `{}` in key when trying to flatten array. "
"Please avoid using the delimiter in your specification.".format(
delimiter))
while any(isinstance(v, dict) for v in dt.values()):
remove = []
add = {}
for key, value in dt.items():
if isinstance(value, dict):
for subkey, v in value.items():
if prevent_delimiter and delimiter in subkey:
# Raise if delimiter is in any of the subkeys
raise ValueError(
"Found delimiter `{}` in key when trying to "
"flatten array. Please avoid using the delimiter "
"in your specification.".format(delimiter))
add[delimiter.join([key, str(subkey)])] = v
remove.append(key)
dt.update(add)
for k in remove:
del dt[k]
return dt


def unflatten_dict(dt, delimiter="/"):
"""Unflatten dict. Does not support unflattening lists."""
dict_type = type(dt)
out = dict_type()
for key, val in dt.items():
path = key.split(delimiter)
item = out
for k in path[:-1]:
item = item.setdefault(k, dict_type())
item[path[-1]] = val
return out


class TuneError(Exception):
"""General error class raised by ray.tune."""
pass
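
A round-trip sketch of the two helpers added above, using a hypothetical nested config: flatten_dict joins nested keys with the delimiter so FLOW2 can treat a nested search space as flat, and unflatten_dict restores the nesting before a config is handed back to the trial runner.

nested = {"model": {"max_depth": 6, "eta": 0.3}, "subsample": 0.8}
flat = flatten_dict(nested)
# {'model/max_depth': 6, 'model/eta': 0.3, 'subsample': 0.8}
assert unflatten_dict(flat) == nested
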
12 changes: 11 additions & 1 deletion flaml/tune/tune.py
@@ -17,6 +17,8 @@
_use_ray = True
_runner = None
_verbose = 0
_running_trial = None
_training_iteration = 0


class ExperimentAnalysis(EA):
@@ -68,6 +70,8 @@ def compute_with_config(config):
'''
global _use_ray
global _verbose
global _running_trial
global _training_iteration
if _use_ray:
from ray import tune
return tune.report(_metric, **kwargs)
@@ -77,6 +81,12 @@ def compute_with_config(config):
logger.info(f"result: {kwargs}")
if _metric: result['_default_anonymous_metric'] = _metric
trial = _runner.running_trial
if _running_trial == trial:
_training_iteration += 1
else:
_training_iteration = 0
_running_trial = trial
result["training_iteration"] = _training_iteration
result['config'] = trial.config
for key, value in trial.config.items():
result['config/'+key] = value
@@ -213,7 +223,7 @@ def compute_with_config(config):
import os
os.makedirs(local_dir, exist_ok=True)
logger.addHandler(logging.FileHandler(local_dir+'/tune_'+str(
datetime.datetime.now())+'.log'))
datetime.datetime.now()).replace(':', '-')+'.log'))
if verbose<=2:
logger.setLevel(logging.INFO)
else:
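
The module-level counters added above record training_iteration for the non-ray code path: consecutive reports from the same running trial increment the counter, and it resets when a new trial starts. (The log-file name change in the same file replaces ':' in the timestamp because colons are not legal in Windows file names.) A standalone sketch of the counter logic with hypothetical trial ids:

_running_trial, _training_iteration = None, 0

def record(trial):
    global _running_trial, _training_iteration
    if _running_trial == trial:
        _training_iteration += 1
    else:
        _training_iteration = 0
        _running_trial = trial
    return _training_iteration

assert record("trial_1") == 0
assert record("trial_1") == 1
assert record("trial_2") == 0
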
2 changes: 1 addition & 1 deletion flaml/version.py
@@ -1 +1 @@
__version__ = "0.2.5"
__version__ = "0.2.6"