Skip to content

Commit

Permalink
Fix multiclass issue
Browse files Browse the repository at this point in the history
  • Loading branch information
ThomasMeissnerDS committed Sep 13, 2023
1 parent fd1553a commit 8e1b866
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 12 deletions.
53 changes: 42 additions & 11 deletions bluecast/blueprints/cast_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,46 @@ def predict(
class_cols: list[str] = []
for fn, pipeline in enumerate(self.bluecast_models):
y_probs, y_classes = pipeline.predict(df.loc[:, or_cols])
df[f"proba_{fn}"] = y_probs
df[f"classes_{fn}"] = y_classes
prob_cols.append(f"proba_{fn}")
class_cols.append(f"classes_{fn}")

if return_sub_models_preds:
return df.loc[:, prob_cols], df.loc[:, class_cols]
if self.class_problem == "multiclass":
proba_cols = [
f"class_{col}_proba_model_{fn}" for col in range(y_probs.shape[1])
]
df[proba_cols] = y_probs
df[f"classes_{fn}"] = y_classes
for col in proba_cols:
prob_cols.append(col)
class_cols.append(f"classes_{fn}")

else:
df[f"proba_{fn}"] = y_probs
df[f"classes_{fn}"] = y_classes
prob_cols.append(f"proba_{fn}")
class_cols.append(f"classes_{fn}")

if self.class_problem == "multiclass":
if return_sub_models_preds:
return df.loc[:, prob_cols], df.loc[:, class_cols]
else:
classes = df.loc[:, class_cols].mode(axis=1)[0].astype(int)

if self.bluecast_models[0].feat_type_detector:
if (
self.bluecast_models[0].target_label_encoder
and self.bluecast_models[0].feat_type_detector
):
classes = self.bluecast_models[
0
].target_label_encoder.label_encoder_reverse_transform(classes)

return (
df.loc[:, prob_cols].mean(axis=1),
classes,
)
else:
return (
df.loc[:, prob_cols].mean(axis=1),
df.loc[:, prob_cols].mean(axis=1) > 0.5,
)
if return_sub_models_preds:
return df.loc[:, prob_cols], df.loc[:, class_cols]
else:
return (
df.loc[:, prob_cols].mean(axis=1),
df.loc[:, prob_cols].mean(axis=1) > 0.5,
)
3 changes: 2 additions & 1 deletion bluecast/preprocessing/encode_target_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ def fit_label_encoder(self, targets: pd.DataFrame) -> Dict[str, int]:
if isinstance(targets, pd.Series):
targets = targets.to_frame()

values = targets[col].unique()
values = sorted(targets[col].unique().tolist())

cat_mapping = {}
for label, cat in enumerate(values):
cat_mapping[cat] = label
Expand Down
Binary file modified dist/bluecast-0.22-py3-none-any.whl
Binary file not shown.
Binary file modified dist/bluecast-0.22.tar.gz
Binary file not shown.

0 comments on commit 8e1b866

Please sign in to comment.