Skip to content

Commit

Permalink
fix zerodivision (#1000)
Browse files Browse the repository at this point in the history
* fix zerodivision

* update

* remove final

---------

Co-authored-by: Li Jiang <[email protected]>
  • Loading branch information
liususan091219 and thinkall authored Apr 23, 2023
1 parent da0d8c0 commit 7114b8f
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions flaml/automl/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,8 +1136,7 @@ def predict_proba(self, X, **pred_kwargs):
except ZeroDivisionError:
logger.warning("Zero division error appeared in HuggingFace Transformers.")
predictions = np.array([-0.05] * len(test_dataset))
else:
return predictions
return predictions

def score(self, X_val: DataFrame, y_val: Series, **kwargs):
import transformers
Expand Down Expand Up @@ -1169,14 +1168,13 @@ def predict(self, X, **pred_kwargs):

kwargs = {} if self._task not in NLG_TASKS else {"metric_key_prefix": "predict"}
try:
predictions = new_trainer.predict(test_dataset, **kwargs)
predictions = new_trainer.predict(test_dataset, **kwargs).predictions
except ZeroDivisionError:
logger.warning("Zero division error appeared in HuggingFace Transformers.")
predictions = np.array([0] * len(test_dataset))

post_y_pred, _ = postprocess_prediction_and_true(
task=self._task,
y_pred=predictions.predictions,
y_pred=predictions,
tokenizer=self.tokenizer,
hf_args=self._training_args,
X=X,
Expand Down

0 comments on commit 7114b8f

Please sign in to comment.