
Commit cb39d98

only use nightly pytorch in ci
Summary:
- change ci to only use nightly since block_current_stream is not in stable yet
- fix new errors in nightly version of pyre
- remove fixme[29] about Future not being a function
- make reduce_scatter_quantized return a Work object
1 parent d1d5844 commit cb39d98
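
In caller terms, the last summary bullet reshapes the call site: reduce_scatter_quantized now returns a c10d Work handle instead of a torch.futures.Future. A minimal before/after sketch, mirroring the torchft/collectives_test.py hunk below (process-group and tensor setup omitted, so this is illustrative rather than standalone-runnable):

    # Before: the collective returned a torch.futures.Future
    fut = reduce_scatter_quantized(actual_output, tensors, opts, pg)
    fut.wait()

    # After: it returns a torch.distributed Work; wait via its future
    work = reduce_scatter_quantized(actual_output, tensors, opts, pg)
    work.get_future().wait()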

File tree: 10 files changed, +56 -62 lines

.github/workflows/lint.yaml
Lines changed: 1 addition & 0 deletions

@@ -26,6 +26,7 @@ jobs:
           pip install lintrunner lintrunner-adapters
           lintrunner init
 
+          pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
           pip install .[dev] -v
       - name: Run lintrunner
         run: |

.github/workflows/unittest.yaml
Lines changed: 1 addition & 5 deletions

@@ -15,11 +15,7 @@ jobs:
         - runs-on: "linux.2xlarge"
           gpu-arch-type: "cpu"
           gpu-arch-version: ""
-          torch-version: "stable"
-        - runs-on: "linux.g5.12xlarge.nvidia.gpu"
-          gpu-arch-type: "cuda"
-          gpu-arch-version: "12.4"
-          torch-version: "stable"
+          torch-version: "nightly"
         - runs-on: "linux.g5.12xlarge.nvidia.gpu"
           gpu-arch-type: "cuda"
           gpu-arch-version: "12.4"

torchft/collectives.py
Lines changed: 42 additions & 39 deletions

@@ -162,7 +162,7 @@ def reduce_scatter_quantized(
     opts: ReduceScatterOptions | ReduceOp,
     process_group: "ProcessGroup",
     sync_stream: cuda.Stream | None = None,
-) -> Future[None]:
+) -> Work:
     """
     Performs a quantized reduce-scatter operation on a list of tensors.

@@ -196,10 +196,10 @@ def reduce_scatter_quantized(
     """

     if isinstance(opts, ReduceOp):
-        reducescatter_opts = ReduceScatterOptions()
+        reducescatter_opts: ReduceScatterOptions = ReduceScatterOptions()
         reducescatter_opts.reduceOp = opts
     else:
-        reducescatter_opts = opts
+        reducescatter_opts: ReduceScatterOptions = opts

     # Check if the reduceOp is AVG or SUM
     if reducescatter_opts.reduceOp not in {

@@ -211,15 +211,15 @@
             f"for quantized reduce-scatter, only AVG and SUM are supported"
         )

-    rank = process_group.rank()
-    world_size = process_group.size()
+    rank: int = process_group.rank()
+    world_size: int = process_group.size()

     reduce_output_sizes = [
         torch.Size((s[0] // world_size, *s[1:]))
         for s in get_padded_sizes(inputs, world_size)
     ]
     reduce_output_numels = [s.numel() for s in reduce_output_sizes]
-    reduce_outputs = [
+    reduce_outputs: list[torch.Tensor] = [
         o.view(s)
         for o, s in zip(
             output.split(reduce_output_numels),

@@ -240,48 +240,51 @@
     quantized_inputs = fused_quantize_into_fp8(inputs, world_size)

     # Allocate output tensor where all-reduce results will be stored
-    quantized_inputs_out = torch.zeros_like(quantized_inputs)
+    quantized_inputs_out: torch.Tensor = torch.zeros_like(quantized_inputs)
     # Collect chunks and their scales from other ranks
-    process_group.alltoall_base(
+    work = process_group.alltoall_base(
         quantized_inputs_out.view(world_size, -1),
         quantized_inputs.view(world_size, -1),
         [],
         [],
         _to_alltoall_options(reducescatter_opts),
-    ).wait()
-
-    # Reduce chunks locally in higher precision after dequantization.
-    # The output is again quantized.
-    fused_reduce_fp8(
-        inputs,
-        quantized_inputs_out,
-        world_size,
-        rank,
-        reducescatter_opts.reduceOp,
     )
+    work.wait()

-    # Get view into the output tensor that corresponds to the
-    # current rank
-    quantized_reduce_scatter = (
-        quantized_inputs_out.view(world_size, -1).split(1)[rank].squeeze(0)
-    )
-    # Dequantize the result back to the original precision for
-    # the current rank
-    fused_dequantize_from_fp8(
-        reduce_outputs,
-        quantized_reduce_scatter,
-        1,
-    )
+    fut = work.get_future()

-    # pyre-ignore[29]
-    return _QuantizedOpFuture(
-        sync_stream,
-        [
-            quantized_inputs,
-            quantized_inputs_out,
-        ],
-        [output],
-    )
+    def callback(fut: Future[list[torch.Tensor]]) -> None:
+        nonlocal inputs, quantized_inputs_out, world_size, sync_stream, rank, reduce_outputs, reducescatter_opts
+
+        with torch.cuda.stream(sync_stream):
+            # Setup stream dependency
+            fut.wait()
+            # Reduce chunks locally in higher precision after dequantization.
+            # The output is again quantized.
+            fused_reduce_fp8(
+                inputs,
+                quantized_inputs_out,
+                world_size,
+                rank,
+                reducescatter_opts.reduceOp,
+            )
+
+            # Get view into the output tensor that corresponds to the
+            # current rank
+            quantized_reduce_scatter = (
+                quantized_inputs_out.view(world_size, -1).split(1)[rank].squeeze(0)
+            )
+            # Dequantize the result back to the original precision for
+            # the current rank
+            fused_dequantize_from_fp8(
+                reduce_outputs,
+                quantized_reduce_scatter,
+                1,
+            )
+
+    fut.add_done_callback(callback)
+
+    return work


 def allreduce_quantized(
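
The shape of the rewrite above: instead of blocking on the all-to-all and post-processing inline, the function now attaches the fused dequantize/reduce steps as a done-callback on the Work's future (scheduled onto sync_stream) and returns the Work handle to the caller. A minimal, runnable sketch of that completion-callback mechanism using a plain torch.futures.Future, with no process group involved:

    import torch

    # Stand-in for work.get_future() from alltoall_base; in the real code the
    # callback runs the fused_reduce_fp8 / fused_dequantize_from_fp8 steps.
    fut: torch.futures.Future[int] = torch.futures.Future()

    def callback(f: torch.futures.Future[int]) -> None:
        # Fires once the future completes; f.wait() then returns immediately.
        print("post-processing result:", f.wait())

    fut.add_done_callback(callback)
    fut.set_result(42)  # completing the future runs the callback

Returning the Work keeps synchronization in the caller's hands: the test below waits via work.get_future().wait(), while other callers can chain further callbacks instead.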

torchft/collectives_test.py
Lines changed: 2 additions & 2 deletions

@@ -141,8 +141,8 @@ def _run_reduce_scatter_collective(
     opts = ReduceScatterOptions()
     opts.reduceOp = reduce_op

-    fut = reduce_scatter_quantized(actual_output, tensors, opts, pg)
-    fut.wait()
+    work = reduce_scatter_quantized(actual_output, tensors, opts, pg)
+    work.get_future().wait()

     padded_sizes = get_padded_sizes(tensors, world_size)
     padded_numel = sum(s.numel() for s in padded_sizes)

torchft/futures.py
Lines changed: 0 additions & 1 deletion

@@ -148,7 +148,6 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:

         loop = self._maybe_start_event_loop()

-        # pyre-fixme[29]: Future is not a function
         timed_fut: Future[T] = Future()
         handle: _TimerHandle = _TimerHandle()
         loop.call_soon_threadsafe(

torchft/futures_test.py
Lines changed: 0 additions & 6 deletions

@@ -24,38 +24,32 @@ def tearDown(self) -> None:
         _TIMEOUT_MANAGER._watchdog_interval = self._original_watchdog_interval

     def test_future_wait(self) -> None:
-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         with self.assertRaisesRegex(TimeoutError, "future did not complete within"):
             future_wait(fut, timeout=timedelta(seconds=0.01))

-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         fut.set_result(1)
         self.assertEqual(future_wait(fut, timeout=timedelta(seconds=1.0)), 1)

-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         fut.set_exception(RuntimeError("test"))
         with self.assertRaisesRegex(RuntimeError, "test"):
             future_wait(fut, timeout=timedelta(seconds=1.0))

     def test_future_timeout(self) -> None:
-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         timed_fut = future_timeout(fut, timeout=timedelta(seconds=0.01))
         with self.assertRaisesRegex(TimeoutError, "future did not complete within"):
             timed_fut.wait()

     def test_future_timeout_result(self) -> None:
-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         timed_fut = future_timeout(fut, timeout=timedelta(seconds=10))
         fut.set_result(1)
         self.assertEqual(timed_fut.wait(), 1)

     def test_future_timeout_exception(self) -> None:
-        # pyre-fixme[29]: Future is not a function
         fut = Future()
         timed_fut = future_timeout(fut, timeout=timedelta(seconds=10))
         fut.set_exception(RuntimeError("test"))

torchft/manager.py
Lines changed: 2 additions & 1 deletion

@@ -388,7 +388,8 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
             )
         else:
             work = self._pg.allreduce([tensor], ReduceOp.SUM)
-            work.wait()
+            if torch.cuda.is_available():
+                work.block_current_stream()

         fut = work.get_future()
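
Per the commit summary, Work.block_current_stream() exists only in PyTorch nightly at this point, which is why the workflow changes above pin nightly wheels; on CUDA it makes the current stream wait on the collective instead of blocking the host thread the way work.wait() can. A hedged, single-process sketch of the surrounding Work-handle pattern using the gloo backend (port and tensor values are arbitrary):

    import torch
    import torch.distributed as dist

    # One-rank gloo group, just to obtain a real Work handle.
    store = dist.TCPStore("localhost", 29501, 1, is_master=True)
    dist.init_process_group("gloo", store=store, rank=0, world_size=1)

    t = torch.ones(4)
    work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
    work.get_future().wait()  # host-side wait on the collective's future

    # On CUDA builds, nightly additionally offers work.block_current_stream(),
    # which the diff above guards behind torch.cuda.is_available().
    dist.destroy_process_group()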

torchft/manager_test.py
Lines changed: 3 additions & 3 deletions

@@ -401,7 +401,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None:

         self.assertFalse(manager._errored)

-        bad_fut = torch.futures.Future()  # pyre-fixme[29]: not a function
+        bad_fut = torch.futures.Future()
         bad_fut.set_exception(RuntimeError("injected failure"))
         manager._pg.allreduce.return_value.get_future.return_value = bad_fut
         manager.allreduce(torch.tensor([1.0])).wait()

@@ -542,7 +542,7 @@ def test_manager_wrap_future(self, client_mock: MagicMock) -> None:

         self.assertIsNone(manager.errored())

-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
+        fut = torch.futures.Future()
         wrapped_fut = manager.wrap_future(fut, 2)
         self.assertIsNone(manager.errored())

@@ -559,7 +559,7 @@ def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:

         self.assertFalse(manager.errored())

-        fut = torch.futures.Future()  # pyre-fixme[29]: not a function
+        fut = torch.futures.Future()
         wrapped_fut = manager.wrap_future(fut, 2)
         wrapped_fut.wait()
         error = manager.errored()

torchft/process_group.py
Lines changed: 5 additions & 4 deletions

@@ -183,6 +183,7 @@ def alltoall_base(
         """
         raise NotImplementedError("not implemented")

+    # pyre-fixme[14]: inconsistent override
     def barrier(self, opts: BarrierOptions) -> Work:
         """
         Synchronizes all processes.

@@ -496,7 +497,7 @@ def alltoall_base(
             opts,
         )

-    def barrier(self, opts: BarrierOptions) -> Work:
+    def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
         with self._run_context():
             return self._wrap_work(self.parent.barrier(self._opts_hook(opts)), opts)

@@ -866,7 +867,7 @@ def alltoall_base(
         self._work.append(res)
         return res

-    def barrier(self, opts: BarrierOptions) -> Work:
+    def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
         return _DummyWork(None)

     def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:

@@ -1497,7 +1498,7 @@ def _get_future(
         self, op_id: int, stream: Optional[torch.cuda.Stream]
     ) -> Future[object]:
         with self._futures_lock:
-            fut = Future()  # pyre-fixme[29]: is not a function
+            fut = Future()
             self._futures[op_id] = _FutureMetadata(future=fut, stream=stream)
             assert self._pipe is not None
             self._pipe.send(("future", op_id))

@@ -1629,7 +1630,7 @@ def alltoall_base(
             opts,
         )

-    def barrier(self, opts: BarrierOptions) -> Work:
+    def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
         return self._run_func("barrier", opts)

     def broadcast(
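
The barrier change above gives the subclass overrides an Optional default so a bare barrier() call stays valid, matching how c10d process groups are driven, while the abstract base keeps its stricter signature and silences the override inconsistency with pyre-fixme[14]. A small self-contained sketch of the idea, using stand-in classes rather than torchft's own:

    from typing import Optional

    class BarrierOptions:
        """Stand-in for torch.distributed's BarrierOptions."""

    class BasePG:
        def barrier(self, opts: Optional[BarrierOptions] = None) -> None: ...

    class WrapperPG(BasePG):
        # The Optional default keeps the override consistent with the base.
        def barrier(self, opts: Optional[BarrierOptions] = None) -> None:
            print("barrier, opts =", opts)

    WrapperPG().barrier()                  # no options constructed
    WrapperPG().barrier(BarrierOptions())  # explicit options still accepted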

torchft/work.py
Lines changed: 0 additions & 1 deletion

@@ -10,7 +10,6 @@ class _DummyWork(dist._Work):
     def __init__(self, result: object) -> None:
         super().__init__()
         self.result_ = result
-        # pyre-fixme[29]: Future is not a function
         self.future_: torch.futures.Future[object] = torch.futures.Future()
         self.future_.set_result(result)
