Commit bb82e26

Remove/fuse element-wise ops in the DeepSeek-R1 FP8 model.
Signed-off-by: Fanrong Li <[email protected]>
1 parent be6d92f commit bb82e26

File tree

tensorrt_llm/_torch/models/modeling_deepseekv3.py
tensorrt_llm/_torch/modules/attention.py
tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py
tensorrt_llm/quantization/utils/fp8_utils.py

4 files changed: +81 -50 lines changed


tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 8 additions & 2 deletions
@@ -279,10 +279,16 @@ def __init__(
         self.routed_scaling_factor = routed_scaling_factor
         self.is_fused = is_fused
 
-    def noaux_tc(self, logits, e_score_correction_bias):
-        n_group = self.n_group
+    @torch.compile(options={"max-autotune": True})
+    def get_scores(self, logits, e_score_correction_bias):
         scores = F.sigmoid(logits)
         scores_with_bias = scores + e_score_correction_bias
+        return scores, scores_with_bias
+
+    def noaux_tc(self, logits, e_score_correction_bias):
+        n_group = self.n_group
+        scores, scores_with_bias = self.get_scores(logits,
+                                                   e_score_correction_bias)
         scores_shape = list(scores_with_bias.shape)
 
         if enable_llm_debug():
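Note: a likely motivation for this split (not stated in the commit message) is that the sigmoid and the bias add are both element-wise, so pulling them into a small get_scores helper decorated with @torch.compile lets the compiler fuse them into a single kernel. A minimal standalone sketch of the same pattern, with assumed tensor shapes, not the repo's actual class:

import torch
import torch.nn.functional as F

# Hedged sketch of the fused element-wise helper (shapes are assumptions).
@torch.compile(options={"max-autotune": True})
def get_scores(logits, e_score_correction_bias):
    # sigmoid followed by a bias add; both are element-wise, so the compiler
    # can emit one fused kernel instead of two separate launches.
    scores = F.sigmoid(logits)
    return scores, scores + e_score_correction_bias

logits = torch.randn(4, 256)   # assumed (num_tokens, num_experts)
bias = torch.randn(256)        # assumed per-expert correction bias
scores, scores_with_bias = get_scores(logits, bias)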

tensorrt_llm/_torch/modules/attention.py

Lines changed: 1 addition & 27 deletions
@@ -555,33 +555,7 @@ def fp8_block_scaling_bmm_out(
             torch.ops.trtllm.fp8_block_scaling_bmm_out(mat1_fp8, mat2_fp8,
                                                         mat1_scale, mat2_scale, out)
         elif sm_version == 100:
-            output = torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2))
-            out.copy_(output)
-
-            # low_latency = True
-            # use_deep_seek_fp8 = True
-            # tile_size = 8
-            # epilogue_tile_m = 64 if use_deep_seek_fp8 else 128
-            # m_size = mat1.shape[0]
-            # if m_size % tile_size != 0:
-            #     tiled_shape = ((m_size + tile_size - 1) // tile_size) * tile_size
-            #     mat1 = torch.nn.functional.pad(
-            #         mat1, (0, 0, 0, 0, 0, tiled_shape - m_size), "constant", 0)
-
-            # mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
-            #     mat1)
-            # output, output_sf = torch.ops.trtllm.fp8_batched_gemm_trtllmgen(
-            #     mat1_fp8,
-            #     mat2_fp8,
-            #     tile_size=tile_size,
-            #     epilogue_tile_m=epilogue_tile_m,
-            #     use_deep_seek_fp8=use_deep_seek_fp8,
-            #     low_latency=low_latency,
-            #     dq_sfs_a=mat1_scale.reshape(mat1.shape[-1] // 128, -1),
-            #     dq_sfs_b=mat2_scale,
-            #     out_dtype=out.dtype,
-            # )
-            # out.copy_(output[:, :m_size])
+            torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)
         else:
             raise NotImplementedError(f"SM{sm_version} is not supported")
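The replacement on the SM100 path also drops a dead block of commented-out code and avoids materializing a temporary result that is then copied into the output buffer. A small illustration of the two patterns with assumed shapes; this is the general out= idiom, not the repo's function:

import torch

m, b, k, n = 4, 8, 16, 32              # assumed sizes
mat1 = torch.randn(m, b, k)            # transposed to (b, m, k) below
mat2_dequant = torch.randn(b, n, k)    # transposed to (b, k, n) below
out = torch.empty(b, m, n)

# old pattern: allocate a temporary, then copy it into the preallocated buffer
tmp = torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2))
out.copy_(tmp)

# new pattern: write the batched matmul result directly into `out`
torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)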

tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py

Lines changed: 71 additions & 20 deletions
@@ -1,7 +1,6 @@
 from typing import Dict, List, Optional, Union
 
 import torch
-import torch.nn.functional as F
 import triton
 import triton.language as tl
 
@@ -216,30 +215,82 @@ def triton_masked_index_gather(output, input, start_offsets, row_indices):
     return
 
 
-@nvtx_range("[DG] act")
-@torch.compile(dynamic=True)
-def swiglu_fused_moe(x):
-    x, gate = x.chunk(2, dim=-1)
-    return F.silu(gate) * x
-
-
-@nvtx_range("[DG] indexing")
-@torch.compile(dynamic=True)
-def indexing(x, mask):
-    return x[mask > 0, :].contiguous()
+@triton.jit
+def _preprocess_after_permute_kernel(
+    expert_offsets_ptr,
+    masked_m_ptr,
+    token_map_ptr,
+    TOTAL_TOKENS: tl.constexpr,
+    NUM_EXPERTS: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+):
+    pid_x = tl.program_id(0)
+    pid_y = tl.program_id(1)
+
+    if pid_y == 0:
+        token_offsets = pid_x * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        token_mask = token_offsets < TOTAL_TOKENS
+        # get expert_id for each token in the block
+        expert_ids = tl.full((BLOCK_SIZE_M, ), NUM_EXPERTS - 1, dtype=tl.int32)
+        found_mask = tl.zeros((BLOCK_SIZE_M, ), dtype=tl.int1)
+        for i in tl.static_range(NUM_EXPERTS):
+            boundary = tl.load(expert_offsets_ptr + i + 1)
+            cond = (token_offsets < boundary) & ~found_mask
+            expert_ids = tl.where(cond, i, expert_ids)
+            found_mask = found_mask | cond
+        tl.store(token_map_ptr + token_offsets, expert_ids, mask=token_mask)
+
+    elif pid_y == 1:
+        # get num_tokens for each expert
+        expert_mask = pid_x < NUM_EXPERTS
+        next_offset = tl.load(expert_offsets_ptr + pid_x + 1,
+                              mask=expert_mask,
+                              other=0)
+        current_offset = tl.load(expert_offsets_ptr + pid_x,
+                                 mask=expert_mask,
+                                 other=0)
+        tokens_per_expert = next_offset - current_offset
+        tl.store(masked_m_ptr + pid_x,
+                 tokens_per_expert.to(tl.int32),
+                 mask=expert_mask)
 
 
 @nvtx_range("[DG] preprocess_after_permute")
 def preprocess_after_permute(expert_first_token_offset_tensor,
                              permuted_data_tensor):
-    # get tokens per expert
-    masked_m = expert_first_token_offset_tensor[
-        1:] - expert_first_token_offset_tensor[:-1]
-    token_to_expert_map = torch.searchsorted(
-        expert_first_token_offset_tensor[1:],
-        torch.arange(permuted_data_tensor.shape[0], device='cuda'),
-        right=True)
-    return masked_m.to(torch.int32), token_to_expert_map
+    """
+    Python wrapper that launches a single fused kernel to get the token-to-expert map
+    and the number of tokens per expert.
+    """
+    total_tokens = permuted_data_tensor.shape[0]
+    num_experts = expert_first_token_offset_tensor.shape[0] - 1
+
+    # create output tensors
+    masked_m = torch.empty(num_experts, dtype=torch.int32, device='cuda')
+    token_to_expert_map = torch.empty(total_tokens,
+                                      dtype=torch.int32,
+                                      device='cuda')
+
+    # calculate the grid size
+    DEFAULT_BLOCK_SIZE_M = 256
+    grid_m_size = triton.cdiv(total_tokens, DEFAULT_BLOCK_SIZE_M)
+    if grid_m_size >= num_experts:
+        BLOCK_SIZE_M = DEFAULT_BLOCK_SIZE_M
+        grid = (grid_m_size, 2)
+    else:
+        BLOCK_SIZE_M = triton.cdiv(total_tokens, num_experts)
+        grid = (num_experts, 2)
+
+    # launch the kernel
+    _preprocess_after_permute_kernel[grid](
+        expert_first_token_offset_tensor,
+        masked_m,
+        token_to_expert_map,
+        TOTAL_TOKENS=total_tokens,
+        NUM_EXPERTS=num_experts,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+    )
+    return masked_m, token_to_expert_map
 
 
 @nvtx_range("[DG]")
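For reference, the removed torch-level logic and the new fused Triton kernel compute the same two outputs: masked_m (tokens per expert, the difference of consecutive offsets) and token_to_expert_map (which offset bucket each permuted token falls into). A small sketch of that semantics in plain torch, using a made-up offsets tensor:

import torch

# expert_first_token_offset_tensor holds a prefix sum over experts:
# offsets[i] is the first permuted-token index owned by expert i,
# offsets[-1] is the total token count. The values below are made up.
offsets = torch.tensor([0, 3, 3, 7, 10])   # 4 experts, 10 permuted tokens
total_tokens = int(offsets[-1])

# tokens per expert: difference of consecutive offsets
masked_m = (offsets[1:] - offsets[:-1]).to(torch.int32)

# token -> expert map: which offset bucket each token index falls into
token_ids = torch.arange(total_tokens)
token_to_expert_map = torch.searchsorted(offsets[1:], token_ids, right=True)

print(masked_m.tolist())             # [3, 0, 4, 3]
print(token_to_expert_map.tolist())  # [0, 0, 0, 2, 2, 2, 2, 3, 3, 3]

The new kernel folds both computations into one launch via its 2D grid: programs with pid_y == 0 fill the token map block by block, while programs with pid_y == 1 each compute one expert's token count; when there are fewer token blocks than experts, BLOCK_SIZE_M is shrunk so the shared x dimension of the grid covers both jobs.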

tensorrt_llm/quantization/utils/fp8_utils.py

Lines changed: 1 addition & 1 deletion
@@ -476,7 +476,7 @@ def per_token_quant_and_transform(
     scale_k = ceil_div(k, quant_group_size)
     m_padded = align(m, alignment)
     scale_k_padded = align(scale_k, alignment)
-    output_scale = torch.zeros((scale_k_padded // 4, m_padded),
+    output_scale = torch.empty((scale_k_padded // 4, m_padded),
                                dtype=torch.int32,
                                device='cuda')
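The zeros-to-empty swap drops the extra fill pass that zero-initialization requires; it is only safe if the subsequent quantization kernel writes every element of output_scale, which this change presumably assumes. A tiny illustration with assumed sizes:

import torch

scale_k_padded, m_padded = 128, 4096   # assumed padded sizes
# torch.zeros allocates and then clears the buffer (an extra fill pass)
zeroed = torch.zeros((scale_k_padded // 4, m_padded), dtype=torch.int32)
# torch.empty only allocates; contents are undefined until written
uninit = torch.empty((scale_k_padded // 4, m_padded), dtype=torch.int32)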
