Commit 357ee52 (1 parent: ffd209b)
sfc-gh-zhwang committed Oct 25, 2024
Showing 3 changed files with 12 additions and 11 deletions.
4 changes: 2 additions & 2 deletions vllm/config.py
@@ -134,7 +134,7 @@ def __init__(self,
enforce_eager: Optional[bool] = None,
max_context_len_to_capture: Optional[int] = None,
max_seq_len_to_capture: Optional[int] = None,
- max_batch_size_to_capture: Optional[int] = None,
+ max_batchsize_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
@@ -166,7 +166,7 @@ def __init__(self,
raise ValueError("`max_context_len_to_capture` is deprecated. "
"Use `max_seq_len_to_capture` instead.")
self.max_seq_len_to_capture = max_seq_len_to_capture
- self.max_batch_size_to_capture = max_batch_size_to_capture
+ self.max_batchsize_to_capture = max_batchsize_to_capture
self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window
self.skip_tokenizer_init = skip_tokenizer_init
7 changes: 4 additions & 3 deletions vllm/engine/arg_utils.py
@@ -128,7 +128,7 @@ class EngineArgs:
enforce_eager: Optional[bool] = None
max_context_len_to_capture: Optional[int] = None
max_seq_len_to_capture: int = 8192
- max_batch_size_to_capture: Optional[int] = None
+ max_batchsize_to_capture: Optional[int] = None
disable_custom_all_reduce: bool = False
tokenizer_pool_size: int = 0
# Note: Specifying a tokenizer pool by passing a class
@@ -515,9 +515,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'Additionally for encoder-decoder models, if the '
'sequence length of the encoder input is larger '
'than this, we fall back to the eager mode.')
- parser.add_argument('--max-batch-size-to-capture',
+ parser.add_argument('--max-batchsize-to-capture',
type=int,
- default=EngineArgs.max_batch_size_to_capture,
+ default=EngineArgs.max_batchsize_to_capture,
help='Maximum batch size covered by CUDA '
'graphs. When the batch size is larger than '
'this, we fall back to eager mode. ')
@@ -889,6 +889,7 @@ def create_model_config(self) -> ModelConfig:
enforce_eager=self.enforce_eager,
max_context_len_to_capture=self.max_context_len_to_capture,
max_seq_len_to_capture=self.max_seq_len_to_capture,
+ max_batchsize_to_capture=self.max_batchsize_to_capture,
max_logprobs=self.max_logprobs,
disable_sliding_window=self.disable_sliding_window,
skip_tokenizer_init=self.skip_tokenizer_init,
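The new flag is wired through EngineArgs into ModelConfig via create_model_config. Below is a minimal sketch of setting it programmatically, assuming this patch is applied; the model name is illustrative, and create_model_config needs the model's Hugging Face config to be resolvable:

```python
from vllm.engine.arg_utils import EngineArgs

# Cap CUDA-graph capture at batch size 64, even if max_num_seqs is larger.
# (facebook/opt-125m is just an example model.)
engine_args = EngineArgs(model="facebook/opt-125m",
                         max_batchsize_to_capture=64)

model_config = engine_args.create_model_config()
assert model_config.max_batchsize_to_capture == 64
```

The same value can be passed on the command line with --max-batchsize-to-capture, as added to add_cli_args above.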
12 changes: 6 additions & 6 deletions vllm/worker/model_runner.py
@@ -980,7 +980,7 @@ def __init__(
self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
self.max_batchsize_to_capture = _get_max_graph_batch_size(
self.scheduler_config.max_num_seqs,
- self.model_config.max_batch_size_to_capture)
+ self.model_config.max_batchsize_to_capture)

self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
{} for _ in range(self.parallel_config.pipeline_parallel_size)
@@ -1907,7 +1907,7 @@ def _get_graph_batch_size(batch_size: int) -> int:


def _get_max_graph_batch_size(max_num_seqs: int,
- max_batch_size_to_capture: Optional[int]) -> int:
+ max_batchsize_to_capture: Optional[int]) -> int:
"""
max_num_seqs: Maximum number of sequences in a batch.
_BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
@@ -1919,10 +1919,10 @@ def _get_max_graph_batch_size(max_num_seqs: int,
if not, it means the padded size is larger than the largest size in
_BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
"""
- if max_batch_size_to_capture is None:
- max_batch_size_to_capture = max_num_seqs
- max_batch_size_to_capture = min(max_batch_size_to_capture, max_num_seqs)
- padded_size = _get_graph_batch_size(max_batch_size_to_capture)
+ if max_batchsize_to_capture is None:
+ max_batchsize_to_capture = max_num_seqs
+ max_batchsize_to_capture = min(max_batchsize_to_capture, max_num_seqs)
+ padded_size = _get_graph_batch_size(max_batchsize_to_capture)
if padded_size in _BATCH_SIZES_TO_CAPTURE:
return padded_size
assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
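The docstring above describes how the capture size is chosen; the sketch below reproduces that selection logic in isolation. The _BATCH_SIZES_TO_CAPTURE values and the body of _get_graph_batch_size are assumptions modeled on vLLM's usual 8-wide padding (1, 2, 4, then multiples of 8 up to 256) and are not part of this diff:

```python
from typing import List, Optional

# Assumed capture sizes: 1, 2, 4, then multiples of 8 up to 256
# (not shown in this diff).
_BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE: List[int] = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
]


def _get_graph_batch_size(batch_size: int) -> int:
    """Pad a batch size up to the nearest capturable size
    (1, 2, 4, or a multiple of 8)."""
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


def _get_max_graph_batch_size(max_num_seqs: int,
                              max_batchsize_to_capture: Optional[int]) -> int:
    """Cap by max_batchsize_to_capture (if given), pad, then clamp to the
    largest entry in _BATCH_SIZES_TO_CAPTURE."""
    if max_batchsize_to_capture is None:
        max_batchsize_to_capture = max_num_seqs
    max_batchsize_to_capture = min(max_batchsize_to_capture, max_num_seqs)
    padded_size = _get_graph_batch_size(max_batchsize_to_capture)
    if padded_size in _BATCH_SIZES_TO_CAPTURE:
        return padded_size
    return _BATCH_SIZES_TO_CAPTURE[-1]


# max_num_seqs=300 pads to 304, past the largest capture size, so 256 is used;
# passing max_batchsize_to_capture=48 keeps CUDA graphs to batches of <= 48.
assert _get_max_graph_batch_size(300, None) == 256
assert _get_max_graph_batch_size(300, 48) == 48
```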
