data validation (#45)
* pickle the AutoML object

* get best model per estimator

* test deberta

* stateless API

* prevent divide by zero

* test roberta

* BlendSearchTuner

* delta time

* reindex columns when dropping int-indexed columns

* test drop columns and small training data

* param set for ensemble builder

* fillna on copy

Co-authored-by: Chi Wang (MSR) <[email protected]>
sonichi and Chi Wang (MSR) authored Mar 19, 2021
1 parent bf95d7c commit ae5f8e5
Showing 7 changed files with 71 additions and 43 deletions.
8 changes: 7 additions & 1 deletion flaml/automl.py
@@ -589,6 +589,12 @@ def _prepare_data(self,
                 self._state.y_val = (X_train, y_train, X_val, y_val)
         if self._split_type == "stratified":
             logger.info("Using StratifiedKFold")
+            assert y_train_all.size >= n_splits, (
+                f"{n_splits}-fold cross validation"
+                f" requires input data with at least {n_splits} examples.")
+            assert y_train_all.size >= 2*n_splits, (
+                f"{n_splits}-fold cross validation with metric=r2 "
+                f"requires input data with at least {n_splits*2} examples.")
             self._state.kf = RepeatedStratifiedKFold(n_splits=n_splits,
                 n_repeats=1, random_state=RANDOM_SEED)
         else:
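These assertions fail fast before sklearn does: a k-fold splitter cannot produce more folds than examples, and an r2 score computed on a one-example validation fold divides by zero variance (the "prevent divide by zero" item in the commit message). A minimal sketch of both failure modes, outside FLAML:

    import numpy as np
    from sklearn.model_selection import RepeatedStratifiedKFold

    y = np.array([0, 1, 0, 1])                    # 4 examples, 5 requested folds
    kf = RepeatedStratifiedKFold(n_splits=5, n_repeats=1, random_state=0)
    try:
        next(kf.split(np.zeros((y.size, 1)), y))  # sklearn raises at iteration time
    except ValueError as err:
        print(err)
    # With exactly n_splits examples, every validation fold holds one sample, so
    # r2 = 1 - SS_res/SS_tot has SS_tot == 0; hence the 2*n_splits requirement.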
@@ -1045,7 +1051,7 @@ def _search(self):
                 init_config=None,
                 search_alg=search_state.search_alg,
                 time_budget_s=budget_left,
-                verbose=max(self.verbose-1,0), local_dir='logs/tune_results',
+                verbose=max(self.verbose-1,0), #local_dir='logs/tune_results',
                 use_ray=False,
             )
             # warnings.resetwarnings()
24 changes: 17 additions & 7 deletions flaml/data.py
@@ -192,39 +192,47 @@ def fit_transform(self, X, y, task):
             X = X.copy()
             n = X.shape[0]
             cat_columns, num_columns = [], []
+            drop = False
             for column in X.columns:
                 if X[column].dtype.name in ('object', 'category'):
                     if X[column].nunique() == 1 or X[column].nunique(
                             dropna=True) == n - X[column].isnull().sum():
                         X.drop(columns=column, inplace=True)
+                        drop = True
                     elif X[column].dtype.name == 'category':
                         current_categories = X[column].cat.categories
                         if '__NAN__' not in current_categories:
                             X[column] = X[column].cat.add_categories(
                                 '__NAN__').fillna('__NAN__')
                         cat_columns.append(column)
                     else:
-                        X[column].fillna('__NAN__', inplace=True)
+                        X[column] = X[column].fillna('__NAN__')
                         cat_columns.append(column)
                 else:
                     # print(X[column].dtype.name)
                     if X[column].nunique(dropna=True) < 2:
                         X.drop(columns=column, inplace=True)
+                        drop = True
                     else:
-                        X[column].fillna(np.nan, inplace=True)
+                        X[column] = X[column].fillna(np.nan)
                         num_columns.append(column)
             X = X[cat_columns + num_columns]
             if cat_columns:
                 X[cat_columns] = X[cat_columns].astype('category')
             if num_columns:
+                X_num = X[num_columns]
+                if drop and np.issubdtype(X_num.columns.dtype, np.integer):
+                    X_num.columns = range(X_num.shape[1])
+                else: drop = False
                 from sklearn.impute import SimpleImputer
                 from sklearn.compose import ColumnTransformer
                 self.transformer = ColumnTransformer([(
                     'continuous',
                     SimpleImputer(missing_values=np.nan, strategy='median'),
-                    num_columns)])
-                X[num_columns] = self.transformer.fit_transform(X)
+                    X_num.columns)])
+                X[num_columns] = self.transformer.fit_transform(X_num)
             self._cat_columns, self._num_columns = cat_columns, num_columns
+            self._drop = drop

         if task == 'regression':
             self.label_transformer = None
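The move from inplace fillna to assignment is the "fillna on copy" item: `X[column].fillna(..., inplace=True)` is chained indexing, which pandas may execute on a temporary object, emitting SettingWithCopyWarning and possibly leaving X untouched. Assigning the result back always takes effect. An illustrative sketch, not FLAML code:

    import numpy as np
    import pandas as pd

    X = pd.DataFrame({'a': ['x', None, 'y'], 'b': [1.0, 2.0, None]})
    # Chained, inplace form: may operate on a temporary and warn:
    #     X['a'].fillna('__NAN__', inplace=True)
    # Assign-back form: always updates X:
    X['a'] = X['a'].fillna('__NAN__')
    X['b'] = X['b'].fillna(np.nan)  # deliberate no-op, mirroring the numeric branch
    print(X)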
@@ -241,7 +249,7 @@ def transform(self, X):
             for column in cat_columns:
                 # print(column, X[column].dtype.name)
                 if X[column].dtype.name == 'object':
-                    X[column].fillna('__NAN__', inplace=True)
+                    X[column] = X[column].fillna('__NAN__')
                 elif X[column].dtype.name == 'category':
                     current_categories = X[column].cat.categories
                     if '__NAN__' not in current_categories:
@@ -250,6 +258,8 @@ def transform(self, X):
             if cat_columns:
                 X[cat_columns] = X[cat_columns].astype('category')
             if num_columns:
-                X[num_columns].fillna(np.nan, inplace=True)
-                X[num_columns] = self.transformer.transform(X)
+                X_num = X[num_columns].fillna(np.nan)
+                if self._drop:
+                    X_num.columns = range(X_num.shape[1])
+                X[num_columns] = self.transformer.transform(X_num)
         return X
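The renumbering exists because ColumnTransformer treats integer column specifiers as positions, not labels: after dropping a column from an integer-indexed frame, surviving labels such as [0, 2] exceed the positions available in the narrower frame. Resetting them to range(k) makes labels and positions coincide at both fit and transform time. A sketch of the failure this avoids (illustrative, not FLAML code):

    import pandas as pd
    from sklearn.compose import ColumnTransformer
    from sklearn.impute import SimpleImputer

    X = pd.DataFrame({0: [1.0, 2.0], 1: [0.0, 0.0], 2: [3.0, 4.0]})
    X = X.drop(columns=1)          # labels are [0, 2]; only positions 0 and 1 exist
    # ColumnTransformer([('continuous', SimpleImputer(strategy='median'),
    #                     X.columns)]).fit(X) would fail here: integer specifiers
    # are positional, and position 2 is out of bounds for a two-column frame.
    X.columns = range(X.shape[1])  # renumber so labels == positions
    ct = ColumnTransformer(
        [('continuous', SimpleImputer(strategy='median'), X.columns)])
    print(ct.fit_transform(X))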
42 changes: 18 additions & 24 deletions flaml/model.py
@@ -239,11 +239,8 @@ def __init__(self, task='binary:logistic', n_jobs=1,
         else: objective = 'regression'
         self.params = {
             "n_estimators": int(round(n_estimators)),
-            "num_leaves": params[
-                'num_leaves'] if 'num_leaves' in params else int(
-                round(max_leaves)),
-            'objective': params[
-                "objective"] if "objective" in params else objective,
+            "num_leaves": params.get('num_leaves', int(round(max_leaves))),
+            'objective': params.get("objective", objective),
             'n_jobs': n_jobs,
             'learning_rate': float(learning_rate),
             'reg_alpha': float(reg_alpha),
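The wrapped `params[k] if k in params else default` ternaries here and in the hunks below collapse to `dict.get`, which behaves identically for these defaults (note that `.get` always evaluates its default argument, harmless in every case above):

    params = {}
    max_leaves = 32
    # before
    num_leaves = params['num_leaves'] if 'num_leaves' in params else int(round(max_leaves))
    # after
    num_leaves = params.get('num_leaves', int(round(max_leaves)))
    assert num_leaves == 32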
@@ -359,18 +356,17 @@ def __init__(self, task='regression', all_thread=False, n_jobs=1,
         self._max_leaves = int(round(max_leaves))
         self.params = {
             'max_leaves': int(round(max_leaves)),
-            'max_depth': 0,
-            'grow_policy': params[
-                "grow_policy"] if "grow_policy" in params else 'lossguide',
-            'tree_method':tree_method,
-            'verbosity': 0,
-            'nthread':n_jobs,
+            'max_depth': params.get('max_depth', 0),
+            'grow_policy': params.get("grow_policy", 'lossguide'),
+            'tree_method': tree_method,
+            'verbosity': params.get('verbosity', 0),
+            'nthread': n_jobs,
             'learning_rate': float(learning_rate),
             'subsample': float(subsample),
             'reg_alpha': float(reg_alpha),
             'reg_lambda': float(reg_lambda),
             'min_child_weight': float(min_child_weight),
-            'booster': params['booster'] if 'booster' in params else 'gbtree',
+            'booster': params.get('booster', 'gbtree'),
             'colsample_bylevel': float(colsample_bylevel),
             'colsample_bytree':float(colsample_bytree),
         }
@@ -429,17 +425,16 @@ def __init__(self, task='binary:logistic', n_jobs=1,
             "n_estimators": int(round(n_estimators)),
             'max_leaves': int(round(max_leaves)),
             'max_depth': 0,
-            'grow_policy': params[
-                "grow_policy"] if "grow_policy" in params else 'lossguide',
-            'tree_method':tree_method,
+            'grow_policy': params.get("grow_policy", 'lossguide'),
+            'tree_method': tree_method,
             'verbosity': 0,
             'n_jobs': n_jobs,
             'learning_rate': float(learning_rate),
             'subsample': float(subsample),
             'reg_alpha': float(reg_alpha),
             'reg_lambda': float(reg_lambda),
             'min_child_weight': float(min_child_weight),
-            'booster': params['booster'] if 'booster' in params else 'gbtree',
+            'booster': params.get('booster', 'gbtree'),
             'colsample_bylevel': float(colsample_bylevel),
             'colsample_bytree': float(colsample_bytree),
         }
@@ -544,10 +539,10 @@ def __init__(self, task='binary:logistic', n_jobs=1, tol=0.0001, C=1.0,
                  **params):
         super().__init__(task, **params)
         self.params = {
-            'penalty': 'l1',
+            'penalty': params.get("penalty", 'l1'),
             'tol': float(tol),
             'C': float(C),
-            'solver': 'saga',
+            'solver': params.get("solver", 'saga'),
             'n_jobs': n_jobs,
         }
         if 'regression' in task:
@@ -573,10 +568,10 @@ def __init__(self, task='binary:logistic', n_jobs=1, tol=0.0001, C=1.0,
                  **params):
         super().__init__(task, **params)
         self.params = {
-            'penalty': 'l2',
+            'penalty': params.get("penalty", 'l2'),
             'tol': float(tol),
             'C': float(C),
-            'solver': 'lbfgs',
+            'solver': params.get("solver", 'lbfgs'),
             'n_jobs': n_jobs,
         }
         if 'regression' in task:
@@ -625,9 +620,8 @@ def __init__(self, task = 'binary:logistic', n_jobs=1,
             "n_estimators": n_estimators,
             'learning_rate': learning_rate,
             'thread_count': n_jobs,
-            'verbose': False,
-            'random_seed': params[
-                "random_seed"] if "random_seed" in params else 10242048,
+            'verbose': params.get('verbose', False),
+            'random_seed': params.get("random_seed", 10242048),
         }
         if 'regression' in task:
             from catboost import CatBoostRegressor
@@ -724,7 +718,7 @@ def __init__(self, task='binary:logistic', n_jobs=1,
         super().__init__(task, **params)
         self.params= {
             'n_neighbors': int(round(n_neighbors)),
-            'weights': 'distance',
+            'weights': params.get('weights', 'distance'),
             'n_jobs': n_jobs,
         }
         if 'regression' in task:
17 changes: 13 additions & 4 deletions flaml/tune/tune.py
@@ -205,7 +205,7 @@ def compute_with_config(config):
             0 = silent, 1 = only status updates, 2 = status and brief trial
             results, 3 = status and detailed trial results. Defaults to 2.
         local_dir: A string of the local dir to save ray logs if ray backend is
-            used.
+            used; or a local dir to save the tuning log.
         num_samples: An integer of the number of configs to try. Defaults to 1.
         resources_per_trial: A dictionary of the hardware resources to allocate
             per trial, e.g., `{'mem': 1024**3}`. When not using ray backend,
@@ -221,9 +221,18 @@ def compute_with_config(config):
     _verbose = verbose
     if verbose > 0:
         import os
-        os.makedirs(local_dir, exist_ok=True)
-        logger.addHandler(logging.FileHandler(local_dir+'/tune_'+str(
-            datetime.datetime.now()).replace(':', '-')+'.log'))
+        if local_dir:
+            os.makedirs(local_dir, exist_ok=True)
+            logger.addHandler(logging.FileHandler(local_dir+'/tune_'+str(
+                datetime.datetime.now()).replace(':', '-')+'.log'))
+        elif not logger.handlers:
+            # Add the console handler.
+            _ch = logging.StreamHandler()
+            logger_formatter = logging.Formatter(
+                '[%(name)s: %(asctime)s] {%(lineno)d} %(levelname)s - %(message)s',
+                '%m-%d %H:%M:%S')
+            _ch.setFormatter(logger_formatter)
+            logger.addHandler(_ch)
         if verbose<=2:
             logger.setLevel(logging.INFO)
         else:
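With this change, local_dir becomes optional: a file handler is attached only when it is given, and otherwise a console handler with the formatter above is installed once. A hedged usage sketch; the callable follows the docstring's compute_with_config example, and the exact run() signature may differ across versions:

    from flaml import tune

    def compute_with_config(config):
        tune.report(score=(config['x'] - 3) ** 2)

    tune.run(
        compute_with_config,
        config={'x': tune.uniform(-10, 10)},
        metric='score', mode='min',
        num_samples=10,
        local_dir=None,  # no log file; console handler is used since verbose > 0
        verbose=1)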
6 changes: 4 additions & 2 deletions test/nni/flaml_nni_wrap.py
@@ -1,5 +1,7 @@
-from flaml.searcher.blendsearch import BlendSearchTuner as BST, BlendSearch
+from flaml.searcher.blendsearch import BlendSearchTuner as BST
+
+
 class BlendSearchTuner(BST):
     # for best performance pass low cost initial parameters here
     def __init__(self, points_to_evaluate=[{"hidden_size":128}]):
-        super.__init__(self,points_to_evaluate=points_to_evaluate)
+        super.__init__(self, points_to_evaluate=points_to_evaluate)
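One nit the diff leaves in place: `super.__init__(self, ...)` calls `__init__` on the built-in `super` type itself, so it never dispatches to BST's initializer. The conventional Python 3 spelling would be:

    from flaml.searcher.blendsearch import BlendSearchTuner as BST

    class BlendSearchTuner(BST):
        # for best performance pass low cost initial parameters here
        def __init__(self, points_to_evaluate=[{"hidden_size": 128}]):
            super().__init__(points_to_evaluate=points_to_evaluate)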
14 changes: 10 additions & 4 deletions test/test_automl.py
@@ -172,6 +172,10 @@ def test_classification(self, as_frame=False):
             "model_history": True
         }
         X_train, y_train = load_iris(return_X_y=True, as_frame=as_frame)
+        if as_frame:
+            # test drop column
+            X_train.columns = range(X_train.shape[1])
+            X_train[X_train.shape[1]] = np.zeros(len(y_train))
         automl_experiment.fit(X_train=X_train, y_train=y_train,
                               **automl_settings)
         print(automl_experiment.classes_)
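The four added lines give the iris frame integer column names plus an all-zeros column, exercising both the constant-column drop (nunique(dropna=True) < 2) and the integer-label renumbering added in flaml/data.py above. The same behavior in isolation (a sketch; the DataTransformer name and return shape are assumed from that file):

    import pandas as pd
    from flaml.data import DataTransformer

    X = pd.DataFrame({0: [1.0, 2.0, 3.0, 4.0], 1: [0.0, 0.0, 0.0, 0.0]})
    y = pd.Series([0, 1, 0, 1])
    transformer = DataTransformer()
    Xt, yt = transformer.fit_transform(X, y, task='classification')
    print(Xt.shape)  # the constant column is dropped; remaining labels renumbered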
@@ -252,7 +256,8 @@ def test_sparse_matrix_regression(self):
             "task": 'regression',
             "log_file_name": "test/sparse_regression.log",
             "n_jobs": 1,
-            "model_history": True
+            "model_history": True,
+            "verbose": 0,
         }
         X_train = scipy.sparse.random(300, 900, density=0.0001)
         y_train = np.random.uniform(size=300)
@@ -327,10 +332,11 @@ def test_sparse_matrix_regression_cv(self):
             "task": 'regression',
             "log_file_name": "test/sparse_regression.log",
             "n_jobs": 1,
-            "model_history": True
+            "model_history": True,
+            "metric": "mse"
         }
-        X_train = scipy.sparse.random(100, 100)
-        y_train = np.random.uniform(size=100)
+        X_train = scipy.sparse.random(8, 100)
+        y_train = np.random.uniform(size=8)
         automl_experiment.fit(X_train=X_train, y_train=y_train,
                               **automl_settings)
         print(automl_experiment.predict(X_train))
3 changes: 2 additions & 1 deletion test/test_training_log.py
@@ -25,7 +25,8 @@ def test_training_log(self):
             "log_training_metric": True,
             "mem_thres": 1024*1024,
             "n_jobs": 1,
-            "model_history": True
+            "model_history": True,
+            "verbose": 2,
         }
         X_train, y_train = load_boston(return_X_y=True)
         automl_experiment.fit(X_train=X_train, y_train=y_train,
