diff --git a/vllm/config.py b/vllm/config.py
index 81e571444ed30..d98cde335c689 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -134,7 +134,7 @@ def __init__(self,
         enforce_eager: Optional[bool] = None,
         max_context_len_to_capture: Optional[int] = None,
         max_seq_len_to_capture: Optional[int] = None,
-        max_batch_size_to_capture: Optional[int] = None,
+        max_batchsize_to_capture: Optional[int] = None,
         max_logprobs: int = 20,
         disable_sliding_window: bool = False,
         skip_tokenizer_init: bool = False,
@@ -166,7 +166,7 @@ def __init__(self,
             raise ValueError("`max_context_len_to_capture` is deprecated. "
                              "Use `max_seq_len_to_capture` instead.")
         self.max_seq_len_to_capture = max_seq_len_to_capture
-        self.max_batch_size_to_capture = max_batch_size_to_capture
+        self.max_batchsize_to_capture = max_batchsize_to_capture
         self.max_logprobs = max_logprobs
         self.disable_sliding_window = disable_sliding_window
         self.skip_tokenizer_init = skip_tokenizer_init
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 251392591ca98..f1e220e34f907 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -128,7 +128,7 @@ class EngineArgs:
     enforce_eager: Optional[bool] = None
     max_context_len_to_capture: Optional[int] = None
     max_seq_len_to_capture: int = 8192
-    max_batch_size_to_capture: Optional[int] = None
+    max_batchsize_to_capture: Optional[int] = None
     disable_custom_all_reduce: bool = False
     tokenizer_pool_size: int = 0
     # Note: Specifying a tokenizer pool by passing a class
@@ -515,9 +515,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                             'Additionally for encoder-decoder models, if the '
                             'sequence length of the encoder input is larger '
                             'than this, we fall back to the eager mode.')
-        parser.add_argument('--max-batch-size-to-capture',
+        parser.add_argument('--max-batchsize-to-capture',
                             type=int,
-                            default=EngineArgs.max_batch_size_to_capture,
+                            default=EngineArgs.max_batchsize_to_capture,
                             help='Maximum batch size covered by CUDA '
                             'graphs. When the batch size is larger than '
                             'this, we fall back to eager mode. ')
@@ -889,6 +889,7 @@ def create_model_config(self) -> ModelConfig:
             enforce_eager=self.enforce_eager,
             max_context_len_to_capture=self.max_context_len_to_capture,
             max_seq_len_to_capture=self.max_seq_len_to_capture,
+            max_batchsize_to_capture=self.max_batchsize_to_capture,
             max_logprobs=self.max_logprobs,
             disable_sliding_window=self.disable_sliding_window,
             skip_tokenizer_init=self.skip_tokenizer_init,
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 84377b5dfd2b8..0f40c3b573db9 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -980,7 +980,7 @@ def __init__(
         self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
         self.max_batchsize_to_capture = _get_max_graph_batch_size(
             self.scheduler_config.max_num_seqs,
-            self.model_config.max_batch_size_to_capture)
+            self.model_config.max_batchsize_to_capture)
 
         self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
             {} for _ in range(self.parallel_config.pipeline_parallel_size)
@@ -1907,7 +1907,7 @@ def _get_graph_batch_size(batch_size: int) -> int:
 
 
 def _get_max_graph_batch_size(max_num_seqs: int,
-                              max_batch_size_to_capture: Optional[int]) -> int:
+                              max_batchsize_to_capture: Optional[int]) -> int:
     """
     max_num_seqs: Maximum number of sequences in a batch.
     _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
@@ -1919,10 +1919,10 @@ def _get_max_graph_batch_size(max_num_seqs: int,
     if not, it means the padded size is larger than the largest size in
     _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
     """
-    if max_batch_size_to_capture is None:
-        max_batch_size_to_capture = max_num_seqs
-    max_batch_size_to_capture = min(max_batch_size_to_capture, max_num_seqs)
-    padded_size = _get_graph_batch_size(max_batch_size_to_capture)
+    if max_batchsize_to_capture is None:
+        max_batchsize_to_capture = max_num_seqs
+    max_batchsize_to_capture = min(max_batchsize_to_capture, max_num_seqs)
+    padded_size = _get_graph_batch_size(max_batchsize_to_capture)
     if padded_size in _BATCH_SIZES_TO_CAPTURE:
        return padded_size
    assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
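
For context when reviewing the rename, below is a minimal, self-contained sketch of how the renamed setting resolves to a CUDA graph capture size. The capture-size list and the padding rule (1, 2, 4, then multiples of 8 up to 256) are assumptions that mirror vLLM's defaults at the time, not part of this diff, and resolve_capture_batch_size is a hypothetical stand-in for _get_max_graph_batch_size.

from typing import List, Optional

# Assumed capture sizes and alignment, for illustration only.
_BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE: List[int] = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]


def _get_graph_batch_size(batch_size: int) -> int:
    """Pad a batch size up to the nearest captured size."""
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


def resolve_capture_batch_size(max_num_seqs: int,
                               max_batchsize_to_capture: Optional[int]) -> int:
    """Cap the requested capture size by max_num_seqs, then pad it."""
    if max_batchsize_to_capture is None:
        max_batchsize_to_capture = max_num_seqs
    max_batchsize_to_capture = min(max_batchsize_to_capture, max_num_seqs)
    padded_size = _get_graph_batch_size(max_batchsize_to_capture)
    if padded_size in _BATCH_SIZES_TO_CAPTURE:
        return padded_size
    # Padded size exceeds the largest capture size; fall back to it.
    return _BATCH_SIZES_TO_CAPTURE[-1]


# E.g. with max_num_seqs=256 and a requested cap of 50, graphs would be
# captured up to batch size 56 (50 padded to the next multiple of 8).
assert resolve_capture_batch_size(256, 50) == 56
assert resolve_capture_batch_size(256, None) == 256

In practice the value arrives via --max-batchsize-to-capture on the CLI or the max_batchsize_to_capture field on EngineArgs, both of which this diff threads through to ModelConfig and the model runner.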