
Commit 7ceb5e5

chang-llfr-0531 and Fanrong Li authored
[TRTLLM-9198][perf] Add torch.compile + multi-stream support for k-cache scatter and weight scaling (#8988)
Signed-off-by: Chang Liu (Enterprise Products) <[email protected]>
Co-authored-by: Fanrong Li <[email protected]>
1 parent: c61b44e · commit: 7ceb5e5

4 files changed (+43, -19 lines)


tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 17 additions & 7 deletions
@@ -17,6 +17,7 @@
     maybe_execute_in_parallel
 from tensorrt_llm._torch.modules.rotary_embedding import RotaryEmbedding
 from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
+from tensorrt_llm._torch.utils import maybe_compile
 from tensorrt_llm._utils import get_size_in_bytes
 from tensorrt_llm.bindings import DataType
 from tensorrt_llm.bindings.executor import KvCacheConfig
@@ -572,6 +573,12 @@ def update_for_spec_dec(self):
         self.on_update_kv_lens()


+@maybe_compile(dynamic=True)
+def _scale(weights: torch.Tensor, q_scale: torch.Tensor,
+           s: float) -> torch.Tensor:
+    return weights * q_scale.squeeze(-1) * s
+
+
 class Indexer(nn.Module):

     def __init__(self,
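The compiled `_scale` helper introduced here is numerically equivalent to the unsqueeze/scale/squeeze sequence it replaces in `weight_scale` (see the later hunk). A minimal check of that equivalence, with shapes and dtypes assumed from the surrounding code (`weights` as `[num_tokens, n_heads]`, `q_scale` as `[num_tokens, n_heads, 1]` float32) since the hunk itself does not state them:

import torch

num_tokens, n_heads = 4, 8                      # illustrative sizes
weights = torch.randn(num_tokens, n_heads, dtype=torch.bfloat16)
q_scale = torch.rand(num_tokens, n_heads, 1, dtype=torch.float32)
s = 0.5                                         # stands in for self.weight_scale_factor

# Old formulation (removed below): lift weights to 3-D, scale, squeeze back.
old = (weights.unsqueeze(-1) * q_scale * s).squeeze(-1)
# New compiled helper: drop the trailing singleton dim on q_scale instead.
new = weights * q_scale.squeeze(-1) * s

assert old.dtype == new.dtype == torch.float32  # q_scale (float32) drives promotion
assert torch.allclose(old, new)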
@@ -964,9 +971,6 @@ def sparse_attn_indexer(
         if not use_custom_topk:
             topk_indices_buffer[:hidden_states.shape[0]] = -1

-        # Store k_fp8 and k_scale into indexer k cache
-        self._update_k_cache(k_fp8, k_scale, metadata)
-
         if has_prefill:
             # Use chunked prefill to reduce memory footprint
             if metadata.indexer_prefill_chunks is not None:
@@ -1121,9 +1125,7 @@ def weight_scale(self, hidden_states: torch.Tensor,
                      q_scale: torch.Tensor) -> torch.Tensor:
         weights = indexer_weights if indexer_weights is not None else self.weights_proj(
             hidden_states)
-        weights = weights.unsqueeze(-1) * q_scale * self.weight_scale_factor
-        # output weights is guaranteed to be float32 due to type promotion from q_scale (float32)
-        weights = weights.squeeze(-1)
+        weights = _scale(weights, q_scale, self.weight_scale_factor)
         return weights

     @torch.inference_mode()
@@ -1192,7 +1194,15 @@ def _prep_q_or_k(qk_pe, qk_nope):
         q_fp8 = q_fp8.view(-1, self.n_heads, self.head_dim)
         q_scale = q_scale.view(-1, self.n_heads, 1)

-        weights = self.weight_scale(hidden_states, indexer_weights, q_scale)
+        weights, _ = maybe_execute_in_parallel(
+            lambda: self.weight_scale(hidden_states, indexer_weights, q_scale),
+            lambda: self._update_k_cache(
+                k_fp8, k_scale, metadata),  # store k_fp8 and k_scale in k cache
+            self.ln_events[0],
+            self.ln_events[1],
+            self.aux_stream,
+        )
+
         # Return topk indices buffer for sparse attention [num_tokens, index_topk]
         return self.sparse_attn_indexer(metadata, hidden_states, q_fp8, k_fp8,
                                         k_scale, weights)
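This hunk is where the multi-stream part of the change lands: the compiled weight scaling and the indexer k-cache scatter are independent, so `maybe_execute_in_parallel` overlaps them, with `self.ln_events[0]`/`[1]` as the fencing events and `self.aux_stream` as the side stream, and only the first result (the weights) is kept. The sketch below is a simplified, hypothetical rendering of that overlap pattern, not the library's implementation; the helper name `overlap_on_aux_stream` and the exact event ordering are assumptions:

import torch


def overlap_on_aux_stream(fn_main, fn_aux, event_main, event_aux, aux_stream):
    """Hypothetical sketch: run fn_main on the current stream while fn_aux
    runs on aux_stream, fencing with two CUDA events so both results are
    safe to use on the main stream afterwards."""
    if aux_stream is None:
        # No side stream available: fall back to sequential execution.
        return fn_main(), fn_aux()
    event_main.record()                      # mark where the main stream is
    with torch.cuda.stream(aux_stream):
        aux_stream.wait_event(event_main)    # aux work must not start earlier
        aux_result = fn_aux()                # e.g. the k-cache scatter
        event_aux.record()                   # completion marker on aux_stream
    main_result = fn_main()                  # e.g. the compiled weight scaling
    torch.cuda.current_stream().wait_event(event_aux)  # rejoin before use
    return main_result, aux_result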

tensorrt_llm/_torch/modules/attention.py

Lines changed: 1 addition & 12 deletions
@@ -23,7 +23,7 @@
 from ..model_config import ModelConfig
 from ..peft.lora.layer import LoraLayer, LoraModuleType
 from ..utils import (Fp4QuantizedTensor, get_model_extra_attrs,
-                     is_piecewise_running, is_torch_compiling)
+                     is_torch_compiling, maybe_compile)
 from .linear import Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig
 from .multi_stream_utils import maybe_execute_in_parallel
 from .rms_norm import RMSNorm
@@ -76,17 +76,6 @@ def extract_extra_attrs(layer_idx: str, attn_type: str):
     return metadata, attn_layer


-def maybe_compile(func):
-
-    def wrapper(*args, **kwargs):
-        if is_piecewise_running():
-            # When piecewise running, we don't need to compile the function to avoid host overhead in attention op.
-            return func(*args, **kwargs)
-        return torch.compile(func)(*args, **kwargs)
-
-    return wrapper
-
-
 @maybe_compile
 def maybe_compiled_copy_(dst, src):
     dst.copy_(src)

tensorrt_llm/_torch/utils.py

Lines changed: 23 additions & 0 deletions
@@ -325,3 +325,26 @@ def get_device_uuid(device_idx: int) -> str:
     property = torch.cuda.get_device_properties(device_idx)
     uuid = "GPU-" + str(property.uuid)
     return uuid
+
+
+def maybe_compile(func=None, **compile_kwargs):
+    """
+    Conditionally compile a function with torch.compile.
+    If is_piecewise_running() is True, the function will not be compiled to avoid host overhead in the attention op.
+    Args:
+        func: The function to decorate (optional, for direct decoration).
+        **compile_kwargs: Keyword arguments for torch.compile.
+    Returns:
+        The conditionally compiled function.
+    """
+
+    def decorator(f):
+
+        def wrapper(*args, **kwargs):
+            if is_piecewise_running():
+                return f(*args, **kwargs)
+            return torch.compile(f, **compile_kwargs)(*args, **kwargs)
+
+        return wrapper
+
+    return decorator(func) if func else decorator
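Because `func` defaults to None and the keyword arguments are forwarded to `torch.compile`, the relocated helper accepts both the bare decoration already used for `maybe_compiled_copy_` in attention.py and the parameterized form used for `_scale` in dsa.py. A small usage sketch (the toy functions are illustrative only, not part of the change):

import torch

from tensorrt_llm._torch.utils import maybe_compile


@maybe_compile                  # bare form: maybe_compile(func) returns the wrapper directly
def copy_into(dst: torch.Tensor, src: torch.Tensor) -> None:
    dst.copy_(src)


@maybe_compile(dynamic=True)    # parameterized form: kwargs are passed to torch.compile
def scale(weights: torch.Tensor, s: float) -> torch.Tensor:
    return weights * s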

tests/unittest/_torch/attention/sparse/test_dsa_indexer.py

Lines changed: 2 additions & 0 deletions
@@ -1175,6 +1175,7 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
             f" Chunk {i}: Q[{chunk.token_start}:{chunk.token_end}] ({num_q} tokens), "
             f"K[{chunk.k_token_start}:{chunk.k_token_end}] ({num_k} tokens)")

+    indexer._update_k_cache(k_fp8, k_scale, metadata_chunked)
     topk_indices_chunked = indexer.sparse_attn_indexer(metadata_chunked,
                                                        hidden_states, q_fp8,
                                                        k_fp8, k_scale, weights)
@@ -1206,6 +1207,7 @@ def test_indexer_chunked_prefill(chunk_size, seq_lens_list, chunking_type):
             f"✓ Created {num_baseline_chunks} chunk(s) (effectively non-chunked)"
         )

+    indexer._update_k_cache(k_fp8, k_scale, metadata_baseline)
     topk_indices_baseline = indexer.sparse_attn_indexer(metadata_baseline,
                                                         hidden_states, q_fp8,
                                                         k_fp8, k_scale, weights)
