diff --git a/docs/source/getting_started/debugging.rst b/docs/source/getting_started/debugging.rst index 77bf550601346..0c1afcbd7c0b9 100644 --- a/docs/source/getting_started/debugging.rst +++ b/docs/source/getting_started/debugging.rst @@ -86,7 +86,6 @@ If GPU/CPU communication cannot be established, you can use the following Python from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator pynccl = PyNcclCommunicator(group=gloo_group, device=local_rank) - pynccl.disabled = False s = torch.cuda.Stream() with torch.cuda.stream(s): diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index f702d7c46ea73..fb24d6bc2c100 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -60,7 +60,7 @@ def worker_fn(): tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) with pynccl_comm.change_state(enable=True): - pynccl_comm.all_reduce(tensor) + tensor = pynccl_comm.all_reduce(tensor) result = tensor.mean().cpu().item() assert result == pynccl_comm.world_size @@ -84,12 +84,12 @@ def multiple_allreduce_worker_fn(): with pynccl_comm.change_state(enable=True): # two groups can communicate independently if torch.distributed.get_rank() in [0, 1]: - pynccl_comm.all_reduce(tensor) - pynccl_comm.all_reduce(tensor) + tensor = pynccl_comm.all_reduce(tensor) + tensor = pynccl_comm.all_reduce(tensor) result = tensor.mean().cpu().item() assert result == 4 else: - pynccl_comm.all_reduce(tensor) + tensor = pynccl_comm.all_reduce(tensor) result = tensor.mean().cpu().item() assert result == 2 @@ -140,14 +140,11 @@ def worker_fn_with_cudagraph(): with torch.cuda.graph( graph, stream=pynccl_comm.stream), pynccl_comm.change_state( enable=True): - # operation during the graph capture is recorded but not executed - # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa - pynccl_comm.all_reduce(a) + a_out = pynccl_comm.all_reduce(a) pynccl_comm.stream.synchronize() - assert a.mean().cpu().item() == pynccl_comm.world_size**0 graph.replay() pynccl_comm.stream.synchronize() - assert a.mean().cpu().item() == pynccl_comm.world_size**1 + assert a_out.mean().cpu().item() == pynccl_comm.world_size**1 @worker_fn_wrapper diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py index 686b697c98e03..5fb1ae7b29fd2 100644 --- a/tests/distributed/test_utils.py +++ b/tests/distributed/test_utils.py @@ -70,14 +70,12 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2): rank=rank, world_size=WORLD_SIZE) pynccl1 = PyNcclCommunicator(pg1, device=rank) - pynccl1.disabled = False if rank <= 2: pg2 = StatelessProcessGroup.create(host="127.0.0.1", port=port2, rank=rank, world_size=3) pynccl2 = PyNcclCommunicator(pg2, device=rank) - pynccl2.disabled = False data = torch.tensor([rank]).cuda() pynccl1.all_reduce(data) pg1.barrier() diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 7411304eb18fa..d4e3f81747038 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -106,30 +106,30 @@ def __init__( self.stream.synchronize() del data - # by default it is disabled, e.g. in profiling models and prefill phase. - # to use it, use under `with obj.change_state(enable=True)`, usually - # when we are using CUDA graph. - self.disabled = True - def all_reduce(self, - tensor: torch.Tensor, + in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, - stream=None): + stream=None) -> torch.Tensor: if self.disabled: - return + return None # nccl communicator created on a specific device # will only work on tensors on the same device # otherwise it will cause "illegal memory access" - assert tensor.device == self.device, ( + assert in_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {in_tensor.device}") + + out_tensor = torch.empty_like(in_tensor) + if stream is None: stream = self.stream - self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()), - buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), + self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + ncclDataTypeEnum.from_torch(in_tensor.dtype), ncclRedOpTypeEnum.from_torch(op), self.comm, cudaStream_t(stream.cuda_stream)) + return out_tensor def all_gather(self, output_tensor: torch.Tensor, diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 87ade377266a2..ccbe00386c5da 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -96,42 +96,24 @@ def _register_group(group: "GroupCoordinator") -> None: _groups[group.unique_name] = weakref.ref(group) -if supports_custom_op(): - - def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: - assert group_name in _groups, f"Group {group_name} is not found." - group = _groups[group_name]() - if group is None: - raise ValueError(f"Group {group_name} is destroyed.") - group._all_reduce_in_place(tensor) - - def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None: - return +def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_reduce_out_place(tensor) - direct_register_custom_op( - op_name="inplace_all_reduce", - op_func=inplace_all_reduce, - mutates_args=["tensor"], - fake_impl=inplace_all_reduce_fake, - ) - def outplace_all_reduce(tensor: torch.Tensor, - group_name: str) -> torch.Tensor: - assert group_name in _groups, f"Group {group_name} is not found." - group = _groups[group_name]() - if group is None: - raise ValueError(f"Group {group_name} is destroyed.") - return group._all_reduce_out_place(tensor) +def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + return torch.empty_like(tensor) - def outplace_all_reduce_fake(tensor: torch.Tensor, - group_name: str) -> torch.Tensor: - return torch.empty_like(tensor) +if supports_custom_op(): direct_register_custom_op( - op_name="outplace_all_reduce", - op_func=outplace_all_reduce, + op_name="all_reduce", + op_func=all_reduce, mutates_args=[], - fake_impl=outplace_all_reduce_fake, + fake_impl=all_reduce_fake, ) @@ -317,30 +299,13 @@ def graph_capture( stream.wait_stream(curr_stream) with torch.cuda.stream(stream), maybe_ca_context: - # In graph mode, we have to be very careful about the collective - # operations. The current status is: - # allreduce \ Mode | Eager | Graph | - # -------------------------------------------- - # custom allreduce | enabled | enabled | - # PyNccl | disabled| enabled | - # torch.distributed | enabled | disabled| - # - # Note that custom allreduce will have a runtime check, if the - # tensor size is too large, it will fallback to the next - # available option. - # In summary: When using CUDA graph, we use - # either custom all-reduce kernel or pynccl. When not using - # CUDA graph, we use either custom all-reduce kernel or - # PyTorch NCCL. We always prioritize using custom all-reduce - # kernel but fall back to PyTorch or pynccl if it is - # disabled or not supported. pynccl_comm = self.pynccl_comm maybe_pynccl_context: Any if not pynccl_comm: maybe_pynccl_context = nullcontext() else: maybe_pynccl_context = pynccl_comm.change_state( - enable=True, stream=torch.cuda.current_stream()) + stream=torch.cuda.current_stream()) with maybe_pynccl_context: yield graph_capture_context @@ -356,8 +321,8 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: coordinator. In addition, PyTorch custom ops do not support mutation or returning - a new tensor in the same op. So we need to figure out if the op is - in-place or out-of-place ahead of time. + a new tensor in the same op. So we always make the all-reduce operation + out-of-place. """ # Bypass the function if we are using only 1 GPU. if self.world_size == 1: @@ -368,10 +333,6 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: ipex.distributed.all_reduce(input_, group=self.device_group) return input_ - if not supports_custom_op(): - self._all_reduce_in_place(input_) - return input_ - if self.tpu_communicator is not None and \ not self.tpu_communicator.disabled: # TPU handles Dynamo with its own logic. @@ -385,30 +346,31 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: not self.xpu_communicator.disabled: return self.xpu_communicator.all_reduce(input_) - if self.ca_comm is not None and \ - not self.ca_comm.disabled and \ - self.ca_comm.should_custom_ar(input_): - return torch.ops.vllm.outplace_all_reduce( - input_, group_name=self.unique_name) - else: - torch.ops.vllm.inplace_all_reduce(input_, - group_name=self.unique_name) - return input_ + return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name) def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: + # always try custom allreduce first, + # and then pynccl. ca_comm = self.ca_comm - assert ca_comm is not None - assert not ca_comm.disabled - out = ca_comm.custom_all_reduce(input_) - assert out is not None - return out - - def _all_reduce_in_place(self, input_: torch.Tensor) -> None: + if ca_comm is not None and not ca_comm.disabled and \ + ca_comm.should_custom_ar(input_): + out = ca_comm.custom_all_reduce(input_) + assert out is not None + return out pynccl_comm = self.pynccl_comm - if (pynccl_comm is not None and not pynccl_comm.disabled): - pynccl_comm.all_reduce(input_) - else: - torch.distributed.all_reduce(input_, group=self.device_group) + assert pynccl_comm is not None + # TODO: pynccl should not use `stream=` + # it can just always use the current stream. + out = pynccl_comm.all_reduce(input_, + stream=torch.cuda.current_stream()) + if out is None: + # fall back to the default all-reduce using PyTorch. + # this usually happens during testing. + # when we run the model, allreduce only happens for the TP + # group, where we always have either custom allreduce or pynccl. + out = input_.clone() + torch.distributed.all_reduce(out, group=self.device_group) + return out def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: world_size = self.world_size diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 02f9498142bb7..13cbc8fa39c03 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -10,6 +10,7 @@ from vllm.compilation.compile_context import set_compile_context from vllm.config import CompilationLevel, VllmConfig +from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger @@ -570,8 +571,9 @@ def capture_model(self) -> None: # Trigger CUDA graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. - for num_tokens in reversed(self.cudagraph_batch_sizes): - self._dummy_run(self.model, num_tokens, self.kv_caches) + with graph_capture(): + for num_tokens in reversed(self.cudagraph_batch_sizes): + self._dummy_run(self.model, num_tokens, self.kv_caches) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0]