
sft_trainer incompatible with accelerator.gather_for_metrics #3047

Closed
jamesbraza opened this issue Mar 10, 2025 · 0 comments · Fixed by #3048
Labels
⚡ accelerate (Related to accelerate) · 🐛 bug (Something isn't working) · ⚡ PEFT (Related to PEFT)

Comments

jamesbraza (Contributor) commented Mar 10, 2025

Reproduction

When using trl.SFTTrainer as of the current main (https://github.com/huggingface/trl/tree/e3244d2d096ff1e2e248c931d06d39e165e20623), I get this error:

[rank0]:   File "/home/user/code/repo/scripts/sft_train.py", line 415, in main
[rank0]:     trainer.train()
[rank0]:   File "/home/user/code/repo/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2245, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/user/code/repo/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2556, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/user/code/repo/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3706, in training_step
[rank0]:     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/user/code/repo/.venv/lib/python3.12/site-packages/trl/trainer/sft_trainer.py", line 468, in compute_loss
[rank0]:     self._total_train_tokens += self.accelerator.gather_for_metrics(inputs["attention_mask"].sum().item())
[rank0]:                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/user/code/repo/.venv/lib/python3.12/site-packages/accelerate/accelerator.py", line 2581, in gather_for_metrics
[rank0]:     data = gather_object(input_data)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/user/code/repo/.venv/lib/python3.12/site-packages/accelerate/utils/operations.py", line 459, in gather_object
[rank0]:     return _gpu_gather_object(object)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/user/code/repo/.venv/lib/python3.12/site-packages/accelerate/utils/operations.py", line 442, in _gpu_gather_object
[rank0]:     return [x for y in output_objects for x in y]
[rank0]:                                                ^
[rank0]: TypeError: 'int' object is not iterable

Here's what inputs["attention_mask"] looks like:

>>> inputs["attention_mask"].shape
torch.Size([6, 1178])
>>> inputs["attention_mask"]
tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')
>>> inputs["attention_mask"].sum()
tensor(5316, device='cuda:0')
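
For context on why this fails: calling `.item()` hands `gather_for_metrics` a plain Python `int`, which falls through to `gather_object`, and `_gpu_gather_object` flattens the per-rank results with a list comprehension that assumes each gathered object is iterable. A minimal single-process sketch (not accelerate itself; `flatten` is a hypothetical stand-in for that final step) reproduces the failure:

```python
# Stand-in for the last line of accelerate's _gpu_gather_object:
#     return [x for y in output_objects for x in y]
def flatten(output_objects):
    return [x for y in output_objects for x in y]

# Per-rank lists flatten fine: each element is iterable.
print(flatten([[5316], [5240]]))  # [5316, 5240]

# Per-rank bare ints (what .sum().item() produces) do not:
try:
    flatten([5316, 5240])
except TypeError as err:
    print(err)  # 'int' object is not iterable
```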

It seems #3012 hit this too, but that issue was closed. The solution proposed there is valid:

- self._total_train_tokens += self.accelerator.gather_for_metrics(inputs["attention_mask"].sum().item())
+ self._total_train_tokens += (
+     self.accelerator.gather_for_metrics(inputs["attention_mask"].sum()).sum().item()
+ )
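
The reason the patch works: the per-rank token count stays a tensor, so `gather_for_metrics` can concatenate it across processes, and `.item()` is only called after the cross-rank sum. A minimal single-process sketch of the arithmetic, with the gather simulated by stacking per-rank counts (no accelerate required, and both hypothetical ranks see the same batch here):

```python
import torch

# Per-rank token count stays a tensor -- no .item() yet.
attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
local_tokens = attention_mask.sum()  # tensor(5)

# Simulate a two-rank gather by stacking the per-rank counts.
gathered = torch.stack([local_tokens, local_tokens])

# Only after the cross-rank sum do we convert to a Python int.
total = gathered.sum().item()
print(total)  # 10
```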

System Info

  • Platform: Linux-5.15.0-112-generic-x86_64-with-glibc2.35
  • Python version: 3.12.9
  • TRL version: 0.16.0.dev0
  • PyTorch version: 2.5.1+cu124
  • CUDA device(s): 8 × NVIDIA H100 80GB HBM3
  • Transformers version: 4.49.0
  • Accelerate version: 1.4.0
  • Accelerate config: not found
  • Datasets version: 3.3.2
  • HF Hub version: 0.29.2
  • bitsandbytes version: not installed
  • DeepSpeed version: 0.16.4
  • Diffusers version: not installed
  • Liger-Kernel version: 0.5.4
  • LLM-Blender version: not installed
  • OpenAI version: 1.65.5
  • PEFT version: 0.14.0
  • vLLM version: not installed

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
  • Any traceback provided is complete