Skip to content

Commit

Permalink
hmm
Browse files Browse the repository at this point in the history
  • Loading branch information
Parry-Parry committed Nov 26, 2024
1 parent 8b52740 commit 8fad54a
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions rankers/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from datasets import Dataset
from transformers.trainer_utils import EvalLoopOutput, speed_metrics
from transformers.integrations.deepspeed import deepspeed_init
from dataclasses import dataclass
from dataclasses import is_dataclass
from .loss import LOSS_REGISTRY

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -65,9 +65,9 @@ def compute_loss(
"The model did not return a loss from the inputs, only the following keys: "
f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
)
# We don't use .loss here since the model may return tuples instead of ModelOutput.
loss = outputs.loss if isinstance(outputs, dataclass) else outputs[0]

loss = outputs.loss if is_dataclass(outputs) else outputs[0]

return (loss, outputs) if return_outputs else loss

def compute_metrics(self, result_frame: pd.DataFrame):
Expand Down

0 comments on commit 8fad54a

Please sign in to comment.