
Commit

fixed mypy typings
vinaysb committed May 15, 2024
1 parent 3f7f76b commit b0f6644
Showing 1 changed file with 33 additions and 17 deletions.
50 changes: 33 additions & 17 deletions src/clep/classification/classify.py
@@ -4,9 +4,10 @@

 import json
 import logging
+from numbers import Real
 import sys
 from collections import defaultdict
-from typing import Dict, List, Any, Callable, Tuple
+from typing import Dict, List, Any, Callable, Tuple, Union
 import copy

 import click
@@ -16,6 +17,7 @@
 import pandera.typing as pat
 from sklearn import linear_model, svm, ensemble, model_selection, multiclass, metrics, preprocessing
 from sklearn.base import BaseEstimator
+from sklearn.model_selection._search import BaseSearchCV
 from sklearn.model_selection import StratifiedKFold
 from skopt import BayesSearchCV
 from xgboost import XGBClassifier
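
The newly imported BaseSearchCV lives in a private module (sklearn.model_selection._search), but it is the shared base class of both search strategies this file constructs, which is what makes it usable as a common return annotation on get_optimizer further down. A minimal sketch (not part of the commit) illustrating the relationship:

    from sklearn.model_selection import GridSearchCV
    from sklearn.model_selection._search import BaseSearchCV  # private module, as imported above
    from sklearn.svm import SVC
    from skopt import BayesSearchCV

    # Both concrete search classes inherit from BaseSearchCV, so either can be
    # returned under a single BaseSearchCV annotation.
    grid = GridSearchCV(SVC(), {"C": [0.1, 1.0]})
    bayes = BayesSearchCV(SVC(), {"C": (1e-3, 1e3, "log-uniform")})
    assert isinstance(grid, BaseSearchCV) and isinstance(bayes, BaseSearchCV)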
@@ -53,14 +55,14 @@ def do_classification(
     model, optimizer_cv = get_classifier(model_name, validation_cv, *args)

     # Separate embeddings from labels in data
-    labels = data['label'].values
+    labels = list(data['label'].values)
     data = data.drop(columns='label')

     if rand_labels:
         np.random.shuffle(labels)

     if len(np.unique(labels)) > 2:
-        multi_roc_auc = metrics.make_scorer(multiclass_score_func, metric_func=metrics.roc_auc_score)
+        multi_roc_auc: Callable = metrics.make_scorer(multiclass_score_func, metric_func=metrics.roc_auc_score)
         optimizer = get_optimizer(optimizer_name, model, model_name, optimizer_cv, multi_roc_auc)
         optimizer.fit(data, labels)

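A pandas Series.values is typed as a union of ExtensionArray and ndarray, which mypy will not narrow to the list type expected downstream; wrapping it in list() sidesteps that. A small sketch of the pattern (with made-up data, not from the commit):

    import numpy as np
    import pandas as pd

    data = pd.DataFrame({"label": [0, 1, 1, 0], "feat": [0.1, 0.4, 0.3, 0.9]})

    # list() converts the ndarray/ExtensionArray union returned by .values
    # into a plain list, which type-checks against List[Any]-style annotations.
    labels = list(data["label"].values)
    data = data.drop(columns="label")

    # Shuffling in place still works: np.random.shuffle accepts any mutable
    # sequence, not only arrays.
    np.random.shuffle(labels)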
@@ -94,8 +96,14 @@ def do_classification(
     return cv_results


-def _do_multiclass_classification(estimator: BaseEstimator, x: pd.DataFrame, y: pat.Series[str | int | float], cv: int, scoring: List[str],
-                                  return_estimator: bool = True) -> Dict[str, Any]:
+def _do_multiclass_classification(
+        estimator: BaseEstimator,
+        x: pd.DataFrame,
+        y: Union[pat.Series[str | int | float], List[Any]],
+        cv: int,
+        scoring: List[str],
+        return_estimator: bool = True
+) -> Dict[str, Any]:
     """Do multiclass classification using OneVsRest classifier.

     :param estimator: estimator/classifier that should be used for cross validation
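
Widening y to Union[pat.Series[str | int | float], List[Any]] lets the labels produced by list(data['label'].values) above flow into this helper without a cast. A hypothetical caller-side sketch (names assumed, not from the commit):

    from typing import Any, List, Union

    import numpy as np
    import pandera.typing as pat

    # Either a pandera-typed Series or a plain list satisfies the widened
    # annotation; np.unique handles both when counting classes.
    def count_classes(y: Union["pat.Series[str | int | float]", List[Any]]) -> int:
        return len(np.unique(y))

    print(count_classes(["a", "b", "a", "c"]))  # 3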
@@ -121,7 +129,7 @@ def _do_multiclass_classification(estimator: BaseEstimator, x: pd.DataFrame, y:
     for run_num, (train_indexes, test_indexes) in enumerate(k_fold.split(x, y)):
         logger.debug(f"\nCurrent Run number: {run_num}\n")
         # Make a One-Hot encoding of the classes
-        y = preprocessing.label_binarize(y, classes=unique_labels)
+        y = preprocessing.label_binarize(y, classes=unique_labels)  # type: ignore

         x_train = np.asarray(
             [x.iloc[train_index, :].values.tolist()
@@ -155,7 +163,7 @@ def _do_multiclass_classification(estimator: BaseEstimator, x: pd.DataFrame, y:
         logger.debug(f"y_true:\n {y_test}\n\n")

         if return_estimator:
-            cv_results['estimator'].append(clf.estimator)
+            cv_results['estimator'].append(clf.estimators_)

         # For the multiclass metric find the score and add it to cv_results.
         for metric in scoring:
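
The clf.estimator → clf.estimators_ change is more than a typing fix: on a OneVsRestClassifier, estimator is the unfitted template passed to the constructor, while the fitted per-class models live in the estimators_ attribute. A sketch with synthetic data (not from the commit):

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.multiclass import OneVsRestClassifier

    X, y = make_classification(n_samples=60, n_classes=3, n_informative=4, random_state=0)
    clf = OneVsRestClassifier(LogisticRegression(max_iter=200)).fit(X, y)

    # estimators_ holds one fitted binary classifier per class;
    # estimator is only the unfitted template.
    assert len(clf.estimators_) == 3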
@@ -225,19 +233,27 @@ def _do_multiclass_classification(estimator: BaseEstimator, x: pd.DataFrame, y:
     return cv_results


-def _multiclass_metric_evaluator(metric_func: Callable[..., float], n_classes: int, y_test: npt.NDArray[Any],
-                                 y_pred: npt.NDArray[Any], **kwargs: str) -> float:
+def _multiclass_metric_evaluator(
+        metric_func: Callable[..., Union[float, np.float16, np.float32, np.float64, npt.NDArray[Any]]],
+        n_classes: int,
+        y_test: npt.NDArray[Any],
+        y_pred: npt.NDArray[Any],
+        **kwargs: str
+) -> float | npt.NDArray[Any]:
     """Calculate the average metric for multiclass classifiers."""
     metric = 0.0

     for label in range(n_classes):
         metric += metric_func(y_test[:, label], y_pred[:, label], **kwargs)
     metric /= n_classes

-    return metric
+    if isinstance(metric, np.ndarray):
+        return metric
+
+    return float(metric)


-def get_classifier(model_name: str, cv_opt: int, *args: str) -> Tuple[BaseEstimator, StratifiedKFold]:
+def get_classifier(model_name: str, cv_opt: int, *args: Any) -> Tuple[BaseEstimator, StratifiedKFold]:
     """Retrieve the appropriate classifier from sci-kit learn based on the arguments."""
     cv = model_selection.StratifiedKFold(n_splits=cv_opt, shuffle=True)

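The widened return type reflects that the accumulated metric may end up as an ndarray when metric_func returns per-class arrays rather than scalars. A short sketch of the per-label averaging the loop performs, using made-up one-hot data:

    import numpy as np
    from sklearn import metrics

    # One-hot ground truth and predictions for three classes (assumed data).
    y_test = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 0, 0]])
    y_pred = np.array([[1, 0, 0], [0, 0, 1], [0, 0, 1], [1, 0, 0]])

    # Each column is scored as a binary problem and the results are
    # averaged over the classes, mirroring the loop above.
    avg = sum(
        metrics.f1_score(y_test[:, k], y_pred[:, k], zero_division=0)
        for k in range(3)
    ) / 3
    print(round(avg, 3))  # 0.556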
@@ -270,8 +286,8 @@ def get_optimizer(
         estimator: BaseEstimator,
         model: str,
         cv: StratifiedKFold,
-        scorer: str
-) -> BaseEstimator:
+        scorer: str | Callable
+) -> BaseSearchCV:
     """Retrieve the appropriate optimizer from sci-kit learn based on the arguments."""
     if optimizer == 'grid_search':
         param_grid = constants.get_param_grid(model)
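
Widening scorer to str | Callable matches what scikit-learn's search classes accept for scoring: either a built-in scorer name or a callable built with make_scorer, such as the multi_roc_auc scorer constructed in do_classification. A sketch of both forms (not from the commit):

    from sklearn import metrics
    from sklearn.model_selection import GridSearchCV
    from sklearn.svm import SVC

    # A scorer can be named by string or supplied as a callable; both
    # satisfy the widened str | Callable annotation.
    by_name = GridSearchCV(SVC(), {"C": [1.0, 10.0]}, scoring="roc_auc")
    by_callable = GridSearchCV(
        SVC(), {"C": [1.0, 10.0]},
        scoring=metrics.make_scorer(metrics.f1_score),
    )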
@@ -295,8 +311,8 @@ def multiclass_score_func(y: npt.NDArray[Any], y_pred: npt.NDArray[Any], metric_
     if n_classes == 2:
         return metric_func(y, y_pred)

-    y = preprocessing.label_binarize(y, classes=classes)
-    y_pred = preprocessing.label_binarize(y_pred, classes=classes)
+    y = preprocessing.label_binarize(y, classes=classes)  # type: ignore
+    y_pred = preprocessing.label_binarize(y_pred, classes=classes)  # type: ignore

    metric = 0.0

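label_binarize one-hot encodes the labels so each class column can be scored as a binary problem; the # type: ignore comments are needed because reassigning the binarized ndarray to y and y_pred conflicts with their original annotations under mypy. A sketch with assumed labels:

    import numpy as np
    from sklearn import preprocessing

    y = np.array(["a", "b", "c", "a"])

    # One-hot encoding: one column per class, in the order given by `classes`.
    one_hot = preprocessing.label_binarize(y, classes=["a", "b", "c"])
    print(one_hot)
    # [[1 0 0]
    #  [0 1 0]
    #  [0 0 1]
    #  [1 0 0]]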
@@ -312,11 +328,11 @@ def _save_json(results: Dict[str, Any], out_dir: str) -> None:
     """Save the cross validation results as a json file."""
     for key in results.keys():
         # Check if the result is a numpy array, if yes convert to list
-        if isinstance(results[key], npt.NDArray):
+        if isinstance(results[key], np.ndarray):
             results[key] = results[key].tolist()

         # Check if the results are numpy float values, if yes skip it
-        elif isinstance(results[key][0], np.float64) or isinstance(results[key][0], np.float32):
+        elif isinstance(results[key][0], np.floating):
             continue

         elif isinstance(results[key][0], list):
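
npt.NDArray is only a typing alias and cannot be used with isinstance, so the runtime check has to use np.ndarray; likewise np.floating is the abstract parent of all numpy float scalar types, collapsing the chained float64/float32 checks into one. A quick sketch:

    import numpy as np

    # np.ndarray is a real class, so it works with isinstance;
    # npt.NDArray is a typing alias and raises TypeError there.
    assert isinstance(np.zeros(3), np.ndarray)

    # np.floating covers float16/32/64 scalars but not plain Python floats.
    assert isinstance(np.float32(1.5), np.floating)
    assert isinstance(np.float64(1.5), np.floating)
    assert not isinstance(1.5, np.floating)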