
Commit b4d5433

Group allreduce futures

1 parent a1d65a6

File tree: 5 files changed, +102 -67 lines

torchft/collectives.py

Lines changed: 6 additions & 3 deletions
@@ -46,7 +46,7 @@ def allreduce_quantized(
     opts: AllreduceOptions | ReduceOp,
     process_group: "ProcessGroup",
     sync_stream: cuda.Stream | None = None,
-) -> Future[None]:
+) -> Future[list[torch.Tensor]]:
     """
     Performs a quantized all-reduce operation on a list of tensors.

@@ -76,6 +76,8 @@ def allreduce_quantized(
        A Future that can be used to wait for the operation to complete and
        clean up intermediate buffers.

+       The future's value is set to an empty list
+
    Raises:
        NotImplementedError: If the reduce operation is not ReduceOp.AVG.
    """
@@ -137,7 +139,7 @@ def allreduce_quantized(
        # Dequantize and copy to output buffer.
        fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)

-    class QuantizedAllReduceFuture(Future[None]):
+    class QuantizedAllReduceFuture(Future[list[torch.Tensor]]):
        def __init__(
            self,
            sync_stream: cuda.Stream,
@@ -149,12 +151,13 @@ def __init__(
            self._quantized_tensors = quantized_tensors
            self._quantized_tensors_out = quantized_tensors_out

-        def wait(self) -> None:
+        def wait(self) -> list[torch.Tensor]:
            # Wait for the synchronization to complete.
            cuda.current_stream().wait_stream(self._sync_stream)
            # Clean up intermediate buffers.
            del self._quantized_tensors_out
            del self._quantized_tensors
+            return []

    # pyre-ignore[29]
    return QuantizedAllReduceFuture(

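For context on the signature change above, a minimal usage sketch (not part of the commit; the process group setup and tensor list are assumed). The returned future is now typed Future[list[torch.Tensor]], but its value is an empty list, so callers still rely on the in-place update of the input tensors:

    import torch
    from torch.distributed import ReduceOp
    from torchft.collectives import allreduce_quantized

    def average_in_place(tensors: list[torch.Tensor], process_group) -> None:
        # `process_group` is an already-initialized torchft ProcessGroup (assumed).
        fut = allreduce_quantized(tensors, ReduceOp.AVG, process_group)
        result = fut.wait()  # blocks until the sync stream completes
        assert result == []  # value carries no data; tensors were averaged in place
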
torchft/local_sgd.py

Lines changed: 41 additions & 33 deletions
@@ -147,15 +147,15 @@ def _average(self) -> list[torch.Tensor]:
        """
        Averages the model parameters across the manager and returns the averaged parameters.
        """
-        works = []
        averaged_parameters = []
        for p in self._model.parameters():
            # Create a new tensor to store the averaged parameter
            avg_param = extract_local_tensor(p)
-            works.append(self._manager.allreduce(avg_param))
            averaged_parameters.append(avg_param)
-        for work in works:
-            work.wait()
+
+        work = self._manager.collect_all_allreduce(averaged_parameters)
+        work.wait()
+
        return averaged_parameters


@@ -193,9 +193,7 @@ def __init__(
        self._outer_optimizer = outer_optimizer

        # Stores pending all reduce
-        self._allreduce_futures: List[
-            torch.futures.Future[None] | torch.futures.Future[torch.Tensor]
-        ] = []
+        self._allreduce_futures: list[torch.futures.Future[None]] = []

        if bucket_cap_mb is not None:
            self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)
@@ -320,18 +318,27 @@ def _average_grads(self) -> None:

    def _allreduce_per_param(self) -> None:
        """Performs allreduce on each gradient tensor separately (original method)."""
+        tensors = []
+
        for p in self._model_fragment.parameters():
            # Perform allreduce on the pseudogradients
            assert p.grad is not None
            if isinstance(p, DTensor):
-                work = self._manager.allreduce(
-                    p.grad._local_tensor, should_quantize=self.should_quantize
-                )
+                tensors.append(p.grad._local_tensor)
            else:
-                work = self._manager.allreduce(
-                    p.grad, should_quantize=self.should_quantize
-                )
-            self._allreduce_futures.append(work)
+                tensors.append(p.grad)
+
+        work = self._manager.collect_all_allreduce(
+            tensors, should_quantize=self.should_quantize
+        )
+
+        def callback(
+            fut: torch.futures.Future[List[torch.futures.Future[torch.Tensor]]],
+        ) -> None:
+            return
+
+        work = work.then(callback)
+        self._allreduce_futures.append(work)

    def bucketize_and_allreduce(
        self,
@@ -351,6 +358,9 @@ def bucketize_and_allreduce(
        total_size = sum(t.numel() for t in tensors)
        dtype, device = tensors[0].dtype, tensors[0].device

+        flat_buffers: list[torch.Tensor] = []
+        all_bucket_tensors: list[list[Tuple[torch.Tensor, int, int]]] = []
+
        offset = 0
        flat_index = 0
        while offset < total_size:
@@ -372,19 +382,27 @@ def bucketize_and_allreduce(
                pack_offset += numel
                flat_index += 1

-            work = self._manager.allreduce(
-                flat_buffer, should_quantize=self.should_quantize
-            )
+            flat_buffers.append(flat_buffer)
+            all_bucket_tensors.append(bucket_tensors)
+
+            offset += chunk_size

-            def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
-                nonlocal bucket_tensors, flat_buffer
+        def callback(
+            fut: torch.futures.Future[List[torch.futures.Future[torch.Tensor]]],
+        ) -> None:
+            nonlocal all_bucket_tensors, flat_buffers
+
+            for i in range(len(flat_buffers)):
+                bucket_tensors = all_bucket_tensors[i]
+                flat_buffer = flat_buffers[i]
                for t, pack_offset, numel in bucket_tensors:
                    t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t))

-            work = work.then(callback)
-            self._allreduce_futures.append(work)
-
-            offset += chunk_size
+        work = self._manager.collect_all_allreduce(
+            flat_buffers, should_quantize=self.should_quantize
+        )
+        work = work.then(callback)
+        self._allreduce_futures.append(work)

    def _allreduce_bucketized(self) -> None:
        """
@@ -455,16 +473,6 @@ def __init__(
        if sync_every < len(model_fragments):
            raise ValueError("Only 1 fragment can be syncrhonized at a time")

-        # TODO: Support multiple fragments
-        # This requires changing the manager to support `should_commit` for each
-        # fragment separately.
-        if len(model_fragments) != 1:
-            raise ValueError("Multiple fragments are not supported yet")
-
-        # TODO: Support `fragment_sync_delay`
-        if fragment_sync_delay != 0:
-            raise ValueError("Fragment synchronization delay is not supported yet")
-
        # TODO: Support `fragment_update_alpha`
        if fragment_update_alpha != 0.0:
            raise ValueError(

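The net effect of the local_sgd.py changes is that pseudogradients are gathered into one list and reduced with a single grouped call instead of one allreduce per parameter. A simplified sketch of that pattern, assuming a manager exposing the new collect_all_allreduce API and an nn.Module fragment (not a drop-in copy of the code above; the DTensor import path is an assumption):

    import torch
    import torch.nn as nn
    from torch.distributed.tensor import DTensor  # assumed import path

    def grouped_pseudograd_allreduce(manager, fragment: nn.Module, should_quantize: bool = False):
        tensors = []
        for p in fragment.parameters():
            assert p.grad is not None
            # DTensor parameters contribute their local shard's gradient.
            tensors.append(p.grad._local_tensor if isinstance(p, DTensor) else p.grad)
        # One grouped call replaces a per-parameter allreduce loop.
        return manager.collect_all_allreduce(tensors, should_quantize=should_quantize)
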
torchft/local_sgd_test.py

Lines changed: 18 additions & 15 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from typing import Dict
+from typing import Dict, List
 from unittest import TestCase
 from unittest.mock import MagicMock, create_autospec

@@ -86,7 +86,7 @@ def test_local_sgd_healthy(self) -> None:
        manager.should_commit.return_value = True
        self.assertEqual(local_sgd._local_step, 0)
        self.assertEqual(manager.should_commit.call_count, 1)
-        self.assertEqual(manager.allreduce.call_count, 4)
+        self.assertEqual(manager.collect_all_allreduce.call_count, 1)

    def test_extract_local_tensor(self) -> None:
        regular_tensor = torch.rand(3, 3, requires_grad=True)
@@ -172,7 +172,7 @@ def test_diloco_healthy(self) -> None:
            diloco._fragments[0].original_parameters, _params_dict(model)
        )
        self.assertEqual(manager.should_commit.call_count, 1)
-        self.assertEqual(manager.allreduce.call_count, parameter_count)
+        self.assertEqual(manager.collect_all_allreduce.call_count, 1)

        outer_opt_state = outer_optimizer.state_dict()
        self.assertEqual(len(outer_opt_state["state"]), parameter_count)
@@ -220,13 +220,12 @@ def test_diloco_allreduce_call_efficiency(
            loss.backward()
            inner_optimizer.step()

-        allreduce_calls = manager.allreduce.call_count
-        param_count = len([p for p in model.parameters() if p.requires_grad])
+        allreduce_calls = manager.collect_all_allreduce.call_count

        if expect_fewer_calls:
-            self.assertLess(int(allreduce_calls), int(param_count))
+            self.assertEqual(int(allreduce_calls), 1)
        else:
-            self.assertEqual(int(allreduce_calls), int(param_count))
+            self.assertEqual(int(allreduce_calls), 1)

    def test_bucketization_correctness(self) -> None:
        class TinyModel(nn.Module):
@@ -251,16 +250,20 @@ def forward(self, x):
        manager._use_async_quorum = False
        manager.should_commit.return_value = True

-        # Define fake allreduce: multiplies buffer by 2
-        def fake_allreduce(
-            tensor: Tensor, should_quantize: bool
-        ) -> torch.futures.Future[Tensor]:
-            tensor.mul_(2)
+        # Define fake collect_all_allreduce: multiplies all buffers by 2
+        def fake_collect_all_allreduce(
+            tensors: List[Tensor], should_quantize: bool
+        ) -> torch.futures.Future[List[torch.futures.Future[Tensor]]]:
+            for tensor in tensors:
+                tensor.mul_(2)
            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            fut.set_result(tensors)

-        manager.allreduce.side_effect = fake_allreduce
+            futs = torch.futures.Future()  # pyre-fixme[29]: not a function
+            futs.set_result([fut])
+            return futs
+
+        manager.collect_all_allreduce.side_effect = fake_collect_all_allreduce

        diloco = DiLoCo(
            manager, [model], inner_opt, outer_opt, sync_every=2, use_bucketization=True

torchft/manager.py

Lines changed: 37 additions & 13 deletions
@@ -278,8 +278,33 @@ def shutdown(self, wait: bool = True) -> None:
        self._manager.shutdown()
        self._executor.shutdown(wait=wait)

+    def collect_all_allreduce(
+        self, tensors: List[torch.Tensor], should_quantize: bool = False
+    ) -> torch.futures.Future[List[torch.futures.Future[torch.Tensor]]]:
+        futs: List[torch.futures.Future[torch.Tensor]] = []
+        default_futs: List[torch.futures.Future[torch.Tensor]] = []
+
+        for tensor in tensors:
+            fut = self.allreduce(tensor, should_quantize=should_quantize)
+            futs.append(fut)
+
+            default_fut = torch.futures.Future()  # pyre-fixme[29]: not a function
+            default_fut.set_result(tensor)
+            default_futs.append(default_fut)
+
+        fut = torch.futures.collect_all(futs)
+
+        return self.wrap_future(fut, default_futs)
+
    def allreduce(
        self, tensor: torch.Tensor, should_quantize: bool = False
+    ) -> torch.futures.Future[torch.Tensor]:
+        fut = self._allreduce(tensor, should_quantize=should_quantize)
+        fut = self.wrap_future(fut, tensor)
+        return fut
+
+    def _allreduce(
+        self, tensor: torch.Tensor, should_quantize: bool = False
    ) -> torch.futures.Future[torch.Tensor]:
        """
        Fault tolerant allreduce the tensor and return a Future that will be completed when
@@ -314,9 +339,8 @@ def allreduce(
            # Run the allreduce async and save the work object so we can wait on
            # it later.
            fut: Optional[
-                torch.futures.Future[None]
+                torch.futures.Future[List[torch.Tensor]]
                | torch.futures.Future[torch.Tensor]
-                | torch.futures.Future[List[torch.Tensor]]
            ] = None
            if should_quantize and IS_TRITON_AVAILABLE:
                fut = allreduce_quantized([tensor], ReduceOp.AVG, self._pg)
@@ -331,19 +355,16 @@ def callback(
            ) -> torch.Tensor:
                nonlocal tensor

-                # check for exceptions
                fut.value()

-                tensor /= self.num_participants()
+                if not should_quantize:
+                    tensor /= self.num_participants()

                return tensor

            assert fut is not None
-            if not should_quantize:
-                fut = fut.then(callback)
-            fut = self.wrap_future(fut, tensor)
+            fut = fut.then(callback)
            return fut
-
        except Exception as e:
            self._logger.exception(
                f"got exception in all reduce -- skipping remaining: {e}"
@@ -668,21 +689,24 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
        Raises:
            RuntimeError: if should_commit fails max_retries times in a row and max_retries is set
        """
-        for work in self._pending_work:
-            # check at the beginning of since .wait() may trigger errors
-            if self._errored is not None:
+        while True:
+            if len(self._pending_work) == 0:
                break

+            work = self._pending_work.pop(0)
            # We swallow the error at in a future then callback so this will
            # never return an error.
            work.wait()

+            # Remove all work if there was an error.
+            # We won't commit in this case as well.
+            if self._errored is None:
+                break
+
        # make sure recovery is complete before committing
        if self._recovery_stream is not None:
            self._recovery_stream.synchronize()

-        self._pending_work = []
-
        if err := self._pg.errored():
            self.report_error(err)


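Based on the new Manager.collect_all_allreduce above, a hedged sketch of how a caller might consume the grouped future; the manager instance and tensors are assumed, and error handling is simplified:

    import torch

    def sync_all(manager, tensors: list[torch.Tensor]) -> None:
        # Returns Future[List[Future[torch.Tensor]]]: the outer future comes from
        # torch.futures.collect_all over the per-tensor allreduce futures and is
        # wrapped by wrap_future with the original tensors as the fallback value.
        grouped = manager.collect_all_allreduce(tensors, should_quantize=False)
        for inner in grouped.wait():
            inner.wait()  # each inner future resolves to a reduced tensor
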
torchft/process_group.py

Lines changed: 0 additions & 3 deletions
@@ -774,9 +774,6 @@ def abort(self) -> None:
        super().abort()

    def errored(self) -> Optional[Exception]:
-        # force a synchronization to ensure all work is complete
-        torch.cuda.synchronize()
-
        return self._errored

    def getBackendName(self) -> str: