Skip to content

Commit

Permalink
test
Browse files · Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Dec 10, 2024
1 parent 4a7cb29 commit ca8a0f7
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions optimum/onnxruntime/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""
The ORTTrainer class, to easily train a 🤗 Transformers model from scratch or finetune it on a new task with ONNX Runtime.
"""

import functools
import math
import os
Expand Down Expand Up @@ -131,11 +132,11 @@ def __init__(self, model, args, label_smoother):
# Label smoothing
self.label_smoother = label_smoother

def forward(self, inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs, num_items_in_batch):
    """Forward pass that delegates to the trainer's loss computation.

    Args:
        inputs: Model inputs keyed by name (tensors or other values).
        return_outputs: Whether the loss computation should also return
            the model outputs alongside the loss.
        num_items_in_batch: Batch-item count forwarded to the loss
            computation — presumably for gradient-accumulation-aware loss
            scaling in recent `transformers` versions (TODO confirm).

    Returns:
        Whatever ``compute_model_plus_loss_internal`` returns: the loss,
        optionally together with the model outputs.
    """
    # `compute_model_plus_loss_internal` is assigned once the class is instantiated.
    # It should have the same signature as `Trainer.compute_loss()`.
    # Delegating (rather than duplicating the loss-computation code here)
    # avoids potentially un-synced state between this wrapper and the trainer.
    return self.compute_model_plus_loss_internal(self._original_model, inputs, return_outputs, num_items_in_batch)

@property
def module(self):
Expand Down Expand Up @@ -291,14 +292,14 @@ def _set_signature_columns_if_needed(self):
# Labels may be named label or label_ids, the default data collator handles that.
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))

def compute_loss(self, model_with_loss, inputs, return_outputs=False, num_items_in_batch=None):
    """Compute the training loss, routing through ``ModuleWithLoss`` when present.

    Args:
        model_with_loss: The model (possibly a ``ModuleWithLoss`` wrapper) to run.
        inputs: Batch inputs; may be a ``BatchEncoding`` or a plain mapping.
        return_outputs: If True, also return the model outputs alongside the loss.
        num_items_in_batch: Optional batch-item count forwarded to the loss
            computation (matches the upstream ``Trainer.compute_loss`` signature).

    Returns:
        The loss, or ``(loss, outputs)`` when ``return_outputs`` is True.
    """
    # Run model forward + loss compute.
    if isinstance(self.model, ModuleWithLoss):
        # ORTModule does not support the BatchEncoding type, so convert to a plain dict.
        dict_inputs = dict(inputs.items())
        return model_with_loss(dict_inputs, return_outputs, num_items_in_batch)
    # Not wrapped: fall back to the stock `transformers.Trainer` implementation.
    return super().compute_loss(model_with_loss, inputs, return_outputs, num_items_in_batch)

def train(
self,
Expand Down

0 comments on commit ca8a0f7

Please sign in to comment.