
fixing custom metric (#357)
* fixing the error for custom metric
liususan091219 authored Dec 24, 2021
1 parent c6c0c29 commit b2900f4
Showing 3 changed files with 77 additions and 50 deletions.
85 changes: 47 additions & 38 deletions flaml/model.py
@@ -307,7 +307,8 @@ def __init__(self, task="seq-classification", **config):
         from transformers import TrainingArguments
         self._TrainingArguments = TrainingArguments

-    def _join(self, X_train, y_train):
+    @staticmethod
+    def _join(X_train, y_train):
         y_train = DataFrame(y_train, columns=["label"], index=X_train.index)
         train_df = X_train.join(y_train)
         return train_df
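Note: making _join a @staticmethod lets code outside the estimator, such as the custom metric in the test file below, rebuild the same label-joined frame without a bound instance. A minimal sketch of what the join produces, on assumed toy inputs:

from pandas import DataFrame

X_train = DataFrame({"sentence": ["good movie", "bad plot"]})
y_train = [1, 0]

# Mirrors the body of _join: align labels on X's index, then join as a column.
y_df = DataFrame(y_train, columns=["label"], index=X_train.index)
train_df = X_train.join(y_df)
print(list(train_df.columns))  # ['sentence', 'label']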
@@ -367,12 +368,20 @@ def _init_hpo_args(self, automl_fit_kwargs: dict = None):
             setattr(custom_hpo_args, key, val)
         self.custom_hpo_args = custom_hpo_args

-    def _preprocess(self, X, y=None, task=None, **kwargs):
+    def _preprocess(self, X, y=None, **kwargs):
         from .nlp.utils import tokenize_text

-        if X.dtypes[0] == "string":
+        # is_str = False
+        # for each_type in ["string", "str"]:
+        #     try:
+        #         is_str = is_str or (X.dtypes[0] == each_type)
+        #     except TypeError:
+        #         pass
+        is_str = str(X.dtypes[0]) in ("string", "str")
+
+        if is_str:
             return tokenize_text(
-                X=X, Y=y, task=task, custom_hpo_args=self.custom_hpo_args
+                X=X, Y=y, task=self._task, custom_hpo_args=self.custom_hpo_args
             )
         else:
             return X, None
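Note: coercing the dtype to str before comparing covers both the pandas "string" extension dtype and plain "str", and avoids the TypeError that direct equality can raise for some dtypes. A small illustration on assumed toy data:

import pandas as pd

X = pd.DataFrame({"sentence": pd.array(["good movie", "bad plot"], dtype="string")})
is_str = str(X.dtypes[0]) in ("string", "str")
print(is_str)  # True -> the column is routed through tokenize_text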
@@ -445,14 +454,16 @@ def on_epoch_end(self, args, state, control, **callback_kwargs):
         y_val = kwargs.get("y_val")

         if self._task not in NLG_TASKS:
-            self._X_train, _ = self._preprocess(X=X_train, task=self._task, **kwargs)
+            self._X_train, _ = self._preprocess(X=X_train, **kwargs)
             self._y_train = y_train
         else:
             self._X_train, self._y_train = self._preprocess(
-                X=X_train, y=y_train, task=self._task, **kwargs
+                X=X_train, y=y_train, **kwargs
             )

-        train_dataset = Dataset.from_pandas(self._join(self._X_train, self._y_train))
+        train_dataset = Dataset.from_pandas(
+            TransformersEstimator._join(self._X_train, self._y_train)
+        )

         # TODO: set a breakpoint here, observe the resulting train_dataset,
         # compare it with the output of the tokenized results in your transformer example
@@ -462,13 +473,13 @@ def on_epoch_end(self, args, state, control, **callback_kwargs):

         if X_val is not None:
             if self._task not in NLG_TASKS:
-                self._X_val, _ = self._preprocess(X=X_val, task=self._task, **kwargs)
+                self._X_val, _ = self._preprocess(X=X_val, **kwargs)
                 self._y_val = y_val
             else:
-                self._X_val, self._y_val = self._preprocess(
-                    X=X_val, y=y_val, task=self._task, **kwargs
-                )
-            eval_dataset = Dataset.from_pandas(self._join(self._X_val, self._y_val))
+                self._X_val, self._y_val = self._preprocess(X=X_val, y=y_val, **kwargs)
+            eval_dataset = Dataset.from_pandas(
+                TransformersEstimator._join(self._X_val, self._y_val)
+            )
         else:
             eval_dataset = None
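Note: both the train and validation frames now pass through the static _join before Dataset.from_pandas. A tiny sketch of that conversion on assumed toy data:

import pandas as pd
from datasets import Dataset

df = pd.DataFrame({"sentence": ["good movie"], "label": [1]})
ds = Dataset.from_pandas(df)
print(ds.column_names)  # ['sentence', 'label']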

@@ -589,7 +600,7 @@ def _select_checkpoint(self, trainer):

         if trainer.ckpt_to_metric:
             best_ckpt, _ = min(
-                trainer.ckpt_to_metric.items(), key=lambda x: x[1]["val_loss"]
+                trainer.ckpt_to_metric.items(), key=lambda x: x[1]["eval_loss"]
             )
             best_ckpt_global_step = trainer.ckpt_to_global_step[best_ckpt]
             for each_ckpt in list(trainer.ckpt_to_metric):
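Note: with the eval_ prefix renaming in TrainerForAuto.evaluate disabled (see the trainer.py diff below), checkpoint metrics keep Hugging Face's eval_ prefix, so selection must key on eval_loss rather than val_loss. A sketch over a hypothetical ckpt_to_metric mapping:

# Hypothetical contents; real entries are recorded by TrainerForAuto.evaluate.
ckpt_to_metric = {
    "ckpt-50": {"eval_loss": 0.41},
    "ckpt-100": {"eval_loss": 0.37},
    "ckpt-150": {"eval_loss": 0.44},
}
best_ckpt, _ = min(ckpt_to_metric.items(), key=lambda x: x[1]["eval_loss"])
print(best_ckpt)  # ckpt-100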
@@ -609,30 +620,28 @@ def _select_checkpoint(self, trainer):
         return best_ckpt

     def _compute_metrics_by_dataset_name(self, eval_pred):
-        from .ml import metric_loss_score
-        from .nlp.utils import postprocess_text
-
-        predictions, labels = eval_pred
-
-        if self._task in NLG_TASKS:
-            if isinstance(predictions, tuple):
-                predictions = np.argmax(predictions[0], axis=2)
-            decoded_preds = self._tokenizer.batch_decode(
-                predictions, skip_special_tokens=True
-            )
-            labels = np.where(labels != -100, labels, self._tokenizer.pad_token_id)
-            decoded_labels = self._tokenizer.batch_decode(
-                labels, skip_special_tokens=True
-            )
-            predictions, labels = postprocess_text(decoded_preds, decoded_labels)
-        else:
-            predictions = (
-                np.squeeze(predictions)
-                if self._task == SEQREGRESSION
-                else np.argmax(predictions, axis=1)
-            )
-
         if isinstance(self._metric, str):
+            from .ml import metric_loss_score
+            from .nlp.utils import postprocess_text
+
+            predictions, labels = eval_pred
+            if self._task in NLG_TASKS:
+                if isinstance(predictions, tuple):
+                    predictions = np.argmax(predictions[0], axis=2)
+                decoded_preds = self._tokenizer.batch_decode(
+                    predictions, skip_special_tokens=True
+                )
+                labels = np.where(labels != -100, labels, self._tokenizer.pad_token_id)
+                decoded_labels = self._tokenizer.batch_decode(
+                    labels, skip_special_tokens=True
+                )
+                predictions, labels = postprocess_text(decoded_preds, decoded_labels)
+            else:
+                predictions = (
+                    np.squeeze(predictions)
+                    if self._task == SEQREGRESSION
+                    else np.argmax(predictions, axis=1)
+                )
             return {
                 "val_loss": metric_loss_score(
                     metric_name=self._metric, y_predict=predictions, y_true=labels
@@ -659,7 +668,7 @@ def predict_proba(self, X_test):
         from transformers import TrainingArguments
         from .nlp.utils import load_model

-        X_test, _ = self._preprocess(X_test, task=self._task, **self._kwargs)
+        X_test, _ = self._preprocess(X_test, **self._kwargs)
         test_dataset = Dataset.from_pandas(X_test)

         best_model = load_model(
@@ -681,7 +690,7 @@ def predict(self, X_test):
         from .nlp.utils import load_model
         from .nlp.huggingface.trainer import TrainerForAuto

-        X_test, _ = self._preprocess(X=X_test, task=self._task, **self._kwargs)
+        X_test, _ = self._preprocess(X=X_test, **self._kwargs)
         test_dataset = Dataset.from_pandas(X_test)

         best_model = load_model(
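Note: the other change above gates _compute_metrics_by_dataset_name on isinstance(self._metric, str): built-in (string-named) metrics are scored inside the trainer, while a callable custom metric skips this hook and evaluates the fitted model itself, as the updated test below shows. A hedged sketch of the dispatch, where zero_one_loss stands in for FLAML's metric_loss_score:

def compute_metrics_sketch(metric, predictions, labels):
    # String metric: score inside the trainer, as the diff above does.
    if isinstance(metric, str):
        from sklearn.metrics import zero_one_loss  # stand-in scorer
        return {"val_loss": zero_one_loss(labels, predictions)}
    # Callable metric: skip this hook; the callable evaluates the model itself.
    return None

print(compute_metrics_sketch("accuracy", [1, 0, 1], [1, 1, 1]))  # {'val_loss': 0.333...}
print(compute_metrics_sketch(lambda: None, [1], [1]))            # None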
10 changes: 5 additions & 5 deletions flaml/nlp/huggingface/trainer.py
@@ -74,18 +74,18 @@ def evaluate(
             ignore_keys,
             metric_key_prefix,
         )
-        if metrics:
-            for key in list(metrics.keys()):
-                if key.startswith("eval_"):
-                    metrics[key[5:]] = metrics.pop(key)
+        # if metrics:
+        #     for key in list(metrics.keys()):
+        #         if key.startswith("eval_"):
+        #             metrics[key[5:]] = metrics.pop(key)
         if hasattr(self, "ckpt_to_global_step"):
             self.ckpt_to_global_step[ckpt_dir] = self.state.global_step
             if metrics:
                 self.ckpt_to_metric[ckpt_dir] = metrics
         else:
             self.ckpt_to_global_step = {ckpt_dir: self.state.global_step}
             self.ckpt_to_metric = {ckpt_dir: metrics} if metrics else {}

         return metrics

     # TODO: if your task is SUMMARIZATION, you need a different
     # class Seq2SeqTrainerForAuto, uncomment the code below
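Note: commenting out this renaming loop is the heart of the fix. Trainer.evaluate returns metric keys prefixed with eval_, and stripping the prefix broke downstream lookups of eval_loss. A before/after sketch on an assumed metrics dict:

metrics = {"eval_loss": 0.37, "eval_runtime": 1.2}

# What the removed loop used to do: strip the "eval_" prefix.
renamed = {k[5:] if k.startswith("eval_") else k: v for k, v in metrics.items()}
print(renamed)               # {'loss': 0.37, 'runtime': 1.2} -- eval_loss lookups fail
print(metrics["eval_loss"])  # 0.37 -- with the loop disabled, the lookup succeeds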
32 changes: 25 additions & 7 deletions test/nlp/test_autohf_custom_metric.py
@@ -2,7 +2,7 @@
 import pytest


-def toy_metric(
+def custom_metric(
     X_test,
     y_test,
     estimator,
@@ -15,11 +15,25 @@ def toy_metric(
     groups_test=None,
     groups_train=None,
 ):
-    return 0, {
-        "val_loss": 0,
-        "train_loss": 0,
-        "pred_time": 0,
-    }
+    from datasets import Dataset
+    from flaml.model import TransformersEstimator
+
+    if y_test is not None:
+        X_test, _ = estimator._preprocess(X_test)
+        eval_dataset = Dataset.from_pandas(TransformersEstimator._join(X_test, y_test))
+    else:
+        X_test, _ = estimator._preprocess(X_test)
+        eval_dataset = Dataset.from_pandas(X_test)
+
+    trainer = estimator._model
+
+    trainer_compute_metrics_cache = trainer.compute_metrics
+    trainer.compute_metrics = None
+
+    metrics = trainer.evaluate(eval_dataset)
+    trainer.compute_metrics = trainer_compute_metrics_cache
+
+    return metrics["eval_loss"], metrics


 @pytest.mark.skipif(sys.platform == "darwin", reason="do not run on mac os")
@@ -54,7 +68,7 @@ def test_custom_metric():
         "max_iter": 1,
         "time_budget": 5,
         "task": "seq-classification",
-        "metric": toy_metric,
+        "metric": custom_metric,
         "log_file_name": "seqclass.log",
     }

@@ -77,3 +91,7 @@
     )

     del automl
+
+
+if __name__ == "__main__":
+    test_custom_metric()
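Note: FLAML calls a user-supplied metric with the fit context and expects a pair back: a scalar loss to minimize and a dict of auxiliary metrics. custom_metric satisfies this by detaching compute_metrics, so that evaluate returns only the raw eval_-prefixed numbers, and by returning eval_loss as the scalar. A purely hypothetical call mirroring that contract (X_val, y_val, and trained_estimator are assumed names, and a real call passes the remaining arguments too):

# Hypothetical; not runnable as-is without a fitted TransformersEstimator.
loss, info = custom_metric(X_test=X_val, y_test=y_val, estimator=trained_estimator)
assert loss == info["eval_loss"]  # FLAML minimizes the first element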
