Skip to content

Commit

Permalink
Fix unit tests. Fix classification threshold in cast_cv
Browse files (browse the repository at this point in the history)
  • Loading branch information
ThomasMeissnerDS committed Aug 19, 2024
1 parent f55589e commit 7808530
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 4 deletions.
11 changes: 10 additions & 1 deletion bluecast/blueprints/cast_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,17 @@ def predict(
if return_sub_models_preds:
return result_df.loc[:, prob_cols], result_df.loc[:, class_cols]
else:
if self.conf_params_xgboost:
classification_threshold = (
self.conf_params_xgboost.classification_threshold
)
else:
classification_threshold = 0.5

Check warning on line 328 in bluecast/blueprints/cast_cv.py

View check run for this annotation

Codecov / codecov/patch

bluecast/blueprints/cast_cv.py#L328

Added line #L328 was not covered by tests

y_probs = result_df.loc[:, prob_cols].mean(axis=1)
y_classes = (result_df.loc[:, prob_cols].mean(axis=1) > 0.5).astype(int)
y_classes = (
result_df.loc[:, prob_cols].mean(axis=1) > classification_threshold
).astype(int)

if (
self.bluecast_models[0].feat_type_detector
Expand Down
4 changes: 2 additions & 2 deletions bluecast/blueprints/custom_model_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def predict(self, df: pd.DataFrame) -> Tuple[PredictedProbas, PredictedClasses]:
class LinearRegressionModel(BaseClassMlModel):
def __init__(self):
self.linear_regression_model: LinearRegression = LinearRegression()
self.model: Optional[GridSearchCV] = None
self.model: Optional[LinearRegression] = None

def autotune(
self,
Expand All @@ -106,7 +106,7 @@ def fit(

def predict(self, df: pd.DataFrame) -> Tuple[PredictedProbas, PredictedClasses]:

if isinstance(self.model, GridSearchCV):
if isinstance(self.model, LinearRegression):
preds = self.model.predict(df)
return preds
else:
Expand Down
Binary file modified dist/bluecast-1.6.0-py3-none-any.whl
Binary file not shown.
Binary file modified dist/bluecast-1.6.0.tar.gz
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def replace_diff_values(x, diff_values=diff_values):
train_config.calculate_shap_values = False
else:
train_config.autotune_model = True
train_config.hypertuning_cv_folds = 1
train_config.hypertuning_cv_folds = 5
train_config.hypertuning_cv_repeats = 1
train_config.cardinality_threshold_for_onehot_encoding = 3
train_config.hyperparameter_tuning_rounds = 50
Expand Down

0 comments on commit 7808530

Please sign in to comment.