Group work for each commit #209
base: main
Conversation
Force-pushed from 61ec926 to b4d5433
Force-pushed from d9acc2d to 76776dd
torchft/manager.py (outdated)
```python
        torch.futures.Future[torch.Tensor]
        | torch.futures.Future[List[torch.Tensor]]
    ] = None
    if should_quantize and IS_TRITON_AVAILABLE:
        fut = allreduce_quantized([tensor], ReduceOp.AVG, self._pg)
```
Slightly unrelated to your PR, but AVG is using world size and not num_participants, right? Wondering if that's an issue.
Yeah, changed to using SUM since that's supported now.
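For context, here is a minimal sketch of averaging over the live participants by summing and then dividing, as raised in the discussion above. It uses plain torch.distributed for illustration; torchft's quantized allreduce path is not shown, and num_participants is simply whatever count of live replicas the caller has, not a specific torchft API.

```python
# Minimal sketch, assuming averaging should use the number of live participants
# rather than the static world size (which is what ReduceOp.AVG divides by).
import torch
import torch.distributed as dist


def allreduce_mean(tensor: torch.Tensor, num_participants: int) -> torch.Tensor:
    # SUM across ranks, then divide by the number of replicas that actually
    # contributed to the reduction.
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)
    tensor /= num_participants
    return tensor
```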
torchft/manager.py (outdated)
```
@@ -288,8 +288,33 @@ def shutdown(self, wait: bool = True) -> None:
        self._manager.shutdown()
        self._executor.shutdown(wait=wait)

    def collect_all_allreduce(
        self, tensors: List[torch.Tensor], should_quantize: bool = False
    ) -> torch.futures.Future[List[torch.futures.Future[torch.Tensor]]]:
```
If we're doing a bulk operation, we should consider using allreduce_coalesced instead of N individual collective operations, and matching the naming scheme.
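A rough sketch of that suggestion, using the plain torch.distributed coalesced API rather than torchft's own helpers (whose exact names may differ):

```python
# Sketch of batching a bulk allreduce into one coalesced collective instead of
# N individual calls. torch.distributed.all_reduce_coalesced is used purely for
# illustration; torchft may expose its own coalesced/quantized variant.
from typing import List

import torch
import torch.distributed as dist


def allreduce_all(tensors: List[torch.Tensor]) -> None:
    # One collective for the whole list cuts per-call launch and latency
    # overhead compared to looping over the tensors one by one.
    dist.all_reduce_coalesced(tensors, op=dist.ReduceOp.SUM)
```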
Force-pushed from ad04693 to 6a53641
Force-pushed from 5b41219 to a6ec9ef
Force-pushed from a3b7ea2 to 2b363b5
Pull Request Overview
This PR refactors and streamlines the synchronization work handling by migrating from direct future-based APIs to the new Work API, while also updating various tests and process group behaviors. Key changes include:
- Converting Manager.allreduce and related interfaces to return Work objects instead of futures (see the call-site sketch after this list).
- Updating distributed training files (e.g. train_ddp.py, ddp.py) to accommodate the new Work API.
- Refactoring process group and manager tests to align with the updated error and synchronization handling.
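As an illustration of what the first change means for call sites, here is a hedged sketch; the exact torchft signatures may differ, and manager/grad are placeholders rather than code from this PR.

```python
# Hypothetical call-site comparison for the Manager.allreduce change described
# in the first bullet above; `manager` and `grad` are placeholders.
import torch


def sync_gradient(manager, grad: torch.Tensor) -> torch.Tensor:
    # Previously callers received a torch.futures.Future and waited on it;
    # with this PR the manager hands back a Work object instead.
    work = manager.allreduce(grad)
    work.wait()  # block until the collective (and its error handling) is done
    return grad
```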
Reviewed Changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated no comments.
Summary per file:

File | Description
--- | ---
train_diloco.py | Introduces a training script using the new Work API.
train_ddp.py | Updates DistributedSampler parameter from replica_group to replica_rank.
torchft/work.py | Adds DummyWork and ErrorSwallowingWork classes to support Work API.
torchft/process_group.py | Updates CUDA synchronization and replaces an internal _DummyWork with DummyWork.
torchft/manager_test.py | Adjusts tests to match new handling of pending work and numeral synchronization.
torchft/manager.py | Changes allreduce to return Work and removes pending future tracking.
torchft/local_sgd.py | Switches to storing Work objects for pending allreduces and revises fragment support.
torchft/ddp_test.py | Adapts tests for the new Work type using DummyWork.
torchft/ddp.py | Modifies the communication hook to use the new Work API for allreduce.
Comments suppressed due to low confidence (2)
torchft/manager.py:320
- Returning DummyWork(None) when an error is detected bypasses the usual scaling and data propagation in an allreduce operation. Please ensure that this behavior is consistent with the overall API design and that downstream consumers of the Work object can handle a 'None' result appropriately.
return DummyWork(None)
torchft/local_sgd.py:468
- [nitpick] The removed check for multiple fragments may introduce inconsistencies since DiLoCo still enforces a single fragment constraint. Review the design to ensure that both LocalSGD and DiLoCo have consistent expectations regarding fragment handling and update the documentation accordingly.
if len(model_fragments) != 1:
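For reference, a DummyWork-style wrapper along the lines of the DummyWork(None) return flagged in the suppressed comment above might look like the following. This is a sketch based on the review comments, not the actual torchft/work.py code, which may differ in its base class and method signatures.

```python
# Illustrative sketch of a DummyWork-style wrapper: an already-completed Work
# handle that carries a fixed result.
from datetime import timedelta
from typing import Optional

import torch
import torch.distributed as dist


class DummyWork(dist.Work):
    def __init__(self, result: object) -> None:
        super().__init__()
        self._result = result

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        # Nothing to wait for: the "work" was complete at construction time.
        return True

    def get_future(self) -> torch.futures.Future:
        fut: torch.futures.Future = torch.futures.Future()
        # Downstream consumers must be prepared for the result to be None,
        # as the suppressed comment above points out.
        fut.set_result(self._result)
        return fut
```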
Force-pushed from 1469dec to 6d12e6f
Force-pushed from add9c84 to 0e69a62
Force-pushed from 0e69a62 to c2bfa46
Summary

Context
- should_commit currently waits for all scheduled work, but we need to wait only for the work belonging to a fragment.

Implementation
- errored() method
- allreduce to support using custom streams, recording an event on this stream when the future is ready (see the sketch below)
- return Work to users
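A minimal sketch of the "custom stream + event" idea, written against plain PyTorch; torchft's actual Manager and stream plumbing in this PR is not reproduced here, and comm_stream/notify_event are illustrative names.

```python
# Minimal sketch: run the collective on a dedicated communication stream and
# record a CUDA event on that stream once the collective's future is ready.
import torch
import torch.distributed as dist

comm_stream = torch.cuda.Stream()


def allreduce_on_comm_stream(tensor: torch.Tensor) -> torch.cuda.Event:
    """Launch an async allreduce on the comm stream and return an event that is
    recorded on that stream when the future completes."""
    notify_event = torch.cuda.Event()
    with torch.cuda.stream(comm_stream):
        work = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=True)
    # When the future is ready, record an event on the comm stream so consumers
    # can order against the result without blocking the CPU.
    work.get_future().then(lambda _: notify_event.record(comm_stream))
    return notify_event


# Consumer side: instead of a blocking Work.wait(), have the compute stream
# wait on the event so communication can overlap with computation, e.g.:
#   torch.cuda.current_stream().wait_event(allreduce_on_comm_stream(grad))
```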
TODO's
- Work::wait() inside a stream blocks the CPU because of the synchronize call; we'll need to change this, otherwise we can't overlap communication with computation.

Test Plan
- Ran train_diloco.py and gathered some profiles to make sure fragment send/sync happens at the right steps (using NCCL); a generic profiling sketch follows below.

Streaming
(profile screenshot)

Non-Streaming
(profile screenshot)
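As a generic illustration of how such profiles can be captured, the following uses standard torch.profiler APIs; train_step is a placeholder callable and is not taken from train_diloco.py.

```python
# Generic torch.profiler sketch for capturing a trace around a few training
# steps and exporting it for inspection (e.g. in TensorBoard or Perfetto).
import torch
from torch.profiler import ProfilerActivity, profile, schedule


def capture_trace(train_step, num_steps: int = 10) -> None:
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        schedule=schedule(wait=1, warmup=2, active=5),
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiles"),
    ) as prof:
        for _ in range(num_steps):
            train_step()
            prof.step()  # advance the profiler schedule each iteration
```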