
Commit

don't init global search with points_to_evaluate unless evaluated_rewards is provided; handle callbacks in fit kwargs (#469)
sonichi authored Mar 2, 2022
1 parent df01031 commit 31ac984
Showing 4 changed files with 25 additions and 9 deletions.
13 changes: 8 additions & 5 deletions flaml/model.py
@@ -959,10 +959,16 @@ def fit(self, X_train, y_train, budget=None, **kwargs):
             # when not trained, train at least one iter
             self.params[self.ITER_HP] = max(max_iter, 1)
         if self.HAS_CALLBACK:
+            kwargs_callbacks = kwargs.get("callbacks")
+            if kwargs_callbacks:
+                callbacks = kwargs_callbacks + self._callbacks(start_time, deadline)
+                kwargs.pop("callbacks")
+            else:
+                callbacks = self._callbacks(start_time, deadline)
             self._fit(
                 X_train,
                 y_train,
-                callbacks=self._callbacks(start_time, deadline),
+                callbacks=callbacks,
                 **kwargs,
             )
             best_iteration = (
@@ -1821,10 +1827,7 @@ def search_space(cls, data_size, pred_horizon, **params):
                 "low_cost_init_value": False,
             },
             "lags": {
-                "domain": tune.randint(
-                    lower=1, upper=int(np.sqrt(data_size[0]))
-
-                ),
+                "domain": tune.randint(lower=1, upper=int(np.sqrt(data_size[0]))),
                 "init_value": 3,
             },
         }
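The kwargs handling added to fit() above can be summarized by a small standalone sketch (fit_with_callbacks, internal_callbacks, and the string callback names are illustrative, not FLAML's internals): user callbacks arriving through **kwargs are concatenated with the internally generated budget callbacks and removed from kwargs before the underlying _fit call, so the two no longer collide on the callbacks keyword.

# Standalone sketch mirroring the diff; internal_callbacks stands in for
# self._callbacks(start_time, deadline).
def fit_with_callbacks(internal_callbacks, **kwargs):
    kwargs_callbacks = kwargs.get("callbacks")
    if kwargs_callbacks:
        # user callbacks come first, then the internal budget/deadline callbacks
        callbacks = kwargs_callbacks + internal_callbacks
        kwargs.pop("callbacks")
    else:
        callbacks = internal_callbacks
    return callbacks, kwargs  # the real code forwards these to self._fit(...)


merged, rest = fit_with_callbacks(["deadline_cb"], callbacks=["user_cb"], verbose=0)
assert merged == ["user_cb", "deadline_cb"]
assert "callbacks" not in rest
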
3 changes: 2 additions & 1 deletion flaml/searcher/blendsearch.py
@@ -171,6 +171,7 @@ def __init__(
             else:
                 sampler = None
             try:
+                assert evaluated_rewards
                 self._gs = GlobalSearch(
                     space=gs_space,
                     metric=metric,
@@ -180,7 +181,7 @@
                     points_to_evaluate=points_to_evaluate,
                     evaluated_rewards=evaluated_rewards,
                 )
-            except ValueError:
+            except (AssertionError, ValueError):
                 self._gs = GlobalSearch(
                     space=gs_space,
                     metric=metric,
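For context, a minimal construction sketch (not part of the commit; assumes optuna is installed, and the space, metric name, and reward value are made up): with points_to_evaluate alone, the assertion above fails and the Optuna-based global search is created through the except branch without the seed points, which remain visible to the local search; passing evaluated_rewards of matching length keeps the previous behavior of also seeding the global search.

from flaml import BlendSearch, tune

space = {"a": tune.randint(1, 10), "b": tune.loguniform(1e-4, 1)}

# Seed points only: global search is built via the except branch, without seeding.
searcher = BlendSearch(
    metric="val_loss",
    mode="min",
    space=space,
    points_to_evaluate=[{"a": 1, "b": 0.01}],
)

# Seed points plus rewards: global search is still warm-started as before.
seeded = BlendSearch(
    metric="val_loss",
    mode="min",
    space=space,
    points_to_evaluate=[{"a": 1, "b": 0.01}],
    evaluated_rewards=[0.5],
)
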
7 changes: 6 additions & 1 deletion test/automl/test_classification.py
@@ -214,7 +214,12 @@ def test_sparse_matrix_xgboost(self):
         }
         X_train = scipy.sparse.eye(900000)
         y_train = np.random.randint(2, size=900000)
-        automl_experiment.fit(X_train=X_train, y_train=y_train, **automl_settings)
+        import xgboost as xgb
+
+        callback = xgb.callback.TrainingCallback()
+        automl_experiment.fit(
+            X_train=X_train, y_train=y_train, callbacks=[callback], **automl_settings
+        )
         print(automl_experiment.predict(X_train))
         print(automl_experiment.model)
         print(automl_experiment.config_history)
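A usage sketch in the spirit of the test above (the dataset, time budget, and settings are arbitrary; assumes xgboost>=1.3 for TrainingCallback): extra keyword arguments to AutoML.fit are forwarded to the estimator's fit, and user callbacks are now merged with FLAML's own budget callbacks instead of clashing with them.

import numpy as np
import xgboost as xgb
from flaml import AutoML

X_train = np.random.rand(200, 4)
y_train = np.random.randint(2, size=200)

automl = AutoML()
automl.fit(
    X_train=X_train,
    y_train=y_train,
    task="classification",
    estimator_list=["xgboost"],
    time_budget=5,
    # forwarded through fit kwargs and combined with FLAML's internal callbacks
    callbacks=[xgb.callback.TrainingCallback()],
)
print(automl.model)
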
11 changes: 9 additions & 2 deletions test/tune/test_searcher.py
@@ -197,14 +197,21 @@ def test_searcher():
         # sign of metric constraints must be <= or >=.
         pass
     searcher = BlendSearch(
-        metric="m", global_search_alg=searcher, metric_constraints=[("c", "<=", 1)]
+        metric="m",
+        global_search_alg=searcher,
+        metric_constraints=[("c", "<=", 1)],
+        points_to_evaluate=[{"a": 1, "b": 0.01}],
     )
     searcher.set_search_properties(
         metric="m2", config=config, setting={"time_budget_s": 0}
     )
     c = searcher.suggest("t1")
-    searcher.on_trial_complete("t1", {"config": c}, True)
+    print("t1", c)
+    c = searcher.suggest("t2")
+    print("t2", c)
+    c = searcher.suggest("t3")
+    print("t3", c)
     searcher.on_trial_complete("t1", {"config": c}, True)
     searcher.on_trial_complete("t2", {"config": c, "m2": 1, "c": 2, "time_total_s": 1})
     config1 = config.copy()
     config1["_choice_"] = 0
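Finally, a self-contained sketch of the suggest/complete loop the expanded test exercises (metric name, space, and reported values are illustrative; assumes optuna is installed): the point supplied via points_to_evaluate should still be proposed by an early suggest() call even though the global search was not seeded with it.

from flaml import BlendSearch, tune

searcher = BlendSearch(
    metric="m",
    mode="min",
    space={"a": tune.randint(1, 10), "b": tune.loguniform(1e-4, 1)},
    points_to_evaluate=[{"a": 1, "b": 0.01}],
)
configs = {}
for trial_id in ("t1", "t2", "t3"):
    configs[trial_id] = searcher.suggest(trial_id)
    print(trial_id, configs[trial_id])
# report one errored and one finished trial, mirroring the test above
searcher.on_trial_complete("t1", {"config": configs["t1"]}, True)
searcher.on_trial_complete("t2", {"config": configs["t2"], "m": 0.3, "time_total_s": 1})
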
