Group work for each commit #209

Open · tushar00jain wants to merge 1 commit into main from feature/group-futures

Conversation

@tushar00jain (Contributor) commented on Jun 4, 2025

Summary

Context

  • should_commit currently waits for all scheduled work, but we only need to wait for the work belonging to a single fragment
  • otherwise we have to wait for all fragments, which effectively turns it back into regular DiLoCo

Implementation

  • removed the CUDA synchronize when checking for errors on the process group; otherwise we also end up waiting for all of the work on the process group to finish
  • instead we synchronize on the current stream, though this might not be needed because the current stream is only doing training -- @d4l3k? also, the baby pg doesn't override the errored() method
  • removed waiting on pending futures in the manager; callers now need to ensure that they wait for the futures to complete
  • had to change the manager's allreduce to support custom streams and record an event on that stream when the future is ready (see the sketch after this list)
    • this way users of the API can synchronize on the custom stream to make the CPU wait for the work to finish
    • it avoids exposing Work to users
    • this also works for baby nccl, which previously just synced on the main stream to wait for the allreduce work to finish
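To make the stream/event pattern concrete, here is a minimal sketch of the idea, assuming a process group that issues collectives on the caller's current stream; the helper name allreduce_on_stream and its structure are illustrative, not torchft's actual API:

    import torch
    import torch.distributed as dist

    def allreduce_on_stream(
        tensor: torch.Tensor,
        group: dist.ProcessGroup,
        comm_stream: torch.cuda.Stream,
    ) -> torch.cuda.Event:
        # Issue the allreduce on a dedicated comm stream and record a CUDA
        # event once the collective's future completes.
        done = torch.cuda.Event()
        with torch.cuda.stream(comm_stream):
            work = dist.all_reduce(
                tensor, op=dist.ReduceOp.SUM, group=group, async_op=True
            )

            def _record(fut: torch.futures.Future) -> None:
                # Callers wait on this event instead of being handed a Work.
                done.record(comm_stream)

            work.get_future().then(_record)
        return done

    # Caller side:
    # torch.cuda.current_stream().wait_event(done)  # GPU-side dependency
    # done.synchronize()                            # CPU-side wait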

TODOs

  • need to make sure streaming also works with all pg's; more tests needed
  • removed support for quantization for now
    • it calls Work::wait() inside a stream, which blocks the CPU because of the synchronize call; we'll need to change this, otherwise we can't overlap communication with computation (see the sketch after this list)
    • need to make sure it works with all pg's
  • unit/integration tests for streaming DiLoCo
  • tests to make sure the sync logic works as intended with all pg's
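As a rough illustration of the Work::wait() issue above, a hedged sketch of chaining the dequantization onto the collective's future instead of blocking the CPU; quantize/dequantize are stand-ins for the real Triton kernels and the helper name is hypothetical:

    import torch
    import torch.distributed as dist

    # Stand-ins for the real quantize/dequantize kernels.
    def quantize(t: torch.Tensor) -> torch.Tensor:
        return t

    def dequantize(t: torch.Tensor) -> torch.Tensor:
        return t

    def allreduce_quantized_async(tensors, group, comm_stream):
        quantized = [quantize(t) for t in tensors]
        with torch.cuda.stream(comm_stream):
            works = [
                dist.all_reduce(q, op=dist.ReduceOp.SUM, group=group, async_op=True)
                for q in quantized
            ]

        def _dequantize(fut):
            # Runs once all collectives are done; no CPU-side Work::wait().
            with torch.cuda.stream(comm_stream):
                return [dequantize(q) for q in quantized]

        return torch.futures.collect_all(
            [w.get_future() for w in works]
        ).then(_dequantize)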

Test Plan

  • added a script, train_diloco.py, and gathered some profiles to make sure fragment send/sync happens at the right steps (using nccl)
  • compared streaming vs non-streaming; streaming performs better

Streaming

  • 2 fragments, H=10, T=5
  • fragments get sent and synced every 5 steps
  • there's still a problem -- if there's already an allreduce in flight, waiting on it also blocks the allreduce issued on the current step, likely because there is only one stream for networking. we need to create a separate stream for each of these allreduces to fix this (see the sketch below); ideally the nccl pg could offer an API to let users specify which stream to use
[profiler trace: streaming]
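A rough sketch of the per-allreduce-stream fix mentioned above, assuming a process group that launches collectives on the caller's current stream (as noted, the nccl pg would ideally expose a way to pick the stream); all names are illustrative:

    import torch
    import torch.distributed as dist

    # One CUDA stream per fragment, so waiting on a previous fragment's
    # allreduce does not serialize the collective issued for the current step.
    _fragment_streams: dict[int, torch.cuda.Stream] = {}

    def allreduce_fragment(fragment_id: int, tensor: torch.Tensor, group):
        stream = _fragment_streams.setdefault(fragment_id, torch.cuda.Stream())
        with torch.cuda.stream(stream):
            work = dist.all_reduce(
                tensor, op=dist.ReduceOp.SUM, group=group, async_op=True
            )
        return work, stream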

Non-Streaming

  • 1 fragment, H=10, T=0
  • fragments get sent and synced (in the same step) every 10 steps
[profiler trace: non-streaming]

    | torch.futures.Future[torch.Tensor]
    | torch.futures.Future[List[torch.Tensor]]
] = None
if should_quantize and IS_TRITON_AVAILABLE:
    fut = allreduce_quantized([tensor], ReduceOp.AVG, self._pg)
Member commented:
slightly unrelated to your PR but AVG is using world size right and not num_participants? Wonder if that's an issue

Contributor (Author) replied:
yeah changed to using SUM since that's supported now

@@ -288,8 +288,33 @@ def shutdown(self, wait: bool = True) -> None:
        self._manager.shutdown()
        self._executor.shutdown(wait=wait)

    def collect_all_allreduce(
        self, tensors: List[torch.Tensor], should_quantize: bool = False
    ) -> torch.futures.Future[List[torch.futures.Future[torch.Tensor]]]:
Member commented:
If we're doing a bulk operation, we should consider using allreduce_coalesced instead of N individual collective operations, and match the naming scheme accordingly.
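For reference, a minimal sketch of what the suggested coalesced variant could look like, assuming self._pg is a plain c10d process group (the torchft wrapper may behave differently); the method name simply mirrors the suggestion:

    from typing import List

    import torch
    import torch.distributed as dist

    def allreduce_coalesced(self, tensors: List[torch.Tensor]):
        # One coalesced collective instead of N individual allreduces.
        return dist.all_reduce_coalesced(
            tensors, op=dist.ReduceOp.SUM, group=self._pg, async_op=True
        )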

Copilot AI left a comment:

Pull Request Overview

This PR refactors and streamlines synchronization handling by migrating from direct future-based APIs to the new Work API, while also updating various tests and process group behaviors. Key changes include:

  • Converting Manager.allreduce and related interfaces to return Work objects instead of futures.
  • Updating distributed training files (e.g. train_ddp.py, ddp.py) to accommodate the new Work API.
  • Refactoring process group and manager tests to align with the updated error and synchronization handling.

Reviewed Changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated no comments.

Summary per file:

  • train_diloco.py: Introduces a training script using the new Work API.
  • train_ddp.py: Updates the DistributedSampler parameter from replica_group to replica_rank.
  • torchft/work.py: Adds DummyWork and ErrorSwallowingWork classes to support the Work API.
  • torchft/process_group.py: Updates CUDA synchronization and replaces the internal _DummyWork with DummyWork.
  • torchft/manager_test.py: Adjusts tests to match the new handling of pending work and synchronization.
  • torchft/manager.py: Changes allreduce to return Work and removes pending-future tracking.
  • torchft/local_sgd.py: Switches to storing Work objects for pending allreduces and revises fragment support.
  • torchft/ddp_test.py: Adapts tests for the new Work type using DummyWork.
  • torchft/ddp.py: Modifies the communication hook to use the new Work API for allreduce.
Comments suppressed due to low confidence (2)

torchft/manager.py:320

  • Returning DummyWork(None) when an error is detected bypasses the usual scaling and data propagation in an allreduce operation. Please ensure that this behavior is consistent with the overall API design and that downstream consumers of the Work object can handle a 'None' result appropriately.
return DummyWork(None)
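For context, a hedged sketch of what a pre-completed Work like this typically looks like (the actual torchft/work.py implementation may differ): wait() is a no-op and the future is already resolved with the stored result, which is why a None result flows straight through to downstream consumers.

    import torch
    # Base Work class; the public import path varies across torch versions.
    from torch._C._distributed_c10d import Work

    class DummyWork(Work):
        def __init__(self, result) -> None:
            super().__init__()
            self._future = torch.futures.Future()
            self._future.set_result(result)

        def wait(self, timeout=None) -> bool:
            # Nothing to wait for; the "collective" is already complete.
            return True

        def get_future(self) -> torch.futures.Future:
            return self._future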

torchft/local_sgd.py:468

  • [nitpick] The removed check for multiple fragments may introduce inconsistencies since DiLoCo still enforces a single fragment constraint. Review the design to ensure that both LocalSGD and DiLoCo have consistent expectations regarding fragment handling and update the documentation accordingly.
if len(model_fragments) != 1:
