
return work from manager allreduce #247


Merged · 3 commits · Aug 1, 2025
20 changes: 15 additions & 5 deletions torchft/collectives.py
@@ -18,6 +18,7 @@
AllreduceOptions,
AllToAllOptions,
ReduceScatterOptions,
Work,
)
from torch.futures import Future

@@ -288,7 +289,7 @@ def allreduce_quantized(
opts: AllreduceOptions | ReduceOp,
process_group: "ProcessGroup",
sync_stream: cuda.Stream | None = None,
) -> Future[list[torch.Tensor]]:
) -> Work:
"""
Performs a quantized all-reduce operation on a list of tensors.

@@ -379,17 +380,26 @@ def allreduce_quantized(
[torch.split(quantized_tensors_out.view(world_size, -1), 1)[rank]],
_to_allgather_options(allreduce_opts),
)

# NOTE: This is not supposed to be used with gloo, only with NCCL.
# So we setup the stream dependency here by calling work.wait(),
# which doesn't block the CPU.
#
# The future callback below will run after the work has been
# completed.

work.wait()
fut = work.get_future()

def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]:
def callback(fut: Future[list[torch.Tensor]]) -> None:
# Dequantize and copy to output buffer.
nonlocal tensors, quantized_tensors, world_size, sync_stream

with torch.cuda.stream(sync_stream):
# Setup stream dependency
fut.wait()
# Dequantize the result back to the original precision
fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
return tensors

fut = fut.then(callback)
return fut
fut.add_done_callback(callback)
return work
Member:
should this be using _WorkWrapper ?

Contributor Author:
that's returned by the manager; the manager calls this method inside `allreduce`
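
For reference, `_WorkWrapper` lives in `torchft/work.py` and is not shown in this diff. A minimal sketch of the shape the manager relies on (real name, assumed behavior): it pairs the process-group `Work` with the post-processed `Future` so that `wait()` and `get_future()` both reflect the chained callback.

```python
# Sketch only -- the actual implementation is in torchft/work.py and may differ.
import torch
from torch.distributed.distributed_c10d import Work


class _WorkWrapper(Work):
    """Pairs a process-group Work with a chained Future (collective + callbacks)."""

    def __init__(self, work: Work, fut: torch.futures.Future) -> None:
        super().__init__()
        self._work = work
        self._fut = fut

    def wait(self, timeout=None) -> bool:
        # Waiting on the chained future covers both the collective and the
        # post-processing (e.g. averaging) registered on it.
        self._fut.wait()
        return True

    def get_future(self) -> torch.futures.Future:
        return self._fut
```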

4 changes: 2 additions & 2 deletions torchft/collectives_test.py
@@ -94,8 +94,8 @@ def _run_all_reduce_collective(
)
]

fut = allreduce_quantized(tensors, reduce_op, pg)
fut.wait()
work = allreduce_quantized(tensors, reduce_op, pg)
work.wait()

work = pg.allreduce([expected], reduce_op)
work.get_future().wait()
3 changes: 2 additions & 1 deletion torchft/ddp.py
@@ -68,7 +68,8 @@ def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> N
def _comm_hook(
state: "Manager", bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
return state.allreduce(bucket.buffer())
work = state.allreduce(bucket.buffer())
return work.get_future()


class PureDistributedDataParallel(nn.Module):
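Context for the `ddp.py` change above (a sketch, not part of the diff): DDP's communication-hook contract still expects the hook to return a `torch.futures.Future` holding the reduced bucket buffer, so a `Work`-returning allreduce has to be adapted with `get_future()` before DDP can consume it. Placeholder names throughout.

```python
# Sketch with placeholder names: adapting a Work-returning allreduce to the
# DDP comm-hook contract, hook(state, bucket) -> Future[Tensor].
import torch
import torch.distributed as dist


def comm_hook(state, bucket: dist.GradBucket) -> torch.futures.Future:
    work = state.allreduce(bucket.buffer())  # state: object exposing allreduce -> Work
    return work.get_future()                 # DDP waits on this Future

# Registration on an existing torch.nn.parallel.DistributedDataParallel model:
# ddp_model.register_comm_hook(state=manager, hook=comm_hook)
```
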
10 changes: 6 additions & 4 deletions torchft/ddp_test.py
@@ -10,11 +10,13 @@
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.distributed_c10d import Work
from torch.futures import Future

from torchft.ddp import DistributedDataParallel, PureDistributedDataParallel
from torchft.manager import Manager
from torchft.process_group import ProcessGroupBabyGloo, ProcessGroupGloo
from torchft.work import _DummyWork


class TestDDP(TestCase):
@@ -39,14 +41,14 @@ def test_ddp(self) -> None:

call_count = 0

def allreduce(tensor: torch.Tensor) -> Future[torch.Tensor]:
def allreduce(
tensor: torch.Tensor,
) -> Work:
nonlocal call_count

call_count += 1

fut = Future() # pyre-fixme[29]: not a function
fut.set_result(tensor)
return fut
return _DummyWork(tensor)

manager.allreduce = allreduce

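The test stub above now hands back `_DummyWork` from `torchft/work.py` instead of a pre-completed `Future`. That class is not part of this diff; a minimal sketch, assuming it is simply an already-completed `Work` carrying a fixed result (which is all the mocks here rely on):

```python
# Sketch only -- the real _DummyWork lives in torchft/work.py.
import torch
from torch.distributed.distributed_c10d import Work


class _DummyWork(Work):
    """A Work that is complete on construction and carries a fixed result."""

    def __init__(self, result: object) -> None:
        super().__init__()
        self._future: torch.futures.Future = torch.futures.Future()
        self._future.set_result(result)

    def wait(self, timeout=None) -> bool:
        return True  # nothing to wait for

    def get_future(self) -> torch.futures.Future:
        return self._future
```
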
33 changes: 21 additions & 12 deletions torchft/local_sgd.py
@@ -18,6 +18,7 @@
import torch
import torch.distributed as dist
from torch import nn, optim
from torch.distributed.distributed_c10d import Work
from torch.distributed.tensor import DTensor
from torch.nn.parameter import Parameter
from torch.optim.optimizer import Optimizer
@@ -200,7 +201,7 @@ def __init__(
self._outer_optimizer = outer_optimizer

# Stores pending all reduce
self._allreduce_futures: list[torch.futures.Future[torch.Tensor]] = []
self._allreduce_work: list[Work] = []
self._stream: Optional[torch.cuda.Stream] = (
torch.cuda.Stream() if torch.cuda.is_available() else None
)
@@ -368,15 +369,15 @@ def wait(self) -> None:
"""
Waits for the previously scheduled allreduce to finish
"""
if len(self._allreduce_futures) == 0:
if len(self._allreduce_work) == 0:
return

if self._stream is not None:
assert self._stop_event is not None
self._stop_event.synchronize()
self._stop_event = None

self._allreduce_futures = []
self._allreduce_work = []

@torch.profiler.record_function("torchft::local_sgd::prepare_sync")
def prepare_sync(self) -> None:
@@ -386,7 +387,7 @@ def prepare_sync(self) -> None:
"""
self._save_grads()

assert len(self._allreduce_futures) == 0
assert len(self._allreduce_work) == 0

# Make sure tensors are available to `_stream`
if self._stream is not None:
@@ -399,7 +400,7 @@
):
self._average_grads()

for work in self._allreduce_futures:
for work in self._allreduce_work:
work.wait()

if self._stream is not None:
@@ -413,7 +414,7 @@ def perform_sync(self) -> bool:
steps using the outer optimizer.
"""
# Waiting for an allreduce before it has been sent is currently not supported.
assert len(self._allreduce_futures) > 0
assert len(self._allreduce_work) > 0

self.wait()

@@ -467,7 +468,8 @@ def _allreduce_per_param(self) -> None:
work = self._manager.allreduce(
self._grads[name], should_quantize=self.should_quantize
)
self._allreduce_futures.append(work)

self._allreduce_work.append(work)

def _bucketize_and_allreduce(
self,
@@ -513,12 +515,19 @@
)

def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
nonlocal bucket_tensors, flat_buffer
for t, pack_offset, numel in bucket_tensors:
t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t))
with torch.cuda.stream(self._stream) if self._stream else nullcontext():
nonlocal bucket_tensors, flat_buffer
# Setup stream dependency
fut.wait()
for t, pack_offset, numel in bucket_tensors:
t.copy_(
flat_buffer[pack_offset : pack_offset + numel].view_as(t)
)

fut = work.get_future()
fut.add_done_callback(callback)

work = work.then(callback)
self._allreduce_futures.append(work)
self._allreduce_work.append(work)

offset += chunk_size

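
The switch above from `work = work.then(callback)` to `fut.add_done_callback(callback)` changes what gets stored: `then()` returns a new `Future` chained on the callback's return value, while `add_done_callback()` returns `None` and discards the callback's result, which is why the code now keeps the original `Work` in `self._allreduce_work`. A small standalone illustration of that standard `torch.futures` behavior (not torchft-specific):

```python
# then() chains into a new Future; add_done_callback() only observes.
import torch

f: torch.futures.Future = torch.futures.Future()

chained = f.then(lambda fut: fut.value() * 2)                      # new Future
f.add_done_callback(lambda fut: print("observed", fut.value()))   # returns None

f.set_result(21)
print(chained.wait())  # 42 -- then() propagates the callback's return value
```
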
25 changes: 12 additions & 13 deletions torchft/local_sgd_test.py
@@ -11,10 +11,12 @@
import torch
from parameterized import parameterized
from torch import Tensor, nn, optim
from torch.distributed.distributed_c10d import Work
from torch.distributed.tensor import DTensor

from torchft.local_sgd import DiLoCo, LocalSGD, extract_local_tensor
from torchft.manager import Manager
from torchft.work import _DummyWork


def create_manager() -> MagicMock:
@@ -26,6 +28,11 @@ def create_manager() -> MagicMock:

manager.errored.return_value = None

def mock_allreduce(tensor: torch.Tensor, should_quantize: bool = False) -> Work:
return _DummyWork(tensor)

manager.allreduce.side_effect = mock_allreduce

return manager


@@ -66,7 +73,7 @@ class LocalSGDTest(TestCase):
def test_local_sgd_healthy(self) -> None:
model = SimpleModel()
optimizer = optim.SGD(model.parameters())
manager = create_autospec(Manager)
manager = create_manager()
with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
self.assertEqual(local_sgd._local_step, 0)
inp = torch.rand(2, 3)
@@ -240,13 +247,9 @@ def test_bucketization_correctness(self) -> None:
manager.should_commit.return_value = True

# Define fake allreduce: multiplies buffer by 2
def fake_allreduce(
tensor: Tensor, should_quantize: bool
) -> torch.futures.Future[Tensor]:
def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
tensor.mul_(2)
fut = torch.futures.Future() # pyre-fixme[29]: not a function
fut.set_result(tensor)
return fut
return _DummyWork(tensor)

manager.allreduce.side_effect = fake_allreduce

@@ -284,13 +287,9 @@ def test_gradient_correctness(self) -> None:
manager.should_commit.return_value = True

# Define fake allreduce: multiplies buffer by 2
def fake_allreduce(
tensor: Tensor, should_quantize: bool
) -> torch.futures.Future[Tensor]:
def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
tensor.mul_(2)
fut = torch.futures.Future() # pyre-fixme[29]: not a function
fut.set_result(tensor)
return fut
return _DummyWork(tensor)

manager.allreduce.side_effect = fake_allreduce

25 changes: 12 additions & 13 deletions torchft/manager.py
@@ -39,11 +39,12 @@

import torch
from torch.distributed import ReduceOp, TCPStore
from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp
from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work

from torchft._torchft import ManagerClient, ManagerServer
from torchft.checkpointing import CheckpointTransport, HTTPTransport
from torchft.futures import future_timeout
from torchft.work import _DummyWork, _WorkWrapper

if TYPE_CHECKING:
from torchft.process_group import ProcessGroup
@@ -343,9 +344,7 @@ def shutdown(self, wait: bool = True) -> None:
self._executor.shutdown(wait=wait)

@torch.profiler.record_function("torchft::manager::allreduce")
def allreduce(
self, tensor: torch.Tensor, should_quantize: bool = False
) -> torch.futures.Future[torch.Tensor]:
def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
"""
Fault tolerant allreduce the tensor and return a Future that will be completed when
the tensor is ready.
@@ -365,9 +364,7 @@ def allreduce(
a Future that will be completed with the allreduced tensor
"""
if self.errored():
fut = torch.futures.Future() # pyre-fixme[29]: not a function
fut.set_result(tensor)
return fut
return _DummyWork(tensor)

self.wait_quorum()
num_participants: int = self.num_participants()
@@ -380,13 +377,14 @@
# Run the allreduce async and save the work object so we can wait on
# it later.
if should_quantize and IS_TRITON_AVAILABLE:
fut = allreduce_quantized(
work = allreduce_quantized(
[tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream()
)
else:
work = self._pg.allreduce([tensor], ReduceOp.SUM)
work.wait()
fut = work.get_future()

fut = work.get_future()

stream: Optional[torch.cuda.Stream] = (
torch.cuda.current_stream() if torch.cuda.is_available() else None
@@ -403,6 +401,8 @@ def callback(
# change the stream to avoid making the callback stream
# dependent on process group stream running the allreduce
with torch.cuda.stream(stream) if stream is not None else nullcontext():
# Setup stream dependency
fut.wait()
fut.value()
tensor /= num_participants

@@ -411,17 +411,16 @@
fut = fut.then(callback)

fut = self.wrap_future(fut, tensor)
return fut

return _WorkWrapper(work, fut)

except Exception as e:
self._logger.exception(
f"got exception in all reduce -- skipping remaining: {e}"
)
self.report_error(e)

fut = torch.futures.Future() # pyre-fixme[29]: not a function
fut.set_result(tensor)
return fut
return _DummyWork(tensor)

def report_error(self, e: Exception) -> None:
"""
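
The `work.wait()` followed by `work.get_future()` sequence in `manager.allreduce` (and in `allreduce_quantized` above) relies on a NCCL-specific property called out in the diff's NOTE: with NCCL, `wait()` only inserts a dependency on the current CUDA stream and does not block the CPU, so post-processing can be queued right away on a stream that waits on the collective. A condensed sketch of that pattern; it requires an initialized NCCL process group and a CUDA device to actually execute, and `num_participants` is a placeholder:

```python
# Sketch of the NCCL stream-dependency pattern used in this PR.
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import Work


def averaged_allreduce(tensor: torch.Tensor, num_participants: int) -> Work:
    # async_op=True returns a Work; with NCCL, wait() only records a stream
    # dependency and does not block the CPU.
    work = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=True)
    work.wait()
    fut = work.get_future()

    stream = torch.cuda.current_stream()

    def callback(f: torch.futures.Future) -> None:
        # Callbacks may run on a different thread, so re-establish the
        # dependency on the chosen stream before touching the result.
        with torch.cuda.stream(stream):
            f.wait()
            tensor.div_(num_participants)

    fut.add_done_callback(callback)
    return work  # callers either work.wait() or chain on work.get_future()
```
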
4 changes: 2 additions & 2 deletions torchft/manager_integ_test.py
@@ -634,7 +634,7 @@ def all_reduce_callback(

manager.start_quorum()
t1 = torch.ones((1, 3), device=device)
fut = manager.allreduce(t1)
fut.wait()
work = manager.allreduce(t1)
work.wait()
return t1
return None
11 changes: 6 additions & 5 deletions torchft/manager_test.py
@@ -16,7 +16,8 @@
from torchft._torchft import QuorumResult
from torchft.checkpointing.transport import CheckpointTransport
from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode
from torchft.process_group import ProcessGroup, _DummyWork
from torchft.process_group import ProcessGroup
from torchft.work import _DummyWork


def mock_should_commit(
@@ -586,16 +587,16 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:
manager._pg.allreduce.return_value = _DummyWork(None)

self.assertTrue(manager.is_participating())
fut = torch.futures.Future() # pyre-fixme[29]: not a function
fut = manager.allreduce(torch.tensor([1.0]))
work = manager.allreduce(torch.tensor([1.0]))
fut = work.get_future()
result = fut.value()
torch.testing.assert_close(result, torch.tensor([1.0 / 5]))

# check healing numerics
manager._healing = True
self.assertFalse(manager.is_participating())
fut = torch.futures.Future() # pyre-fixme[29]: not a function
fut = manager.allreduce(torch.tensor([1.0]))
work = manager.allreduce(torch.tensor([1.0]))
fut = work.get_future()
result = fut.value()
torch.testing.assert_close(result, torch.tensor([0.0]))
