
Commit ad04693

Group allreduce futures

1 parent db07843 · commit ad04693

4 files changed: +16 −37 lines

torchft/collectives.py
Lines changed: 6 additions & 3 deletions

@@ -135,7 +135,7 @@ def allocate_reduce_scatter_output(
     return tensor, padded_sizes


-class _QuantizedOpFuture(Future[None]):
+class _QuantizedOpFuture(Future[list[torch.Tensor]]):
     def __init__(
         self,
         sync_stream: cuda.Stream,
@@ -145,11 +145,12 @@ def __init__(
         self._sync_stream = sync_stream
         self._keep_alive_tensors = keep_alive_tensors

-    def wait(self) -> None:
+    def wait(self) -> list[torch.Tensor]:
         # Wait for the synchronization to complete.
         cuda.current_stream().wait_stream(self._sync_stream)
         # Clean up intermediate buffers.
         del self._keep_alive_tensors
+        return []


 def reduce_scatter_quantized(
@@ -284,7 +285,7 @@ def allreduce_quantized(
     opts: AllreduceOptions | ReduceOp,
     process_group: "ProcessGroup",
     sync_stream: cuda.Stream | None = None,
-) -> Future[None]:
+) -> Future[list[torch.Tensor]]:
     """
     Performs a quantized all-reduce operation on a list of tensors.

@@ -314,6 +315,8 @@ def allreduce_quantized(
         A Future that can be used to wait for the operation to complete and
         clean up intermediate buffers.

+        The future's value is set to an empty list
+
     Raises:
         NotImplementedError: If the reduce operation is not ReduceOp.AVG.
     """

torchft/local_sgd.py
Lines changed: 0 additions & 10 deletions

@@ -467,16 +467,6 @@ def __init__(
         if fragment_update_alpha < 0 or fragment_update_alpha > 1:
             raise ValueError("fragment_update_alpha must be between 0 and 1")

-        # TODO: Support multiple fragments
-        # This requires changing the manager to support `should_commit` for each
-        # fragment separately.
-        if len(model_fragments) != 1:
-            raise ValueError("Multiple fragments are not supported yet")
-
-        # TODO: Support `fragment_sync_delay`
-        if fragment_sync_delay != 0:
-            raise ValueError("Fragment synchronization delay is not supported yet")
-
         # TODO: Support `fragment_update_alpha`
         if fragment_update_alpha != 0.0:
             raise ValueError(

torchft/manager.py
Lines changed: 9 additions & 22 deletions

@@ -259,7 +259,6 @@ def __init__(
         self._quorum_id = -1
         self._errored: Optional[ExceptionWithTraceback] = None
         self._healing = False
-        self._pending_work: List[torch.futures.Future[object]] = []
         self._batches_committed = 0

         # first step is 1
@@ -332,15 +331,17 @@ def allreduce(
             # Run the allreduce async and save the work object so we can wait on
             # it later.
             fut: Optional[
-                torch.futures.Future[None]
-                | torch.futures.Future[torch.Tensor]
-                | torch.futures.Future[List[torch.Tensor]]
+                torch.futures.Future[torch.Tensor]
+                | torch.futures.Future[list[torch.Tensor]]
             ] = None
             if should_quantize and IS_TRITON_AVAILABLE:
-                fut = allreduce_quantized([tensor], ReduceOp.AVG, self._pg)
+                fut = allreduce_quantized([tensor], ReduceOp.SUM, self._pg)
             else:
-                work = self._pg.allreduce([tensor], ReduceOp.SUM)
-                fut = work.get_future()
+                sync_stream = torch.cuda.Stream()
+                sync_stream.wait_stream(torch.cuda.current_stream())
+                with torch.cuda.stream(sync_stream):
+                    work = self._pg.allreduce([tensor], ReduceOp.SUM)
+                    fut = work.get_future()

             # schedule grad normalization as a continuation
             # on the Future
@@ -357,11 +358,9 @@ def callback(
                 return tensor

             assert fut is not None
-            if not should_quantize:
-                fut = fut.then(callback)
+            fut = fut.then(callback)
             fut = self.wrap_future(fut, tensor)
             return fut
-
         except Exception as e:
             self._logger.exception(
                 f"got exception in all reduce -- skipping remaining: {e}"
@@ -429,7 +428,6 @@ def callback(
             return default

         fut = fut.then(callback)
-        self._pending_work.append(cast(torch.futures.Future[object], fut))
         return fut

     def start_quorum(
@@ -694,21 +692,10 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
        Raises:
            RuntimeError: if should_commit fails max_retries times in a row and max_retries is set
        """
-        for work in self._pending_work:
-            # check at the beginning of since .wait() may trigger errors
-            if self._errored is not None:
-                break
-
-            # We swallow the error at in a future then callback so this will
-            # never return an error.
-            work.wait()
-
        # make sure recovery is complete before committing
        if self._recovery_stream is not None:
            self._recovery_stream.synchronize()

-        self._pending_work = []
-
        if err := self._pg.errored():
            self.report_error(err)

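The non-quantized path now launches the collective on a dedicated CUDA stream and relies on the returned future, rather than a `_pending_work` list drained in `should_commit`, to order completion. A self-contained sketch of that pattern using plain `torch.distributed` (assuming an initialized NCCL process group and a CUDA tensor; torchft goes through its own `ProcessGroup` wrapper instead):

```python
import torch
import torch.distributed as dist

def async_allreduce_avg(grad: torch.Tensor) -> torch.futures.Future:
    """Sketch of the new allreduce flow: launch on a side stream, chain the
    normalization callback on the future, and hand the future to the caller."""
    sync_stream = torch.cuda.Stream()
    # Ensure kernels that produced `grad` are ordered before the collective.
    sync_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(sync_stream):
        work = dist.all_reduce(grad, op=dist.ReduceOp.SUM, async_op=True)
        fut = work.get_future()

    def _normalize(_: torch.futures.Future) -> torch.Tensor:
        # SUM followed by divide-by-world-size gives the average, matching the
        # callback chained in manager.py.
        grad.div_(dist.get_world_size())
        return grad

    # The continuation replaces the old "append to _pending_work, wait in
    # should_commit" bookkeeping: whoever holds the future waits on it directly.
    return fut.then(_normalize)
```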

torchft/process_group.py
Lines changed: 1 addition & 2 deletions

@@ -775,8 +775,7 @@ def abort(self) -> None:

     def errored(self) -> Optional[Exception]:
         # force a synchronization to ensure all work is complete
-        torch.cuda.synchronize()
-
+        torch.cuda.current_stream().synchronize()
         return self._errored

     def getBackendName(self) -> str:
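For context on the narrower synchronization (an illustrative snippet, not part of the commit): `torch.cuda.synchronize()` blocks until every stream on the device is idle, while `torch.cuda.current_stream().synchronize()` only waits for work queued on the caller's current stream, so a side stream that is still running collectives does not stall the error check.

```python
import torch

# Illustrative only: device-wide vs. current-stream synchronization.
if torch.cuda.is_available():
    side = torch.cuda.Stream()
    with torch.cuda.stream(side):
        a = torch.randn(4096, 4096, device="cuda")
        b = a @ a  # long-ish kernel queued on the side stream

    torch.cuda.current_stream().synchronize()  # waits only for the current stream
    torch.cuda.synchronize()                   # waits for every stream, including `side`
```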
