From fef4abc8772d34b9a50a8e7a4d3d37b2c1a8009e Mon Sep 17 00:00:00 2001
From: Tushar Jain
Date: Fri, 25 Jul 2025 18:30:02 -0700
Subject: [PATCH 1/4] use http transport

Summary:
use http transport instead of pg transport -- pg transport fails to
resolve address when running locally
---
 train_diloco.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/train_diloco.py b/train_diloco.py
index 0c6b9cf..e207e73 100644
--- a/train_diloco.py
+++ b/train_diloco.py
@@ -34,7 +34,7 @@
     ProcessGroupGloo,
     ProcessGroupNCCL,
 )
-from torchft.checkpointing.pg_transport import PGTransport
+from torchft.checkpointing.http_transport import HTTPTransport
 from torchft.local_sgd import DiLoCo
 
 logging.basicConfig(level=logging.INFO)
@@ -67,13 +67,12 @@ def state_dict():
         timeout=timedelta(seconds=10),
     )
     if torch.cuda.is_available() and USE_NCCL
-    else ProcessGroupGloo(timeout=timedelta(seconds=5))
+    else ProcessGroupGloo(timeout=timedelta(seconds=10))
 )
 
-transport = PGTransport(
-    pg,
+transport = HTTPTransport(
     timeout=timedelta(seconds=10),
-    device=device,
+    num_chunks=0,
 )
 
 manager = Manager(

From 09bbdea7ca569378f6ac0f9a39161bb9b6421bbb Mon Sep 17 00:00:00 2001
From: Tushar Jain
Date: Thu, 31 Jul 2025 18:55:23 -0700
Subject: [PATCH 2/4] fix stream dependencies in callbacks

Summary:
- call future.wait in callbacks to make sure the continuation executes
  after the future has completed
- set the stream correctly to execute callback scheduled by bucketized
  allreduce
---
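[Note, not part of the commit message: a minimal sketch of the callback
pattern the three hunks below apply. The helper name `chain_on_stream` is
illustrative only; the real call sites are in collectives.py, local_sgd.py
and manager.py.]

    from contextlib import nullcontext

    import torch

    def chain_on_stream(
        fut: torch.futures.Future,
        stream: "torch.cuda.Stream | None",
    ) -> torch.futures.Future:
        # Run the continuation on `stream` instead of the process-group
        # stream, and call fut.wait() first so that stream is ordered after
        # the work that produced the future.
        def callback(f: torch.futures.Future) -> object:
            with torch.cuda.stream(stream) if stream is not None else nullcontext():
                f.wait()  # set up the stream dependency
                return f.value()

        return fut.then(callback)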
 torchft/collectives.py |  2 ++
 torchft/local_sgd.py   | 11 ++++++++---
 torchft/manager.py     |  2 ++
 3 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/torchft/collectives.py b/torchft/collectives.py
index 837fbcd..927309a 100644
--- a/torchft/collectives.py
+++ b/torchft/collectives.py
@@ -387,6 +387,8 @@ def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]:
         nonlocal tensors, quantized_tensors, world_size, sync_stream
 
         with torch.cuda.stream(sync_stream):
+            # Setup stream dependency
+            fut.wait()
             # Dequantize the result back to the original precision
             fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
             return tensors

diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index c7230ee..d0eeccc 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -513,9 +513,14 @@ def _bucketize_and_allreduce(
             )
 
             def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
-                nonlocal bucket_tensors, flat_buffer
-                for t, pack_offset, numel in bucket_tensors:
-                    t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t))
+                with torch.cuda.stream(self._stream) if self._stream else nullcontext():
+                    nonlocal bucket_tensors, flat_buffer
+                    # Setup stream dependency
+                    fut.wait()
+                    for t, pack_offset, numel in bucket_tensors:
+                        t.copy_(
+                            flat_buffer[pack_offset : pack_offset + numel].view_as(t)
+                        )
 
             work = work.then(callback)
             self._allreduce_futures.append(work)

diff --git a/torchft/manager.py b/torchft/manager.py
index e01a965..09100c3 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -403,6 +403,8 @@ def callback(
             # change the stream to avoid making the callback stream
             # dependent on process group stream running the allreduce
             with torch.cuda.stream(stream) if stream is not None else nullcontext():
+                # Setup stream dependency
+                fut.wait()
                 fut.value()
                 tensor /= num_participants
 

From 9683ef4db489813a5811ff6a8266372945b33b96 Mon Sep 17 00:00:00 2001
From: Tushar Jain
Date: Thu, 31 Jul 2025 18:55:23 -0700
Subject: [PATCH 3/4] return work from manager allreduce

Summary:
returns the work object so callers have more flexibility in how they use it
---
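[Note, not part of the commit message: the caller-side contract changes from
a Future to a Work object; a short sketch of the two ways callers in this
series consume it. `consume` is an illustrative name.]

    import torch
    from torch.distributed.distributed_c10d import Work

    def consume(work: Work) -> torch.Tensor:
        # Synchronous use, as in the updated integration test.
        work.wait()
        # Future-based use, as in the updated DDP comm hook.
        fut = work.get_future()
        return fut.value()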
 torchft/collectives.py        | 18 ++++++++++++-----
 torchft/collectives_test.py   |  4 ++--
 torchft/ddp.py                |  3 ++-
 torchft/ddp_test.py           | 10 ++++++----
 torchft/local_sgd.py          | 22 ++++++++++++---------
 torchft/local_sgd_test.py     | 25 ++++++++++++-----------
 torchft/manager.py            | 23 ++++++++++------------
 torchft/manager_integ_test.py |  4 ++--
 torchft/manager_test.py       | 11 ++++++-----
 torchft/process_group.py      | 16 +--------------
 torchft/process_group_test.py |  2 +-
 torchft/work.py               | 37 +++++++++++++++++++++++++++++++++++
 12 files changed, 105 insertions(+), 70 deletions(-)
 create mode 100644 torchft/work.py

diff --git a/torchft/collectives.py b/torchft/collectives.py
index 927309a..af95cbb 100644
--- a/torchft/collectives.py
+++ b/torchft/collectives.py
@@ -18,6 +18,7 @@
     AllreduceOptions,
     AllToAllOptions,
     ReduceScatterOptions,
+    Work,
 )
 
 from torch.futures import Future
@@ -288,7 +289,7 @@ def allreduce_quantized(
     opts: AllreduceOptions | ReduceOp,
     process_group: "ProcessGroup",
     sync_stream: cuda.Stream | None = None,
-) -> Future[list[torch.Tensor]]:
+) -> Work:
     """
     Performs a quantized all-reduce operation on a list of tensors.
 
@@ -379,10 +380,18 @@ def allreduce_quantized(
         [torch.split(quantized_tensors_out.view(world_size, -1), 1)[rank]],
         _to_allgather_options(allreduce_opts),
     )
+
+    # NOTE: This is not supposed to be used with gloo, only with NCCL.
+    # So we setup the stream dependency here by calling work.wait(),
+    # which doesn't block the CPU.
+    #
+    # The future callback below will run after the work has been
+    # completed.
+    work.wait()
     fut = work.get_future()
 
-    def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]:
+    def callback(fut: Future[list[torch.Tensor]]) -> None:
         # Dequantize and copy to output buffer.
         nonlocal tensors, quantized_tensors, world_size, sync_stream
 
@@ -391,7 +400,6 @@
             fut.wait()
             # Dequantize the result back to the original precision
             fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
-            return tensors
 
-    fut = fut.then(callback)
-    return fut
+    fut.add_done_callback(callback)
+    return work

diff --git a/torchft/collectives_test.py b/torchft/collectives_test.py
index c4b826b..6660abe 100644
--- a/torchft/collectives_test.py
+++ b/torchft/collectives_test.py
@@ -94,8 +94,8 @@ def _run_all_reduce_collective(
         )
     ]
 
-    fut = allreduce_quantized(tensors, reduce_op, pg)
-    fut.wait()
+    work = allreduce_quantized(tensors, reduce_op, pg)
+    work.wait()
 
     work = pg.allreduce([expected], reduce_op)
     work.get_future().wait()

diff --git a/torchft/ddp.py b/torchft/ddp.py
index 6fbea8f..1355317 100644
--- a/torchft/ddp.py
+++ b/torchft/ddp.py
@@ -68,7 +68,8 @@ def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> N
     def _comm_hook(
         state: "Manager", bucket: dist.GradBucket
     ) -> torch.futures.Future[torch.Tensor]:
-        return state.allreduce(bucket.buffer())
+        work = state.allreduce(bucket.buffer())
+        return work.get_future()
 
 
 class PureDistributedDataParallel(nn.Module):

diff --git a/torchft/ddp_test.py b/torchft/ddp_test.py
index 1a56dce..690bfd0 100644
--- a/torchft/ddp_test.py
+++ b/torchft/ddp_test.py
@@ -10,11 +10,13 @@
 import torch
 import torch.distributed as dist
 from torch import nn
+from torch.distributed.distributed_c10d import Work
 from torch.futures import Future
 
 from torchft.ddp import DistributedDataParallel, PureDistributedDataParallel
 from torchft.manager import Manager
 from torchft.process_group import ProcessGroupBabyGloo, ProcessGroupGloo
+from torchft.work import _DummyWork
 
 
 class TestDDP(TestCase):
@@ -39,14 +41,14 @@ def test_ddp(self) -> None:
 
         call_count = 0
 
-        def allreduce(tensor: torch.Tensor) -> Future[torch.Tensor]:
+        def allreduce(
+            tensor: torch.Tensor,
+        ) -> Work:
             nonlocal call_count
 
             call_count += 1
 
-            fut = Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)
 
         manager.allreduce = allreduce

diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index d0eeccc..761a74c 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -18,6 +18,7 @@
 import torch
 import torch.distributed as dist
 from torch import nn, optim
+from torch.distributed.distributed_c10d import Work
 from torch.distributed.tensor import DTensor
 from torch.nn.parameter import Parameter
 from torch.optim.optimizer import Optimizer
@@ -200,7 +201,7 @@ def __init__(
         self._outer_optimizer = outer_optimizer
 
         # Stores pending all reduce
-        self._allreduce_futures: list[torch.futures.Future[torch.Tensor]] = []
+        self._allreduce_work: list[Work] = []
         self._stream: Optional[torch.cuda.Stream] = (
             torch.cuda.Stream() if torch.cuda.is_available() else None
         )
@@ -368,7 +369,7 @@ def wait(self) -> None:
         """
         Waits for the previously scheduled allreduce to finish
         """
-        if len(self._allreduce_futures) == 0:
+        if len(self._allreduce_work) == 0:
             return
 
         if self._stream is not None:
@@ -376,7 +377,7 @@ def wait(self) -> None:
             self._stop_event.synchronize()
             self._stop_event = None
 
-        self._allreduce_futures = []
+        self._allreduce_work = []
 
     @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
     def prepare_sync(self) -> None:
@@ -386,7 +387,7 @@ def prepare_sync(self) -> None:
         """
         self._save_grads()
 
-        assert len(self._allreduce_futures) == 0
+        assert len(self._allreduce_work) == 0
 
         # Make sure tensors are available to `_stream`
         if self._stream is not None:
@@ -399,7 +400,7 @@ def prepare_sync(self) -> None:
         ):
             self._average_grads()
 
-        for work in self._allreduce_futures:
+        for work in self._allreduce_work:
             work.wait()
 
         if self._stream is not None:
@@ -413,7 +414,7 @@ def perform_sync(self) -> bool:
         steps using the outer optimizer.
         """
         # Waiting for an allreduce before it has been sent is currently not supported.
-        assert len(self._allreduce_futures) > 0
+        assert len(self._allreduce_work) > 0
 
         self.wait()
 
@@ -467,7 +468,8 @@ def _allreduce_per_param(self) -> None:
             work = self._manager.allreduce(
                 self._grads[name], should_quantize=self.should_quantize
             )
-            self._allreduce_futures.append(work)
+
+            self._allreduce_work.append(work)
 
     def _bucketize_and_allreduce(
         self,
@@ -522,8 +524,10 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
                         flat_buffer[pack_offset : pack_offset + numel].view_as(t)
                     )
 
-            work = work.then(callback)
-            self._allreduce_futures.append(work)
+            fut = work.get_future()
+            fut.add_done_callback(callback)
+
+            self._allreduce_work.append(work)
 
             offset += chunk_size

diff --git a/torchft/local_sgd_test.py b/torchft/local_sgd_test.py
index 04aede4..881b96e 100644
--- a/torchft/local_sgd_test.py
+++ b/torchft/local_sgd_test.py
@@ -11,10 +11,12 @@
 import torch
 from parameterized import parameterized
 from torch import Tensor, nn, optim
+from torch.distributed.distributed_c10d import Work
 from torch.distributed.tensor import DTensor
 
 from torchft.local_sgd import DiLoCo, LocalSGD, extract_local_tensor
 from torchft.manager import Manager
+from torchft.work import _DummyWork
 
 
 def create_manager() -> MagicMock:
@@ -26,6 +28,11 @@ def create_manager() -> MagicMock:
 
     manager.errored.return_value = None
 
+    def mock_allreduce(tensor: torch.Tensor, should_quantize: bool = False) -> Work:
+        return _DummyWork(tensor)
+
+    manager.allreduce.side_effect = mock_allreduce
+
     return manager
 
 
@@ -66,7 +73,7 @@ class LocalSGDTest(TestCase):
     def test_local_sgd_healthy(self) -> None:
         model = SimpleModel()
         optimizer = optim.SGD(model.parameters())
-        manager = create_autospec(Manager)
+        manager = create_manager()
         with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
             self.assertEqual(local_sgd._local_step, 0)
             inp = torch.rand(2, 3)
@@ -240,13 +247,9 @@ def test_bucketization_correctness(self) -> None:
         manager.should_commit.return_value = True
 
         # Define fake allreduce: multiplies buffer by 2
-        def fake_allreduce(
-            tensor: Tensor, should_quantize: bool
-        ) -> torch.futures.Future[Tensor]:
+        def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
             tensor.mul_(2)
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)
 
         manager.allreduce.side_effect = fake_allreduce
 
@@ -284,13 +287,9 @@ def test_gradient_correctness(self) -> None:
         manager.should_commit.return_value = True
 
         # Define fake allreduce: multiplies buffer by 2
-        def fake_allreduce(
-            tensor: Tensor, should_quantize: bool
-        ) -> torch.futures.Future[Tensor]:
+        def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
             tensor.mul_(2)
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)
 
         manager.allreduce.side_effect = fake_allreduce

diff --git a/torchft/manager.py b/torchft/manager.py
index 09100c3..0b6e63b 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -39,11 +39,12 @@
 
 import torch
 from torch.distributed import ReduceOp, TCPStore
-from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp
+from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
 
 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
+from torchft.work import _DummyWork, _WorkWrapper
 
 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
@@ -343,9 +344,7 @@ def shutdown(self, wait: bool = True) -> None:
             self._executor.shutdown(wait=wait)
 
     @torch.profiler.record_function("torchft::manager::allreduce")
-    def allreduce(
-        self, tensor: torch.Tensor, should_quantize: bool = False
-    ) -> torch.futures.Future[torch.Tensor]:
+    def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
         the tensor is ready.
@@ -365,9 +364,7 @@ def allreduce(
             a Future that will be completed with the allreduced tensor
         """
         if self.errored():
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)
 
         self.wait_quorum()
         num_participants: int = self.num_participants()
@@ -380,13 +377,14 @@ def allreduce(
             # Run the allreduce async and save the work object so we can wait on
             # it later.
             if should_quantize and IS_TRITON_AVAILABLE:
-                fut = allreduce_quantized(
+                work = allreduce_quantized(
                     [tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream()
                 )
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
                 work.wait()
-                fut = work.get_future()
+
+            fut = work.get_future()
 
             stream: Optional[torch.cuda.Stream] = (
                 torch.cuda.current_stream() if torch.cuda.is_available() else None
@@ -413,7 +411,8 @@ def callback(
 
             fut = fut.then(callback)
             fut = self.wrap_future(fut, tensor)
-            return fut
+
+            return _WorkWrapper(work, fut)
 
         except Exception as e:
             self._logger.exception(
@@ -421,9 +420,7 @@
             )
             self.report_error(e)
 
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)
 
     def report_error(self, e: Exception) -> None:
         """

diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
index 6bdab58..ed2d11e 100644
--- a/torchft/manager_integ_test.py
+++ b/torchft/manager_integ_test.py
@@ -634,7 +634,7 @@ def all_reduce_callback(
 
         manager.start_quorum()
         t1 = torch.ones((1, 3), device=device)
-        fut = manager.allreduce(t1)
-        fut.wait()
+        work = manager.allreduce(t1)
+        work.wait()
         return t1
     return None

diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index 3140319..d0c81a0 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -16,7 +16,8 @@
 from torchft._torchft import QuorumResult
 from torchft.checkpointing.transport import CheckpointTransport
 from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode
-from torchft.process_group import ProcessGroup, _DummyWork
+from torchft.process_group import ProcessGroup
+from torchft.work import _DummyWork
 
 
 def mock_should_commit(
@@ -586,16 +587,16 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None:
 
         manager._pg.allreduce.return_value = _DummyWork(None)
 
         self.assertTrue(manager.is_participating())
-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-        fut = manager.allreduce(torch.tensor([1.0]))
+        work = manager.allreduce(torch.tensor([1.0]))
+        fut = work.get_future()
         result = fut.value()
         torch.testing.assert_close(result, torch.tensor([1.0 / 5]))
 
         # check healing numerics
         manager._healing = True
         self.assertFalse(manager.is_participating())
-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-        fut = manager.allreduce(torch.tensor([1.0]))
+        work = manager.allreduce(torch.tensor([1.0]))
+        fut = work.get_future()
         result = fut.value()
         torch.testing.assert_close(result, torch.tensor([0.0]))

diff --git a/torchft/process_group.py b/torchft/process_group.py
index 8f9c27b..4750dc9 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -69,6 +69,7 @@
 from torchft.device_mesh import *  # noqa: F401
 from torchft.futures import context_timeout, stream_timeout
 from torchft.multiprocessing import _MonitoredPipe
+from torchft.work import _DummyWork
 
 if TYPE_CHECKING:
     from torchft.manager import Manager
@@ -790,21 +791,6 @@ def getBackendName(self) -> str:
         return "torchft-nccl"
 
 
-class _DummyWork(dist._Work):
-    def __init__(self, result: object) -> None:
-        super().__init__()
-        self.result_ = result
-        # pyre-fixme[29]: Future is not a function
-        self.future_: torch.futures.Future[object] = torch.futures.Future()
-        self.future_.set_result(result)
-
-    def wait(self, timeout: Optional[timedelta] = None) -> bool:
-        return True
-
-    def get_future(self) -> torch.futures.Future[object]:
-        return self.future_
-
-
 class ProcessGroupDummy(ProcessGroup):
     """
     This process group discards all data passed to it and returns success. This

diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index 4c3455d..072cf1e 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -47,12 +47,12 @@
     ProcessGroupGloo,
     ProcessGroupNCCL,
     ProcessGroupWrapper,
-    _DummyWork,
     _ErrorSwallowingWork,
     _ManagedWork,
     extend_device_mesh,
     ft_init_device_mesh,
 )
+from torchft.work import _DummyWork
 
 
 def dummy_init_pg() -> None:

diff --git a/torchft/work.py b/torchft/work.py
new file mode 100644
index 0000000..7211c0d
--- /dev/null
+++ b/torchft/work.py
@@ -0,0 +1,37 @@
+from contextlib import nullcontext
+from datetime import timedelta
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+
+class _DummyWork(dist._Work):
+    def __init__(self, result: object) -> None:
+        super().__init__()
+        self.result_ = result
+        # pyre-fixme[29]: Future is not a function
+        self.future_: torch.futures.Future[object] = torch.futures.Future()
+        self.future_.set_result(result)
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        return True
+
+    def get_future(self) -> torch.futures.Future[object]:
+        return self.future_
+
+
+class _WorkWrapper(dist._Work):
+    def __init__(
+        self, work: dist._Work, fut: torch.futures.Future[torch.Tensor]
+    ) -> None:
+        super().__init__()
+        self._work = work
+        self._fut = fut
+
+    def wait(self, timeout: Optional[timedelta] = None) -> bool:
+        self._fut.wait()
+        return True
+
+    def get_future(self) -> torch.futures.Future[torch.Tensor]:
+        return self._fut

From 07446f69486f04348694f2ea27d36ff0c1ba14e0 Mon Sep 17 00:00:00 2001
From: Tushar Jain
Date: Thu, 31 Jul 2025 18:55:23 -0700
Subject: [PATCH 4/4] only use nightly pytorch in ci

Summary:
- change CI to only use nightly since block_current_stream is not in
  stable yet
- fix new errors in the nightly version of pyre
- remove fixme[29] about Future not being a function
- make reduce_scatter_quantized return a Work object
---
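[Note, not part of the commit message: with this change reduce_scatter_quantized
matches allreduce_quantized and hands back a Work object; a hedged usage sketch,
assuming a CUDA-capable process group `pg` and pre-allocated `output`/`inputs`
tensors supplied by the caller.]

    import torch
    from torch.distributed import ReduceOp
    from torchft.collectives import reduce_scatter_quantized
    from torchft.process_group import ProcessGroup

    def run(output: torch.Tensor, inputs: list[torch.Tensor], pg: ProcessGroup) -> None:
        work = reduce_scatter_quantized(output, inputs, ReduceOp.AVG, pg)
        # Dequantization now runs in a future callback on the sync stream, so
        # wait on the future (as the updated test does) before reading `output`.
        work.get_future().wait()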
 .github/workflows/lint.yaml     |  1 +
 .github/workflows/unittest.yaml |  6 +---
 torchft/collectives.py          | 81 +++++++++++++++++----------------
 torchft/collectives_test.py     |  4 +-
 torchft/futures.py              |  1 -
 torchft/futures_test.py         |  6 ---
 torchft/manager_test.py         |  6 +--
 torchft/process_group.py        |  9 ++--
 torchft/work.py                 |  1 -
 9 files changed, 54 insertions(+), 61 deletions(-)

diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml
index 937e34c..589dada 100644
--- a/.github/workflows/lint.yaml
+++ b/.github/workflows/lint.yaml
@@ -26,6 +26,7 @@ jobs:
         pip install lintrunner lintrunner-adapters
         lintrunner init
+        pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
         pip install .[dev] -v
     - name: Run lintrunner
       run: |

diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml
index 3cb05d4..d1f5bd5 100644
--- a/.github/workflows/unittest.yaml
+++ b/.github/workflows/unittest.yaml
@@ -15,11 +15,7 @@ jobs:
           - runs-on: "linux.2xlarge"
             gpu-arch-type: "cpu"
             gpu-arch-version: ""
-            torch-version: "stable"
-          - runs-on: "linux.g5.12xlarge.nvidia.gpu"
-            gpu-arch-type: "cuda"
-            gpu-arch-version: "12.4"
-            torch-version: "stable"
+            torch-version: "nightly"
           - runs-on: "linux.g5.12xlarge.nvidia.gpu"
             gpu-arch-type: "cuda"
             gpu-arch-version: "12.4"

diff --git a/torchft/collectives.py b/torchft/collectives.py
index af95cbb..cd84b0b 100644
--- a/torchft/collectives.py
+++ b/torchft/collectives.py
@@ -162,7 +162,7 @@ def reduce_scatter_quantized(
     opts: ReduceScatterOptions | ReduceOp,
     process_group: "ProcessGroup",
     sync_stream: cuda.Stream | None = None,
-) -> Future[None]:
+) -> Work:
     """
     Performs a quantized reduce-scatter operation on a list of tensors.
 
@@ -196,10 +196,10 @@
     """
 
     if isinstance(opts, ReduceOp):
-        reducescatter_opts = ReduceScatterOptions()
+        reducescatter_opts: ReduceScatterOptions = ReduceScatterOptions()
         reducescatter_opts.reduceOp = opts
     else:
-        reducescatter_opts = opts
+        reducescatter_opts: ReduceScatterOptions = opts
 
     # Check if the reduceOp is AVG or SUM
     if reducescatter_opts.reduceOp not in {
@@ -211,15 +211,15 @@
             f"for quantized reduce-scatter, only AVG and SUM are supported"
         )
 
-    rank = process_group.rank()
-    world_size = process_group.size()
+    rank: int = process_group.rank()
+    world_size: int = process_group.size()
 
     reduce_output_sizes = [
         torch.Size((s[0] // world_size, *s[1:]))
         for s in get_padded_sizes(inputs, world_size)
     ]
     reduce_output_numels = [s.numel() for s in reduce_output_sizes]
-    reduce_outputs = [
+    reduce_outputs: list[torch.Tensor] = [
         o.view(s)
         for o, s in zip(
             output.split(reduce_output_numels),
@@ -240,48 +240,51 @@
     quantized_inputs = fused_quantize_into_fp8(inputs, world_size)
 
     # Allocate output tensor where all-reduce results will be stored
-    quantized_inputs_out = torch.zeros_like(quantized_inputs)
+    quantized_inputs_out: torch.Tensor = torch.zeros_like(quantized_inputs)
     # Collect chunks and their scales from other ranks
-    process_group.alltoall_base(
+    work = process_group.alltoall_base(
         quantized_inputs_out.view(world_size, -1),
         quantized_inputs.view(world_size, -1),
         [],
         [],
         _to_alltoall_options(reducescatter_opts),
-    ).wait()
-
-    # Reduce chunks locally in higher precision after dequantization.
-    # The output is again quantized.
-    fused_reduce_fp8(
-        inputs,
-        quantized_inputs_out,
-        world_size,
-        rank,
-        reducescatter_opts.reduceOp,
     )
+    work.wait()
 
-    # Get view into the output tensor that corresponds to the
-    # current rank
-    quantized_reduce_scatter = (
-        quantized_inputs_out.view(world_size, -1).split(1)[rank].squeeze(0)
-    )
-    # Dequantize the result back to the original precision for
-    # the current rank
-    fused_dequantize_from_fp8(
-        reduce_outputs,
-        quantized_reduce_scatter,
-        1,
-    )
+    fut = work.get_future()
 
-    # pyre-ignore[29]
-    return _QuantizedOpFuture(
-        sync_stream,
-        [
-            quantized_inputs,
-            quantized_inputs_out,
-        ],
-        [output],
-    )
+    def callback(fut: Future[list[torch.Tensor]]) -> None:
+        nonlocal inputs, quantized_inputs_out, world_size, sync_stream, rank, reduce_outputs, reducescatter_opts
+
+        with torch.cuda.stream(sync_stream):
+            # Setup stream dependency
+            fut.wait()
+            # Reduce chunks locally in higher precision after dequantization.
+            # The output is again quantized.
+            fused_reduce_fp8(
+                inputs,
+                quantized_inputs_out,
+                world_size,
+                rank,
+                reducescatter_opts.reduceOp,
+            )
+
+            # Get view into the output tensor that corresponds to the
+            # current rank
+            quantized_reduce_scatter = (
+                quantized_inputs_out.view(world_size, -1).split(1)[rank].squeeze(0)
+            )
+            # Dequantize the result back to the original precision for
+            # the current rank
+            fused_dequantize_from_fp8(
+                reduce_outputs,
+                quantized_reduce_scatter,
+                1,
+            )
+
+    fut.add_done_callback(callback)
+
+    return work
 
 
 def allreduce_quantized(

diff --git a/torchft/collectives_test.py b/torchft/collectives_test.py
index 6660abe..b73a18b 100644
--- a/torchft/collectives_test.py
+++ b/torchft/collectives_test.py
@@ -141,8 +141,8 @@ def _run_reduce_scatter_collective(
     opts = ReduceScatterOptions()
     opts.reduceOp = reduce_op
 
-    fut = reduce_scatter_quantized(actual_output, tensors, opts, pg)
-    fut.wait()
+    work = reduce_scatter_quantized(actual_output, tensors, opts, pg)
+    work.get_future().wait()
 
     padded_sizes = get_padded_sizes(tensors, world_size)
     padded_numel = sum(s.numel() for s in padded_sizes)

diff --git a/torchft/futures.py b/torchft/futures.py
index 52bb96e..c20ad65 100644
--- a/torchft/futures.py
+++ b/torchft/futures.py
@@ -148,7 +148,6 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
 
         loop = self._maybe_start_event_loop()
 
-        # pyre-fixme[29]: Future is not a function
         timed_fut: Future[T] = Future()
         handle: _TimerHandle = _TimerHandle()
         loop.call_soon_threadsafe(

diff --git a/torchft/futures_test.py b/torchft/futures_test.py
index cdc4cb1..59ca73d 100644
--- a/torchft/futures_test.py
+++ b/torchft/futures_test.py
@@ -24,38 +24,32 @@ def tearDown(self) -> None:
         _TIMEOUT_MANAGER._watchdog_interval = self._original_watchdog_interval
 
     def test_future_wait(self) -> None:
-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         with self.assertRaisesRegex(TimeoutError, "future did not complete within"):
             future_wait(fut, timeout=timedelta(seconds=0.01))
 
-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         fut.set_result(1)
         self.assertEqual(future_wait(fut, timeout=timedelta(seconds=1.0)), 1)
 
-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         fut.set_exception(RuntimeError("test"))
         with self.assertRaisesRegex(RuntimeError, "test"):
             future_wait(fut, timeout=timedelta(seconds=1.0))
 
     def test_future_timeout(self) -> None:
-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         timed_fut = future_timeout(fut, timeout=timedelta(seconds=0.01))
         with self.assertRaisesRegex(TimeoutError, "future did not complete within"):
             timed_fut.wait()
 
     def test_future_timeout_result(self) -> None:
-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         timed_fut = future_timeout(fut, timeout=timedelta(seconds=10))
         fut.set_result(1)
         self.assertEqual(timed_fut.wait(), 1)
 
     def test_future_timeout_exception(self) -> None:
-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         timed_fut = future_timeout(fut, timeout=timedelta(seconds=10))
         fut.set_exception(RuntimeError("test"))

diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index d0c81a0..b5616bb 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -401,7 +401,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:
 
         self.assertFalse(manager._errored)
 
-        bad_fut = torch.futures.Future()  # pyre-fixme[29]: not a function
+        bad_fut = torch.futures.Future()
         bad_fut.set_exception(RuntimeError("injected failure"))
         manager._pg.allreduce.return_value.get_future.return_value = bad_fut
         manager.allreduce(torch.tensor([1.0])).wait()
@@ -542,7 +542,7 @@ def test_manager_wrap_future(self, client_mock: MagicMock) -> None:
 
         self.assertIsNone(manager.errored())
 
-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
+        fut = torch.futures.Future()
         wrapped_fut = manager.wrap_future(fut, 2)
 
         self.assertIsNone(manager.errored())
@@ -559,7 +559,7 @@ def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:
 
         self.assertFalse(manager.errored())
 
-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
+        fut = torch.futures.Future()
         wrapped_fut = manager.wrap_future(fut, 2)
         wrapped_fut.wait()
         error = manager.errored()

diff --git a/torchft/process_group.py b/torchft/process_group.py
index 4750dc9..bfbfe56 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -183,6 +183,7 @@ def alltoall_base(
         """
         raise NotImplementedError("not implemented")
 
+    # pyre-fixme[14]: inconsistent override
     def barrier(self, opts: BarrierOptions) -> Work:
         """
         Synchronizes all processes.
@@ -496,7 +497,7 @@ def alltoall_base(
             opts,
         )
 
-    def barrier(self, opts: BarrierOptions) -> Work:
+    def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
         with self._run_context():
             return self._wrap_work(self.parent.barrier(self._opts_hook(opts)), opts)
 
@@ -866,7 +867,7 @@ def alltoall_base(
         self._work.append(res)
         return res
 
-    def barrier(self, opts: BarrierOptions) -> Work:
+    def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
         return _DummyWork(None)
 
     def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
@@ -1497,7 +1498,7 @@ def _get_future(
         self, op_id: int, stream: Optional[torch.cuda.Stream]
     ) -> Future[object]:
         with self._futures_lock:
-            fut = Future()  # pyre-fixme[29]: is not a function
+            fut = Future()
             self._futures[op_id] = _FutureMetadata(future=fut, stream=stream)
         assert self._pipe is not None
         self._pipe.send(("future", op_id))
@@ -1629,7 +1630,7 @@ def alltoall_base(
             opts,
         )
 
-    def barrier(self, opts: BarrierOptions) -> Work:
+    def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
         return self._run_func("barrier", opts)
 
     def broadcast(

diff --git a/torchft/work.py b/torchft/work.py
index 7211c0d..8cb056a 100644
--- a/torchft/work.py
+++ b/torchft/work.py
@@ -10,7 +10,6 @@ class _DummyWork(dist._Work):
     def __init__(self, result: object) -> None:
         super().__init__()
         self.result_ = result
-        # pyre-fixme[29]: Future is not a function
         self.future_: torch.futures.Future[object] = torch.futures.Future()
         self.future_.set_result(result)