Skip to content

Commit

Permalink
support for customized splitters (#333)
Browse files Browse the repository at this point in the history
* add support for customized splitters

* use the param split_type for feeding generators

* use single API for customized splitter and add test

* when task==TS_FORECAST, always set shuffle=False

* update docstr

Co-authored-by: Chi Wang <[email protected]>
  • Loading branch information
wuchihsu and sonichi authored Dec 17, 2021
1 parent 7b24662 commit 671ccbb
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 9 deletions.
32 changes: 26 additions & 6 deletions flaml/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,11 @@ def custom_metric(
True - retrain only after search finishes; False - no retraining;
'budget' - do best effort to retrain without violating the time
budget.
split_type: str, default="auto" | the data split type.
split_type: str or splitter object, default="auto" | the data split type.
A valid splitter object is an instance of a derived class of scikit-learn KFold
(https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html#sklearn.model_selection.KFold)
and has ``split`` and ``get_n_splits`` methods with the same signatures.
Valid str options depend on different tasks.
For classification tasks, valid choices are [
"auto", 'stratified', 'uniform', 'time']. "auto" -> stratified.
For regression tasks, valid choices are ["auto", 'uniform', 'time'].
Expand Down Expand Up @@ -955,7 +959,7 @@ def _prepare_data(self, eval_method, split_ratio, n_splits):
self._state.task in CLASSIFICATION
and self._auto_augment
and self._state.fit_kwargs.get("sample_weight") is None
and self._split_type not in ["time", "group"]
and self._split_type in ["stratified", "uniform"]
):
# logger.info(f"label {pd.unique(y_train_all)}")
label_set, counts = np.unique(y_train_all, return_counts=True)
Expand Down Expand Up @@ -1183,11 +1187,14 @@ def _prepare_data(self, eval_method, split_ratio, n_splits):
self._state.kf = TimeSeriesSplit(n_splits=n_splits, test_size=period)
else:
self._state.kf = TimeSeriesSplit(n_splits=n_splits)
else:
elif isinstance(self._split_type, str):
# logger.info("Using RepeatedKFold")
self._state.kf = RepeatedKFold(
n_splits=n_splits, n_repeats=1, random_state=RANDOM_SEED
)
else:
# logger.info("Using splitter object")
self._state.kf = self._split_type

def add_learner(self, learner_name, learner_class):
"""Add a customized learner.
Expand Down Expand Up @@ -1277,7 +1284,11 @@ def retrain_from_log(
['auto', 'cv', 'holdout'].
split_ratio: A float of the validation data percentage for holdout.
n_splits: An integer of the number of folds for cross-validation.
split_type: str, default="auto" | the data split type.
split_type: str or splitter object, default="auto" | the data split type.
A valid splitter object is an instance of a derived class of scikit-learn KFold
(https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html#sklearn.model_selection.KFold)
and has ``split`` and ``get_n_splits`` methods with the same signatures.
Valid str options depend on different tasks.
For classification tasks, valid choices are [
"auto", 'stratified', 'uniform', 'time', 'group']. "auto" -> stratified.
For regression tasks, valid choices are ["auto", 'uniform', 'time'].
Expand Down Expand Up @@ -1399,7 +1410,12 @@ def _decide_split_type(self, split_type):
self._state.task = get_classification_objective(
len(np.unique(self._y_train_all))
)
if self._state.task in CLASSIFICATION:
if not isinstance(split_type, str):
assert hasattr(split_type, "split") and hasattr(
split_type, "get_n_splits"
), "split_type must be a string or a splitter object with split and get_n_splits methods."
self._split_type = split_type
elif self._state.task in CLASSIFICATION:
assert split_type in ["auto", "stratified", "uniform", "time", "group"]
self._split_type = (
split_type
Expand Down Expand Up @@ -1786,7 +1802,11 @@ def custom_metric(
True - retrain only after search finishes; False - no retraining;
'budget' - do best effort to retrain without violating the time
budget.
split_type: str, default="auto" | the data split type.
split_type: str or splitter object, default="auto" | the data split type.
A valid splitter object is an instance of a derived class of scikit-learn KFold
(https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.KFold.html#sklearn.model_selection.KFold)
and has ``split`` and ``get_n_splits`` methods with the same signatures.
Valid str options depend on different tasks.
For classification tasks, valid choices are [
"auto", 'stratified', 'uniform', 'time']. "auto" -> stratified.
For regression tasks, valid choices are ["auto", 'uniform', 'time'].
Expand Down
3 changes: 1 addition & 2 deletions flaml/ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def evaluate_model_CV(
else:
labels = None
groups = None
shuffle = True
shuffle = False if task == TS_FORECAST else True
if isinstance(kf, RepeatedStratifiedKFold):
kf = kf.split(X_train_split, y_train_split)
elif isinstance(kf, GroupKFold):
Expand All @@ -423,7 +423,6 @@ def evaluate_model_CV(
y_train_all = pd.DataFrame(y_train_all, columns=[TS_VALUE_COL])
train = X_train_all.join(y_train_all)
kf = kf.split(train)
shuffle = False
elif isinstance(kf, TimeSeriesSplit):
kf = kf.split(X_train_split, y_train_split)
else:
Expand Down
41 changes: 40 additions & 1 deletion test/automl/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from sklearn.datasets import fetch_openml
from flaml.automl import AutoML
from sklearn.model_selection import train_test_split
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import accuracy_score


Expand Down Expand Up @@ -123,6 +123,45 @@ def test_rank():
automl.fit(X, y, **automl_settings)


def test_object():
    """Fit AutoML with a custom splitter object passed as ``split_type``.

    The data is loaded with ``fetch_openml``; if that raises
    ``ArffException`` or ``ValueError`` the test falls back to the
    ``load_wine`` dataset. A small ``KFold`` subclass producing random
    80/20 train/validation index splits is then handed to ``AutoML.fit``
    through the ``split_type`` setting.
    """
    from sklearn.externals._arff import ArffException

    try:
        X, y = fetch_openml(name=dataset, return_X_y=True)
    except (ArffException, ValueError):
        from sklearn.datasets import load_wine

        X, y = load_wine(return_X_y=True)

    import numpy as np

    class TestKFold(KFold):
        """Splitter yielding ``n_splits`` random 80/20 permutation splits."""

        def __init__(self, n_splits):
            # Let KFold initialise its own attributes (n_splits, shuffle,
            # random_state) so the instance is a fully formed KFold.
            super().__init__(n_splits=int(n_splits))

        def split(self, X, y=None, groups=None):
            # Same signature as KFold.split; ``y`` and ``groups`` are
            # accepted for compatibility but not used.
            # A fixed seed keeps the folds identical across calls, so the
            # test is reproducible.
            rng = np.random.default_rng(0)
            n_samples = len(X)
            train_num = int(n_samples * 0.8)
            for _ in range(self.n_splits):
                permu_idx = rng.permutation(n_samples)
                yield permu_idx[:train_num], permu_idx[train_num:]

        def get_n_splits(self, X=None, y=None, groups=None):
            return self.n_splits

    automl = AutoML()
    automl_settings = {
        "time_budget": 2,
        # "metric": 'accuracy',
        "task": "classification",
        "log_file_name": "test/{}.log".format(dataset),
        "model_history": True,
        "log_training_metric": True,
        # The splitter instance itself, rather than a string option.
        "split_type": TestKFold(5),
    }
    automl.fit(X, y, **automl_settings)


if __name__ == "__main__":
    # Direct execution runs test_groups only; unittest.main() is not invoked.
    test_groups()

0 comments on commit 671ccbb

Please sign in to comment.