
Commit 6d12e6f

Group allreduce futures
1 parent db07843 commit 6d12e6f

9 files changed, +328 -83 lines changed

torchft/ddp.py

Lines changed: 9 additions & 1 deletion
@@ -68,7 +68,15 @@ def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> N
         def _comm_hook(
             state: "Manager", bucket: dist.GradBucket
         ) -> torch.futures.Future[torch.Tensor]:
-            return state.allreduce(bucket.buffer())
+            work = state.allreduce(bucket.buffer())
+            fut = work.get_future()
+
+            def callback(fut: torch.futures.Future[None]) -> torch.Tensor:
+                nonlocal bucket
+                return bucket.buffer()
+
+            fut = fut.then(callback)
+            return fut


 class PureDistributedDataParallel(nn.Module):
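The hook now gets a Work handle back from Manager.allreduce() and has to turn it
into the Future[torch.Tensor] that DDP's comm-hook contract expects. A tiny,
self-contained illustration of the then() chaining involved (illustrative only,
not torchft code; "buffer" stands in for bucket.buffer()):

import torch

buffer = torch.zeros(3)  # stands in for bucket.buffer()
done: torch.futures.Future[None] = torch.futures.Future()

# then() returns a new Future whose value is whatever the callback returns,
# which is how the hook turns a bare completion signal into the bucket tensor
# that DDP expects back from a comm hook.
result = done.then(lambda _: buffer)

done.set_result(None)  # the collective "finishes"
print(result.wait() is buffer)  # prints: True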

torchft/ddp_test.py

Lines changed: 4 additions & 4 deletions
@@ -10,11 +10,13 @@
 import torch
 import torch.distributed as dist
 from torch import nn
+from torch.distributed.distributed_c10d import Work
 from torch.futures import Future

 from torchft.ddp import DistributedDataParallel, PureDistributedDataParallel
 from torchft.manager import Manager
 from torchft.process_group import ProcessGroupBabyGloo, ProcessGroupGloo
+from torchft.work import DummyWork


 class TestDDP(TestCase):
@@ -39,14 +41,12 @@ def test_ddp(self) -> None:

         call_count = 0

-        def allreduce(tensor: torch.Tensor) -> Future[torch.Tensor]:
+        def allreduce(tensor: torch.Tensor) -> Work:
             nonlocal call_count

             call_count += 1

-            fut = Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return DummyWork(None)

         manager.allreduce = allreduce

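The stubbed allreduce now returns an already-completed Work handle instead of a
pre-resolved Future. DummyWork itself lives in the new torchft/work.py, which is
not included in this section; judging from the _DummyWork class this commit
removes from process_group.py (see that diff below), it presumably looks roughly
like the following sketch (an assumption, not verified against the new file):

from datetime import timedelta
from typing import Optional

import torch
import torch.distributed as dist


class DummyWork(dist._Work):
    """A Work handle that is already complete and carries a fixed result."""

    def __init__(self, result: object) -> None:
        super().__init__()
        self.result_ = result
        # an already-completed future holding the result
        self.future_: torch.futures.Future[object] = torch.futures.Future()
        self.future_.set_result(result)

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        return True

    def get_future(self) -> torch.futures.Future[object]:
        return self.future_

wait() returns immediately and get_future() yields an already-completed Future,
which is exactly what the test's fake allreduce needs.
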
torchft/local_sgd.py

Lines changed: 12 additions & 15 deletions
@@ -17,6 +17,7 @@
 import torch
 import torch.distributed as dist
 from torch import nn, optim
+from torch.distributed.distributed_c10d import Work
 from torch.distributed.tensor import DTensor
 from torch.nn.parameter import Parameter
 from torch.optim.optimizer import Optimizer
@@ -197,9 +198,7 @@ def __init__(
         self._outer_optimizer = outer_optimizer

         # Stores pending all reduce
-        self._allreduce_futures: list[
-            torch.futures.Future[None] | torch.futures.Future[torch.Tensor]
-        ] = []
+        self._allreduce_futures: list[Work] = []

         if bucket_cap_mb is not None:
             self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)
@@ -272,6 +271,7 @@ def should_sync_fragment(self, step: int) -> bool:
         step_to_sync = step - self._fragment_sync_offset - self._fragment_sync_delay
         return step_to_sync % self._sync_every == 0

+    @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
     def prepare_sync(self) -> None:
         """
         Calculate the pseugradient, average them across the manager group and starts
@@ -288,6 +288,7 @@ def prepare_sync(self) -> None:

         self._average_grads()

+    @torch.profiler.record_function("torchft::local_sgd::perform_sync")
     def perform_sync(self) -> bool:
         """
         Overrides the sync method to wait for the scheduled allreduce to finish and
@@ -467,16 +468,6 @@ def __init__(
         if fragment_update_alpha < 0 or fragment_update_alpha > 1:
             raise ValueError("fragment_update_alpha must be between 0 and 1")

-        # TODO: Support multiple fragments
-        # This requires changing the manager to support `should_commit` for each
-        # fragment separately.
-        if len(model_fragments) != 1:
-            raise ValueError("Multiple fragments are not supported yet")
-
-        # TODO: Support `fragment_sync_delay`
-        if fragment_sync_delay != 0:
-            raise ValueError("Fragment synchronization delay is not supported yet")
-
         # TODO: Support `fragment_update_alpha`
         if fragment_update_alpha != 0.0:
             raise ValueError(
@@ -522,6 +513,8 @@ def __init__(
                 use_bucketization,
                 bucket_cap_mb,
                 should_quantize,
+                fragment_sync_delay,
+                fragment_update_alpha,
             )
             for i, model_fragment in enumerate(model_fragments)
         ]
@@ -606,16 +599,20 @@ def _step_post_hook(
         step = self._local_step

         # Start sending fragments
-        for fragment in self._fragments:
+        for i, fragment in enumerate(self._fragments):
            if not fragment.should_prepare_fragment(step):
                continue

+            logger.info(f"preparing fragment {i} at step {step}")
+
            fragment.prepare_sync()

-        for fragment in self._fragments:
+        for i, fragment in enumerate(self._fragments):
            if not fragment.should_sync_fragment(step):
                continue

+            logger.info(f"syncing fragment {i} at step {step}")
+
            if not fragment.perform_sync():
                # Cancel all the previously scheduled allreduce by simply
                # waiting for them. They should have failed but lets be
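The pending collectives are now tracked as a list of c10d Work handles rather
than Futures. The perform_sync code that drains them is not shown in these
hunks; a minimal sketch of the general pattern (the helper name is made up for
illustration) could look like:

from torch.distributed.distributed_c10d import Work


def drain_pending_allreduces(pending: list[Work]) -> None:
    # Block until every scheduled allreduce has finished. torchft's Work
    # wrappers report failures to the manager instead of raising here.
    for work in pending:
        work.wait()
    pending.clear()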

torchft/manager.py

Lines changed: 18 additions & 36 deletions
@@ -39,11 +39,12 @@

 import torch
 from torch.distributed import ReduceOp, TCPStore
-from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp
+from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work

 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
+from torchft.work import DummyWork, ErrorSwallowingWork

 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
@@ -259,7 +260,6 @@ def __init__(
         self._quorum_id = -1
         self._errored: Optional[ExceptionWithTraceback] = None
         self._healing = False
-        self._pending_work: List[torch.futures.Future[object]] = []
         self._batches_committed = 0

         # first step is 1
@@ -296,9 +296,8 @@ def shutdown(self, wait: bool = True) -> None:
         self._manager.shutdown()
         self._executor.shutdown(wait=wait)

-    def allreduce(
-        self, tensor: torch.Tensor, should_quantize: bool = False
-    ) -> torch.futures.Future[torch.Tensor]:
+    @torch.profiler.record_function("torchft::manager::allreduce")
+    def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
         the tensor is ready.
@@ -318,9 +317,7 @@ def allreduce(
             a Future that will be completed with the allreduced tensor
         """
         if self.errored():
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return DummyWork(None)

         self.wait_quorum()

@@ -332,45 +329,42 @@ def allreduce(
            # Run the allreduce async and save the work object so we can wait on
            # it later.
            fut: Optional[
-                torch.futures.Future[None]
-                | torch.futures.Future[torch.Tensor]
-                | torch.futures.Future[List[torch.Tensor]]
+                torch.futures.Future[None] | torch.futures.Future[list[torch.Tensor]]
            ] = None
+            work: Optional[Work] = None
+
            if should_quantize and IS_TRITON_AVAILABLE:
-                fut = allreduce_quantized([tensor], ReduceOp.AVG, self._pg)
+                assert False, "allreduce_quantized is not supported yet"
+                # TODO: Support `allreduce_quantized`
+                # fut = allreduce_quantized([tensor], ReduceOp.AVG, self._pg)
            else:
                work = self._pg.allreduce([tensor], ReduceOp.SUM)
+                assert work is not None
                fut = work.get_future()

            # schedule grad normalization as a continuation
            # on the Future
            def callback(
                fut: torch.futures.Future[List[torch.Tensor]],
-            ) -> torch.Tensor:
+            ) -> None:
                nonlocal tensor

                # check for exceptions
                fut.value()

                tensor /= self.num_participants()

-                return tensor
-
            assert fut is not None
-            if not should_quantize:
-                fut = fut.then(callback)
-            fut = self.wrap_future(fut, tensor)
-            return fut
-
+            fut = fut.then(callback)
+            fut = self.wrap_future(fut, None)
+            return ErrorSwallowingWork(work, self.report_error, None)
        except Exception as e:
            self._logger.exception(
                f"got exception in all reduce -- skipping remaining: {e}"
            )
            self.report_error(e)

-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return DummyWork(None)

    def report_error(self, e: Exception) -> None:
        """
@@ -429,7 +423,6 @@ def callback(
            return default

        fut = fut.then(callback)
-        self._pending_work.append(cast(torch.futures.Future[object], fut))
        return fut

    def start_quorum(
@@ -562,7 +555,7 @@ def _async_quorum(
            self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
            # We use the replica rank and world as we want all replicas in the PG.
            try:
-                with torch.profiler.record_function("torchft::manager::_pg.configure"):
+                with torch.profiler.record_function("torchft::manager::_pg::configure"):
                    self._pg.configure(
                        store_prefixed_addr, replica_rank, replica_world_size
                    )
@@ -694,21 +687,10 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
        Raises:
            RuntimeError: if should_commit fails max_retries times in a row and max_retries is set
        """
-        for work in self._pending_work:
-            # check at the beginning of since .wait() may trigger errors
-            if self._errored is not None:
-                break
-
-            # We swallow the error at in a future then callback so this will
-            # never return an error.
-            work.wait()
-
        # make sure recovery is complete before committing
        if self._recovery_stream is not None:
            self._recovery_stream.synchronize()

-        self._pending_work = []
-
        if err := self._pg.errored():
            self.report_error(err)
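Manager.allreduce() now returns a Work handle: DummyWork covers the errored and
exception paths, and ErrorSwallowingWork wraps the real collective so failures
are reported via report_error() instead of being raised to the caller. Neither
class is shown in this section (they come from the new torchft/work.py); the
sketch below is only a guess at the interface implied by
ErrorSwallowingWork(work, self.report_error, None), so the real implementation
may differ:

from datetime import timedelta
from typing import Callable, Optional

import torch
import torch.distributed as dist


class ErrorSwallowingWork(dist._Work):
    """Wrap a Work so errors are reported, not raised (assumed shape)."""

    def __init__(
        self,
        work: dist._Work,
        error_handler: Callable[[Exception], None],
        default_result: object,
    ) -> None:
        super().__init__()
        self._work = work
        self._error_handler = error_handler
        self._default_result = default_result

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        try:
            self._work.wait()
        except Exception as e:
            # Swallow the error and hand it to the manager instead.
            self._error_handler(e)
        return True

    def get_future(self) -> torch.futures.Future[object]:
        fut = self._work.get_future()

        def callback(fut: torch.futures.Future[object]) -> object:
            try:
                return fut.value()
            except Exception as e:
                self._error_handler(e)
                return self._default_result

        return fut.then(callback)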

torchft/manager_test.py

Lines changed: 6 additions & 12 deletions
@@ -164,9 +164,7 @@ def test_quorum_happy(self, client_mock: MagicMock) -> None:

         manager.start_quorum()
         manager.allreduce(torch.tensor([1.0])).wait()
-        self.assertEqual(len(manager._pending_work), 1)
         self.assertTrue(manager.should_commit())
-        self.assertEqual(len(manager._pending_work), 0)

         self.assertEqual(manager._quorum_id, 123)
         self.assertEqual(manager.current_step(), 1)
@@ -554,8 +552,6 @@ def test_manager_wrap_future(self, client_mock: MagicMock) -> None:
         self.assertIs(error.original_exception, e)
         self.assertEqual(wrapped_fut.value(), 2)

-        self.assertEqual(manager._pending_work, [wrapped_fut])
-
     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:
         manager = self._create_manager(timeout=timedelta(seconds=0.01))
@@ -590,18 +586,16 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:
         manager._pg.allreduce.return_value = _DummyWork(None)

         self.assertTrue(manager.is_participating())
-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-        fut = manager.allreduce(torch.tensor([1.0]))
-        result = fut.value()
-        torch.testing.assert_close(result, torch.tensor([1.0 / 5]))
+        tensor = torch.tensor([1.0])
+        manager.allreduce(tensor).wait()
+        torch.testing.assert_close(tensor, torch.tensor([1.0 / 5]))

         # check healing numerics
         manager._healing = True
         self.assertFalse(manager.is_participating())
-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-        fut = manager.allreduce(torch.tensor([1.0]))
-        result = fut.value()
-        torch.testing.assert_close(result, torch.tensor([0.0]))
+        tensor = torch.tensor([1.0])
+        manager.allreduce(tensor).wait()
+        torch.testing.assert_close(tensor, torch.tensor([0.0]))


     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None:

torchft/process_group.py

Lines changed: 3 additions & 14 deletions
@@ -69,6 +69,7 @@
 from torchft.device_mesh import *  # noqa: F401
 from torchft.futures import context_timeout, stream_timeout
 from torchft.multiprocessing import _MonitoredPipe
+from torchft.work import DummyWork

 if TYPE_CHECKING:
     from torchft.manager import Manager
@@ -775,27 +776,15 @@ def abort(self) -> None:

     def errored(self) -> Optional[Exception]:
         # force a synchronization to ensure all work is complete
-        torch.cuda.synchronize()
+        torch.cuda.current_stream().synchronize()

         return self._errored

     def getBackendName(self) -> str:
         return "torchft-nccl"


-class _DummyWork(dist._Work):
-    def __init__(self, result: object) -> None:
-        super().__init__()
-        self.result_ = result
-        # pyre-fixme[29]: Future is not a function
-        self.future_: torch.futures.Future[object] = torch.futures.Future()
-        self.future_.set_result(result)
-
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        return True
-
-    def get_future(self) -> torch.futures.Future[object]:
-        return self.future_
+_DummyWork = DummyWork


 class ProcessGroupDummy(ProcessGroup):
