Skip to content

Commit

Permalink
fix conflict in trainer.py (#75)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
Cemberk and github-actions[bot] authored Dec 3, 2024
1 parent acdc0ff commit 3c34382
Showing 1 changed file with 3 additions and 65 deletions.
68 changes: 3 additions & 65 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2483,73 +2483,11 @@ def _inner_training_loop(
if not do_sync_step:
self.accelerator.gradient_state._set_sync_gradients(False)
else:
self.state.num_input_tokens_seen += (
torch.sum(
self.accelerator.gather(
torch.tensor(
inputs[main_input_name].numel(), device=self.args.device, dtype=torch.int64
)
)
)
.cpu()
.item()
)
if rng_to_sync:
self._load_rng_state(resume_from_checkpoint)
rng_to_sync = False

if (self.state.global_step == 10):
start_train_stable_time = time.time()

# Skip past any already trained steps if resuming training
if steps_trained_in_current_epoch > 0:
steps_trained_in_current_epoch -= 1
if steps_trained_progress_bar is not None:
steps_trained_progress_bar.update(1)
if steps_trained_in_current_epoch == 0:
self._load_rng_state(resume_from_checkpoint)
continue
elif steps_trained_progress_bar is not None:
steps_trained_progress_bar.close()
steps_trained_progress_bar = None

if step % args.gradient_accumulation_steps == 0:
self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

with self.accelerator.accumulate(model):
tr_loss_step = self.training_step(model, inputs)

if (
args.logging_nan_inf_filter
and not is_torch_xla_available()
and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
):
# if loss is nan or inf simply add the average of previous logged losses
tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
else:
if tr_loss.device != tr_loss_step.device:
raise ValueError(
f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
)
tr_loss += tr_loss_step

self.current_flos += float(self.floating_point_ops(inputs))

is_last_step_and_steps_less_than_grad_acc = (
steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
)

if (
total_batched_samples % args.gradient_accumulation_steps == 0
or
# last step in epoch but step is always smaller than gradient_accumulation_steps
is_last_step_and_steps_less_than_grad_acc
):
# the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
# in accelerate. So, explicitly enable sync gradients to True in that case.
if is_last_step_and_steps_less_than_grad_acc:
self.accelerator.gradient_state._set_sync_gradients(True)

if (self.state.global_step == 10):
start_train_stable_time = time.time()

if self.args.include_num_input_tokens_seen:
main_input_name = getattr(self.model, "main_input_name", "input_ids")
if main_input_name not in inputs:
Expand Down

0 comments on commit 3c34382

Please sign in to comment.