@@ -26,8 +26,7 @@
 from .resource_manager import (KVCacheManager, MambaHybridCacheManager,
                                PeftCacheManager, ResourceManager,
                                ResourceManagerType)
-from .sampler import (EarlyStopSampler, TorchSampler, TorchStarAttentionSampler,
-                      TRTLLMSampler)
+from .sampler import EarlyStopSampler, TorchSampler, TRTLLMSampler
 from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
                         SimpleScheduler)
 from .seq_slot_manager import SeqSlotManager
@@ -514,6 +513,7 @@ def create_py_executor_instance(
         sampler=sampler,
         drafter=drafter,
         dist=dist,
+        max_num_sequences=max_num_sequences,
         disable_overlap_scheduler=pytorch_backend_config.
         disable_overlap_scheduler,
         max_batch_size=executor_config.max_batch_size,
@@ -525,27 +525,44 @@ def create_py_executor_instance(
         garbage_collection_gen0_threshold=garbage_collection_gen0_threshold)


-def instantiate_sampler(model_engine: PyTorchModelEngine,
+def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping,
+                              *, max_seq_len: int, mixed_sampler: bool):
+    max_num_sequences = executor_config.max_batch_size * mapping.pp_size
+    max_draft_tokens = (0 if executor_config.speculative_config is None else
+                        executor_config.speculative_config.max_draft_tokens)
+    return TorchSampler.Args(
+        max_seq_len=max_seq_len,
+        max_draft_tokens=max_draft_tokens,
+        max_num_sequences=max_num_sequences,
+        max_beam_width=executor_config.max_beam_width,
+        mixed_sampler=mixed_sampler,
+    )
+
+
+def instantiate_sampler(engine: PyTorchModelEngine,
                         executor_config: ExecutorConfig,
                         pytorch_backend_config: PyTorchConfig,
                         mapping: Mapping):
+    sampler_args = create_torch_sampler_args(
+        executor_config,
+        mapping,
+        max_seq_len=engine.max_seq_len,
+        mixed_sampler=pytorch_backend_config.mixed_sampler)
     if mapping.cp_config.get('cp_type') == 'star_attention':
         assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'"
-        return TorchStarAttentionSampler(max_seq_len=model_engine.max_seq_len)
-    spec_config = model_engine.spec_config
-    if spec_config is not None and spec_config.spec_dec_mode.has_spec_decoder():
-        return get_spec_decoder(max_seq_len=model_engine.max_seq_len,
-                                spec_config=spec_config)
+        return TorchSampler(sampler_args)
+    if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder(
+    ):
+        return get_spec_decoder(sampler_args, engine.spec_config)
     if pytorch_backend_config.enable_trtllm_sampler:
-        return TRTLLMSampler(executor_config, model_engine.model,
-                             model_engine.dtype, mapping,
-                             get_decoding_mode(executor_config),
+        decoding_mode = get_decoding_mode(executor_config)
+        return TRTLLMSampler(executor_config, engine.model, engine.dtype,
+                             mapping, decoding_mode,
                              pytorch_backend_config.disable_overlap_scheduler)
-    elif not model_engine.model.model_config.is_generation:
+    if not engine.model.model_config.is_generation:
         # NOTE: choose sampler based on model type
         return EarlyStopSampler()
-    return TorchSampler(max_seq_len=model_engine.max_seq_len,
-                        mixed_sampler=pytorch_backend_config.mixed_sampler)
+    return TorchSampler(sampler_args)


 def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode:
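Note on the refactor: the new create_torch_sampler_args() helper gathers the bounds every Torch-based sampler needs (max_seq_len, max_draft_tokens, max_num_sequences, max_beam_width, mixed_sampler) into a single TorchSampler.Args value, so the default TorchSampler, the star-attention path, and get_spec_decoder() are all built from the same limits instead of ad-hoc keyword arguments.

Below is a minimal sketch of how the helper computes those bounds, assuming create_torch_sampler_args (and the TorchSampler it references) are importable from the patched module. The SimpleNamespace stand-ins are hypothetical substitutes for ExecutorConfig and Mapping carrying only the attributes the helper reads; the real classes live in tensorrt_llm and have many more fields, and the numbers are illustrative:

    from types import SimpleNamespace

    # Hypothetical stand-ins: only the attributes create_torch_sampler_args() touches.
    executor_config = SimpleNamespace(
        max_batch_size=8,
        max_beam_width=1,
        speculative_config=None,  # None -> max_draft_tokens falls back to 0
    )
    mapping = SimpleNamespace(pp_size=2)  # two pipeline-parallel ranks

    args = create_torch_sampler_args(
        executor_config,
        mapping,
        max_seq_len=4096,
        mixed_sampler=False,
    )
    # args.max_num_sequences == max_batch_size * pp_size == 16
    # args.max_draft_tokens == 0, args.max_beam_width == 1

Sizing max_num_sequences as max_batch_size * pp_size mirrors the diff: with pipeline parallelism, each PP rank can have a full batch in flight, so the sampler must reserve slots for all of them.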