From 518399909bdd31b9169a9af3f154b0c9699c4f49 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 1 Nov 2024 22:00:19 +0000 Subject: [PATCH] add flux support Signed-off-by: Bill Nell --- vllm/compilation/collective_fusion.py | 126 +++++------------- .../device_communicators/pynccl_wrapper.py | 1 + vllm/envs.py | 2 +- 3 files changed, 37 insertions(+), 92 deletions(-) diff --git a/vllm/compilation/collective_fusion.py b/vllm/compilation/collective_fusion.py index 80668039a9cf7..816227c2a4857 100644 --- a/vllm/compilation/collective_fusion.py +++ b/vllm/compilation/collective_fusion.py @@ -7,27 +7,27 @@ fwd_only, register_replacement) import vllm.envs as envs - from vllm.compilation.inductor_pass import InductorPass from vllm.compilation.utils import (find_auto_fn, find_fn, find_getitem, last_node_in_match) -from vllm.distributed import tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather +from vllm.distributed import (tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import ( - get_tp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group) from vllm.logger import init_logger +logger = init_logger(__name__) + use_flux = False if envs.VLLM_USE_FLUX: try: import flux use_flux = True - print("USE FLUX") + logger.info("USING FLUX") except ImportError: use_flux = False - -logger = init_logger(__name__) - # TODO: factor out somehow TP_GROUP_NAME = "tp:0" @@ -79,71 +79,6 @@ def match_gemm_rs_ag_gemm( return mm_2, new_residual -@torch.library.custom_op("vllm::gemm_rs_ag_gemm_old", mutates_args=()) -def gemm_rs_ag_gemm( - residual: torch.Tensor, old_my_residual: torch.Tensor, - gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, - rms_norm_weight: torch.Tensor, gemm_2_weights: torch.Tensor, - first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - - if should_slice(residual.shape) and first_layer: - res_slices = slice_residual(residual) - slice_size = res_slices[get_tensor_model_parallel_rank()].shape[0] - residual_chunk = torch.ops.aten.split.Tensor(residual, slice_size) - my_residual = residual_chunk[0] - else: - my_residual = residual - slice_size = residual.shape[0] - - gemm_1_w_perm = torch.ops.aten.permute.default(gemm_1_weights, [1, 0]) - - if not should_slice(residual.shape): - output = torch.matmul(gemm_1_activations, gemm_1_w_perm) - reduced_output = tensor_model_parallel_all_reduce(output) - - norm_res = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=reduced_output, - residual=my_residual, - weight=rms_norm_weight, - epsilon=1e-05) - normalized = norm_res[1] - new_residual = norm_res[2] - - gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) - mm_2 = torch.matmul(normalized, gemm_2_w_perm) - return mm_2, new_residual, new_residual.clone() - else: - group_name = get_world_name() - output = torch.ops.symm_mem.fused_matmul_reduce_scatter.default( - gemm_1_activations, gemm_1_w_perm, 'avg', 0, group_name) - - norm_res = torch.ops.higher_order.auto_functionalized( - torch.ops._C.fused_add_rms_norm.default, - input=output, - residual=my_residual, - weight=rms_norm_weight, - epsilon=1e-05) - normalized = norm_res[1] - new_residual = norm_res[2] - - residual_1 = residual if first_layer else old_my_residual - slice_scatter = torch.ops.aten.slice_scatter.default( - residual_1, new_residual, 0, 0, slice_size) - split_2 = torch.ops.aten.split.Tensor(slice_scatter, slice_size) - new_residual = split_2[0] - - gemm_2_w_perm = torch.ops.aten.permute.default(gemm_2_weights, [1, 0]) - fused_all_gather_matmul = ( - torch.ops.symm_mem.fused_all_gather_matmul.default( - normalized, [gemm_2_w_perm], 0, group_name)) - mm_2 = fused_all_gather_matmul[1] - - # TODO: can we avoid clone here? - return mm_2[0], new_residual.clone(), slice_scatter - - -#@torch.library.register_fake("vllm::gemm_rs_ag_gemm") def gemm_rs_ag_gemm_fake( residual: torch.Tensor, my_residual: torch.Tensor, @@ -171,7 +106,8 @@ def gemm_rs_ag_gemm_fake( return (mm_res, my_residual, residual) -def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_weights: torch.Size, gemm_2_weights: torch.Size): +def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_weights: torch.Size, + gemm_2_weights: torch.Size): if use_flux: gemm_rs_op = flux.GemmRS( @@ -214,15 +150,18 @@ def get_gemm_rs_ag_gemm(use_flux: bool, gemm_1_weights: torch.Size, gemm_2_weigh gemm_rs = lambda act, wt: gemm_rs_op.forward(act, wt).squeeze(0) ag_gemm = lambda act, wt: ag_gemm_op.forward(act, wt) - name = f"gemm_rs_ag_gemm_{gemm_1_weights[0]}_{gemm_2_weights[0]}_{gemm_2_weights[1]}" + name = (f"gemm_rs_ag_gemm_{gemm_1_weights[0]}_" + f"{gemm_2_weights[0]}_{gemm_2_weights[1]}") else: group_name = get_world_name() - gemm_rs = lambda act, wt: torch.ops.symm_mem.fused_matmul_reduce_scatter.default( - act, wt.transpose(1,0), 'avg', 0, group_name) + gemm_rs = lambda act, wt: \ + torch.ops.symm_mem.fused_matmul_reduce_scatter.default( + act, wt.transpose(1, 0), 'avg', 0, group_name) - ag_gemm = lambda act, wt: torch.ops.symm_mem.fused_all_gather_matmul.default( - act, [wt.transpose(1,0)], 0, group_name)[1] + ag_gemm = lambda act, wt: \ + torch.ops.symm_mem.fused_all_gather_matmul.default( + act, [wt.transpose(1, 0)], 0, group_name)[1] name = "gemm_rs_ag_gemm" @@ -230,7 +169,8 @@ def gemm_rs_ag_gemm( residual: torch.Tensor, old_my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weight: torch.Tensor, gemm_2_weights: torch.Tensor, - first_layer: bool) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + first_layer: bool + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: if should_slice(residual.shape) and first_layer: res_slices = slice_residual(residual) @@ -242,7 +182,8 @@ def gemm_rs_ag_gemm( slice_size = residual.shape[0] if not should_slice(residual.shape): - output = torch.matmul(gemm_1_activations, gemm_1_weights.transpose(1,0)) + output = torch.matmul(gemm_1_activations, + gemm_1_weights.transpose(1, 0)) reduced_output = tensor_model_parallel_all_reduce(output) norm_res = torch.ops.higher_order.auto_functionalized( @@ -254,10 +195,9 @@ def gemm_rs_ag_gemm( normalized = norm_res[1] new_residual = norm_res[2] - mm_2 = torch.matmul(normalized, gemm_2_weights.transpose(1,0)) + mm_2 = torch.matmul(normalized, gemm_2_weights.transpose(1, 0)) return mm_2, new_residual, new_residual.clone() else: - group_name = get_world_name() output = gemm_rs(gemm_1_activations, gemm_1_weights) norm_res = torch.ops.higher_order.auto_functionalized( @@ -282,7 +222,9 @@ def gemm_rs_ag_gemm( if not hasattr(torch.ops.vllm, name): logger.info("registering torch.ops.vllm.%s", name) - torch.library.custom_op(f"vllm::{name}", gemm_rs_ag_gemm, mutates_args=()) + torch.library.custom_op(f"vllm::{name}", + gemm_rs_ag_gemm, + mutates_args=()) torch.library.register_fake(f"vllm::{name}", gemm_rs_ag_gemm_fake) assert getattr(torch.ops.vllm, name) @@ -311,6 +253,7 @@ def match_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, return normalized + # Register this as a custom op since all reduce cannot be torch.compiled. #@torch.library.custom_op("vllm::gemm_ag_final", mutates_args=()) def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, @@ -329,8 +272,9 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, if True: group_name = get_world_name() world_size = get_tensor_model_parallel_world_size() - all_gather = torch.ops._c10d_functional.all_gather_into_tensor.default( - my_residual, world_size, group_name) + all_gather = ( + torch.ops._c10d_functional.all_gather_into_tensor.default( + my_residual, world_size, group_name)) wait_tensor = torch.ops._c10d_functional.wait_tensor.default( all_gather) else: @@ -352,9 +296,9 @@ def replace_final(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, def replace_final_fake(my_residual: torch.Tensor, gemm_1_weights: torch.Tensor, gemm_1_activations: torch.Tensor, rms_norm_weights: torch.Tensor) -> torch.Tensor: - return torch.empty([gemm_1_activations.shape[0], - my_residual.shape[1]], - dtype=my_residual.dtype, device=my_residual.device) + return torch.empty([gemm_1_activations.shape[0], my_residual.shape[1]], + dtype=my_residual.dtype, + device=my_residual.device) class CollectiveFusionPass(InductorPass): @@ -421,9 +365,9 @@ def find_min_index(match: Match) -> int: gemm_1_w = kwargs["gemm_1_weights"].meta["val"].shape gemm_2_w = kwargs["gemm_2_weights"].meta["val"].shape - fused_node = graph.call_function( - get_gemm_rs_ag_gemm(use_flux, gemm_1_w, gemm_2_w), - kwargs=kwargs) + fused_node = graph.call_function(get_gemm_rs_ag_gemm( + use_flux, gemm_1_w, gemm_2_w), + kwargs=kwargs) graph.inserting_after(fused_node) result_node_new = graph.call_function(operator.getitem, diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index dbaa8b6b34716..2c5a1dbe1b389 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -292,6 +292,7 @@ def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, datatype, comm, stream)) + __all__ = [ "NCCLLibrary", "ncclDataTypeEnum", "ncclRedOpTypeEnum", "ncclUniqueId", "ncclComm_t", "cudaStream_t", "buffer_type" diff --git a/vllm/envs.py b/vllm/envs.py index 35cb7439dbd8f..bbaf9aa26c7b8 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -467,7 +467,7 @@ def get_default_config_root(): "VLLM_USE_V1": lambda: bool(int(os.getenv("VLLM_USE_V1", "0"))), - # If set, try to use the flux fused collective comminucation gemm kernels + # If set, try to use the flux fused collective communication gemm kernels "VLLM_USE_FLUX": lambda: bool(int(os.getenv("VLLM_USE_FLUX", "0"))), }