
Commit a3b7ea2

Group allreduce futures
1 parent db07843 commit a3b7ea2

9 files changed: 325 additions, 81 deletions


torchft/ddp.py

Lines changed: 9 additions & 1 deletion
@@ -68,7 +68,15 @@ def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> None:
     def _comm_hook(
         state: "Manager", bucket: dist.GradBucket
     ) -> torch.futures.Future[torch.Tensor]:
-        return state.allreduce(bucket.buffer())
+        work = state.allreduce(bucket.buffer())
+        fut = work.get_future()
+
+        def callback(fut: torch.futures.Future[None]) -> torch.Tensor:
+            nonlocal bucket
+            return bucket.buffer()
+
+        fut = fut.then(callback)
+        return fut


 class PureDistributedDataParallel(nn.Module):
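
Note: DDP's comm hook contract still requires a Future that resolves to the bucket's flattened gradient, so the Work returned by Manager.allreduce is adapted via get_future() plus a continuation. A minimal sketch of the same adapter pattern in isolation, using the DummyWork added by this commit to stand in for a real allreduce (the to_future helper is illustrative, not part of the commit):

import torch

from torchft.work import DummyWork

def to_future(work, tensor: torch.Tensor) -> torch.futures.Future[torch.Tensor]:
    # Chain a continuation that ignores the Work's own result and
    # resolves to the tensor the caller cares about, as _comm_hook does.
    return work.get_future().then(lambda fut: tensor)

grad_bucket = torch.ones(4)
fut = to_future(DummyWork(None), grad_bucket)
torch.testing.assert_close(fut.wait(), grad_bucket)  # wait() returns the value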

torchft/ddp_test.py

Lines changed: 4 additions & 4 deletions
@@ -10,11 +10,13 @@
 import torch
 import torch.distributed as dist
 from torch import nn
+from torch.distributed.distributed_c10d import Work
 from torch.futures import Future

 from torchft.ddp import DistributedDataParallel, PureDistributedDataParallel
 from torchft.manager import Manager
 from torchft.process_group import ProcessGroupBabyGloo, ProcessGroupGloo
+from torchft.work import DummyWork


 class TestDDP(TestCase):
@@ -39,14 +41,12 @@ def test_ddp(self) -> None:

         call_count = 0

-        def allreduce(tensor: torch.Tensor) -> Future[torch.Tensor]:
+        def allreduce(tensor: torch.Tensor) -> Work:
             nonlocal call_count

             call_count += 1

-            fut = Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return DummyWork(None)

         manager.allreduce = allreduce


torchft/local_sgd.py

Lines changed: 4 additions & 13 deletions
@@ -17,6 +17,7 @@
 import torch
 import torch.distributed as dist
 from torch import nn, optim
+from torch.distributed.distributed_c10d import Work
 from torch.distributed.tensor import DTensor
 from torch.nn.parameter import Parameter
 from torch.optim.optimizer import Optimizer
@@ -197,9 +198,7 @@ def __init__(
         self._outer_optimizer = outer_optimizer

         # Stores pending all reduce
-        self._allreduce_futures: list[
-            torch.futures.Future[None] | torch.futures.Future[torch.Tensor]
-        ] = []
+        self._allreduce_futures: list[Work] = []

         if bucket_cap_mb is not None:
             self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)
@@ -467,16 +466,6 @@ def __init__(
         if fragment_update_alpha < 0 or fragment_update_alpha > 1:
             raise ValueError("fragment_update_alpha must be between 0 and 1")

-        # TODO: Support multiple fragments
-        # This requires changing the manager to support `should_commit` for each
-        # fragment separately.
-        if len(model_fragments) != 1:
-            raise ValueError("Multiple fragments are not supported yet")
-
-        # TODO: Support `fragment_sync_delay`
-        if fragment_sync_delay != 0:
-            raise ValueError("Fragment synchronization delay is not supported yet")
-
         # TODO: Support `fragment_update_alpha`
         if fragment_update_alpha != 0.0:
             raise ValueError(
@@ -522,6 +511,8 @@ def __init__(
                 use_bucketization,
                 bucket_cap_mb,
                 should_quantize,
+                fragment_sync_delay,
+                fragment_update_alpha,
             )
             for i, model_fragment in enumerate(model_fragments)
         ]
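
With pending allreduces tracked as Work handles instead of futures, draining them before an outer-optimizer step is a plain wait loop. A minimal sketch under that assumption (wait_all is an illustrative helper, not part of the commit):

from torch.distributed.distributed_c10d import Work

def wait_all(pending: list[Work]) -> None:
    # Block until every outstanding allreduce has finished, then drop
    # the handles so the list can be refilled on the next sync step.
    for work in pending:
        work.wait()
    pending.clear()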

torchft/manager.py

Lines changed: 21 additions & 36 deletions
@@ -39,11 +39,12 @@

 import torch
 from torch.distributed import ReduceOp, TCPStore
-from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp
+from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work

 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
+from torchft.work import DummyWork, ErrorSwallowingWork

 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
@@ -259,7 +260,6 @@ def __init__(
         self._quorum_id = -1
         self._errored: Optional[ExceptionWithTraceback] = None
         self._healing = False
-        self._pending_work: List[torch.futures.Future[object]] = []
         self._batches_committed = 0

         # first step is 1
@@ -296,9 +296,8 @@ def shutdown(self, wait: bool = True) -> None:
         self._manager.shutdown()
         self._executor.shutdown(wait=wait)

-    def allreduce(
-        self, tensor: torch.Tensor, should_quantize: bool = False
-    ) -> torch.futures.Future[torch.Tensor]:
+    @torch.profiler.record_function("torchft::manager::allreduce")
+    def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
         the tensor is ready.
@@ -318,9 +317,8 @@ def allreduce(
             a Future that will be completed with the allreduced tensor
         """
         if self.errored():
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            work = DummyWork(None)
+            return work

         self.wait_quorum()

@@ -332,45 +330,44 @@
             # Run the allreduce async and save the work object so we can wait on
             # it later.
             fut: Optional[
-                torch.futures.Future[None]
-                | torch.futures.Future[torch.Tensor]
-                | torch.futures.Future[List[torch.Tensor]]
+                torch.futures.Future[None] | torch.futures.Future[list[torch.Tensor]]
             ] = None
+            work: Optional[Work] = None
+
             if should_quantize and IS_TRITON_AVAILABLE:
-                fut = allreduce_quantized([tensor], ReduceOp.AVG, self._pg)
+                assert False, "allreduce_quantized is not supported yet"
+                # TODO: Support `allreduce_quantized`
+                # fut = allreduce_quantized([tensor], ReduceOp.AVG, self._pg)
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
+                assert work is not None
                 fut = work.get_future()

             # schedule grad normalization as a continuation
             # on the Future
             def callback(
                 fut: torch.futures.Future[List[torch.Tensor]],
-            ) -> torch.Tensor:
+            ) -> None:
                 nonlocal tensor

                 # check for exceptions
                 fut.value()

                 tensor /= self.num_participants()

-                return tensor
-
             assert fut is not None
-            if not should_quantize:
-                fut = fut.then(callback)
-            fut = self.wrap_future(fut, tensor)
-            return fut
-
+            fut = fut.then(callback)
+            fut = self.wrap_future(fut, None)
+            return ErrorSwallowingWork(work, self.report_error, None)
         except Exception as e:
             self._logger.exception(
                 f"got exception in all reduce -- skipping remaining: {e}"
             )
             self.report_error(e)

-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            work = DummyWork(None)
+
+            return work

     def report_error(self, e: Exception) -> None:
         """
@@ -429,7 +426,6 @@ def callback(
             return default

         fut = fut.then(callback)
-        self._pending_work.append(cast(torch.futures.Future[object], fut))
         return fut

     def start_quorum(
@@ -562,7 +558,7 @@ def _async_quorum(
         self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
         # We use the replica rank and world as we want all replicas in the PG.
         try:
-            with torch.profiler.record_function("torchft::manager::_pg.configure"):
+            with torch.profiler.record_function("torchft::manager::_pg::configure"):
                 self._pg.configure(
                     store_prefixed_addr, replica_rank, replica_world_size
                 )
@@ -694,21 +690,10 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
         Raises:
             RuntimeError: if should_commit fails max_retries times in a row and max_retries is set
         """
-        for work in self._pending_work:
-            # check at the beginning of since .wait() may trigger errors
-            if self._errored is not None:
-                break
-
-            # We swallow the error at in a future then callback so this will
-            # never return an error.
-            work.wait()
-
         # make sure recovery is complete before committing
         if self._recovery_stream is not None:
             self._recovery_stream.synchronize()

-        self._pending_work = []
-
         if err := self._pg.errored():
             self.report_error(err)
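
Manager.allreduce now hands back a Work whose wait() never raises: failures are routed to report_error so the training loop can keep going and skip the step at commit time. A minimal sketch of that error-routing contract, with an illustrative FailingWork standing in for a collective that dies mid-flight:

from datetime import timedelta
from typing import Optional

import torch.distributed as dist

from torchft.work import ErrorSwallowingWork

class FailingWork(dist._Work):
    # Illustrative stand-in for an allreduce whose wait() raises.
    def __init__(self) -> None:
        super().__init__()

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        raise RuntimeError("simulated collective failure")

errors: list[Exception] = []
work = ErrorSwallowingWork(FailingWork(), errors.append, None)
assert work.wait()  # the RuntimeError is reported, not raised
assert len(errors) == 1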

torchft/manager_test.py

Lines changed: 6 additions & 12 deletions
@@ -164,9 +164,7 @@ def test_quorum_happy(self, client_mock: MagicMock) -> None:

         manager.start_quorum()
         manager.allreduce(torch.tensor([1.0])).wait()
-        self.assertEqual(len(manager._pending_work), 1)
         self.assertTrue(manager.should_commit())
-        self.assertEqual(len(manager._pending_work), 0)

         self.assertEqual(manager._quorum_id, 123)
         self.assertEqual(manager.current_step(), 1)
@@ -554,8 +552,6 @@ def test_manager_wrap_future(self, client_mock: MagicMock) -> None:
         self.assertIs(error.original_exception, e)
         self.assertEqual(wrapped_fut.value(), 2)

-        self.assertEqual(manager._pending_work, [wrapped_fut])
-
     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:
         manager = self._create_manager(timeout=timedelta(seconds=0.01))
@@ -590,18 +586,16 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:
         manager._pg.allreduce.return_value = _DummyWork(None)

         self.assertTrue(manager.is_participating())
-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-        fut = manager.allreduce(torch.tensor([1.0]))
-        result = fut.value()
-        torch.testing.assert_close(result, torch.tensor([1.0 / 5]))
+        tensor = torch.tensor([1.0])
+        manager.allreduce(tensor).wait()
+        torch.testing.assert_close(tensor, torch.tensor([1.0 / 5]))

         # check healing numerics
         manager._healing = True
         self.assertFalse(manager.is_participating())
-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-        fut = manager.allreduce(torch.tensor([1.0]))
-        result = fut.value()
-        torch.testing.assert_close(result, torch.tensor([0.0]))
+        tensor = torch.tensor([1.0])
+        manager.allreduce(tensor).wait()
+        torch.testing.assert_close(tensor, torch.tensor([0.0]))

     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None:

torchft/process_group.py

Lines changed: 3 additions & 14 deletions
@@ -69,6 +69,7 @@
 from torchft.device_mesh import *  # noqa: F401
 from torchft.futures import context_timeout, stream_timeout
 from torchft.multiprocessing import _MonitoredPipe
+from torchft.work import DummyWork

 if TYPE_CHECKING:
     from torchft.manager import Manager
@@ -775,27 +776,15 @@ def abort(self) -> None:

     def errored(self) -> Optional[Exception]:
         # force a synchronization to ensure all work is complete
-        torch.cuda.synchronize()
+        torch.cuda.current_stream().synchronize()

         return self._errored

     def getBackendName(self) -> str:
         return "torchft-nccl"


-class _DummyWork(dist._Work):
-    def __init__(self, result: object) -> None:
-        super().__init__()
-        self.result_ = result
-        # pyre-fixme[29]: Future is not a function
-        self.future_: torch.futures.Future[object] = torch.futures.Future()
-        self.future_.set_result(result)
-
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        return True
-
-    def get_future(self) -> torch.futures.Future[object]:
-        return self.future_
+_DummyWork = DummyWork


 class ProcessGroupDummy(ProcessGroup):
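
The module-level alias keeps the old private name importable, so existing call sites and isinstance checks are unaffected. A quick check (assuming both modules are on the path):

from torchft.process_group import _DummyWork
from torchft.work import DummyWork

# _DummyWork is a re-export of the shared class, not a copy.
assert _DummyWork is DummyWork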

torchft/work.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+import logging
+from datetime import timedelta
+from typing import TYPE_CHECKING, Callable, Optional
+
+import torch
+import torch.distributed as dist
+from torch.distributed.distributed_c10d import Work
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+class DummyWork(dist._Work):
+    def __init__(self, result: object) -> None:
+        super().__init__()
+        self.result_ = result
+        # pyre-fixme[29]: Future is not a function
+        self.future_: torch.futures.Future[object] = torch.futures.Future()
+        self.future_.set_result(result)
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        return True
+
+    def get_future(self) -> torch.futures.Future[object]:
+        return self.future_
+
+
+class ErrorSwallowingWork(Work):
+    def __init__(
+        self,
+        work: Work,
+        report_error: Callable[[Exception], None],
+        default_result: object,
+    ) -> None:
+        super().__init__()
+
+        self._work = work
+        self._default_result = default_result
+        self._report_error = report_error
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        try:
+            self._work.wait()
+        except Exception as e:
+            self._report_error(e)
+
+        return True
+
+    def get_future(self) -> torch.futures.Future[object]:
+        return self._work.get_future()
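
A short usage sketch of the two new helpers (values are illustrative):

import torch

from torchft.work import DummyWork, ErrorSwallowingWork

# DummyWork: an already-completed Work carrying a fixed result.
done = DummyWork(torch.tensor([3.0]))
assert done.wait()
print(done.get_future().value())  # tensor([3.])

# ErrorSwallowingWork: delegates to an inner Work and reports, rather
# than raises, any exception surfaced by wait().
wrapped = ErrorSwallowingWork(done, print, None)
assert wrapped.wait()  # inner wait() succeeds, so nothing is reported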

train_ddp.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def main() -> None:
     # majority of groups will be available so few batches will be dropped.
     sampler = DistributedSampler(
         trainset,
-        replica_group=REPLICA_GROUP_ID,
+        replica_rank=REPLICA_GROUP_ID,
         num_replica_groups=NUM_REPLICA_GROUPS,
         group_rank=0,
         # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.
