From 7bed9ba8a9579fe6be8040933f83e86b06db24be Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Thu, 7 Nov 2024 15:41:35 +0000 Subject: [PATCH 01/13] init --- .../device_communicators/pynccl.py | 21 ++-- vllm/distributed/parallel_state.py | 101 ++++++------------ 2 files changed, 46 insertions(+), 76 deletions(-) diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 7319566545678..cefa9524c300d 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -102,24 +102,29 @@ def __init__( self.disabled = True def all_reduce(self, - tensor: torch.Tensor, + in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, - stream=None): + stream=None) -> torch.Tensor: if self.disabled: - return + return None # nccl communicator created on a specific device # will only work on tensors on the same device # otherwise it will cause "illegal memory access" - assert tensor.device == self.device, ( + assert in_tensor.device == self.device, ( f"this nccl communicator is created to work on {self.device}, " - f"but the input tensor is on {tensor.device}") + f"but the input tensor is on {in_tensor.device}") + + out_tensor = torch.empty_like(in_tensor) + if stream is None: stream = self.stream - self.nccl.ncclAllReduce(buffer_type(tensor.data_ptr()), - buffer_type(tensor.data_ptr()), tensor.numel(), - ncclDataTypeEnum.from_torch(tensor.dtype), + self.nccl.ncclAllReduce(buffer_type(in_tensor.data_ptr()), + buffer_type(out_tensor.data_ptr()), + in_tensor.numel(), + ncclDataTypeEnum.from_torch(in_tensor.dtype), ncclRedOpTypeEnum.from_torch(op), self.comm, cudaStream_t(stream.cuda_stream)) + return out_tensor def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 0d15403264eee..b21f66ede4ce7 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -96,43 +96,24 @@ def _register_group(group: "GroupCoordinator") -> None: _groups[group.unique_name] = weakref.ref(group) -if supports_custom_op(): +def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_reduce_out_place(tensor) - def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: - assert group_name in _groups, f"Group {group_name} is not found." - group = _groups[group_name]() - if group is None: - raise ValueError(f"Group {group_name} is destroyed.") - group._all_reduce_in_place(tensor) - def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None: - return +def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + return torch.empty_like(tensor) - direct_register_custom_op( - op_name="inplace_all_reduce", - op_func=inplace_all_reduce, - mutates_args=["tensor"], - fake_impl=inplace_all_reduce_fake, - ) - def outplace_all_reduce(tensor: torch.Tensor, - group_name: str) -> torch.Tensor: - assert group_name in _groups, f"Group {group_name} is not found." 
- group = _groups[group_name]() - if group is None: - raise ValueError(f"Group {group_name} is destroyed.") - return group._all_reduce_out_place(tensor) - - def outplace_all_reduce_fake(tensor: torch.Tensor, - group_name: str) -> torch.Tensor: - return torch.empty_like(tensor) - - direct_register_custom_op( - op_name="outplace_all_reduce", - op_func=outplace_all_reduce, - mutates_args=[], - fake_impl=outplace_all_reduce_fake, - ) +direct_register_custom_op( + op_name="all_reduce", + op_func=all_reduce, + mutates_args=[], + fake_impl=all_reduce_fake, +) class GroupCoordinator: @@ -361,8 +342,8 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: return input_ if not supports_custom_op(): - self._all_reduce_in_place(input_) - return input_ + return torch.ops.vllm.all_reduce(input_, + group_name=self.unique_name) if self.tpu_communicator is not None and \ not self.tpu_communicator.disabled: @@ -373,31 +354,21 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: not self.hpu_communicator.disabled: return self.hpu_communicator.all_reduce(input_) - if self.ca_comm is not None and \ - not self.ca_comm.disabled and \ - self.ca_comm.should_custom_ar(input_): - return torch.ops.vllm.outplace_all_reduce( - input_, group_name=self.unique_name) - else: - torch.ops.vllm.inplace_all_reduce(input_, - group_name=self.unique_name) - return input_ + return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name) def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: - ca_comm = self.ca_comm - assert ca_comm is not None - assert not ca_comm.disabled - out = ca_comm.custom_all_reduce(input_) + if supports_custom_op(): + ca_comm = self.ca_comm + if ca_comm is not None and not ca_comm.disabled: + out = ca_comm.custom_all_reduce(input_) + assert out is not None + return out + pynccl_comm = self.pynccl_comm + assert pynccl_comm is not None and not pynccl_comm.disabled + out = pynccl_comm.all_reduce(input_) assert out is not None return out - def _all_reduce_in_place(self, input_: torch.Tensor) -> None: - pynccl_comm = self.pynccl_comm - if (pynccl_comm is not None and not pynccl_comm.disabled): - pynccl_comm.all_reduce(input_) - else: - torch.distributed.all_reduce(input_, group=self.device_group) - def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: world_size = self.world_size # Bypass the function if we are using only 1 GPU. @@ -436,8 +407,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: output_tensor = output_tensor.reshape((world_size, ) + input_size) output_tensor = output_tensor.movedim(0, dim) output_tensor = output_tensor.reshape(input_size[:dim] + - (world_size * - input_size[dim], ) + + (world_size * input_size[dim], ) + input_size[dim + 1:]) return output_tensor @@ -591,8 +561,7 @@ def recv_object(self, src: int) -> Any: assert src < self.world_size, f"Invalid src rank ({src})" assert src != self.rank_in_group, ( - "Invalid source rank. Source rank is the same as the current rank." - ) + "Invalid source rank. 
Source rank is the same as the current rank.") size_tensor = torch.empty(1, dtype=torch.long, device="cpu") @@ -754,9 +723,7 @@ def send_tensor_dict( group=metadata_group) else: # use group for GPU tensors - torch.distributed.send(tensor, - dst=self.ranks[dst], - group=group) + torch.distributed.send(tensor, dst=self.ranks[dst], group=group) return None def recv_tensor_dict( @@ -938,8 +905,7 @@ def get_tp_group() -> GroupCoordinator: def get_pp_group() -> GroupCoordinator: - assert _PP is not None, ( - "pipeline model parallel group is not initialized") + assert _PP is not None, ("pipeline model parallel group is not initialized") return _PP @@ -1050,8 +1016,8 @@ def initialize_model_parallel( backend = backend or torch.distributed.get_backend( get_world_group().device_group) - if (world_size != - tensor_model_parallel_size * pipeline_model_parallel_size): + if (world_size + != tensor_model_parallel_size * pipeline_model_parallel_size): raise RuntimeError( f"world_size ({world_size}) is not equal to " f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " @@ -1080,8 +1046,7 @@ def initialize_model_parallel( num_pipeline_model_parallel_groups: int = (world_size // pipeline_model_parallel_size) global _PP - assert _PP is None, ( - "pipeline model parallel group is already initialized") + assert _PP is None, ("pipeline model parallel group is already initialized") group_ranks = [] for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) From fa683d45fc43bc60031c7233cbb742dc5f8f6fdb Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 8 Nov 2024 15:37:59 +0000 Subject: [PATCH 02/13] use context manager when running with pynccl --- vllm/distributed/parallel_state.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index b21f66ede4ce7..e3ddd6e850c3f 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -364,10 +364,12 @@ def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: assert out is not None return out pynccl_comm = self.pynccl_comm - assert pynccl_comm is not None and not pynccl_comm.disabled - out = pynccl_comm.all_reduce(input_) - assert out is not None - return out + assert pynccl_comm is not None + with pynccl_comm.change_state(enable=True, + stream=torch.cuda.current_stream()): + out = pynccl_comm.all_reduce(input_) + assert out is not None + return out def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: world_size = self.world_size @@ -1016,8 +1018,8 @@ def initialize_model_parallel( backend = backend or torch.distributed.get_backend( get_world_group().device_group) - if (world_size - != tensor_model_parallel_size * pipeline_model_parallel_size): + if (world_size != + tensor_model_parallel_size * pipeline_model_parallel_size): raise RuntimeError( f"world_size ({world_size}) is not equal to " f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " From 85766106aea5e602c5f13874dbebb5402d8b422d Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Fri, 8 Nov 2024 15:39:03 +0000 Subject: [PATCH 03/13] only run custom ar when should_custom_ar is true --- vllm/distributed/parallel_state.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index e3ddd6e850c3f..9a6fa8455b198 100644 --- a/vllm/distributed/parallel_state.py +++ 
b/vllm/distributed/parallel_state.py @@ -359,7 +359,8 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: if supports_custom_op(): ca_comm = self.ca_comm - if ca_comm is not None and not ca_comm.disabled: + if ca_comm is not None and not ca_comm.disabled and ca_comm.should_custom_ar( + input_): out = ca_comm.custom_all_reduce(input_) assert out is not None return out From 564cdadc52d61c91ed999e9697196eec94e86e51 Mon Sep 17 00:00:00 2001 From: Sage Moore Date: Tue, 12 Nov 2024 19:30:42 +0000 Subject: [PATCH 04/13] temporarily disable custom ar --- vllm/distributed/parallel_state.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 1106c8f1a1a5c..36dbbed8bde91 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -368,13 +368,13 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name) def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: - if supports_custom_op(): - ca_comm = self.ca_comm - if ca_comm is not None and not ca_comm.disabled and ca_comm.should_custom_ar( - input_): - out = ca_comm.custom_all_reduce(input_) - assert out is not None - return out + # if supports_custom_op(): + # ca_comm = self.ca_comm + # if ca_comm is not None and not ca_comm.disabled and ca_comm.should_custom_ar( + # input_): + # out = ca_comm.custom_all_reduce(input_) + # assert out is not None + # return out pynccl_comm = self.pynccl_comm assert pynccl_comm is not None with pynccl_comm.change_state(enable=True, From a86bb7c4a7fd14701f747b867cb743b8307ae4a9 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Nov 2024 16:09:55 -0800 Subject: [PATCH 05/13] update Signed-off-by: youkaichao --- vllm/distributed/parallel_state.py | 52 ++++++++++++++++-------------- 1 file changed, 27 insertions(+), 25 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 36dbbed8bde91..ec42365fa0758 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -108,12 +108,13 @@ def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor: return torch.empty_like(tensor) -direct_register_custom_op( - op_name="all_reduce", - op_func=all_reduce, - mutates_args=[], - fake_impl=all_reduce_fake, -) +if supports_custom_op(): + direct_register_custom_op( + op_name="all_reduce", + op_func=all_reduce, + mutates_args=[], + fake_impl=all_reduce_fake, + ) class GroupCoordinator: @@ -304,7 +305,6 @@ def graph_capture( # -------------------------------------------- # custom allreduce | enabled | enabled | # PyNccl | disabled| enabled | - # torch.distributed | enabled | disabled| # # Note that custom allreduce will have a runtime check, if the # tensor size is too large, it will fallback to the next @@ -349,10 +349,6 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: ipex.distributed.all_reduce(input_, group=self.device_group) return input_ - if not supports_custom_op(): - return torch.ops.vllm.all_reduce(input_, - group_name=self.unique_name) - if self.tpu_communicator is not None and \ not self.tpu_communicator.disabled: # TPU handles Dynamo with its own logic. 
@@ -365,16 +361,16 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: if self.xpu_communicator is not None and \ not self.xpu_communicator.disabled: return self.xpu_communicator.all_reduce(input_) + return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name) def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: - # if supports_custom_op(): - # ca_comm = self.ca_comm - # if ca_comm is not None and not ca_comm.disabled and ca_comm.should_custom_ar( - # input_): - # out = ca_comm.custom_all_reduce(input_) - # assert out is not None - # return out + ca_comm = self.ca_comm + if ca_comm is not None and not ca_comm.disabled and ca_comm.should_custom_ar( + input_): + out = ca_comm.custom_all_reduce(input_) + assert out is not None + return out pynccl_comm = self.pynccl_comm assert pynccl_comm is not None with pynccl_comm.change_state(enable=True, @@ -421,7 +417,8 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: output_tensor = output_tensor.reshape((world_size, ) + input_size) output_tensor = output_tensor.movedim(0, dim) output_tensor = output_tensor.reshape(input_size[:dim] + - (world_size * input_size[dim], ) + + (world_size * + input_size[dim], ) + input_size[dim + 1:]) return output_tensor @@ -445,8 +442,8 @@ def gather(self, dim += input_.dim() if self.xpu_communicator is not None and \ not self.xpu_communicator.disabled: - return self.xpu_communicator.gather(input_, self.rank_in_group, dst, - dim) + return self.xpu_communicator.gather(input_, self.rank_in_group, + dst, dim) # Allocate output tensor. if self.rank_in_group == dst: gather_list = [torch.empty_like(input_) for _ in range(world_size)] @@ -557,7 +554,8 @@ def recv_object(self, src: int) -> Any: assert src < self.world_size, f"Invalid src rank ({src})" assert src != self.rank_in_group, ( - "Invalid source rank. Source rank is the same as the current rank.") + "Invalid source rank. Source rank is the same as the current rank." 
+ ) size_tensor = torch.empty(1, dtype=torch.long, device="cpu") @@ -719,7 +717,9 @@ def send_tensor_dict( group=metadata_group) else: # use group for GPU tensors - torch.distributed.send(tensor, dst=self.ranks[dst], group=group) + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=group) return None def recv_tensor_dict( @@ -903,7 +903,8 @@ def get_tp_group() -> GroupCoordinator: def get_pp_group() -> GroupCoordinator: - assert _PP is not None, ("pipeline model parallel group is not initialized") + assert _PP is not None, ( + "pipeline model parallel group is not initialized") return _PP @@ -1044,7 +1045,8 @@ def initialize_model_parallel( num_pipeline_model_parallel_groups: int = (world_size // pipeline_model_parallel_size) global _PP - assert _PP is None, ("pipeline model parallel group is already initialized") + assert _PP is None, ( + "pipeline model parallel group is already initialized") group_ranks = [] for i in range(num_pipeline_model_parallel_groups): ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) From 35d0e178a0460a4ffdc1346425c4141b49cfac88 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Nov 2024 16:14:51 -0800 Subject: [PATCH 06/13] improving Signed-off-by: youkaichao --- vllm/distributed/parallel_state.py | 26 ++++++-------------------- 1 file changed, 6 insertions(+), 20 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index ec42365fa0758..a0ec264bca847 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -299,22 +299,6 @@ def graph_capture( stream.wait_stream(curr_stream) with torch.cuda.stream(stream), maybe_ca_context: - # In graph mode, we have to be very careful about the collective - # operations. The current status is: - # allreduce \ Mode | Eager | Graph | - # -------------------------------------------- - # custom allreduce | enabled | enabled | - # PyNccl | disabled| enabled | - # - # Note that custom allreduce will have a runtime check, if the - # tensor size is too large, it will fallback to the next - # available option. - # In summary: When using CUDA graph, we use - # either custom all-reduce kernel or pynccl. When not using - # CUDA graph, we use either custom all-reduce kernel or - # PyTorch NCCL. We always prioritize using custom all-reduce - # kernel but fall back to PyTorch or pynccl if it is - # disabled or not supported. pynccl_comm = self.pynccl_comm maybe_pynccl_context: Any if not pynccl_comm: @@ -337,8 +321,8 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: coordinator. In addition, PyTorch custom ops do not support mutation or returning - a new tensor in the same op. So we need to figure out if the op is - in-place or out-of-place ahead of time. + a new tensor in the same op. So we always make the all-reduce operation + out-of-place. """ # Bypass the function if we are using only 1 GPU. if self.world_size == 1: @@ -365,9 +349,11 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name) def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: + # always try custom allreduce first, + # and then pynccl. 
ca_comm = self.ca_comm - if ca_comm is not None and not ca_comm.disabled and ca_comm.should_custom_ar( - input_): + if ca_comm is not None and not ca_comm.disabled and \ + ca_comm.should_custom_ar(input_): out = ca_comm.custom_all_reduce(input_) assert out is not None return out From 3e9d2180270ffbced901dde956678b4cdf60453e Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Nov 2024 16:19:20 -0800 Subject: [PATCH 07/13] enable pynccl by default Signed-off-by: youkaichao --- docs/source/getting_started/debugging.rst | 1 - tests/distributed/test_utils.py | 2 -- vllm/distributed/device_communicators/pynccl.py | 5 ----- 3 files changed, 8 deletions(-) diff --git a/docs/source/getting_started/debugging.rst b/docs/source/getting_started/debugging.rst index 77bf550601346..0c1afcbd7c0b9 100644 --- a/docs/source/getting_started/debugging.rst +++ b/docs/source/getting_started/debugging.rst @@ -86,7 +86,6 @@ If GPU/CPU communication cannot be established, you can use the following Python from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator pynccl = PyNcclCommunicator(group=gloo_group, device=local_rank) - pynccl.disabled = False s = torch.cuda.Stream() with torch.cuda.stream(s): diff --git a/tests/distributed/test_utils.py b/tests/distributed/test_utils.py index 686b697c98e03..5fb1ae7b29fd2 100644 --- a/tests/distributed/test_utils.py +++ b/tests/distributed/test_utils.py @@ -70,14 +70,12 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2): rank=rank, world_size=WORLD_SIZE) pynccl1 = PyNcclCommunicator(pg1, device=rank) - pynccl1.disabled = False if rank <= 2: pg2 = StatelessProcessGroup.create(host="127.0.0.1", port=port2, rank=rank, world_size=3) pynccl2 = PyNcclCommunicator(pg2, device=rank) - pynccl2.disabled = False data = torch.tensor([rank]).cuda() pynccl1.all_reduce(data) pg1.barrier() diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index fc86d7a0b4d93..d4e3f81747038 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -106,11 +106,6 @@ def __init__( self.stream.synchronize() del data - # by default it is disabled, e.g. in profiling models and prefill phase. - # to use it, use under `with obj.change_state(enable=True)`, usually - # when we are using CUDA graph. - self.disabled = True - def all_reduce(self, in_tensor: torch.Tensor, op: ReduceOp = ReduceOp.SUM, From 54c7fb134be973a80d3a78cde249a126d4e74c6d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Nov 2024 16:23:34 -0800 Subject: [PATCH 08/13] draft Signed-off-by: youkaichao --- vllm/v1/worker/gpu_model_runner.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 02f9498142bb7..13cbc8fa39c03 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -10,6 +10,7 @@ from vllm.compilation.compile_context import set_compile_context from vllm.config import CompilationLevel, VllmConfig +from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger @@ -570,8 +571,9 @@ def capture_model(self) -> None: # Trigger CUDA graph capture for specific shapes. # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. 
- for num_tokens in reversed(self.cudagraph_batch_sizes): - self._dummy_run(self.model, num_tokens, self.kv_caches) + with graph_capture(): + for num_tokens in reversed(self.cudagraph_batch_sizes): + self._dummy_run(self.model, num_tokens, self.kv_caches) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] From 145ad3b49314f5ae784303cb732aff42c08dc00f Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Nov 2024 16:25:42 -0800 Subject: [PATCH 09/13] simplify Signed-off-by: youkaichao --- vllm/distributed/parallel_state.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index a0ec264bca847..65d619ea8992e 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -304,8 +304,7 @@ def graph_capture( if not pynccl_comm: maybe_pynccl_context = nullcontext() else: - maybe_pynccl_context = pynccl_comm.change_state( - enable=True, stream=torch.cuda.current_stream()) + maybe_pynccl_context = pynccl_comm.change_state(stream=torch.cuda.current_stream()) with maybe_pynccl_context: yield graph_capture_context From 100d26ca641061edd82bd533b214f24098041c32 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Nov 2024 16:27:37 -0800 Subject: [PATCH 10/13] simplify Signed-off-by: youkaichao --- vllm/distributed/parallel_state.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 65d619ea8992e..d71c84f7c098e 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -358,11 +358,11 @@ def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: return out pynccl_comm = self.pynccl_comm assert pynccl_comm is not None - with pynccl_comm.change_state(enable=True, - stream=torch.cuda.current_stream()): - out = pynccl_comm.all_reduce(input_) - assert out is not None - return out + # TODO: pynccl should not use `stream=` + # it can just always use the current stream. + out = pynccl_comm.all_reduce(input_, stream=torch.cuda.current_stream()) + assert out is not None + return out def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: world_size = self.world_size From a2dceca926cf570977b40f28e131363fe8c61d6d Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Nov 2024 16:34:14 -0800 Subject: [PATCH 11/13] add fallback Signed-off-by: youkaichao --- vllm/distributed/parallel_state.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index d71c84f7c098e..ccbe00386c5da 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -304,7 +304,8 @@ def graph_capture( if not pynccl_comm: maybe_pynccl_context = nullcontext() else: - maybe_pynccl_context = pynccl_comm.change_state(stream=torch.cuda.current_stream()) + maybe_pynccl_context = pynccl_comm.change_state( + stream=torch.cuda.current_stream()) with maybe_pynccl_context: yield graph_capture_context @@ -360,8 +361,15 @@ def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: assert pynccl_comm is not None # TODO: pynccl should not use `stream=` # it can just always use the current stream. 
- out = pynccl_comm.all_reduce(input_, stream=torch.cuda.current_stream()) - assert out is not None + out = pynccl_comm.all_reduce(input_, + stream=torch.cuda.current_stream()) + if out is None: + # fall back to the default all-reduce using PyTorch. + # this usually happens during testing. + # when we run the model, allreduce only happens for the TP + # group, where we always have either custom allreduce or pynccl. + out = input_.clone() + torch.distributed.all_reduce(out, group=self.device_group) return out def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: From c50abd226853cd51ea879e11abb5f8c5c0cc05d7 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Nov 2024 19:10:58 -0800 Subject: [PATCH 12/13] fix tests (update for out-of-place allreduce) Signed-off-by: youkaichao --- tests/distributed/test_pynccl.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index f702d7c46ea73..4d63a80a0fe52 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -60,7 +60,7 @@ def worker_fn(): tensor = torch.ones(16, 1024, 1024, dtype=torch.float32).cuda(pynccl_comm.rank) with pynccl_comm.change_state(enable=True): - pynccl_comm.all_reduce(tensor) + tensor = pynccl_comm.all_reduce(tensor) result = tensor.mean().cpu().item() assert result == pynccl_comm.world_size @@ -84,12 +84,12 @@ def multiple_allreduce_worker_fn(): with pynccl_comm.change_state(enable=True): # two groups can communicate independently if torch.distributed.get_rank() in [0, 1]: - pynccl_comm.all_reduce(tensor) - pynccl_comm.all_reduce(tensor) + tensor = pynccl_comm.all_reduce(tensor) + tensor = pynccl_comm.all_reduce(tensor) result = tensor.mean().cpu().item() assert result == 4 else: - pynccl_comm.all_reduce(tensor) + tensor = pynccl_comm.all_reduce(tensor) result = tensor.mean().cpu().item() assert result == 2 @@ -140,14 +140,12 @@ def worker_fn_with_cudagraph(): with torch.cuda.graph( graph, stream=pynccl_comm.stream), pynccl_comm.change_state( enable=True): - # operation during the graph capture is recorded but not executed - # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#creating-a-graph-using-stream-capture # noqa - pynccl_comm.all_reduce(a) + a_out = pynccl_comm.all_reduce(a) pynccl_comm.stream.synchronize() - assert a.mean().cpu().item() == pynccl_comm.world_size**0 + assert a_out.mean().cpu().item() == pynccl_comm.world_size**0 graph.replay() pynccl_comm.stream.synchronize() - assert a.mean().cpu().item() == pynccl_comm.world_size**1 + assert a_out.mean().cpu().item() == pynccl_comm.world_size**1 @worker_fn_wrapper From 53e009905720a5d4cd6796e80e7662ae4643d199 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 25 Nov 2024 19:31:55 -0800 Subject: [PATCH 13/13] fix tests Signed-off-by: youkaichao --- tests/distributed/test_pynccl.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index 4d63a80a0fe52..fb24d6bc2c100 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -142,7 +142,6 @@ def worker_fn_with_cudagraph(): enable=True): a_out = pynccl_comm.all_reduce(a) pynccl_comm.stream.synchronize() - assert a_out.mean().cpu().item() == pynccl_comm.world_size**0 graph.replay() pynccl_comm.stream.synchronize() assert a_out.mean().cpu().item() == pynccl_comm.world_size**1
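
Taken together, the series converges on a single out-of-place all-reduce path: PATCH 01 makes PyNcclCommunicator.all_reduce write into a new output tensor and return it, PATCH 01 through PATCH 11 collapse the old inplace_all_reduce/outplace_all_reduce custom ops into one `vllm.all_reduce` op registered with `mutates_args=[]` and a `torch.empty_like` fake impl, and PATCH 11 adds a plain torch.distributed fallback. The sketch below is illustrative only, not the vLLM implementation; the attribute names `ca_comm`, `pynccl_comm`, and `device_group` mirror the GroupCoordinator fields touched in the diffs, while the `group` argument and the helper name are hypothetical.

import torch

def all_reduce_out_of_place(group, input_: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch only. Dispatch order follows the final patches:
    # custom all-reduce kernel, then pynccl, then a PyTorch fallback,
    # always returning a new tensor instead of mutating the input.

    # 1. Custom all-reduce kernel, only when it reports support for this
    #    tensor (mirrors the ca_comm.should_custom_ar check from PATCH 03/05).
    ca_comm = getattr(group, "ca_comm", None)
    if ca_comm is not None and not ca_comm.disabled and \
            ca_comm.should_custom_ar(input_):
        out = ca_comm.custom_all_reduce(input_)
        assert out is not None
        return out

    # 2. PyNccl, which after PATCH 01 allocates the output tensor itself
    #    and returns it (or None when the communicator is disabled).
    pynccl_comm = getattr(group, "pynccl_comm", None)
    if pynccl_comm is not None:
        out = pynccl_comm.all_reduce(input_,
                                     stream=torch.cuda.current_stream())
        if out is not None:
            return out

    # 3. Fallback from PATCH 11: clone first so the op stays out-of-place,
    #    then reduce with the default PyTorch process group.
    out = input_.clone()
    torch.distributed.all_reduce(out, group=group.device_group)
    return out

Callers have to rebind the result: the test updates in PATCH 12 change `pynccl_comm.all_reduce(tensor)` to `tensor = pynccl_comm.all_reduce(tensor)` for exactly this reason. The out-of-place contract is also what allows the custom op to be registered with `mutates_args=[]` and a fake impl of `torch.empty_like(tensor)`, so torch.compile tracing and CUDA graph capture can treat the collective as a pure function.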