
Commit 6ee94c7

Reintroduce with perf fixes: feature: unify new_tokens format of sample state to trtllm sampler tokens format (#5513)

58a8a8f - these changes were previously merged to main. 6aef149 - the changes were temporarily reverted on main due to a significant perf regression in models using the TorchSampler (observed by @byshiue). This PR re-merges the changes together with a fix that prevents the regression. The first commit of this PR is just the revert of the revert; filter it out of the diff to see the previously unmerged changes. Signed-off-by: Netanel Haber <[email protected]>
1 parent f28cd30 commit 6ee94c7

12 files changed: +457 −494 lines

tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 15 additions & 2 deletions
@@ -5,6 +5,7 @@
 import torch
 from torch._prims_common import DeviceLikeType
 
+from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager
 from tensorrt_llm._utils import nvtx_range
 
 from ...._utils import mpi_rank, mpi_world_size
@@ -256,6 +257,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
     assert isinstance(executor_config.pytorch_backend_config, LlmArgs), msg
     ad_config: LlmArgs = executor_config.pytorch_backend_config
 
+    max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size
     # some derivative properties
     max_draft_tokens = (
         0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_tokens
@@ -272,7 +274,13 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
         max_seq_len=ad_config.max_seq_len,
         max_batch_size=ad_config.max_batch_size,
     )
-    resource_manager = ResourceManager({ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
+    seq_slot_manager = SeqSlotManager(max_num_sequences=max_num_sequences)
+    resource_manager = ResourceManager(
+        {
+            ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager,
+            ResourceManagerType.SEQ_SLOT_MANAGER: seq_slot_manager,
+        }
+    )
     resource_manager.resource_managers.move_to_end(ResourceManagerType.KV_CACHE_MANAGER, last=True)
 
     # scheduling
@@ -287,10 +295,14 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
     # https://github.com/NVIDIA/TensorRT-LLM/issues/5254
     # We should expose mixed_sample to our build_and_run_ad script so we can configure this
     # correctly for models as needed.
-    sampler = TorchSampler(
+    sampler_args = TorchSampler.Args(
         max_seq_len=ad_config.max_seq_len,
+        max_draft_tokens=max_draft_tokens,
+        max_num_sequences=max_num_sequences,
+        max_beam_width=executor_config.max_beam_width,
         mixed_sampler=ad_config.mixed_sampler,
     )
+    sampler = TorchSampler(sampler_args)
 
     # creating the executor object
     py_executor = PyExecutor(
@@ -299,6 +311,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir:
         model_engine=engine,
         sampler=sampler,
         dist=mpi_dist,
+        max_num_sequences=max_num_sequences,
         disable_overlap_scheduler=ad_config.disable_overlap_scheduler,
         max_input_len=ad_config.max_input_len,
         max_batch_size=ad_config.max_batch_size,
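
Taken together, the hunks above change how the AutoDeploy executor wires up sampling: the per-field TorchSampler constructor is replaced by a bundled TorchSampler.Args value, and a SeqSlotManager sized to max_num_sequences (max_batch_size times pp_size) is registered alongside the KV-cache manager. Below is a minimal sketch of the new construction pattern, assuming config objects shaped like ad_config and executor_config in the diff; the helper name build_ad_sampler is hypothetical.

from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler


def build_ad_sampler(ad_config, executor_config, pp_size: int) -> TorchSampler:
    # Derived limits, mirroring the diff: enough sequence slots for a full
    # batch on every pipeline stage, and draft tokens only when speculative
    # decoding is configured.
    max_num_sequences = ad_config.max_batch_size * pp_size
    max_draft_tokens = (0 if ad_config.speculative_config is None
                        else ad_config.speculative_config.max_draft_tokens)
    args = TorchSampler.Args(
        max_seq_len=ad_config.max_seq_len,
        max_draft_tokens=max_draft_tokens,
        max_num_sequences=max_num_sequences,
        max_beam_width=executor_config.max_beam_width,
        mixed_sampler=ad_config.mixed_sampler,
    )
    # The Args bundle replaces the old keyword-only TorchSampler constructor.
    return TorchSampler(args)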

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 31 additions & 14 deletions
@@ -26,8 +26,7 @@
 from .resource_manager import (KVCacheManager, MambaHybridCacheManager,
                                PeftCacheManager, ResourceManager,
                                ResourceManagerType)
-from .sampler import (EarlyStopSampler, TorchSampler, TorchStarAttentionSampler,
-                      TRTLLMSampler)
+from .sampler import EarlyStopSampler, TorchSampler, TRTLLMSampler
 from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
                         SimpleScheduler)
 from .seq_slot_manager import SeqSlotManager
@@ -514,6 +513,7 @@ def create_py_executor_instance(
         sampler=sampler,
         drafter=drafter,
         dist=dist,
+        max_num_sequences=max_num_sequences,
         disable_overlap_scheduler=pytorch_backend_config.
         disable_overlap_scheduler,
         max_batch_size=executor_config.max_batch_size,
@@ -525,27 +525,44 @@ def create_py_executor_instance(
         garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)
 
 
-def instantiate_sampler(model_engine: PyTorchModelEngine,
+def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
+                              *, max_seq_len: int, mixed_sampler: bool):
+    max_num_sequences = executor_config.max_batch_size * mapping.pp_size
+    max_draft_tokens = (0 if executor_config.speculative_config is None else
+                        executor_config.speculative_config.max_draft_tokens)
+    return TorchSampler.Args(
+        max_seq_len=max_seq_len,
+        max_draft_tokens=max_draft_tokens,
+        max_num_sequences=max_num_sequences,
+        max_beam_width=executor_config.max_beam_width,
+        mixed_sampler=mixed_sampler,
+    )
+
+
+def instantiate_sampler(engine: PyTorchModelEngine,
                         executor_config: ExecutorConfig,
                         pytorch_backend_config: PyTorchConfig,
                         mapping: Mapping):
+    sampler_args = create_torch_sampler_args(
+        executor_config,
+        mapping,
+        max_seq_len=engine.max_seq_len,
+        mixed_sampler=pytorch_backend_config.mixed_sampler)
     if mapping.cp_config.get('cp_type') == 'star_attention':
         assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
-        return TorchStarAttentionSampler(max_seq_len=model_engine.max_seq_len)
-    spec_config = model_engine.spec_config
-    if spec_config is not None and spec_config.spec_dec_mode.has_spec_decoder():
-        return get_spec_decoder(max_seq_len=model_engine.max_seq_len,
-                                spec_config=spec_config)
+        return TorchSampler(sampler_args)
+    if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
+    ):
+        return get_spec_decoder(sampler_args, engine.spec_config)
     if pytorch_backend_config.enable_trtllm_sampler:
-        return TRTLLMSampler(executor_config, model_engine.model,
-                             model_engine.dtype, mapping,
-                             get_decoding_mode(executor_config),
+        decoding_mode = get_decoding_mode(executor_config)
+        return TRTLLMSampler(executor_config, engine.model, engine.dtype,
+                             mapping, decoding_mode,
                              pytorch_backend_config.disable_overlap_scheduler)
-    elif not model_engine.model.model_config.is_generation:
+    if not engine.model.model_config.is_generation:
         # NOTE: choose sampler based on model type
         return EarlyStopSampler()
-    return TorchSampler(max_seq_len=model_engine.max_seq_len,
-                        mixed_sampler=pytorch_backend_config.mixed_sampler)
+    return TorchSampler(sampler_args)
 
 
 def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode:
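
With this refactor, instantiate_sampler builds a single TorchSampler.Args up front and reuses it in every branch: star attention, speculative decoding, and the default TorchSampler path. A small usage sketch of the new create_torch_sampler_args helper follows; the dataclass stubs are hypothetical stand-ins for ExecutorConfig and Mapping, used only to show which fields the helper reads.

from dataclasses import dataclass
from typing import Optional

from tensorrt_llm._torch.pyexecutor._util import create_torch_sampler_args


@dataclass
class _FakeExecutorConfig:  # hypothetical stub, not a real TRT-LLM class
    max_batch_size: int = 16
    max_beam_width: int = 1
    speculative_config: Optional[object] = None  # None => max_draft_tokens == 0


@dataclass
class _FakeMapping:  # hypothetical stub, not a real TRT-LLM class
    pp_size: int = 2


args = create_torch_sampler_args(_FakeExecutorConfig(), _FakeMapping(),
                                 max_seq_len=2048, mixed_sampler=False)
# args.max_num_sequences == 16 * 2 == 32; args.max_draft_tokens == 0

Because the same args object is handed to TorchSampler and to get_spec_decoder, the regular and speculative sampler paths are now sized from the same numbers.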

tensorrt_llm/_torch/pyexecutor/guided_decoder.py

Lines changed: 2 additions & 6 deletions
@@ -1,4 +1,3 @@
-import itertools
 import math
 from typing import List, Optional
 
@@ -52,8 +51,7 @@ def bitmask_size(self) -> int:
 
     def build(self, scheduled_requests: ScheduledRequests,
               resource_manager: SeqSlotManager) -> None:
-        for llm_req in itertools.chain(scheduled_requests.context_requests,
-                                       scheduled_requests.generation_requests):
+        for llm_req in scheduled_requests.all_requests():
             if llm_req.guided_decoding_params is None:
                 continue
             slot = resource_manager.slot_manager.get_slot(llm_req.request_id)
@@ -84,9 +82,7 @@ def execute(self, scheduled_requests: ScheduledRequests,
         torch.cuda.current_stream().wait_stream(self._stream)
 
         batched_logits, batched_bitmask = [], []
-        for i, llm_req in enumerate(
-                itertools.chain(scheduled_requests.context_requests,
-                                scheduled_requests.generation_requests)):
+        for i, llm_req in enumerate(scheduled_requests.all_requests()):
             if llm_req.guided_decoding_params is None:
                 continue
             if llm_req.is_context_init_state and not llm_req.is_last_context_chunk:
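
Both loops now iterate over scheduled_requests.all_requests() instead of chaining the context and generation lists inline, which also lets the itertools import go away. The helper itself is not part of this commit; given how the removed itertools.chain calls were used, it presumably looks roughly like this sketch.

class ScheduledRequests:
    # Sketch only: the real class carries more state. The semantics match the
    # removed itertools.chain usage: context requests first, then generation
    # requests.
    context_requests: list
    generation_requests: list

    def all_requests(self) -> list:
        return self.context_requests + self.generation_requests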

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 2 additions & 0 deletions
@@ -254,6 +254,7 @@ def __init__(
             exclude_last_generation_logits: bool = False,
             return_perf_metrics: bool = False,
             stop_words_list: list[list[int]] | None = None,
+            is_draft: bool = False,
             **kwargs):
         self.py_logits_post_processors = kwargs.pop("py_logits_post_processors",
                                                     None)
@@ -288,6 +289,7 @@ def __init__(
         self.py_return_context_logits = return_context_logits
         self.py_return_generation_logits = return_generation_logits
         self.py_return_logits_device_memory = return_logits_device_memory
+        self.py_is_draft = is_draft
 
         # TODO: remove this when use DynamicDecodeOp in pytorch flow.
         # currently, keep py_stop_words_list as python list, rather than tensor.
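
The new is_draft keyword is simply stored on the request as py_is_draft, giving downstream components a way to tell speculative-decoding draft requests apart from regular ones. A hedged sketch of how such a flag might be consumed; the helper below is hypothetical, and only the py_is_draft attribute comes from this diff.

def split_draft_requests(requests):
    """Partition requests into (draft, regular) using the new py_is_draft flag."""
    drafts = [r for r in requests if getattr(r, "py_is_draft", False)]
    regular = [r for r in requests if not getattr(r, "py_is_draft", False)]
    return drafts, regular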
