diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
index 56274f875f4..2e257c306ae 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py
@@ -110,13 +110,11 @@ def __init__(
         assert len(
             self.initial_local_expert_ids) == self.expert_size_per_partition
 
-        max_num_tokens = model_config.max_num_tokens
         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        if self.use_dp:
-            max_num_tokens *= model_config.mapping.world_size
-        self.moe_max_num_tokens = model_config.moe_max_num_tokens or max_num_tokens
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
         # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
-        if self.moe_max_num_tokens < max_num_tokens:
+        if self.moe_max_num_tokens < moe_max_num_tokens:
             self.aux_stream = aux_stream_dict[
                 AuxStreamType.
                 MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
index fcfdeb0fad2..a5ca05694b9 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
@@ -11,7 +11,7 @@
 
 from ...distributed import allgather
 from ...model_config import ModelConfig
-from ...utils import AuxStreamType, Fp4QuantizedTensor
+from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
 from .fused_moe_cutlass import CutlassFusedMoE
 from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
                            MoEWeightLoadingMode, UnquantizedFusedMoEMethod)
@@ -88,6 +88,7 @@ def _masked_index_copy_group_quant_fp8(
 
 def masked_index_copy_group_quant_fp8(
     output: torch.Tensor,
+    output_s: torch.Tensor,
     input: torch.Tensor,
     start_offsets: torch.Tensor,
     row_indices: torch.Tensor,
@@ -108,14 +109,10 @@ def masked_index_copy_group_quant_fp8(
     col_size = output.shape[1]
     dim_size = output.shape[2]
 
-    # create padded output_s
    alignment = 4
     scale_dim = (dim_size + group_size - 1) // group_size
     padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment
     padded_col_size = (col_size + alignment - 1) // alignment * alignment
-    output_s = torch.zeros((row_size, padded_dim_size // 4, padded_col_size),
-                           dtype=torch.int32,
-                           device='cuda')
 
     # get block/grid/stage/warp
     num_groups = (dim_size + group_size - 1) // group_size
@@ -247,6 +244,7 @@ def preprocess_after_permute(expert_first_token_offset_tensor,
 
 @nvtx_range("[DG]")
 def deepgemm_fp8_group_blockwise_gemm(
+    d: torch.Tensor,
     a: torch.Tensor,
     b: torch.Tensor,
     sfa: torch.Tensor,
@@ -254,10 +252,6 @@ def deepgemm_fp8_group_blockwise_gemm(
     masked_m: torch.Tensor,
     expected_m: int,
 ) -> torch.Tensor:
-    d = torch.empty((a.shape[0], a.shape[1], b.shape[1]),
-                    device=b.device,
-                    dtype=torch.bfloat16)
-
     # NOTES: shape must be `[G, M, K] @ [G, N, K].mT`
     assert a.stride(-1) == 1
     assert b.stride(-1) == 1
@@ -287,7 +281,16 @@ def deepgemm_fp8_group_blockwise_gemm(
         masked_m,
         expected_m,
         disable_ue8m0_cast=True)
-    return d
+    return
+
+
+def set_strides(workspace: torch.Tensor, g: int, m: int, k: int):
+    workspace = workspace[0:g * m * k]
+    workspace = workspace.as_strided(
+        size=(g, m, k),
+        stride=(m * k, k, 1),
+    )
+    return workspace
 
 
 class DeepGemmFusedMoE(CutlassFusedMoE):
@@ -327,6 +330,18 @@ def __init__(
         apply_router_weight_on_input: bool = False,
         layer_idx: Optional[int] = None,
     ):
+        if model_config.moe_max_num_tokens is None:
+            moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+            # The default moe_max_num_tokens is calculated from the following formula:
+            # max_isl = 8192, max_batch_size = 1024, mtp = 0
+            # max_num_tokens = ((mtp+1)*max_batch_size+max_isl+128+63)//64*64 = 9344
+            # moe_max_num_tokens = max_num_tokens * 2 = 18688
+            # This default avoids OOM for 8k/1k cases.
+            default_moe_max_num_tokens = 18688
+            if moe_max_num_tokens > default_moe_max_num_tokens:
+                model_config._frozen = False
+                model_config.moe_max_num_tokens = default_moe_max_num_tokens
+                model_config._frozen = True
 
         super().__init__(
             routing_method=routing_method,
@@ -342,6 +357,37 @@ def __init__(
             layer_idx=layer_idx,
         )
 
+    def get_workspace(self, m_max: int, group_size: int):
+        hidden_size = self.hidden_size
+        intermediate_size = self.intermediate_size
+        num_experts = self.expert_size_per_partition
+
+        # create workspace
+        fp8_dim = max(hidden_size, intermediate_size)
+        workspace_0 = torch.empty((num_experts * m_max * fp8_dim),
+                                  dtype=torch.float8_e4m3fn,
+                                  device='cuda')
+        workspace_1 = torch.empty(
+            (num_experts * m_max * max(intermediate_size * 2, hidden_size)),
+            dtype=torch.bfloat16,
+            device='cuda')
+
+        # create workspace for scaling factors
+        m_padded = fp8_utils.align(m_max, 4)
+        scale_k = fp8_utils.ceil_div(fp8_dim, group_size)
+        scale_k_padded = fp8_utils.align(scale_k, 4)
+        workspace_sf = torch.empty(
+            (num_experts * (scale_k_padded // 4) * m_padded),
+            dtype=torch.int32,
+            device='cuda')
+
+        workspace = {
+            "workspace_0": workspace_0,
+            "workspace_1": workspace_1,
+            "workspace_sf": workspace_sf,
+        }
+        return workspace
+
     def _get_quant_method(self):
         if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant(
                 exclude_kv_cache=True):
@@ -362,6 +408,7 @@ def forward_chunk(
         output_dtype: Optional[torch.dtype] = None,
         all_rank_num_tokens: Optional[List[int]] = None,
         use_dp_padding: Optional[bool] = None,
+        workspace: Optional[dict] = None,
     ) -> torch.Tensor:
         if isinstance(x, Fp4QuantizedTensor):
             assert output_dtype is not None
@@ -437,22 +484,38 @@ def forward_chunk(
 
         masked_m, token_to_expert_map = preprocess_after_permute(
             expert_first_token_offset_tensor, permuted_data_tensor)
-        m_max = (x.shape[0] + 127) // 128 * 128
         expected_m = (token_selected_experts.numel() +
                       self.expert_size_per_partition -
                       1) // self.expert_size_per_partition
-        act_input_fp8 = torch.empty(
-            (self.expert_size_per_partition, m_max, self.hidden_size),
-            dtype=torch.float8_e4m3fn,
-            device='cuda')
+
+        # padding and quantization
+        m_max = fp8_utils.align(x.shape[0], 128)
+        act_input_fp8 = set_strides(workspace["workspace_0"],
+                                    self.expert_size_per_partition, m_max,
+                                    self.hidden_size)
+
+        m_padded = fp8_utils.align(m_max, 4)
+        scale_k = fp8_utils.ceil_div(self.hidden_size, 128)
+        scale_k_padded = fp8_utils.align(scale_k, 4)
+        act_input_sf = set_strides(workspace["workspace_sf"],
+                                   self.expert_size_per_partition,
+                                   scale_k_padded // 4, m_padded)
+
         act_input_sf = masked_index_copy_group_quant_fp8(
             act_input_fp8,
+            act_input_sf,
             permuted_data_tensor,
             expert_first_token_offset_tensor,
             token_to_expert_map,
             group_size=128)
 
-        h1 = deepgemm_fp8_group_blockwise_gemm(
+        # grouped gemm 1
+        h1 = set_strides(workspace["workspace_1"],
+                         self.expert_size_per_partition, m_max,
+                         self.intermediate_size * 2)
+
+        deepgemm_fp8_group_blockwise_gemm(
+            d=h1,
             a=act_input_fp8,
             b=self.w3_w1_weight,
             sfa=act_input_sf,
@@ -460,9 +523,33 @@ def forward_chunk(
             masked_m=masked_m,
             expected_m=expected_m,
         )
-        act_input_fp8, act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd(
-            input=h1, quant_group_size=128, masked_m=masked_m, scale_ue8m0=True)
-        h3 = deepgemm_fp8_group_blockwise_gemm(
+
+        # activation and quantization
+        act_input_fp8 = set_strides(workspace["workspace_0"],
+                                    self.expert_size_per_partition, m_max,
+                                    self.intermediate_size)
+
+        scale_k = fp8_utils.ceil_div(self.intermediate_size, 128)
+        scale_k_padded = fp8_utils.align(scale_k, 4)
+        act_input_sf = set_strides(workspace["workspace_sf"],
+                                   self.expert_size_per_partition,
+                                   scale_k_padded // 4, m_padded)
+
+        act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd(
+            output=act_input_fp8,
+            output_scale=act_input_sf,
+            input=h1,
+            quant_group_size=128,
+            masked_m=masked_m,
+            scale_ue8m0=True)
+
+        # grouped gemm 2
+        h3 = set_strides(workspace["workspace_1"],
+                         self.expert_size_per_partition, m_max,
+                         self.hidden_size)
+
+        deepgemm_fp8_group_blockwise_gemm(
+            d=h3,
             a=act_input_fp8,
             b=self.w2_weight,
             sfa=act_input_sf,
@@ -471,6 +558,7 @@ def forward_chunk(
             expected_m=expected_m,
         )
 
+        # gather and finalize
         triton_masked_index_gather(permuted_data_tensor, h3,
                                    expert_first_token_offset_tensor,
                                    token_to_expert_map)
@@ -495,3 +583,137 @@ def forward_chunk(
         )
 
         return final_hidden_states
+
+    def forward(
+        self,
+        x: Union[torch.Tensor, Fp4QuantizedTensor],
+        router_logits: torch.Tensor,
+        do_finalize: bool = True,  # used by other MoE backends
+        output_dtype: Optional[torch.dtype] = None,
+        all_rank_num_tokens: Optional[List[int]] = None,
+        all_rank_max_num_tokens: Optional[int] = None,
+        use_dp_padding: Optional[bool] = None,
+    ) -> torch.Tensor:
+        assert do_finalize, "CutlassFusedMoE does not support do_finalize=False"
+        if self.use_dp and self.parallel_size > 1:
+            assert all_rank_num_tokens is not None
+            assert use_dp_padding is not None
+            num_rows = sum(all_rank_num_tokens)
+        else:
+            num_rows = x.shape[0]
+
+        # If num_rows is larger than max_chunk_size * 2, split the input into multiple chunks,
+        # because chunked MoE uses two streams and preallocates two workspaces.
+        num_chunks = 1
+        if num_rows > self.moe_max_num_tokens * 2:
+            num_chunks = (num_rows + self.moe_max_num_tokens -
+                          1) // self.moe_max_num_tokens
+
+        if use_dp_padding:
+            all_rank_num_tokens_padded = [all_rank_max_num_tokens
+                                          ] * len(all_rank_num_tokens)
+        else:
+            all_rank_num_tokens_padded = all_rank_num_tokens
+
+        if num_chunks == 1:
+            # create workspace
+            num_rows = x.shape[0]
+            if self.use_dp:
+                num_rows = sum(all_rank_num_tokens_padded)
+            m_max = fp8_utils.align(num_rows, 128)
+            workspace = self.get_workspace(m_max, 128)
+            outputs = self.forward_chunk(
+                x,
+                router_logits,
+                output_dtype,
+                all_rank_num_tokens=all_rank_num_tokens_padded,
+                use_dp_padding=use_dp_padding,
+                workspace=workspace)
+            outputs = self.reducescatter_or_allreduce(
+                outputs,
+                all_rank_num_tokens=all_rank_num_tokens_padded,
+                use_dp_padding=use_dp_padding)
+        else:
+            if self.use_dp:
+                all_rank_chunk_size_list = [
+                    self.split_chunk(val, num_chunks)
+                    for val in all_rank_num_tokens_padded
+                ]
+                all_rank_num_tokens_list = [[
+                    val[idx_chunk] for val in all_rank_chunk_size_list
+                ] for idx_chunk in range(num_chunks)]
+                chunk_size_list = all_rank_chunk_size_list[self.rank]
+            else:
+                all_rank_num_tokens_list = [None] * num_chunks
+                chunk_size_list = self.split_chunk(x.shape[0], num_chunks)
+
+            # create workspace
+            chunk_size_0 = sum(all_rank_num_tokens_list[0]
+                               ) if self.use_dp else chunk_size_list[0]
+            chunk_size_1 = sum(all_rank_num_tokens_list[1]
+                               ) if self.use_dp else chunk_size_list[1]
+            workspace_0 = self.get_workspace(fp8_utils.align(chunk_size_0, 128),
+                                             128)
+            workspace_1 = self.get_workspace(fp8_utils.align(chunk_size_1, 128),
+                                             128)
+
+            x_list = x.split(chunk_size_list)
+            router_logits_list = router_logits.split(chunk_size_list)
+
+            self.event_dict[EventType.Main].record()
+            with torch.cuda.stream(self.aux_stream):
+                self.event_dict[EventType.Main].wait()
+
+            def _forward_chunk(x_, router_logits_, idx, workspace):
+                return self.forward_chunk(
+                    x_,
+                    router_logits_,
+                    all_rank_num_tokens=all_rank_num_tokens_list[idx]
+                    if self.use_dp else None,
+                    use_dp_padding=use_dp_padding,
+                    workspace=workspace)
+
+            def _reducescatter_or_allreduce(x_, idx):
+                return self.reducescatter_or_allreduce(
+                    x_,
+                    all_rank_num_tokens=all_rank_num_tokens_list[idx],
+                    use_dp_padding=use_dp_padding)
+
+            outputs_list = []
+            # Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
+            for idx_chunk, (x, router_logits) in enumerate(
+                    zip(x_list, router_logits_list)):
+
+                if idx_chunk % 2 == 0:
+                    with torch.cuda.stream(self.aux_stream):
+                        outputs = _forward_chunk(x, router_logits, idx_chunk,
+                                                 workspace_0)
+                    if idx_chunk > 0:
+                        outputs_list[-1] = _reducescatter_or_allreduce(
+                            outputs_list[-1], idx_chunk - 1)
+                else:
+                    outputs = _forward_chunk(x, router_logits, idx_chunk,
+                                             workspace_1)
+                    with torch.cuda.stream(self.aux_stream):
+                        outputs_list[-1] = _reducescatter_or_allreduce(
+                            outputs_list[-1], idx_chunk - 1)
+
+                outputs_list.append(outputs)
+
+            if num_chunks % 2 == 0:
+                outputs_list[-1] = _reducescatter_or_allreduce(
+                    outputs_list[-1], -1)
+            else:
+                with torch.cuda.stream(self.aux_stream):
+                    outputs_list[-1] = _reducescatter_or_allreduce(
+                        outputs_list[-1], -1)
+            with torch.cuda.stream(self.aux_stream):
+                self.event_dict[EventType.MoeChunkingOverlap].record()
+            self.event_dict[EventType.MoeChunkingOverlap].wait()
+
+            outputs = torch.cat(outputs_list)
+
+        if self.use_dp and self.parallel_size > 1:
+            rank = self.mapping.tp_rank
+            outputs = outputs[:all_rank_num_tokens[rank]]
+        return outputs
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py
index 3249bac979b..ed6f11993b2 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py
@@ -81,13 +81,9 @@ def __init__(
                                         self.num_experts)
         self.expert_size_per_partition = self.expert_end - self.expert_start
 
-        max_num_tokens = model_config.max_num_tokens
         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        if self.use_dp:
-            max_num_tokens *= model_config.mapping.world_size
-        self.moe_max_num_tokens = (model_config.moe_max_num_tokens
-                                   if model_config.moe_max_num_tokens
-                                   is not None else max_num_tokens)
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
 
         self._weights_created = False
         if not model_config.skip_create_weights_in_init:
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
index c74c8966b68..4eb9d77606b 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -150,12 +150,11 @@ def __init__(
         assert len(
             self.initial_local_expert_ids) == self.expert_size_per_partition
 
-        max_num_tokens = model_config.max_num_tokens
         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        max_num_tokens *= model_config.mapping.world_size
-        self.moe_max_num_tokens = model_config.moe_max_num_tokens if model_config.moe_max_num_tokens is not None else max_num_tokens
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
         # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
-        if self.moe_max_num_tokens < max_num_tokens:
+        if self.moe_max_num_tokens < moe_max_num_tokens:
             self.aux_stream = aux_stream_dict[
                 AuxStreamType.
                 MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(
diff --git a/tensorrt_llm/mapping.py b/tensorrt_llm/mapping.py
index f78fe093f71..22824ea350d 100644
--- a/tensorrt_llm/mapping.py
+++ b/tensorrt_llm/mapping.py
@@ -372,6 +372,10 @@ def node_rank(self):
     def local_rank(self):
         return self.rank % self.gpus_per_node
 
+    @property
+    def dp_size(self):
+        return self.tp_size if self.enable_attention_dp else 1
+
     def has_cp(self):
         return self.cp_size > 1
 
diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py
index 19bd24671dd..4c486a15115 100644
--- a/tensorrt_llm/quantization/utils/fp8_utils.py
+++ b/tensorrt_llm/quantization/utils/fp8_utils.py
@@ -302,6 +302,8 @@ def _silu_and_mul_post_quant_kernel(
 
 
 def silu_and_mul_masked_post_quant_fwd(
+    output: torch.Tensor,
+    output_scale: torch.Tensor,
     input: torch.Tensor,
     quant_group_size: int,
     masked_m: torch.Tensor,
@@ -328,18 +330,6 @@ def silu_and_mul_masked_post_quant_fwd(
     g, m, k = input.shape
     k = k // 2
 
-    # Create output
-    output = torch.empty((g, m, k), dtype=torch.float8_e4m3fn, device="cuda")
-
-    # Create output scale
-    alignment = 4
-    scale_k = ceil_div(k, quant_group_size)
-    m_padded = align(m, alignment)
-    scale_k_padded = align(scale_k, alignment)
-    output_scale = torch.zeros((g, scale_k_padded // 4, m_padded),
-                               dtype=torch.int32,
-                               device='cuda')
-
     # Get block/grid/stage/warp
     expert_num = len(masked_m)
 
@@ -382,7 +372,7 @@ def silu_and_mul_masked_post_quant_fwd(
         g,
         tma_stride_check=True,
     )
-    return output, output_scale
+    return output_scale
 
 
 @triton.jit
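Reviewer note: the standalone sketch below illustrates the workspace-reuse pattern this patch introduces, namely one flat buffer allocated per chunk and re-viewed with as_strided for each intermediate tensor, plus a packed int32 buffer for the group-quantization scaling factors. The helpers mirror set_strides and the fp8_utils.ceil_div/fp8_utils.align utilities used above; the tensor sizes, the default device, and the bfloat16 stand-in for the FP8 buffers are illustrative assumptions, not the production configuration.

import torch


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def align(x: int, alignment: int) -> int:
    return ceil_div(x, alignment) * alignment


def set_strides(workspace: torch.Tensor, g: int, m: int, k: int) -> torch.Tensor:
    # View the first g*m*k elements of the flat buffer as a (g, m, k) tensor.
    return workspace[:g * m * k].as_strided(size=(g, m, k), stride=(m * k, k, 1))


# Illustrative sizes only (assumptions, not the production configuration).
num_experts, m_max, hidden_size, intermediate_size, group_size = 8, 256, 1024, 2048, 128

# One flat activation buffer sized for the largest view it has to back.
# The real code allocates float8_e4m3fn on CUDA; bfloat16 on the default
# device is used here so the sketch runs anywhere.
fp8_dim = max(hidden_size, intermediate_size)
workspace_0 = torch.empty(num_experts * m_max * fp8_dim, dtype=torch.bfloat16)

# The same buffer backs both the gemm-1 input view (hidden_size wide) and the
# activation-output view (intermediate_size wide); nothing is reallocated
# between the two grouped GEMMs.
act_in = set_strides(workspace_0, num_experts, m_max, hidden_size)
act_mid = set_strides(workspace_0, num_experts, m_max, intermediate_size)
assert act_in.data_ptr() == act_mid.data_ptr() == workspace_0.data_ptr()

# Scaling-factor workspace: K is grouped by group_size, both dimensions are
# padded to a multiple of 4, and scales are packed four per int32 element,
# matching the scale_k_padded // 4 layout in get_workspace().
scale_k_padded = align(ceil_div(fp8_dim, group_size), 4)
workspace_sf = torch.empty(num_experts * (scale_k_padded // 4) * align(m_max, 4),
                           dtype=torch.int32)
print(act_in.shape, act_mid.shape, workspace_sf.numel())

In the chunked path the same idea is applied twice: two workspaces are preallocated and alternated between the main and auxiliary CUDA streams, so each chunk writes into its own buffers while the previous chunk's reduce-scatter/all-reduce overlaps on the other stream.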