Skip to content

Commit

Permalink
Add CV wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasMeissnerDS committed Jul 4, 2023
1 parent 594aa74 commit 79089e7
Showing 1 changed file with 36 additions and 28 deletions.
64 changes: 36 additions & 28 deletions bluecast/blueprints/cast_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,77 +12,84 @@


class BlueCastCV:
"""Wrapper to train and predict multiple BlueCast instances.
A custom splitter can be provided."""
def __init__(
    self,
    class_problem: Literal["binary", "multiclass"] = "binary",
    stratifier: Optional[Any] = None,
    conf_training: Optional[TrainingConfig] = None,
    conf_xgboost: Optional[XgboostTuneParamsConfig] = None,
    conf_params_xgboost: Optional[XgboostFinalParamConfig] = None,
):
    """Store the shared configuration and set up the fold splitter.

    :param class_problem: Either "binary" or "multiclass"; stored on the
        instance as given.
    :param stratifier: Optional custom splitter, stored as
        ``self.stratifier``. When None, a 5-fold shuffled
        ``StratifiedKFold`` seeded with
        ``conf_training.global_random_state`` is created instead.
    :param conf_training: Optional training configuration. When None, a
        default ``TrainingConfig()`` is created.
    :param conf_xgboost: Optional Xgboost tuning configuration, stored as
        given (may stay None).
    :param conf_params_xgboost: Optional final Xgboost parameter
        configuration, stored as given (may stay None).
    """
    self.class_problem = class_problem
    self.conf_xgboost = conf_xgboost
    self.conf_params_xgboost = conf_params_xgboost
    # Filled with one fitted BlueCast instance per split by the fit methods.
    self.bluecast_models: List[BlueCast] = []

    # Resolve the training config before the splitter: the default
    # splitter below reads its global_random_state.
    if conf_training is None:
        conf_training = TrainingConfig()
    self.conf_training = conf_training

    # Compare against None explicitly so that a user-supplied splitter is
    # never discarded just because it happens to evaluate as falsy.
    if stratifier is None:
        stratifier = StratifiedKFold(
            n_splits=5,
            shuffle=True,
            random_state=self.conf_training.global_random_state,
        )
    self.stratifier = stratifier

def prepare_data(
    self, df: pd.DataFrame, target: str
) -> Tuple[pd.DataFrame, pd.Series]:
    """Split a dataframe into features and target.

    The index is reset to a fresh default index first (the old index is
    dropped), so both returned objects share the same new index. The input
    dataframe itself is not modified.

    :param df: Dataframe that contains the target column.
    :param target: Name of the target column.
    :return: Tuple of (features without the target column, target column).
    """
    frame = df.reset_index(drop=True)
    labels = frame[target]
    features = frame.drop(columns=target)
    return features, labels

def fit(self, df: pd.DataFrame, target: str, stratifier: Optional[Any]) -> None:
X, y = self.prepare_data(df, target)
def fit(self, df: pd.DataFrame, target_col: str) -> None:
"""Fit multiple BlueCast instances on different data splits.
if not self.conf_training:
self.conf_training = TrainingConfig()

if not stratifier:
stratifier = StratifiedKFold(
n_splits=5,
shuffle=True,
random_state=self.conf_training.global_random_state,
)
Input df is expected to contain the target column."""
X, y = self.prepare_data(df, target_col)

for fn, (trn_idx, val_idx) in enumerate(stratifier.split(X, y)):
for fn, (trn_idx, val_idx) in enumerate(self.stratifier.split(X, y)):
X_train, X_val = X.iloc[trn_idx], X.iloc[val_idx]
y_train, y_val = y.iloc[trn_idx], y.iloc[val_idx]
x_train = pd.concat([X_train, X_val], axis=1)
y_train = pd.concat([y_train, y_val], axis=1)
x_train = pd.concat([X_train, X_val], ignore_index=True)
y_train = pd.concat([y_train, y_val], ignore_index=True)

x_train = x_train.reset_index(drop=True)
X_train = x_train.reset_index(drop=True)
y_train = y_train.reset_index(drop=True)
x_train[target] = y_train[target]
X_train[target_col] = y_train.values

self.conf_training.global_random_state += fn

automl = BlueCast(
class_problem=self.class_problem,
target_column=target,
target_column=target_col,
conf_training=self.conf_training,
conf_xgboost=self.conf_xgboost,
conf_params_xgboost=self.conf_params_xgboost,
)
automl.fit(X_train, target_col=target)
automl.fit(X_train, target_col=target_col)
self.bluecast_models.append(automl)

def fit_eval(
self, df: pd.DataFrame, target_col: str, stratifier: Optional[Any] = None
self, df: pd.DataFrame, target_col: str
) -> None:
"""Fit multiple BlueCast instances on different data splits.
Input df is expected to contain the target column. Evaluation is executed on the out-of-fold dataset
in each split."""
X, y = self.prepare_data(df, target_col)

if not self.conf_training:
self.conf_training = TrainingConfig()

if not stratifier:
stratifier = StratifiedKFold(
n_splits=5,
shuffle=True,
random_state=self.conf_training.global_random_state,
)

for fn, (trn_idx, val_idx) in enumerate(stratifier.split(X, y)):
for fn, (trn_idx, val_idx) in enumerate(self.stratifier.split(X, y)):
X_train, X_val = X.iloc[trn_idx], X.iloc[val_idx]
y_train, y_val = y.iloc[trn_idx], y.iloc[val_idx]

Expand All @@ -103,6 +110,7 @@ def fit_eval(
def predict(
self, df: pd.DataFrame, return_sub_models_preds: bool = False
) -> Tuple[Union[pd.DataFrame, pd.Series], Union[pd.DataFrame, pd.Series]]:
"""Predict on unseen data using multiple trained BlueCast instances"""
or_cols = df.columns
prob_cols: list[str] = []
class_cols: list[str] = []
Expand Down

0 comments on commit 79089e7

Please sign in to comment.