metric constraint (#90)
* penalty change

* metric modification

* catboost init
sonichi authored May 22, 2021
1 parent 0925e2b commit b206363
Showing 3 changed files with 55 additions and 8 deletions.
1 change: 1 addition & 0 deletions flaml/automl.py
@@ -922,6 +922,7 @@ def custom_metric(X_test, y_test, estimator, labels,
         # set up learner search space
         for estimator_name in estimator_list:
             estimator_class = self._state.learner_classes[estimator_name]
+            estimator_class.init()
             self._search_states[estimator_name] = SearchState(
                 learner_class=estimator_class,
                 data_size=self._state.data_size, task=self._state.task,
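The new `init()` call gives every estimator class a chance to reset class-level state once, before the search states are built. A minimal sketch of a custom learner using the hook (the class name and cached attribute are hypothetical, not part of this commit):

```python
from flaml.model import BaseEstimator

class MyEstimator(BaseEstimator):
    # class-level cache shared by all trials of this learner
    _timing_cache = None

    @classmethod
    def init(cls):
        # called once by AutoML before the search starts, so stale
        # state from a previous fit() does not carry over
        cls._timing_cache = None
```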
10 changes: 10 additions & 0 deletions flaml/model.py
@@ -163,6 +163,11 @@ def cost_relative2lgbm(cls):
         '''[optional method] relative cost compared to lightgbm'''
         return 1.0
 
+    @classmethod
+    def init(cls):
+        '''[optional method] initialize the class'''
+        pass
+
 
 class SKLearnEstimator(BaseEstimator):
 
@@ -632,6 +637,11 @@ def size(cls, config):
     def cost_relative2lgbm(cls):
         return 15
 
+    @classmethod
+    def init(cls):
+        CatBoostEstimator._time_per_iter = None
+        CatBoostEstimator._train_size = 0
+
     def __init__(
             self, task='binary:logistic', n_jobs=1,
             n_estimators=8192, learning_rate=0.1, early_stopping_rounds=4, **params
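`CatBoostEstimator` keeps `_time_per_iter` and `_train_size` as class attributes to estimate per-iteration training cost across trials; resetting them in `init()` presumably keeps timing estimates from one AutoML run from leaking into the next, which is why AutoML now calls the hook before every search.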
52 changes: 44 additions & 8 deletions flaml/searcher/blendsearch.py
@@ -27,6 +27,8 @@ class BlendSearch(Searcher):
     '''
 
     cost_attr = "time_total_s" # cost attribute in result
+    lagrange = '_lagrange' # suffix for lagrange-modified metric
+    penalty = 1e+10 # penalty term for constraints
 
     def __init__(self,
                  metric: Optional[str] = None,
@@ -106,6 +108,11 @@ def __init__(self,
         self._metric, self._mode = metric, mode
         init_config = low_cost_partial_config or {}
         self._points_to_evaluate = points_to_evaluate or []
+        self._config_constraints = config_constraints
+        self._metric_constraints = metric_constraints
+        if self._metric_constraints:
+            # metric modified by lagrange
+            metric += self.lagrange
         if global_search_alg is not None:
             self._gs = global_search_alg
         elif getattr(self, '__name__', None) != 'CFO':
Expand All @@ -115,8 +122,6 @@ def __init__(self,
self._ls = LocalSearch(
init_config, metric, mode, cat_hp_cost, space,
prune_attr, min_resource, max_resource, reduction_factor, seed)
self._config_constraints = config_constraints
self._metric_constraints = metric_constraints
self._init_search()

def set_search_properties(self,
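The constructor now accepts `metric_constraints`; each constraint is later unpacked as `metric_constraint, sign, threshold = constraint`, so a plausible usage sketch looks like the following (the metric names are illustrative, not part of the commit):

```python
from flaml.searcher.blendsearch import BlendSearch

# require prediction time of at most 1.0s while minimizing val_loss;
# 'pred_time' must be a key reported in each trial's result dict
search = BlendSearch(
    metric='val_loss', mode='min',
    metric_constraints=[('pred_time', '<=', 1.0)],
)
# internally the searcher optimizes metric + BlendSearch.lagrange,
# i.e. 'val_loss_lagrange', which holds the penalized objective
```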
Expand All @@ -131,6 +136,11 @@ def set_search_properties(self,
else:
if metric:
self._metric = metric
if self._metric_constraints:
# metric modified by lagrange
metric += self.lagrange
# TODO: don't change metric for global search methods that
# can handle constraints already
if mode:
self._mode = mode
self._ls.set_search_properties(metric, mode, config)
Expand All @@ -156,6 +166,13 @@ def _init_search(self):
self._gs_admissible_max = self._ls_bound_max.copy()
self._result = {} # config_signature: tuple -> result: Dict
self._deadline = np.inf
if self._metric_constraints:
self._metric_constraint_satisfied = False
self._metric_constraint_penalty = [
self.penalty for _ in self._metric_constraints]
else:
self._metric_constraint_satisfied = True
self._metric_constraint_penalty = None

def save(self, checkpoint_path: str):
save_object = self
Expand All @@ -182,6 +199,8 @@ def restore(self, checkpoint_path: str):
self._ls = state._ls
self._config_constraints = state._config_constraints
self._metric_constraints = state._metric_constraints
self._metric_constraint_satisfied = state._metric_constraint_satisfied
self._metric_constraint_penalty = state._metric_constraint_penalty

def restore_from_dir(self, checkpoint_dir: str):
super.restore_from_dir(checkpoint_dir)
Expand All @@ -190,10 +209,11 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
error: bool = False):
''' search thread updater and cleaner
'''
metric_constraint_satisfied = True
if result and not error and self._metric_constraints:
# accout for metric constraints if any
# account for metric constraints if any
objective = result[self._metric]
for constraint in self._metric_constraints:
for i, constraint in enumerate(self._metric_constraints):
metric_constraint, sign, threshold = constraint
value = result.get(metric_constraint)
if value:
Expand All @@ -202,8 +222,16 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
violation = (value - threshold) * sign_op
if violation > 0:
# add penalty term to the metric
objective += 1e+10 * violation * self._ls.metric_op
result[self._metric] = objective
objective += self._metric_constraint_penalty[
i] * violation * self._ls.metric_op
metric_constraint_satisfied = False
if self._metric_constraint_penalty[i] < self.penalty:
self._metric_constraint_penalty[i] += violation
result[self._metric + self.lagrange] = objective
if metric_constraint_satisfied and not self._metric_constraint_satisfied:
# found a feasible point
self._metric_constraint_penalty = [1 for _ in self._metric_constraints]
self._metric_constraint_satisfied |= metric_constraint_satisfied
thread_id = self._trial_proposed_by.get(trial_id)
if thread_id in self._search_thread_pool:
self._search_thread_pool[thread_id].on_trial_complete(
Expand All @@ -219,10 +247,13 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
else: # add to result cache
self._result[self._ls.config_signature(config)] = result
# update target metric if improved
objective = result[self._metric]
objective = result[
self._metric + self.lagrange] if self._metric_constraints \
else result[self._metric]
if (objective - self._metric_target) * self._ls.metric_op < 0:
self._metric_target = objective
if not thread_id and self._create_condition(result):
if not thread_id and metric_constraint_satisfied \
and self._create_condition(result):
# thread creator
self._search_thread_pool[self._thread_count] = SearchThread(
self._ls.mode,
Expand All @@ -233,6 +264,9 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None,
self._thread_count += 1
self._update_admissible_region(
config, self._ls_bound_min, self._ls_bound_max)
elif thread_id and not self._metric_constraint_satisfied:
# no point has been found to satisfy metric constraint
self._expand_admissible_region()
# reset admissible region to ls bounding box
self._gs_admissible_min.update(self._ls_bound_min)
self._gs_admissible_max.update(self._ls_bound_max)
@@ -306,6 +340,8 @@ def on_trial_result(self, trial_id: str, result: Dict):
         thread_id = self._trial_proposed_by[trial_id]
         if thread_id not in self._search_thread_pool:
             return
+        if result and self._metric_constraints:
+            result[self._metric + self.lagrange] = result[self._metric]
         self._search_thread_pool[thread_id].on_trial_result(trial_id, result)
 
     def suggest(self, trial_id: str) -> Optional[Dict]:
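Note that for intermediate results, `on_trial_result` only mirrors the raw metric into the `_lagrange` key without adding any penalty, so search threads tracking the modified metric always find it in the result dict; the penalty itself is applied once per trial, in `on_trial_complete`.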
