From 31be8c716365869aa964f1b8bfd01abf5d270def Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 1 Nov 2024 21:51:14 +0000 Subject: [PATCH] add flux support Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 173 ++++++++++++++++-- .../device_communicators/pynccl.py | 20 ++ .../device_communicators/pynccl_wrapper.py | 20 ++ vllm/distributed/parallel_state.py | 14 +- vllm/envs.py | 5 + 5 files changed, 214 insertions(+), 18 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index d9db7bc949de6..80668039a9cf7 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -6,14 +6,26 @@ from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, fwd_only, register_replacement) +import vllm.envs as envs + from vllm.compilation.inductor_pass import InductorPass from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem, last_node_in_match) -from vllm.distributed import tensor_model_parallel_all_reduce +from vllm.distributed import tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather from vllm.distributed.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_tp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.logger import init_logger +use_flux = False +if envs.VLLM_USE_FLUX: + try: + import flux + use_flux = True + print("USE FLUX") + except ImportError: + use_flux = False + + logger = init_logger(__name__) # TODO: factor out somehow @@ -67,7 +79,7 @@ def match_gemm_rs_ag_gemm( return mm_2, new_residual -@torch.library.custom_op("vllm::gemm_rs_ag_gemm", mutates_args=()) +@torch.library.custom_op("vllm::gemm_rs_ag_gemm_old", mutates_args=()) def gemm_rs_ag_gemm( residual: torch.Tensor, old_my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, @@ -131,7 +143,7 @@ def gemm_rs_ag_gemm( return mm_2[0], new_residual.clone(), slice_scatter -@torch.library.register_fake("vllm::gemm_rs_ag_gemm") +#@torch.library.register_fake("vllm::gemm_rs_ag_gemm") def gemm_rs_ag_gemm_fake( residual: torch.Tensor, my_residual: torch.Tensor, @@ -159,6 +171,124 @@ def gemm_rs_ag_gemm_fake( return (mm_res, my_residual, residual) +def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_weights: torch.Size, gemm_2_weights: torch.Size): + + if use_flux: + gemm_rs_op = flux.GemmRS( + get_tp_group().device_group, + 1, # One node + 8192, # Max M. TODO: Pass in correctly. + gemm_1_weights[0], # N + # TODO: Pass in input dtype correctly. + # TODO: It would be nicer to modify flux to dispatch based on dtype + # at run time, but I don't know what the downside would be. + # Similar comment for max m. + torch.float16, + # Note: transpose_weight=False means that B is transposed + transpose_weight=False, + # Note: bfloat16 requires fuse_reduction=False. + fuse_reduction=False, + ) + + ag_gemm_op = flux.AGKernel( + get_tp_group().device_group, + 1, # One node + 8192, # Max M. TODO: Pass in correctly. + gemm_2_weights[0], # N + gemm_2_weights[1], # K + # TODO: Pass in input dtype correctly. + # TODO: It would be nicer to modify flux to dispatch based on dtype + # at run time, but I don't know what the downside would be. + # Similar comment for max m. 
+ torch.float16, + torch.float16, + # Note: transpose_weight=False means that B is transposed + transpose_weight=False, + # Note: if local_copy=True, I hit the following runtime error: + # /flux/src/all_gather/ths_op/all_gather_gemm_kernel.cc:648 + # Check failed: 33554432((input.numel() * input.element_size())) + # == 139836453421056((this->chunk_size)) + local_copy=False, + ) + + gemm_rs = lambda act, wt: gemm_rs_op.forward(act, wt).squeeze(0) + ag_gemm = lambda act, wt: ag_gemm_op.forward(act, wt) + + name = f"gemm_rs_ag_gemm_{gemm_1_weights[0]}_{gemm_2_weights[0]}_{gemm_2_weights[1]}" + else: + group_name = get_world_name() + + gemm_rs = lambda act, wt: torch.ops.symm_mem.fused_matmul_reduce_scatter.default( + act, wt.transpose(1,0), 'avg', 0, group_name) + + ag_gemm = lambda act, wt: torch.ops.symm_mem.fused_all_gather_matmul.default( + act, [wt.transpose(1,0)], 0, group_name)[1] + + name = "gemm_rs_ag_gemm" + + def gemm_rs_ag_gemm( + residual: torch.Tensor, old_my_residual: torch.Tensor, + gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, + rms_norm_weight: torch.Tensor, gemm_2_weights: torch.Tensor, + first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + + if should_slice(residual.shape) and first_layer: + res_slices = slice_residual(residual) + slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] + residual_chunk = torch.ops.aten.split.Tensor(residual, slice_size) + my_residual = residual_chunk[0] + else: + my_residual = residual + slice_size = residual.shape[0] + + if not should_slice(residual.shape): + output = torch.matmul(gemm_1_activations, gemm_1_weights.transpose(1,0)) + reduced_output = tensor_model_parallel_all_reduce(output) + + norm_res = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=reduced_output, + residual=my_residual, + weight=rms_norm_weight, + epsilon=1e-05) + normalized = norm_res[1] + new_residual = norm_res[2] + + mm_2 = torch.matmul(normalized, gemm_2_weights.transpose(1,0)) + return mm_2, new_residual, new_residual.clone() + else: + group_name = get_world_name() + output = gemm_rs(gemm_1_activations, gemm_1_weights) + + norm_res = torch.ops.higher_order.auto_functionalized( + torch.ops._C.fused_add_rms_norm.default, + input=output, + residual=my_residual, + weight=rms_norm_weight, + epsilon=1e-05) + normalized = norm_res[1] + new_residual = norm_res[2] + + residual_1 = residual if first_layer else old_my_residual + slice_scatter = torch.ops.aten.slice_scatter.default( + residual_1, new_residual, 0, 0, slice_size) + split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_size) + new_residual = split_2[0] + + mm_2 = ag_gemm(normalized, gemm_2_weights) + + # TODO: can we avoid clone here? + return mm_2[0], new_residual.clone(), slice_scatter + + if not hasattr(torch.ops.vllm, name): + logger.info("registering torch.ops.vllm.%s", name) + torch.library.custom_op(f"vllm::{name}", gemm_rs_ag_gemm, mutates_args=()) + torch.library.register_fake(f"vllm::{name}", gemm_rs_ag_gemm_fake) + assert getattr(torch.ops.vllm, name) + + return getattr(torch.ops.vllm, name).default + + def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: @@ -181,7 +311,8 @@ def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, return normalized - +# Register this as a custom op since all reduce cannot be torch.compiled. 
+#@torch.library.custom_op("vllm::gemm_ag_final", mutates_args=()) def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: @@ -195,12 +326,15 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, # is this the right thing to call it on? if should_slice(gemm_1_activations.shape): - group_name = get_world_name() - world_size = get_tensor_model_parallel_world_size() - all_gather = torch.ops._c10d_functional.all_gather_into_tensor.default( - my_residual, world_size, group_name) - wait_tensor = torch.ops._c10d_functional.wait_tensor.default( - all_gather) + if True: + group_name = get_world_name() + world_size = get_tensor_model_parallel_world_size() + all_gather = torch.ops._c10d_functional.all_gather_into_tensor.default( + my_residual, world_size, group_name) + wait_tensor = torch.ops._c10d_functional.wait_tensor.default( + all_gather) + else: + wait_tensor = tensor_model_parallel_all_gather(my_residual) else: wait_tensor = my_residual @@ -214,6 +348,15 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, return norm_res[1] +#@torch.library.register_fake("vllm::gemm_ag_final") +def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, + gemm_1_activations: torch.Tensor, + rms_norm_weights: torch.Tensor) -> torch.Tensor: + return torch.empty([gemm_1_activations.shape[0], + my_residual.shape[1]], + dtype=my_residual.dtype, device=my_residual.device) + + class CollectiveFusionPass(InductorPass): def __init__(self): @@ -273,8 +416,14 @@ def find_min_index(match: Match) -> int: res_replacements) > 0 else match.kwargs["residual"] kwargs["old_my_residual"] = my_res_replacements[-1] if len( my_res_replacements) > 0 else match.kwargs["residual"] + + # TODO: use get + gemm_1_w = kwargs["gemm_1_weights"].meta["val"].shape + gemm_2_w = kwargs["gemm_2_weights"].meta["val"].shape + fused_node = graph.call_function( - torch.ops.vllm.gemm_rs_ag_gemm.default, kwargs=kwargs) + get_gemm_rs_ag_gemm(use_flux, gemm_1_w, gemm_2_w), + kwargs=kwargs) graph.inserting_after(fused_node) result_node_new = graph.call_function(operator.getitem, diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 7319566545678..0cff8e175b916 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -121,6 +121,26 @@ def all_reduce(self, ncclRedOpTypeEnum.from_torch(op), self.comm, cudaStream_t(stream.cuda_stream)) + def all_gather(self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = self.stream + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, + cudaStream_t(stream.cuda_stream)) + def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: return diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index 
7619c98f22148..dbaa8b6b34716 100644
--- a/vllm/distributed/device_communicators/pynccl_wrapper.py
+++ b/vllm/distributed/device_communicators/pynccl_wrapper.py
@@ -151,6 +151,17 @@ class NCCLLibrary:
             ncclRedOp_t, ncclComm_t, cudaStream_t
         ]),
 
+        # ncclResult_t ncclAllGather(
+        #   const void* sendbuff, void* recvbuff, size_t count,
+        #   ncclDataType_t datatype, ncclComm_t comm,
+        #   cudaStream_t stream);
+        # note that cudaStream_t is a pointer type, so the last argument
+        # is a pointer
+        Function("ncclAllGather", ncclResult_t, [
+            buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t,
+            ncclComm_t, cudaStream_t
+        ]),
+
         # ncclResult_t ncclSend(
         #   const void* sendbuff, size_t count, ncclDataType_t datatype,
         #   int dest, ncclComm_t comm, cudaStream_t stream);
@@ -271,6 +282,15 @@ def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int,
     def ncclCommDestroy(self, comm: ncclComm_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
 
+    def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type,
+                      count: int, datatype: int, comm: ncclComm_t,
+                      stream: cudaStream_t) -> None:
+        # `datatype` actually should be `ncclDataType_t`,
+        # which is an alias of `ctypes.c_int`.
+        # When we pass an int to a function, it will be converted to
+        # `ctypes.c_int` by ctypes automatically.
+        self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count,
+                                                     datatype, comm, stream))
 
 __all__ = [
     "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId",
diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py
index d8dba0b7e2c06..f1c87d1a14c21 100644
--- a/vllm/distributed/parallel_state.py
+++ b/vllm/distributed/parallel_state.py
@@ -30,12 +30,6 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from unittest.mock import patch
 
-try:
-    import flux
-    has_flux = True
-except ImportError:
-    has_flux = False
-
 import torch
 import torch.distributed
 from torch.distributed import Backend, ProcessGroup, _symmetric_memory
@@ -45,6 +39,14 @@
 from vllm.platforms import current_platform
 from vllm.utils import supports_custom_op
 
+has_flux = False
+if envs.VLLM_USE_FLUX:
+    try:
+        import flux
+        has_flux = True
+    except ImportError:
+        has_flux = False
+
 
 @dataclass
 class GraphCaptureContext:
diff --git a/vllm/envs.py b/vllm/envs.py
index 720b56b4b8c94..35cb7439dbd8f 100644
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -70,6 +70,7 @@
     VLLM_CUSTOM_OPS: List[str] = []
     VLLM_DISABLED_KERNELS: List[str] = []
     VLLM_USE_V1: bool = False
+    VLLM_USE_FLUX: bool = False
 
 
 def get_default_cache_root():
@@ -465,6 +466,10 @@ def get_default_config_root():
     # If set, use the V1 code path.
     "VLLM_USE_V1":
     lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))),
+
+    # If set, try to use the flux fused collective communication GEMM kernels.
+    "VLLM_USE_FLUX":
+    lambda: bool(int(os.getenv("VLLM_USE_FLUX", "0"))),
 }
 
 # end-env-vars-definition
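
Usage sketch (not part of the diff): a minimal way to exercise the new VLLM_USE_FLUX flag added above, assuming only what the patch introduces (the env var in vllm/envs.py and the guarded `import flux`). The lazy read of the env var and the symm_mem fallback behaviour are inferred from the surrounding code.

    # Set the flag before any vllm module reads it.
    import os
    os.environ["VLLM_USE_FLUX"] = "1"

    import vllm.envs as envs

    # The new entry defaults to "0"; accessing it evaluates the lambda added in
    # vllm/envs.py, so it should now report True.
    print("VLLM_USE_FLUX:", envs.VLLM_USE_FLUX)

    # Mirror the guarded import used in collective_fusion.py / parallel_state.py:
    # with flux installed, the fusion pass can build flux.GemmRS / flux.AGKernel
    # ops; otherwise it keeps the torch.ops.symm_mem fused-matmul fallback.
    try:
        import flux  # noqa: F401
        print("flux kernels available")
    except ImportError:
        print("flux not installed; symm_mem fallback will be used")

Note that even with flux installed, the fused path is only taken when should_slice() holds for the residual shape; otherwise gemm_rs_ag_gemm falls back to plain matmul plus tensor_model_parallel_all_reduce, so "flux kernels available" alone does not prove the fused ops ran.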