
Commit add9c84

Group allreduce futures
1 parent db07843 commit add9c84


6 files changed (+269 additions, −54 deletions)


torchft/local_sgd.py

Lines changed: 28 additions & 24 deletions

@@ -197,9 +197,8 @@ def __init__(
         self._outer_optimizer = outer_optimizer
 
         # Stores pending all reduce
-        self._allreduce_futures: list[
-            torch.futures.Future[None] | torch.futures.Future[torch.Tensor]
-        ] = []
+        self._allreduce_futures: list[torch.futures.Future[torch.Tensor]] = []
+        self._stream: torch.cuda.Stream = torch.cuda.Stream()
 
         if bucket_cap_mb is not None:
             self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)
@@ -222,13 +221,15 @@ def __init__(
                 t = t.pin_memory()
             self.original_parameters[name] = t
 
+    @torch.profiler.record_function("torchft::local_sgd::save_parameters")
     def save_parameters(self) -> None:
         with torch.no_grad():
             # TODO: consider running copy on a separate stream
             for name, p in self._model_fragment.named_parameters():
                 param_to_local = extract_local_tensor(p.data)
                 self.original_parameters[name].copy_(param_to_local, non_blocking=True)
 
+    @torch.profiler.record_function("torchft::local_sgd::restore_parameters")
     def restore_parameters(self) -> None:
         with torch.no_grad():
             # TODO: consider running copy on a separate stream
@@ -248,6 +249,7 @@ def restore_parameters(self) -> None:
                 else:
                     p.data.copy_(self.original_parameters[name], non_blocking=False)
 
+    @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
         """
         Waits for the previously scheduled allreduce to finish
@@ -272,22 +274,27 @@ def should_sync_fragment(self, step: int) -> bool:
         step_to_sync = step - self._fragment_sync_offset - self._fragment_sync_delay
         return step_to_sync % self._sync_every == 0
 
+    @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
     def prepare_sync(self) -> None:
         """
         Calculate the pseugradient, average them across the manager group and starts
         allreduce on the pseudo-gradients but doesn't wait for it to finish.
         """
-        # Set the .grad field of each parameter to its pseudogradient
-        for name, p in self._model_fragment.named_parameters():
-            local_param = extract_local_tensor(p.data)
-            pseudogradient = local_param - self.original_parameters[name].to(p.device)
-            if isinstance(p, DTensor):
-                p.grad._local_tensor = pseudogradient
-            else:
-                p.grad = pseudogradient
+        with torch.cuda.stream(self._stream):
+            # Set the .grad field of each parameter to its pseudogradient
+            for name, p in self._model_fragment.named_parameters():
+                local_param = extract_local_tensor(p.data)
+                pseudogradient = local_param - self.original_parameters[name].to(
+                    p.device
+                )
+                if isinstance(p, DTensor):
+                    p.grad._local_tensor = pseudogradient
+                else:
+                    p.grad = pseudogradient
 
-        self._average_grads()
+            self._average_grads()
 
+    @torch.profiler.record_function("torchft::local_sgd::perform_sync")
     def perform_sync(self) -> bool:
         """
         Overrides the sync method to wait for the scheduled allreduce to finish and
@@ -297,6 +304,7 @@ def perform_sync(self) -> bool:
             return True
 
         self.wait()
+        self._stream.synchronize()
 
         # Restore the parameters back to the previous state
         self.restore_parameters()
@@ -467,16 +475,6 @@ def __init__(
         if fragment_update_alpha < 0 or fragment_update_alpha > 1:
             raise ValueError("fragment_update_alpha must be between 0 and 1")
 
-        # TODO: Support multiple fragments
-        # This requires changing the manager to support `should_commit` for each
-        # fragment separately.
-        if len(model_fragments) != 1:
-            raise ValueError("Multiple fragments are not supported yet")
-
-        # TODO: Support `fragment_sync_delay`
-        if fragment_sync_delay != 0:
-            raise ValueError("Fragment synchronization delay is not supported yet")
-
         # TODO: Support `fragment_update_alpha`
         if fragment_update_alpha != 0.0:
             raise ValueError(
@@ -522,6 +520,8 @@ def __init__(
                 use_bucketization,
                 bucket_cap_mb,
                 should_quantize,
+                fragment_sync_delay,
+                fragment_update_alpha,
             )
             for i, model_fragment in enumerate(model_fragments)
         ]
@@ -606,16 +606,20 @@ def _step_post_hook(
         step = self._local_step
 
         # Start sending fragments
-        for fragment in self._fragments:
+        for i, fragment in enumerate(self._fragments):
            if not fragment.should_prepare_fragment(step):
                continue
 
+            logger.info(f"preparing fragment {i} at step {step}")
+
            fragment.prepare_sync()
 
-        for fragment in self._fragments:
+        for i, fragment in enumerate(self._fragments):
            if not fragment.should_sync_fragment(step):
                continue
 
+            logger.info(f"syncing fragment {i} at step {step}")
+
            if not fragment.perform_sync():
                # Cancel all the previously scheduled allreduce by simply
                # waiting for them. They should have failed but lets be
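
The local_sgd.py changes above move pseudogradient computation and the allreduce launch onto a dedicated CUDA stream (self._stream) in prepare_sync(), while perform_sync() waits on the collected futures and then synchronizes that stream before restoring parameters. Below is a minimal sketch of the same prepare/perform split, assuming an initialized process group and an available CUDA device; _FragmentSketch and its members are illustrative names, not torchft's classes.

import torch
import torch.distributed as dist


class _FragmentSketch:
    """Toy stand-in for a model fragment; not one of torchft's classes."""

    def __init__(self, param: torch.Tensor) -> None:
        self._param = param
        self._stream = torch.cuda.Stream()  # dedicated side stream, as in the diff
        self._futures: list[torch.futures.Future[object]] = []

    def prepare_sync(self) -> None:
        # Queue the collective on the side stream so the default stream keeps running.
        with torch.cuda.stream(self._stream):
            work = dist.all_reduce(self._param, op=dist.ReduceOp.SUM, async_op=True)
            self._futures.append(work.get_future())

    def perform_sync(self) -> None:
        # Wait for the collective, then make the side stream's work visible.
        for fut in self._futures:
            fut.wait()
        self._futures.clear()
        self._stream.synchronize()

The point of the side stream is that the training loop can continue issuing local steps while the pseudogradient math and the collective are queued elsewhere; perform_sync() is the only place that blocks.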

torchft/manager.py

Lines changed: 13 additions & 24 deletions

@@ -259,7 +259,6 @@ def __init__(
         self._quorum_id = -1
         self._errored: Optional[ExceptionWithTraceback] = None
         self._healing = False
-        self._pending_work: List[torch.futures.Future[object]] = []
         self._batches_committed = 0
 
         # first step is 1
@@ -296,6 +295,7 @@ def shutdown(self, wait: bool = True) -> None:
         self._manager.shutdown()
         self._executor.shutdown(wait=wait)
 
+    @torch.profiler.record_function("torchft::manager::allreduce")
     def allreduce(
         self, tensor: torch.Tensor, should_quantize: bool = False
     ) -> torch.futures.Future[torch.Tensor]:
@@ -331,34 +331,33 @@ def allreduce(
         try:
             # Run the allreduce async and save the work object so we can wait on
             # it later.
-            fut: Optional[
-                torch.futures.Future[None]
-                | torch.futures.Future[torch.Tensor]
-                | torch.futures.Future[List[torch.Tensor]]
-            ] = None
             if should_quantize and IS_TRITON_AVAILABLE:
-                fut = allreduce_quantized([tensor], ReduceOp.AVG, self._pg)
+                assert False, "allreduce_quantized is not supported yet"
+                # TODO: Support `allreduce_quantized`
+                # fut = allreduce_quantized([tensor], ReduceOp.SUM, self._pg)
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
                 fut = work.get_future()
 
+            stream: torch.cuda.Stream = torch.cuda.current_stream()
+
             # schedule grad normalization as a continuation
            # on the Future
            def callback(
                fut: torch.futures.Future[List[torch.Tensor]],
-            ) -> torch.Tensor:
-                nonlocal tensor
+            ) -> None:
+                nonlocal tensor, stream
 
                # check for exceptions
                fut.value()
 
                tensor /= self.num_participants()
 
+                stream.wait_stream(torch.cuda.current_stream())
+
                return tensor
 
-            assert fut is not None
-            if not should_quantize:
-                fut = fut.then(callback)
+            fut = cast(torch.futures.Future[torch.Tensor], fut.then(callback))
            fut = self.wrap_future(fut, tensor)
            return fut
 
@@ -429,7 +428,6 @@ def callback(
                return default
 
        fut = fut.then(callback)
-        self._pending_work.append(cast(torch.futures.Future[object], fut))
        return fut
 
    def start_quorum(
@@ -562,7 +560,7 @@ def _async_quorum(
            self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
            # We use the replica rank and world as we want all replicas in the PG.
            try:
-                with torch.profiler.record_function("torchft::manager::_pg.configure"):
+                with torch.profiler.record_function("torchft::manager::_pg::configure"):
                    self._pg.configure(
                        store_prefixed_addr, replica_rank, replica_world_size
                    )
@@ -694,20 +692,11 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
        Raises:
            RuntimeError: if should_commit fails max_retries times in a row and max_retries is set
        """
-        for work in self._pending_work:
-            # check at the beginning of since .wait() may trigger errors
-            if self._errored is not None:
-                break
-
-            # We swallow the error at in a future then callback so this will
-            # never return an error.
-            work.wait()
-
        # make sure recovery is complete before committing
        if self._recovery_stream is not None:
            self._recovery_stream.synchronize()
 
-        self._pending_work = []
+        torch.cuda.current_stream().synchronize()
 
        if err := self._pg.errored():
            self.report_error(err)
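
In manager.py the commit drops the _pending_work list: the averaging step is chained directly onto the allreduce future with .then(), so each call returns a single grouped Future[torch.Tensor], and should_commit() only synchronizes the current CUDA stream. Below is a rough sketch of that chaining pattern, assuming an initialized process group; allreduce_avg and world_size are illustrative names, not the Manager API.

from typing import List, cast

import torch
import torch.distributed as dist


def allreduce_avg(
    tensor: torch.Tensor, world_size: int
) -> torch.futures.Future[torch.Tensor]:
    # Launch the SUM allreduce asynchronously and grab its future.
    work = dist.all_reduce(tensor, op=dist.ReduceOp.SUM, async_op=True)
    fut = work.get_future()

    def callback(f: torch.futures.Future[List[torch.Tensor]]) -> torch.Tensor:
        f.value()                # re-raises any error from the collective
        tensor.div_(world_size)  # normalize in place, mirroring the diff's division
        return tensor

    # One future now covers both the collective and the normalization.
    return cast(torch.futures.Future[torch.Tensor], fut.then(callback))


# With the work grouped into the returned future, a commit point can simply
# drain the current CUDA stream instead of tracking a list of pending futures:
#     torch.cuda.current_stream().synchronize()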

torchft/manager_test.py

Lines changed: 0 additions & 4 deletions

@@ -164,9 +164,7 @@ def test_quorum_happy(self, client_mock: MagicMock) -> None:
 
         manager.start_quorum()
         manager.allreduce(torch.tensor([1.0])).wait()
-        self.assertEqual(len(manager._pending_work), 1)
         self.assertTrue(manager.should_commit())
-        self.assertEqual(len(manager._pending_work), 0)
 
         self.assertEqual(manager._quorum_id, 123)
         self.assertEqual(manager.current_step(), 1)
@@ -554,8 +552,6 @@ def test_manager_wrap_future(self, client_mock: MagicMock) -> None:
         self.assertIs(error.original_exception, e)
         self.assertEqual(wrapped_fut.value(), 2)
 
-        self.assertEqual(manager._pending_work, [wrapped_fut])
-
     @patch("torchft.manager.ManagerClient", autospec=True)
     def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:
         manager = self._create_manager(timeout=timedelta(seconds=0.01))

torchft/process_group.py

Lines changed: 1 addition & 1 deletion

@@ -775,7 +775,7 @@ def abort(self) -> None:
 
     def errored(self) -> Optional[Exception]:
         # force a synchronization to ensure all work is complete
-        torch.cuda.synchronize()
+        torch.cuda.current_stream().synchronize()
 
         return self._errored
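
The process_group.py change narrows the barrier in errored() from a device-wide torch.cuda.synchronize() to torch.cuda.current_stream().synchronize(), which only drains work queued on the caller's current stream. An illustrative comparison follows, assuming a CUDA device is available; it is not part of torchft.

import torch

side = torch.cuda.Stream()
with torch.cuda.stream(side):
    _ = torch.randn(1 << 20, device="cuda").sum()

torch.cuda.current_stream().synchronize()  # drains only the caller's current stream
torch.cuda.synchronize()                   # drains every stream on the device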

train_ddp.py

Lines changed: 1 addition & 1 deletion

@@ -51,7 +51,7 @@ def main() -> None:
     # majority of groups will be available so few batches will be dropped.
     sampler = DistributedSampler(
         trainset,
-        replica_group=REPLICA_GROUP_ID,
+        replica_rank=REPLICA_GROUP_ID,
         num_replica_groups=NUM_REPLICA_GROUPS,
         group_rank=0,
         # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.
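
The train_ddp.py fix renames the keyword at this call site from replica_group to replica_rank. A hedged usage sketch of the corrected call is below; the import path, the placeholder dataset, and the constant values are assumptions for illustration, not taken from this diff.

import torch
from torch.utils.data import TensorDataset

from torchft.data import DistributedSampler  # assumed import path

REPLICA_GROUP_ID = 0      # placeholder values, not from the diff
NUM_REPLICA_GROUPS = 2

trainset = TensorDataset(torch.arange(16, dtype=torch.float32))
sampler = DistributedSampler(
    trainset,
    replica_rank=REPLICA_GROUP_ID,   # renamed from `replica_group=` by this commit
    num_replica_groups=NUM_REPLICA_GROUPS,
    group_rank=0,
)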
