Commit a6ec9ef

Group allreduce futures

1 parent db07843

8 files changed, +275 -70 lines changed


torchft/ddp.py

Lines changed: 9 additions & 1 deletion
@@ -68,7 +68,15 @@ def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> N
     def _comm_hook(
         state: "Manager", bucket: dist.GradBucket
     ) -> torch.futures.Future[torch.Tensor]:
-        return state.allreduce(bucket.buffer())
+        work = state.allreduce(bucket.buffer())
+        fut = work.get_future()
+
+        def callback(fut: torch.futures.Future[None]) -> torch.Tensor:
+            nonlocal bucket
+            return bucket.buffer()
+
+        fut = fut.then(callback)
+        return fut
 
 
 class PureDistributedDataParallel(nn.Module):
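
A hedged sketch of what the new hook does, written as standalone code rather than quoting the diff: DDP's comm-hook contract still expects a Future[Tensor] that resolves to the bucket's gradient, while Manager.allreduce now returns a c10d Work, so the hook adapts one into the other. The names `manager` and `make_comm_hook` below are illustrative, not part of the commit.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def make_comm_hook(manager):
    def comm_hook(state, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
        # The allreduce reduces bucket.buffer() in place and returns a Work handle.
        work = manager.allreduce(bucket.buffer())
        # Work.get_future() completes when the collective does; chain a callback
        # that hands the (already reduced) buffer back so DDP gets a Future[Tensor].
        return work.get_future().then(lambda _: bucket.buffer())

    return comm_hook


# Usage (assuming `module` is an nn.Module and `manager` is a torchft Manager):
# model = DDP(module)
# model.register_comm_hook(state=None, hook=make_comm_hook(manager))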

torchft/ddp_test.py

Lines changed: 5 additions & 5 deletions
@@ -10,11 +10,12 @@
 import torch
 import torch.distributed as dist
 from torch import nn
+from torch.distributed.distributed_c10d import Work
 from torch.futures import Future
 
 from torchft.ddp import DistributedDataParallel, PureDistributedDataParallel
 from torchft.manager import Manager
-from torchft.process_group import ProcessGroupBabyGloo, ProcessGroupGloo
+from torchft.process_group import ProcessGroupBabyGloo, ProcessGroupGloo, _DummyWork
 
 
 class TestDDP(TestCase):
@@ -39,14 +40,13 @@ def test_ddp(self) -> None:
 
         call_count = 0
 
-        def allreduce(tensor: torch.Tensor) -> Future[torch.Tensor]:
+        def allreduce(tensor: torch.Tensor) -> Work:
             nonlocal call_count
 
             call_count += 1
 
-            fut = Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            work = _DummyWork(None)
+            return work
 
         manager.allreduce = allreduce
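
The test now stubs Manager.allreduce with `_DummyWork(None)` imported from torchft.process_group. Its definition is not part of this diff; as a rough mental model (an assumption, not the actual class), a no-op Work that completes immediately could look like this:

import torch
from torch.distributed.distributed_c10d import Work


class ImmediateWork(Work):  # hypothetical stand-in for torchft's _DummyWork
    def __init__(self, result: object) -> None:
        super().__init__()
        self._future: torch.futures.Future = torch.futures.Future()
        self._future.set_result(result)  # resolved before anyone can wait on it

    def wait(self, timeout=None) -> bool:
        # nothing to wait for; the "collective" is already done
        return True

    def get_future(self) -> torch.futures.Future:
        return self._future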

torchft/local_sgd.py

Lines changed: 4 additions & 13 deletions
@@ -17,6 +17,7 @@
 import torch
 import torch.distributed as dist
 from torch import nn, optim
+from torch.distributed.distributed_c10d import Work
 from torch.distributed.tensor import DTensor
 from torch.nn.parameter import Parameter
 from torch.optim.optimizer import Optimizer
@@ -197,9 +198,7 @@ def __init__(
         self._outer_optimizer = outer_optimizer
 
         # Stores pending all reduce
-        self._allreduce_futures: list[
-            torch.futures.Future[None] | torch.futures.Future[torch.Tensor]
-        ] = []
+        self._allreduce_futures: list[Work] = []
 
         if bucket_cap_mb is not None:
             self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)
@@ -467,16 +466,6 @@ def __init__(
         if fragment_update_alpha < 0 or fragment_update_alpha > 1:
             raise ValueError("fragment_update_alpha must be between 0 and 1")
 
-        # TODO: Support multiple fragments
-        # This requires changing the manager to support `should_commit` for each
-        # fragment separately.
-        if len(model_fragments) != 1:
-            raise ValueError("Multiple fragments are not supported yet")
-
-        # TODO: Support `fragment_sync_delay`
-        if fragment_sync_delay != 0:
-            raise ValueError("Fragment synchronization delay is not supported yet")
-
         # TODO: Support `fragment_update_alpha`
         if fragment_update_alpha != 0.0:
             raise ValueError(
@@ -522,6 +511,8 @@ def __init__(
                 use_bucketization,
                 bucket_cap_mb,
                 should_quantize,
+                fragment_sync_delay,
+                fragment_update_alpha,
             )
             for i, model_fragment in enumerate(model_fragments)
         ]
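
With `_allreduce_futures` now holding c10d Work handles rather than torch futures, the fragment's accumulate-then-drain pattern looks roughly like the sketch below. This is illustrative only: the class and method names are assumptions, not the module's actual code.

from torch.distributed.distributed_c10d import Work


class _FragmentSketch:
    def __init__(self) -> None:
        # pending async allreduces, drained before the outer optimizer step
        self._allreduce_futures: list[Work] = []

    def _average_grads(self, manager, params) -> None:
        # kick off one async allreduce per gradient and keep the Work handle
        for p in params:
            if p.grad is not None:
                self._allreduce_futures.append(manager.allreduce(p.grad))

    def wait(self) -> None:
        # block until every pending collective has finished, then clear the list
        for work in self._allreduce_futures:
            work.wait()
        self._allreduce_futures = []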

torchft/manager.py

Lines changed: 21 additions & 37 deletions
@@ -39,14 +39,14 @@
 
 import torch
 from torch.distributed import ReduceOp, TCPStore
-from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp
+from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
 
 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
 
 if TYPE_CHECKING:
-    from torchft.process_group import ProcessGroup
+    from torchft.process_group import ProcessGroup, _DummyWork
 
 IS_TRITON_AVAILABLE = True
 try:
@@ -259,7 +259,6 @@ def __init__(
         self._quorum_id = -1
         self._errored: Optional[ExceptionWithTraceback] = None
         self._healing = False
-        self._pending_work: List[torch.futures.Future[object]] = []
         self._batches_committed = 0
 
         # first step is 1
@@ -296,9 +295,8 @@ def shutdown(self, wait: bool = True) -> None:
             self._manager.shutdown()
         self._executor.shutdown(wait=wait)
 
-    def allreduce(
-        self, tensor: torch.Tensor, should_quantize: bool = False
-    ) -> torch.futures.Future[torch.Tensor]:
+    @torch.profiler.record_function("torchft::manager::allreduce")
+    def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
         the tensor is ready.
@@ -318,9 +316,8 @@ def allreduce(
             a Future that will be completed with the allreduced tensor
         """
         if self.errored():
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            work = _DummyWork(None)
+            return work
 
         self.wait_quorum()
 
@@ -332,45 +329,44 @@
             # Run the allreduce async and save the work object so we can wait on
             # it later.
             fut: Optional[
-                torch.futures.Future[None]
-                | torch.futures.Future[torch.Tensor]
-                | torch.futures.Future[List[torch.Tensor]]
+                torch.futures.Future[None] | torch.futures.Future[list[torch.Tensor]]
             ] = None
+            work: Optional[Work] = None
+
             if should_quantize and IS_TRITON_AVAILABLE:
-                fut = allreduce_quantized([tensor], ReduceOp.AVG, self._pg)
+                assert False, "allreduce_quantized is not supported yet"
+                # TODO: Support `allreduce_quantized`
+                # fut = allreduce_quantized([tensor], ReduceOp.AVG, self._pg)
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
+            assert work is not None
             fut = work.get_future()
 
             # schedule grad normalization as a continuation
            # on the Future
             def callback(
                 fut: torch.futures.Future[List[torch.Tensor]],
-            ) -> torch.Tensor:
+            ) -> None:
                 nonlocal tensor
 
                 # check for exceptions
                 fut.value()
 
                 tensor /= self.num_participants()
 
-                return tensor
-
             assert fut is not None
-            if not should_quantize:
-                fut = fut.then(callback)
-            fut = self.wrap_future(fut, tensor)
-            return fut
-
+            fut = fut.then(callback)
+            fut = self.wrap_future(fut, None)
+            return work
         except Exception as e:
             self._logger.exception(
                 f"got exception in all reduce -- skipping remaining: {e}"
             )
             self.report_error(e)
 
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            work = _DummyWork(None)
+
+            return work
 
     def report_error(self, e: Exception) -> None:
         """
@@ -429,7 +425,6 @@ def callback(
             return default
 
         fut = fut.then(callback)
-        self._pending_work.append(cast(torch.futures.Future[object], fut))
        return fut
 
     def start_quorum(
@@ -562,7 +557,7 @@ def _async_quorum(
         self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
         # We use the replica rank and world as we want all replicas in the PG.
         try:
-            with torch.profiler.record_function("torchft::manager::_pg.configure"):
+            with torch.profiler.record_function("torchft::manager::_pg::configure"):
                 self._pg.configure(
                     store_prefixed_addr, replica_rank, replica_world_size
                 )
@@ -694,21 +689,10 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
         Raises:
             RuntimeError: if should_commit fails max_retries times in a row and max_retries is set
         """
-        for work in self._pending_work:
-            # check at the beginning of since .wait() may trigger errors
-            if self._errored is not None:
-                break
-
-            # We swallow the error at in a future then callback so this will
-            # never return an error.
-            work.wait()
-
         # make sure recovery is complete before committing
         if self._recovery_stream is not None:
             self._recovery_stream.synchronize()
 
-        self._pending_work = []
-
         if err := self._pg.errored():
             self.report_error(err)
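
A hedged, caller-side sketch of the new Manager.allreduce contract (assuming `manager` is a configured torchft Manager with an active quorum; this mirrors how the updated tests drive the API rather than quoting the commit):

import torch

grad = torch.ones(8)
work = manager.allreduce(grad)  # returns a c10d Work handle right away

# As the updated tests do, block on the handle; the reduce writes into `grad`
# in place and the continuation registered inside allreduce divides it by
# manager.num_participants().
work.wait()

# Or stay asynchronous and chain follow-up work on the handle's future.
work.get_future().then(lambda _: print("allreduce done"))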

torchft/manager_test.py

Lines changed: 6 additions & 12 deletions
@@ -164,9 +164,7 @@ def test_quorum_happy(self, client_mock: MagicMock) -> None:
 
         manager.start_quorum()
         manager.allreduce(torch.tensor([1.0])).wait()
-        self.assertEqual(len(manager._pending_work), 1)
         self.assertTrue(manager.should_commit())
-        self.assertEqual(len(manager._pending_work), 0)
 
         self.assertEqual(manager._quorum_id, 123)
         self.assertEqual(manager.current_step(), 1)
@@ -554,8 +552,6 @@ def test_manager_wrap_future(self, client_mock: MagicMock) -> None:
         self.assertIs(error.original_exception, e)
         self.assertEqual(wrapped_fut.value(), 2)
 
-        self.assertEqual(manager._pending_work, [wrapped_fut])
-
     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:
         manager = self._create_manager(timeout=timedelta(seconds=0.01))
@@ -590,18 +586,16 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:
         manager._pg.allreduce.return_value = _DummyWork(None)
 
         self.assertTrue(manager.is_participating())
-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-        fut = manager.allreduce(torch.tensor([1.0]))
-        result = fut.value()
-        torch.testing.assert_close(result, torch.tensor([1.0 / 5]))
+        tensor = torch.tensor([1.0])
+        manager.allreduce(tensor).wait()
+        torch.testing.assert_close(tensor, torch.tensor([1.0 / 5]))
 
         # check healing numerics
         manager._healing = True
         self.assertFalse(manager.is_participating())
-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-        fut = manager.allreduce(torch.tensor([1.0]))
-        result = fut.value()
-        torch.testing.assert_close(result, torch.tensor([0.0]))
+        tensor = torch.tensor([1.0])
+        manager.allreduce(tensor).wait()
+        torch.testing.assert_close(tensor, torch.tensor([0.0]))
 
     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None:

torchft/process_group.py

Lines changed: 1 addition & 1 deletion
@@ -775,7 +775,7 @@ def abort(self) -> None:
 
     def errored(self) -> Optional[Exception]:
         # force a synchronization to ensure all work is complete
-        torch.cuda.synchronize()
+        torch.cuda.current_stream().synchronize()
 
         return self._errored
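
The one-line change narrows the synchronization scope: torch.cuda.synchronize() waits for all streams on the current device, while synchronizing only the current stream waits just for the work queued on it (useful when other streams, for example a recovery stream, may still be busy). A small illustration of the difference, assuming a CUDA device is available (not code from the commit):

import torch

if torch.cuda.is_available():
    side = torch.cuda.Stream()
    with torch.cuda.stream(side):
        # queue a large matmul on a side stream
        a = torch.randn(4096, 4096, device="cuda")
        b = torch.randn(4096, 4096, device="cuda")
        c = a @ b

    # waits only for kernels queued on the current (default) stream;
    # the matmul on `side` may still be running afterwards
    torch.cuda.current_stream().synchronize()

    # waits for every stream on the device, including `side`
    torch.cuda.synchronize()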

train_ddp.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def main() -> None:
     # majority of groups will be available so few batches will be dropped.
     sampler = DistributedSampler(
         trainset,
-        replica_group=REPLICA_GROUP_ID,
+        replica_rank=REPLICA_GROUP_ID,
         num_replica_groups=NUM_REPLICA_GROUPS,
         group_rank=0,
         # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.
