Skip to content

Commit

Permalink
Pre-commit changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasMeissnerDS committed Jul 4, 2023
1 parent 79089e7 commit 4623fc2
Showing 1 changed file with 23 additions and 15 deletions.
38 changes: 23 additions & 15 deletions bluecast/blueprints/cast_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class BlueCastCV:
"""Wrapper to train and predict multiple BlueCast instances.
A custom splitter can be provided."""

def __init__(
self,
class_problem: Literal["binary", "multiclass"] = "binary",
Expand All @@ -25,21 +26,10 @@ def __init__(
):
self.class_problem = class_problem
self.conf_xgboost = conf_xgboost

self.conf_training = conf_training
if not self.conf_training:
self.conf_training = TrainingConfig()

self.conf_params_xgboost = conf_params_xgboost
self.bluecast_models: List[BlueCast] = []

self.stratifier = stratifier
if not self.stratifier:
self.stratifier = StratifiedKFold(
n_splits=5,
shuffle=True,
random_state=self.conf_training.global_random_state,
)

def prepare_data(
self, df: pd.DataFrame, target: str
Expand All @@ -55,6 +45,16 @@ def fit(self, df: pd.DataFrame, target_col: str) -> None:
Input df is expected to contain the target column."""
X, y = self.prepare_data(df, target_col)

if not self.conf_training:
self.conf_training = TrainingConfig()

if not self.stratifier:
self.stratifier = StratifiedKFold(
n_splits=5,
shuffle=True,
random_state=self.conf_training.global_random_state,
)

for fn, (trn_idx, val_idx) in enumerate(self.stratifier.split(X, y)):
X_train, X_val = X.iloc[trn_idx], X.iloc[val_idx]
y_train, y_val = y.iloc[trn_idx], y.iloc[val_idx]
Expand All @@ -77,9 +77,7 @@ def fit(self, df: pd.DataFrame, target_col: str) -> None:
automl.fit(X_train, target_col=target_col)
self.bluecast_models.append(automl)

def fit_eval(
self, df: pd.DataFrame, target_col: str
) -> None:
def fit_eval(self, df: pd.DataFrame, target_col: str) -> None:
"""Fit multiple BlueCast instances on different data splits.
Input df is expected to contain the target column. Evaluation is executed on the out-of-fold dataset
Expand All @@ -89,6 +87,13 @@ def fit_eval(
if not self.conf_training:
self.conf_training = TrainingConfig()

if not self.stratifier:
self.stratifier = StratifiedKFold(
n_splits=5,
shuffle=True,
random_state=self.conf_training.global_random_state,
)

for fn, (trn_idx, val_idx) in enumerate(self.stratifier.split(X, y)):
X_train, X_val = X.iloc[trn_idx], X.iloc[val_idx]
y_train, y_val = y.iloc[trn_idx], y.iloc[val_idx]
Expand Down Expand Up @@ -124,4 +129,7 @@ def predict(
if return_sub_models_preds:
return df.loc[:, prob_cols], df.loc[:, class_cols]
else:
return df.loc[:, prob_cols].mean(axis=1), df.loc[:, prob_cols].mean(axis=1) > 0.5
return (
df.loc[:, prob_cols].mean(axis=1),
df.loc[:, prob_cols].mean(axis=1) > 0.5,
)

0 comments on commit 4623fc2

Please sign in to comment.