
Commit 0e69a62

Group allreduce futures
1 parent db07843 commit 0e69a62

6 files changed: +282 −54 lines changed

torchft/local_sgd.py

Lines changed: 37 additions & 24 deletions
@@ -11,6 +11,7 @@
 import logging
 import math
 import threading
+from contextlib import nullcontext
 from types import TracebackType
 from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type

@@ -197,9 +198,10 @@ def __init__(
         self._outer_optimizer = outer_optimizer

         # Stores pending all reduce
-        self._allreduce_futures: list[
-            torch.futures.Future[None] | torch.futures.Future[torch.Tensor]
-        ] = []
+        self._allreduce_futures: list[torch.futures.Future[torch.Tensor]] = []
+        self._stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.Stream() if torch.cuda.is_available() else None
+        )

         if bucket_cap_mb is not None:
             self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)
@@ -222,13 +224,15 @@ def __init__(
                 t = t.pin_memory()
             self.original_parameters[name] = t

+    @torch.profiler.record_function("torchft::local_sgd::save_parameters")
     def save_parameters(self) -> None:
         with torch.no_grad():
             # TODO: consider running copy on a separate stream
             for name, p in self._model_fragment.named_parameters():
                 param_to_local = extract_local_tensor(p.data)
                 self.original_parameters[name].copy_(param_to_local, non_blocking=True)

+    @torch.profiler.record_function("torchft::local_sgd::restore_parameters")
     def restore_parameters(self) -> None:
         with torch.no_grad():
             # TODO: consider running copy on a separate stream
@@ -248,6 +252,7 @@ def restore_parameters(self) -> None:
                 else:
                     p.data.copy_(self.original_parameters[name], non_blocking=False)

+    @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
         """
         Waits for the previously scheduled allreduce to finish
@@ -272,22 +277,31 @@ def should_sync_fragment(self, step: int) -> bool:
         step_to_sync = step - self._fragment_sync_offset - self._fragment_sync_delay
         return step_to_sync % self._sync_every == 0

+    @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
     def prepare_sync(self) -> None:
         """
         Calculate the pseugradient, average them across the manager group and starts
         allreduce on the pseudo-gradients but doesn't wait for it to finish.
         """
-        # Set the .grad field of each parameter to its pseudogradient
-        for name, p in self._model_fragment.named_parameters():
-            local_param = extract_local_tensor(p.data)
-            pseudogradient = local_param - self.original_parameters[name].to(p.device)
-            if isinstance(p, DTensor):
-                p.grad._local_tensor = pseudogradient
-            else:
-                p.grad = pseudogradient
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            # Set the .grad field of each parameter to its pseudogradient
+            for name, p in self._model_fragment.named_parameters():
+                local_param = extract_local_tensor(p.data)
+                pseudogradient = local_param - self.original_parameters[name].to(
+                    p.device
+                )
+                if isinstance(p, DTensor):
+                    p.grad._local_tensor = pseudogradient
+                else:
+                    p.grad = pseudogradient

-        self._average_grads()
+            self._average_grads()

+    @torch.profiler.record_function("torchft::local_sgd::perform_sync")
     def perform_sync(self) -> bool:
         """
         Overrides the sync method to wait for the scheduled allreduce to finish and
@@ -298,6 +312,9 @@ def perform_sync(self) -> bool:

         self.wait()

+        if self._stream is not None:
+            self._stream.synchronize()
+
         # Restore the parameters back to the previous state
         self.restore_parameters()

@@ -467,16 +484,6 @@ def __init__(
         if fragment_update_alpha < 0 or fragment_update_alpha > 1:
             raise ValueError("fragment_update_alpha must be between 0 and 1")

-        # TODO: Support multiple fragments
-        # This requires changing the manager to support `should_commit` for each
-        # fragment separately.
-        if len(model_fragments) != 1:
-            raise ValueError("Multiple fragments are not supported yet")
-
-        # TODO: Support `fragment_sync_delay`
-        if fragment_sync_delay != 0:
-            raise ValueError("Fragment synchronization delay is not supported yet")
-
         # TODO: Support `fragment_update_alpha`
         if fragment_update_alpha != 0.0:
             raise ValueError(
@@ -522,6 +529,8 @@ def __init__(
                 use_bucketization,
                 bucket_cap_mb,
                 should_quantize,
+                fragment_sync_delay,
+                fragment_update_alpha,
             )
             for i, model_fragment in enumerate(model_fragments)
         ]
@@ -606,16 +615,20 @@ def _step_post_hook(
         step = self._local_step

         # Start sending fragments
-        for fragment in self._fragments:
+        for i, fragment in enumerate(self._fragments):
             if not fragment.should_prepare_fragment(step):
                 continue

+            logger.info(f"preparing fragment {i} at step {step}")
+
             fragment.prepare_sync()

-        for fragment in self._fragments:
+        for i, fragment in enumerate(self._fragments):
             if not fragment.should_sync_fragment(step):
                 continue

+            logger.info(f"syncing fragment {i} at step {step}")
+
             if not fragment.perform_sync():
                 # Cancel all the previously scheduled allreduce by simply
                 # waiting for them. They should have failed but lets be

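For context, the pattern these local_sgd.py changes adopt can be sketched in isolation: enqueue the pseudogradient allreduce on a dedicated CUDA side stream so the default stream keeps training, then synchronize the side stream only once the averaged result is consumed (as perform_sync now does with wait() followed by self._stream.synchronize()). This is a minimal illustration, not torchft code: the helper names and the direct use of torch.distributed.all_reduce with ReduceOp.AVG (which assumes the NCCL backend) are assumptions.

```python
from contextlib import nullcontext
from typing import List, Optional

import torch
import torch.distributed as dist


def start_async_average(
    grads: List[torch.Tensor], stream: Optional[torch.cuda.Stream]
) -> List[dist.Work]:
    """Enqueue allreduces on a side stream so the default stream keeps training."""
    with torch.cuda.stream(stream) if stream is not None else nullcontext():
        # async_op=True returns Work handles instead of blocking the host.
        return [dist.all_reduce(g, op=dist.ReduceOp.AVG, async_op=True) for g in grads]


def finish_async_average(
    works: List[dist.Work], stream: Optional[torch.cuda.Stream]
) -> None:
    """Wait for the collectives, then make the side-stream work visible."""
    for w in works:
        w.wait()
    if stream is not None:
        stream.synchronize()
```
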
torchft/manager.py

Lines changed: 17 additions & 24 deletions
@@ -259,7 +259,6 @@ def __init__(
         self._quorum_id = -1
         self._errored: Optional[ExceptionWithTraceback] = None
         self._healing = False
-        self._pending_work: List[torch.futures.Future[object]] = []
         self._batches_committed = 0

         # first step is 1
@@ -296,6 +295,7 @@ def shutdown(self, wait: bool = True) -> None:
         self._manager.shutdown()
         self._executor.shutdown(wait=wait)

+    @torch.profiler.record_function("torchft::manager::allreduce")
     def allreduce(
         self, tensor: torch.Tensor, should_quantize: bool = False
     ) -> torch.futures.Future[torch.Tensor]:
@@ -331,34 +331,36 @@ def allreduce(
         try:
             # Run the allreduce async and save the work object so we can wait on
             # it later.
-            fut: Optional[
-                torch.futures.Future[None]
-                | torch.futures.Future[torch.Tensor]
-                | torch.futures.Future[List[torch.Tensor]]
-            ] = None
             if should_quantize and IS_TRITON_AVAILABLE:
-                fut = allreduce_quantized([tensor], ReduceOp.AVG, self._pg)
+                assert False, "allreduce_quantized is not supported yet"
+                # TODO: Support `allreduce_quantized`
+                # fut = allreduce_quantized([tensor], ReduceOp.SUM, self._pg)
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
                 fut = work.get_future()

+            stream: Optional[torch.cuda.Stream] = (
+                torch.cuda.current_stream() if torch.cuda.is_available() else None
+            )
+
             # schedule grad normalization as a continuation
             # on the Future
             def callback(
                 fut: torch.futures.Future[List[torch.Tensor]],
-            ) -> torch.Tensor:
-                nonlocal tensor
+            ) -> None:
+                nonlocal tensor, stream

                 # check for exceptions
                 fut.value()

                 tensor /= self.num_participants()

+                if stream is not None:
+                    stream.wait_stream(torch.cuda.current_stream())
+
                 return tensor

-            assert fut is not None
-            if not should_quantize:
-                fut = fut.then(callback)
+            fut = cast(torch.futures.Future[torch.Tensor], fut.then(callback))
             fut = self.wrap_future(fut, tensor)
             return fut
@@ -429,7 +431,6 @@ def callback(
                 return default

         fut = fut.then(callback)
-        self._pending_work.append(cast(torch.futures.Future[object], fut))
         return fut

     def start_quorum(
@@ -562,7 +563,7 @@ def _async_quorum(
             self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
             # We use the replica rank and world as we want all replicas in the PG.
             try:
-                with torch.profiler.record_function("torchft::manager::_pg.configure"):
+                with torch.profiler.record_function("torchft::manager::_pg::configure"):
                     self._pg.configure(
                         store_prefixed_addr, replica_rank, replica_world_size
                     )
@@ -694,20 +695,12 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
         Raises:
             RuntimeError: if should_commit fails max_retries times in a row and max_retries is set
         """
-        for work in self._pending_work:
-            # check at the beginning of since .wait() may trigger errors
-            if self._errored is not None:
-                break
-
-            # We swallow the error at in a future then callback so this will
-            # never return an error.
-            work.wait()
-
         # make sure recovery is complete before committing
         if self._recovery_stream is not None:
             self._recovery_stream.synchronize()

-        self._pending_work = []
+        if torch.cuda.is_available():
+            torch.cuda.current_stream().synchronize()

         if err := self._pg.errored():
             self.report_error(err)

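The manager no longer accumulates futures in `_pending_work`; each allreduce's normalization is chained directly onto its future, and completion is enforced by synchronizing the current CUDA stream in `should_commit`. Below is a standalone sketch of the chaining idea using plain torch.distributed rather than the torchft ProcessGroup; the function name is illustrative, and `Work.get_future()` assumes a backend (such as NCCL) that supports it.

```python
import torch
import torch.distributed as dist


def allreduce_avg_future(tensor: torch.Tensor, world_size: int) -> torch.futures.Future:
    """Chain normalization onto an async allreduce via Future.then."""
    work = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=True)
    fut = work.get_future()

    def _normalize(f: torch.futures.Future) -> torch.Tensor:
        f.value()  # re-raises if the collective failed
        tensor.div_(world_size)  # mirrors `tensor /= self.num_participants()`
        return tensor

    # The returned future resolves to the averaged tensor once the callback runs.
    return fut.then(_normalize)
```

Several such futures can be grouped and drained by a single synchronization of the stream the collectives were queued on, instead of tracking and waiting on each future individually.
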
torchft/manager_test.py

Lines changed: 0 additions & 4 deletions
@@ -164,9 +164,7 @@ def test_quorum_happy(self, client_mock: MagicMock) -> None:

         manager.start_quorum()
         manager.allreduce(torch.tensor([1.0])).wait()
-        self.assertEqual(len(manager._pending_work), 1)
         self.assertTrue(manager.should_commit())
-        self.assertEqual(len(manager._pending_work), 0)

         self.assertEqual(manager._quorum_id, 123)
         self.assertEqual(manager.current_step(), 1)
@@ -554,8 +552,6 @@ def test_manager_wrap_future(self, client_mock: MagicMock) -> None:
         self.assertIs(error.original_exception, e)
         self.assertEqual(wrapped_fut.value(), 2)

-        self.assertEqual(manager._pending_work, [wrapped_fut])
-
     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:
         manager = self._create_manager(timeout=timedelta(seconds=0.01))

torchft/process_group.py

Lines changed: 1 addition & 1 deletion
@@ -775,7 +775,7 @@ def abort(self) -> None:

     def errored(self) -> Optional[Exception]:
         # force a synchronization to ensure all work is complete
-        torch.cuda.synchronize()
+        torch.cuda.current_stream().synchronize()

         return self._errored

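The one-line change in errored() narrows the synchronization scope: torch.cuda.synchronize() waits for every stream on the device, including the new allreduce side stream, while synchronizing only the current stream leaves those collectives in flight. A trivial sketch of the distinction, assuming a CUDA device is present:

```python
import torch

if torch.cuda.is_available():
    # Device-wide barrier: blocks until *all* streams on the device are idle.
    torch.cuda.synchronize()

    # Narrower barrier: blocks only for work queued on the current stream,
    # so collectives running on a side stream keep making progress.
    torch.cuda.current_stream().synchronize()
```
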
train_ddp.py

Lines changed: 1 addition & 1 deletion
@@ -51,7 +51,7 @@ def main() -> None:
     # majority of groups will be available so few batches will be dropped.
     sampler = DistributedSampler(
         trainset,
-        replica_group=REPLICA_GROUP_ID,
+        replica_rank=REPLICA_GROUP_ID,
         num_replica_groups=NUM_REPLICA_GROUPS,
         group_rank=0,
         # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.
