diff --git a/README.md b/README.md
index 2fc7360c2b..7c9be73a8b 100644
--- a/README.md
+++ b/README.md
@@ -140,7 +140,6 @@ For more technical details, please check our papers.
 ```
 * [Frugal Optimization for Cost-related Hyperparameters](https://arxiv.org/abs/2005.01571). Qingyun Wu, Chi Wang, Silu Huang. AAAI 2021.
 * [Economical Hyperparameter Optimization With Blended Search Strategy](https://www.microsoft.com/en-us/research/publication/economical-hyperparameter-optimization-with-blended-search-strategy/). Chi Wang, Qingyun Wu, Silu Huang, Amin Saied. ICLR 2021.
-* ChaCha for online AutoML. Qingyun Wu, Chi Wang, John Langford, Paul Mineiro and Marco Rossi. To appear in ICML 2021.
 
 ## Contributing
 
diff --git a/flaml/__init__.py b/flaml/__init__.py
index 6b8900f28e..345915d3b3 100644
--- a/flaml/__init__.py
+++ b/flaml/__init__.py
@@ -3,7 +3,8 @@
 try:
     from flaml.onlineml.autovw import AutoVW
 except ImportError:
-    print('need to install vowpalwabbit to use AutoVW')
+    # print('need to install vowpalwabbit to use AutoVW')
+    pass
 from flaml.version import __version__
 import logging
 
diff --git a/flaml/automl.py b/flaml/automl.py
index 30698ddab4..4bff9ad385 100644
--- a/flaml/automl.py
+++ b/flaml/automl.py
@@ -809,14 +809,17 @@ def fit(self,
                 dataframe and label are ignored;
                 If not, dataframe and label must be provided.
             metric: A string of the metric name or a function,
-                e.g., 'accuracy', 'roc_auc', 'f1', 'micro_f1', 'macro_f1', 'log_loss', 'mae', 'mse', 'r2'
+                e.g., 'accuracy', 'roc_auc', 'f1', 'micro_f1', 'macro_f1',
+                'log_loss', 'mae', 'mse', 'r2'
                 if passing a customized metric function, the function needs to have the following signature:
 
         .. code-block:: python
 
-            def custom_metric(X_test, y_test, estimator, labels,
-                X_train, y_train, weight_test=None, weight_train=None):
+            def custom_metric(
+                X_test, y_test, estimator, labels,
+                X_train, y_train, weight_test=None, weight_train=None
+            ):
                 return metric_to_minimize, metrics_to_log
 
             which returns a float number as the minimization objective,
@@ -1238,7 +1241,7 @@ def _select_estimator(self, estimator_list):
         for i, estimator in enumerate(estimator_list):
             if estimator in self._search_states and (
                 self._search_states[estimator].sample_size
-            ):  # sample_size=none meaning no result
+            ):  # sample_size=None meaning no result
                 search_state = self._search_states[estimator]
                 if (self._search_states[estimator].time2eval_best
                         > self._state.time_budget - self._state.time_from_start
diff --git a/flaml/ml.py b/flaml/ml.py
index b7c27a8883..2f6b658aa3 100644
--- a/flaml/ml.py
+++ b/flaml/ml.py
@@ -294,3 +294,45 @@ def get_classification_objective(num_labels: int) -> str:
     else:
         objective_name = 'multi:softmax'
     return objective_name
+
+
+def norm_confusion_matrix(y_true, y_pred):
+    '''Normalized confusion matrix
+
+    Args:
+        y_true: A numpy array or a pandas series of true labels
+        y_pred: A numpy array or a pandas series of predicted labels
+
+    Returns:
+        A normalized confusion matrix
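+
+    Example (an illustrative sketch; the labels below are toy values):
+
+    .. code-block:: python
+
+        import numpy as np
+        from flaml.ml import norm_confusion_matrix
+        y_true = np.array([0, 0, 1, 1, 2])
+        y_pred = np.array([0, 1, 1, 1, 2])
+        norm_confusion_matrix(y_true, y_pred)
+        # -> [[0.5, 0.5, 0.], [0., 1., 0.], [0., 0., 1.]]
+        # each row sums to 1: entry (i, j) is the fraction of true class-i
+        # samples predicted as class j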
+    '''
+    from sklearn.metrics import confusion_matrix
+    conf_mat = confusion_matrix(y_true, y_pred)
+    norm_conf_mat = conf_mat.astype('float') / conf_mat.sum(axis=1)[:, np.newaxis]
+    return norm_conf_mat
+
+
+def multi_class_curves(y_true, y_pred_proba, curve_func):
+    '''Binarize the data for multi-class tasks and produce ROC or precision-recall curves
+
+    Args:
+        y_true: A numpy array or a pandas series of true labels
+        y_pred_proba: A numpy array or a pandas dataframe of predicted probabilities
+        curve_func: A function to produce a curve (e.g., roc_curve or precision_recall_curve)
+
+    Returns:
+        A tuple of two dictionaries with the same set of keys (class indices).
+        The first dictionary curve_x stores the x coordinates of each curve, e.g.,
+        curve_x[0] is a 1D array of the x coordinates of class 0.
+        The second dictionary curve_y stores the y coordinates of each curve, e.g.,
+        curve_y[0] is a 1D array of the y coordinates of class 0.
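+
+    Example (an illustrative sketch; the probabilities below are toy values):
+
+    .. code-block:: python
+
+        import numpy as np
+        from sklearn.metrics import roc_curve
+        from flaml.ml import multi_class_curves
+        y_true = np.array([0, 1, 2, 1])
+        # one column of predicted probabilities per class
+        y_pred_proba = np.array([
+            [0.7, 0.2, 0.1],
+            [0.2, 0.6, 0.2],
+            [0.1, 0.2, 0.7],
+            [0.3, 0.4, 0.3],
+        ])
+        curve_x, curve_y = multi_class_curves(y_true, y_pred_proba, roc_curve)
+        # with roc_curve, curve_x[i] and curve_y[i] hold the false positive
+        # rates and true positive rates for class i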
+ "consider providing low-cost values for cost-related hps via " + "'low_cost_partial_config'." ) self._points_to_evaluate = points_to_evaluate or [] self._config_constraints = config_constraints diff --git a/flaml/searcher/flow2.py b/flaml/searcher/flow2.py index f4120c3eee..14497138b9 100644 --- a/flaml/searcher/flow2.py +++ b/flaml/searcher/flow2.py @@ -124,15 +124,16 @@ def _init_search(self): if callable(getattr(domain, 'get_sampler', None)): self._tunable_keys.append(key) sampler = domain.get_sampler() - # if isinstance(sampler, sample.Quantized): - # sampler_inner = sampler.get_sampler() - # if str(sampler_inner) == 'Uniform': - # self._step_lb = min( - # self._step_lb, sampler.q/(domain.upper-domain.lower)) - # elif isinstance(domain, sample.Integer) and str( - # sampler) == 'Uniform': - # self._step_lb = min( - # self._step_lb, 1.0/(domain.upper-domain.lower)) + # the step size lower bound for uniform variables doesn't depend + # on the current config + if isinstance(sampler, sample.Quantized): + sampler_inner = sampler.get_sampler() + if str(sampler_inner) == 'Uniform': + self._step_lb = min( + self._step_lb, sampler.q / (domain.upper - domain.lower)) + elif isinstance(domain, sample.Integer) and str(sampler) == 'Uniform': + self._step_lb = min( + self._step_lb, 1.0 / (domain.upper - domain.lower)) if isinstance(domain, sample.Categorical): cat_hp_cost = self.cat_hp_cost if cat_hp_cost and key in cat_hp_cost: @@ -199,6 +200,8 @@ def step_lower_bound(self) -> float: continue domain = self.space[key] sampler = domain.get_sampler() + # the stepsize lower bound for log uniform variables depends on the + # current config if isinstance(sampler, sample.Quantized): sampler_inner = sampler.get_sampler() if str(sampler_inner) == 'LogUniform': diff --git a/test/test_automl.py b/test/test_automl.py index 15211d9192..7e3bfdc51c 100644 --- a/test/test_automl.py +++ b/test/test_automl.py @@ -260,6 +260,14 @@ def test_micro_macro_f1(self): X_train=X_train, y_train=y_train, metric='micro_f1', **automl_settings) automl_experiment_macro.fit( X_train=X_train, y_train=y_train, metric='macro_f1', **automl_settings) + estimator = automl_experiment_macro.model + y_pred = estimator.predict(X_train) + y_pred_proba = estimator.predict_proba(X_train) + from flaml.ml import norm_confusion_matrix, multi_class_curves + print(norm_confusion_matrix(y_train, y_pred)) + from sklearn.metrics import roc_curve, precision_recall_curve + print(multi_class_curves(y_train, y_pred_proba, roc_curve)) + print(multi_class_curves(y_train, y_pred_proba, precision_recall_curve)) def test_regression(self): automl_experiment = AutoML()