Skip to content

Commit

Permalink
Fix multiclass issue
Browse files: browse the repository at this point in the history
  • Loading branch information
ThomasMeissnerDS committed Sep 13, 2023
1 parent d825e15 commit fd1553a
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 6 deletions.
14 changes: 13 additions & 1 deletion bluecast/blueprints/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,19 @@ def fit_eval(
"""
self.fit(df, target_col)
y_probs, y_classes = self.predict(df_eval)
eval_dict = eval_classifier(target_eval.values, y_probs, y_classes)

if self.feat_type_detector:
if self.target_label_encoder and self.feat_type_detector:
eval_df = pd.DataFrame(target_eval.values, columns=[target_col])
y_true = self.target_label_encoder.transform_target_labels(
eval_df, target_col
)
else:
y_true = target_eval.values
else:
y_true = target_eval.values

eval_dict = eval_classifier(y_true, y_probs, y_classes)
self.eval_metrics = eval_dict
return eval_dict

Expand Down
25 changes: 20 additions & 5 deletions bluecast/preprocessing/encode_target_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""

from datetime import datetime
from typing import Dict
from typing import Dict, Optional, Union

import pandas as pd

Expand Down Expand Up @@ -41,12 +41,23 @@ def fit_label_encoder(self, targets: pd.DataFrame) -> Dict[str, int]:
return cat_mapping

def label_encoder_transform(
self, targets: pd.DataFrame, mapping: Dict[str, int]
self,
targets: pd.DataFrame,
mapping: Dict[str, int],
target_col: Optional[Union[str, int, float]] = None,
) -> pd.DataFrame:
"""Transform target column from categorical to numerical representation."""
logger(f"{datetime.utcnow()}: Start encoding target labels.")
if (
isinstance(target_col, str)
or isinstance(target_col, int)
or isinstance(target_col, float)
):
col = target_col
else:
col = targets.name

targets = targets.astype("category")
col = targets.name
if isinstance(targets, pd.Series):
targets = targets.to_frame()
mapping = self.target_label_mapping
Expand All @@ -61,9 +72,13 @@ def fit_transform_target_labels(self, targets: pd.DataFrame) -> pd.DataFrame:
targets = self.label_encoder_transform(targets, self.target_label_mapping)
return targets

def transform_target_labels(
    self, targets: pd.DataFrame, target_col: Optional[Union[str, int, float]] = None
) -> pd.DataFrame:
    """Transform the target column based on already created mappings.

    Delegates to ``label_encoder_transform`` using the mapping stored in
    ``self.target_label_mapping``.

    :param targets: Target labels to encode.
    :param target_col: Optional name of the target column. It is forwarded
        unchanged to ``label_encoder_transform``, which uses it as the column
        name when it is a str, int or float and otherwise falls back to
        ``targets.name``. Defaults to ``None`` so that existing callers which
        pass only ``targets`` keep working.
    :return: Whatever ``label_encoder_transform`` returns for these inputs.
    """
    return self.label_encoder_transform(
        targets, self.target_label_mapping, target_col
    )

def label_encoder_reverse_transform(self, targets: pd.Series) -> pd.DataFrame:
Expand Down
Binary file modified dist/bluecast-0.22-py3-none-any.whl
Binary file not shown.
Binary file modified dist/bluecast-0.22.tar.gz
Binary file not shown.

0 comments on commit fd1553a

Please sign in to comment.