From b206363c9a412728df0fabebf441e3b02b41e5b7 Mon Sep 17 00:00:00 2001 From: Chi Wang Date: Sat, 22 May 2021 08:51:38 -0700 Subject: [PATCH] metric constraint (#90) * penalty change * metric modification * catboost init --- flaml/automl.py | 1 + flaml/model.py | 10 +++++++ flaml/searcher/blendsearch.py | 52 +++++++++++++++++++++++++++++------ 3 files changed, 55 insertions(+), 8 deletions(-) diff --git a/flaml/automl.py b/flaml/automl.py index e0435f7675..d15e6c279e 100644 --- a/flaml/automl.py +++ b/flaml/automl.py @@ -922,6 +922,7 @@ def custom_metric(X_test, y_test, estimator, labels, # set up learner search space for estimator_name in estimator_list: estimator_class = self._state.learner_classes[estimator_name] + estimator_class.init() self._search_states[estimator_name] = SearchState( learner_class=estimator_class, data_size=self._state.data_size, task=self._state.task, diff --git a/flaml/model.py b/flaml/model.py index 9deec6f400..08ce5eeb0e 100644 --- a/flaml/model.py +++ b/flaml/model.py @@ -163,6 +163,11 @@ def cost_relative2lgbm(cls): '''[optional method] relative cost compared to lightgbm''' return 1.0 + @classmethod + def init(cls): + '''[optional method] initialize the class''' + pass + class SKLearnEstimator(BaseEstimator): @@ -632,6 +637,11 @@ def size(cls, config): def cost_relative2lgbm(cls): return 15 + @classmethod + def init(cls): + CatBoostEstimator._time_per_iter = None + CatBoostEstimator._train_size = 0 + def __init__( self, task='binary:logistic', n_jobs=1, n_estimators=8192, learning_rate=0.1, early_stopping_rounds=4, **params diff --git a/flaml/searcher/blendsearch.py b/flaml/searcher/blendsearch.py index fae5de2a3b..fcf8c343cf 100644 --- a/flaml/searcher/blendsearch.py +++ b/flaml/searcher/blendsearch.py @@ -27,6 +27,8 @@ class BlendSearch(Searcher): ''' cost_attr = "time_total_s" # cost attribute in result + lagrange = '_lagrange' # suffix for lagrange-modified metric + penalty = 1e+10 # penalty term for constraints def __init__(self, metric: Optional[str] = None, @@ -106,6 +108,11 @@ def __init__(self, self._metric, self._mode = metric, mode init_config = low_cost_partial_config or {} self._points_to_evaluate = points_to_evaluate or [] + self._config_constraints = config_constraints + self._metric_constraints = metric_constraints + if self._metric_constraints: + # metric modified by lagrange + metric += self.lagrange if global_search_alg is not None: self._gs = global_search_alg elif getattr(self, '__name__', None) != 'CFO': @@ -115,8 +122,6 @@ def __init__(self, self._ls = LocalSearch( init_config, metric, mode, cat_hp_cost, space, prune_attr, min_resource, max_resource, reduction_factor, seed) - self._config_constraints = config_constraints - self._metric_constraints = metric_constraints self._init_search() def set_search_properties(self, @@ -131,6 +136,11 @@ def set_search_properties(self, else: if metric: self._metric = metric + if self._metric_constraints: + # metric modified by lagrange + metric += self.lagrange + # TODO: don't change metric for global search methods that + # can handle constraints already if mode: self._mode = mode self._ls.set_search_properties(metric, mode, config) @@ -156,6 +166,13 @@ def _init_search(self): self._gs_admissible_max = self._ls_bound_max.copy() self._result = {} # config_signature: tuple -> result: Dict self._deadline = np.inf + if self._metric_constraints: + self._metric_constraint_satisfied = False + self._metric_constraint_penalty = [ + self.penalty for _ in self._metric_constraints] + else: + self._metric_constraint_satisfied = True + self._metric_constraint_penalty = None def save(self, checkpoint_path: str): save_object = self @@ -182,6 +199,8 @@ def restore(self, checkpoint_path: str): self._ls = state._ls self._config_constraints = state._config_constraints self._metric_constraints = state._metric_constraints + self._metric_constraint_satisfied = state._metric_constraint_satisfied + self._metric_constraint_penalty = state._metric_constraint_penalty def restore_from_dir(self, checkpoint_dir: str): super.restore_from_dir(checkpoint_dir) @@ -190,10 +209,11 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, error: bool = False): ''' search thread updater and cleaner ''' + metric_constraint_satisfied = True if result and not error and self._metric_constraints: - # accout for metric constraints if any + # account for metric constraints if any objective = result[self._metric] - for constraint in self._metric_constraints: + for i, constraint in enumerate(self._metric_constraints): metric_constraint, sign, threshold = constraint value = result.get(metric_constraint) if value: @@ -202,8 +222,16 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, violation = (value - threshold) * sign_op if violation > 0: # add penalty term to the metric - objective += 1e+10 * violation * self._ls.metric_op - result[self._metric] = objective + objective += self._metric_constraint_penalty[ + i] * violation * self._ls.metric_op + metric_constraint_satisfied = False + if self._metric_constraint_penalty[i] < self.penalty: + self._metric_constraint_penalty[i] += violation + result[self._metric + self.lagrange] = objective + if metric_constraint_satisfied and not self._metric_constraint_satisfied: + # found a feasible point + self._metric_constraint_penalty = [1 for _ in self._metric_constraints] + self._metric_constraint_satisfied |= metric_constraint_satisfied thread_id = self._trial_proposed_by.get(trial_id) if thread_id in self._search_thread_pool: self._search_thread_pool[thread_id].on_trial_complete( @@ -219,10 +247,13 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, else: # add to result cache self._result[self._ls.config_signature(config)] = result # update target metric if improved - objective = result[self._metric] + objective = result[ + self._metric + self.lagrange] if self._metric_constraints \ + else result[self._metric] if (objective - self._metric_target) * self._ls.metric_op < 0: self._metric_target = objective - if not thread_id and self._create_condition(result): + if not thread_id and metric_constraint_satisfied \ + and self._create_condition(result): # thread creator self._search_thread_pool[self._thread_count] = SearchThread( self._ls.mode, @@ -233,6 +264,9 @@ def on_trial_complete(self, trial_id: str, result: Optional[Dict] = None, self._thread_count += 1 self._update_admissible_region( config, self._ls_bound_min, self._ls_bound_max) + elif thread_id and not self._metric_constraint_satisfied: + # no point has been found to satisfy metric constraint + self._expand_admissible_region() # reset admissible region to ls bounding box self._gs_admissible_min.update(self._ls_bound_min) self._gs_admissible_max.update(self._ls_bound_max) @@ -306,6 +340,8 @@ def on_trial_result(self, trial_id: str, result: Dict): thread_id = self._trial_proposed_by[trial_id] if thread_id not in self._search_thread_pool: return + if result and self._metric_constraints: + result[self._metric + self.lagrange] = result[self._metric] self._search_thread_pool[thread_id].on_trial_result(trial_id, result) def suggest(self, trial_id: str) -> Optional[Dict]: