From 8fad54ad9933923939a1b0c5a5f16a324068db88 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CAndrew?= Date: Tue, 26 Nov 2024 10:14:28 +0000 Subject: [PATCH] hmm --- rankers/train/trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rankers/train/trainer.py b/rankers/train/trainer.py index cc87d39..7227536 100644 --- a/rankers/train/trainer.py +++ b/rankers/train/trainer.py @@ -9,7 +9,7 @@ from datasets import Dataset from transformers.trainer_utils import EvalLoopOutput, speed_metrics from transformers.integrations.deepspeed import deepspeed_init -from dataclasses import dataclass +from dataclasses import is_dataclass from .loss import LOSS_REGISTRY logger = logging.getLogger(__name__) @@ -65,9 +65,9 @@ def compute_loss( "The model did not return a loss from the inputs, only the following keys: " f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}." ) - # We don't use .loss here since the model may return tuples instead of ModelOutput. - loss = outputs.loss if isinstance(outputs, dataclass) else outputs[0] - + + loss = outputs.loss if is_dataclass(outputs) else outputs[0] + return (loss, outputs) if return_outputs else loss def compute_metrics(self, result_frame: pd.DataFrame):