Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix missing argument in StratifiedKFold.split() #1304

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions flaml/automl/task/generic_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,9 +414,7 @@ def prepare_data(
sample_weight_full,
random_state=RANDOM_SEED,
)
state.fit_kwargs[
"sample_weight"
] = (
state.fit_kwargs["sample_weight"] = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lets reduce the change, keep it focus

state.sample_weight_all
) # NOTE: _prepare_data is before kwargs is updated to fit_kwargs_by_estimator
if isinstance(state.sample_weight_all, pd.Series):
Expand Down Expand Up @@ -501,9 +499,7 @@ def prepare_data(
y_rest = (
y_train_all[rest]
if isinstance(y_train_all, np.ndarray)
else iloc_pandas_on_spark(y_train_all, rest)
if is_spark_dataframe
else y_train_all.iloc[rest]
else iloc_pandas_on_spark(y_train_all, rest) if is_spark_dataframe else y_train_all.iloc[rest]
)
stratify = y_rest if split_type == "stratified" else None
X_train, X_val, y_train, y_val = self._train_test_split(
Expand Down Expand Up @@ -619,9 +615,11 @@ def preprocess(self, X, transformer=None):
X = pd.DataFrame(
dict(
[
(transformer._str_columns[idx], X[idx])
if isinstance(X[0], List)
else (transformer._str_columns[idx], [X[idx]])
(
(transformer._str_columns[idx], X[idx])
if isinstance(X[0], List)
else (transformer._str_columns[idx], [X[idx]])
)
for idx in range(len(X))
]
)
Expand Down Expand Up @@ -701,7 +699,7 @@ def evaluate_model_CV(
elif isinstance(kf, TimeSeriesSplit):
kf = kf.split(X_train_split, y_train_split)
else:
kf = kf.split(X_train_split)
kf = kf.split(X_train_split, y_train_split)

for train_index, val_index in kf:
if shuffle:
Expand Down