diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py
index 7ef502abee345..aa11524812cdd 100644
--- a/tests/compile/piecewise/test_simple.py
+++ b/tests/compile/piecewise/test_simple.py
@@ -7,7 +7,6 @@
 from torch import nn
 from torch.library import Library

-from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
@@ -81,6 +80,7 @@ def test_simple_piecewise_compile():
             use_cudagraph=True,
             splitting_ops=["silly.attention"],
             cudagraph_copy_inputs=True,
+            cudagraph_capture_sizes=[1, 2],
         ))
     with set_current_vllm_config(vllm_config):
         model = SillyModel(vllm_config=vllm_config, prefix='')
@@ -96,11 +96,10 @@ def test_simple_piecewise_compile():
             6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):

-        with set_compile_context([1, 2]):
-            model(inputs)
+        model(inputs)

-            model(torch.randn(2).cuda())
-            model(torch.randn(1).cuda())
+        model(torch.randn(2).cuda())
+        model(torch.randn(1).cuda())

     input = torch.zeros(2).cuda()
     global global_counter
diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py
index dbd5a3bbffeab..07c10a3a18c55 100644
--- a/tests/compile/piecewise/test_toy_llama.py
+++ b/tests/compile/piecewise/test_toy_llama.py
@@ -13,7 +13,6 @@
 from torch import nn
 from torch.library import Library

-from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
@@ -256,6 +255,7 @@ def run_model(llama_config,
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
+            cudagraph_capture_sizes=[1, 2],
         )
         if split_attn:
             compilation_config.splitting_ops = ["silly.attention"]
@@ -273,10 +273,9 @@ def run_model(llama_config,
     input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
     positions = torch.arange(B).cuda()

-    with set_compile_context([1, 2]):
-        model(input_ids, positions)
-        model(input_ids[:2], positions[:2])
-        model(input_ids[:1], positions[:1])
+    model(input_ids, positions)
+    model(input_ids[:2], positions[:2])
+    model(input_ids[:1], positions[:1])

     input_ids[:2].zero_()
     output = model(input_ids[:2], positions[:2])
@@ -379,10 +378,13 @@ def benchmark():
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
             splitting_ops=["silly.attention"],
+            cudagraph_capture_sizes=cudagraph_sizes,
         )
     else:
         compilation_config = CompilationConfig(
-            level=CompilationLevel.PIECEWISE, )
+            level=CompilationLevel.PIECEWISE,
+            cudagraph_capture_sizes=cudagraph_sizes,
+        )

     vllm_config = VllmConfig(compilation_config=compilation_config)
     with set_current_vllm_config(vllm_config):
@@ -396,17 +398,16 @@ def benchmark():
     graphs = {}

-    with set_compile_context(cudagraph_sizes):
-        model(input_ids, positions)
-        for b in cudagraph_sizes[::-1]:
-            if not piecewise:
-                graph = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(graph, pool=pool):
-                    output = model(input_ids[:b], positions[:b])
-                graphs[b] = (graph, output)
-            else:
+    model(input_ids, positions)
+    for b in cudagraph_sizes[::-1]:
+        if not piecewise:
+            graph = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(graph, pool=pool):
                 output = model(input_ids[:b], positions[:b])
-                graphs[b] = (model, output)
+            graphs[b] = (graph, output)
+        else:
+            output = model(input_ids[:b], positions[:b])
+            graphs[b] = (model, output)

     for b in cudagraph_sizes:
         if piecewise:
             # noqa is for `Function definition does not bind loop variable`
diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py
index 87a05b3011393..cae25ae9fa2c8 100644
--- a/tests/models/decoder_only/language/test_jamba.py
+++ b/tests/models/decoder_only/language/test_jamba.py
@@ -1,8 +1,8 @@
 import pytest

 from tests.utils import multi_gpu_test
+from vllm.config import VllmConfig
 from vllm.sampling_params import SamplingParams
-from vllm.worker.model_runner import _get_graph_batch_size

 from ...utils import check_outputs_equal
@@ -189,7 +189,8 @@ def test_mamba_cache_cg_padding(
     # This test is for verifying that mamba cache is padded to CG captured
     # batch size. If it's not, a torch RuntimeError will be raised because
     # tensor dimensions aren't compatible
-    while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
+    while len(example_prompts) == VllmConfig.get_graph_batch_size(
+            len(example_prompts)):
         example_prompts.append(example_prompts[0])

     try:
diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py
index 01e208347bff4..35018c3c14dee 100644
--- a/tests/models/decoder_only/language/test_mamba.py
+++ b/tests/models/decoder_only/language/test_mamba.py
@@ -5,8 +5,8 @@
 import pytest
 from transformers import AutoModelForCausalLM, AutoTokenizer

+from vllm.config import VllmConfig
 from vllm.sampling_params import SamplingParams
-from vllm.worker.model_runner import _get_graph_batch_size

 from ...utils import check_outputs_equal
@@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
     # This test is for verifying that mamba cache is padded to CG captured
     # batch size. If it's not, a torch RuntimeError will be raised because
     # tensor dimensions aren't compatible
-    while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
+    while len(example_prompts) == VllmConfig.get_graph_batch_size(
+            len(example_prompts)):
         example_prompts.append(example_prompts[0])

     try:
diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py
index 9e166ae64dbfb..5289c91f201cd 100644
--- a/tests/worker/test_encoder_decoder_model_runner.py
+++ b/tests/worker/test_encoder_decoder_model_runner.py
@@ -4,12 +4,12 @@
 import pytest
 import torch

+from vllm.config import VllmConfig
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import make_tensor_with_pad
 from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
-from vllm.worker.model_runner import _get_graph_batch_size

 BATCH_SIZES = [1, 4, 16, 64, 256]
@@ -548,7 +548,7 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
     # With CUDA Graph capture and replay enabled, the decoder and encoder
     # input sequences will be padded. Create the expected padded tensors
     # accordingly.
-    graph_batch_size = _get_graph_batch_size(expanded_batch_size)
+    graph_batch_size = VllmConfig.get_graph_batch_size(expanded_batch_size)
     cuda_graph_pad_size = graph_batch_size - expanded_batch_size
     padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
     padded_encoder_seq_lens = encoder_seq_lens + list(
diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py
index 433a9b30ba57a..4055524f3e0c7 100644
--- a/tests/worker/test_model_runner.py
+++ b/tests/worker/test_model_runner.py
@@ -3,13 +3,14 @@
 import pytest
 import torch

+from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
                                              init_distributed_environment)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
 from vllm.utils import get_open_port
-from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
+from vllm.worker.model_runner import ModelRunner


 def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
@@ -176,7 +177,7 @@ def test_prepare_decode_cuda_graph(batch_size):
         model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
     assert len(slot_mapping) == len(input_tokens)

-    expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
+    expected_bs = VllmConfig.get_graph_batch_size(len(seq_group_metadata_list))
     # Verify input metadata is correct for prompts.
     device = model_runner.device
     assert attn_metadata.num_prefills == 0
diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py
index 464bc2af8fd6d..d49a83fe3981f 100644
--- a/vllm/compilation/backends.py
+++ b/vllm/compilation/backends.py
@@ -242,10 +242,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         assert not self._called, "VllmBackend can only be called once"

         self.graph = graph
-        # config is updated now, because only here can
-        # we get the sizes to capture for cudagraph
-        # from compilation context
-        self.compilation_configs.init_during_runtime()
         self.configure_post_pass()

         self.split_gm, self.piecewise_graphs = split_graph(
diff --git a/vllm/compilation/compile_context.py b/vllm/compilation/compile_context.py
deleted file mode 100644
index 29db3d4c637b9..0000000000000
--- a/vllm/compilation/compile_context.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from contextlib import contextmanager
-from typing import Any
-
-_compile_context: Any = None
-
-
-def get_compile_context() -> Any:
-    """Get the current compile context."""
-    return _compile_context
-
-
-@contextmanager
-def set_compile_context(context: Any):
-    """A context manager that stores the current compile context,
-    usually it is a list of sizes to specialize.
-    """
-    global _compile_context
-    prev_context = _compile_context
-    _compile_context = context
-    try:
-        yield
-    finally:
-        _compile_context = prev_context
diff --git a/vllm/config.py b/vllm/config.py
index 5f50d65ec87e1..326340d3fa655 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -2357,15 +2357,10 @@ def init_backend(self) -> Union[str, Callable]:
         from vllm.compilation.backends import VllmBackend
         return VllmBackend(self)

-    def init_during_runtime(self):
+    def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
         """To complete the initialization of config,
-        we need to know the compile context, which is only available
-        during the first run of the model.
-        """
-        from vllm.compilation.compile_context import get_compile_context
-        context = get_compile_context()
-        context = copy.deepcopy(context) if context is not None else []
-        sizes_to_specialize: List[int] = context
+        we need to know the cudagraph sizes."""
+
         if self.cudagraph_capture_sizes is None:
             self.capture_sizes = sizes_to_specialize
         else:
@@ -2386,6 +2381,21 @@ def init_during_runtime(self):
             self.inductor_compile_sizes = []
         self.compile_sizes = self.inductor_compile_sizes

+        # sort to make sure cudagraph capture sizes are in descending order
+        self.capture_sizes.sort(reverse=True)
+
+
+_BATCH_SIZE_ALIGNMENT = 8
+# all the token sizes that **can** be captured by cudagraph.
+# they can be arbitrarily large.
+# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
+# the actual sizes to capture will be determined by the model,
+# depending on the model's max_num_seqs.
+# NOTE: get_graph_batch_size needs to be updated if this list is changed.
+_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
+    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
+]
+

 @dataclass
 class VllmConfig:
@@ -2413,6 +2423,41 @@ class VllmConfig:
     kv_transfer_config: KVTransferConfig = field(default=None,
                                                  init=True)  # type: ignore

+    @staticmethod
+    def get_graph_batch_size(batch_size: int) -> int:
+        """Returns the padded batch size given actual batch size.
+
+        Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
+        2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
+        """
+        if batch_size <= 2:
+            return batch_size
+        elif batch_size <= 4:
+            return 4
+        else:
+            return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
+                    _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
+
+    @staticmethod
+    def get_max_graph_batch_size(max_num_seqs: int) -> int:
+        """
+        max_num_seqs: Maximum number of sequences in a batch.
+        _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
+
+        pad the max_num_seqs if necessary by calling get_graph_batch_size,
+        which will deal with some edge cases like 1, 2, 4.
+
+        if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded
+        size. if not, it means the padded size is larger than the largest size
+        in _BATCH_SIZES_TO_CAPTURE, return the largest size in
+        _BATCH_SIZES_TO_CAPTURE.
+        """
+        padded_size = VllmConfig.get_graph_batch_size(max_num_seqs)
+        if padded_size in _BATCH_SIZES_TO_CAPTURE:
+            return padded_size
+        assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
+        return _BATCH_SIZES_TO_CAPTURE[-1]
+
     @staticmethod
     def _get_quantization_config(
         model_config: ModelConfig,
@@ -2496,6 +2541,28 @@ def __post_init__(self):
                 self.compilation_config.pass_config.enable_reshape = False
                 self.compilation_config.level = CompilationLevel.PIECEWISE

+        if not envs.VLLM_USE_V1:
+            max_batchsize_to_capture = 0
+            if self.scheduler_config is not None and \
+                self.model_config is not None and \
+                    not self.model_config.enforce_eager:
+                max_batchsize_to_capture = \
+                    self.get_max_graph_batch_size(
+                        self.scheduler_config.max_num_seqs)
+            batch_size_capture_list = [
+                size for size in _BATCH_SIZES_TO_CAPTURE
+                if size <= max_batchsize_to_capture
+            ]
+        else:
+            batch_size_capture_list = []
+            if self.model_config is not None and \
+                not self.model_config.enforce_eager:
+                batch_size_capture_list = [1, 2, 4
+                                           ] + [i for i in range(8, 513, 8)]
+
+        self.compilation_config.init_with_cudagraph_sizes(
+            batch_size_capture_list)
+
         if self.cache_config is not None and \
             self.cache_config.cpu_offload_gb > 0 and \
             self.compilation_config.level != CompilationLevel.NO_COMPILATION:
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index 099ca7e12b288..5d5e8ae1ee532 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -7,7 +7,7 @@
 from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.attention.layer import Attention
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -25,8 +25,6 @@
                                                MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
-                                      _get_graph_batch_size)

 from .interfaces import HasInnerState, SupportsLoRA
 from .utils import maybe_prefix
@@ -404,7 +402,7 @@ def forward(self,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
         if self.mamba_cache is None:
-            max_batch_size = (_get_graph_batch_size(
+            max_batch_size = (VllmConfig.get_graph_batch_size(
                 self.scheduler_config.max_num_seqs) if self.scheduler_config
                               else max(_BATCH_SIZES_TO_CAPTURE) + 2)
diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index ac0d265a961f0..b32032e411b0a 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -6,7 +6,7 @@
 from transformers import MambaConfig

 from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -23,8 +23,6 @@
                                                MambaCacheParams)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
-from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
-                                      _get_graph_batch_size)

 from .utils import maybe_prefix
@@ -187,7 +185,7 @@ def forward(self,
                 inputs_embeds: Optional[torch.Tensor] = None,
                 **kwargs):
         if self.mamba_cache is None:
-            max_batch_size = (_get_graph_batch_size(
+            max_batch_size = (VllmConfig.get_graph_batch_size(
                 self.scheduler_config.max_num_seqs) if self.scheduler_config
                               else max(_BATCH_SIZES_TO_CAPTURE) + 2)
             self.mamba_cache = MambaCacheManager(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 1fa47f553dfd6..4692762493f00 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -8,7 +8,6 @@
 import torch.distributed
 import torch.nn as nn

-from vllm.compilation.compile_context import set_compile_context
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
@@ -100,7 +99,11 @@ def __init__(
             == CompilationLevel.PIECEWISE
             and not self.model_config.enforce_eager)
         # TODO(woosuk): Provide an option to tune the max cudagraph batch size.
-        self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)]
+        # The convention is different.
+        # self.cudagraph_batch_sizes sorts in ascending order.
+        # The batch sizes in the config are in descending order.
+        self.cudagraph_batch_sizes = list(
+            reversed(self.vllm_config.compilation_config.capture_sizes))

         self.positions = torch.zeros(self.max_num_tokens,
                                      dtype=torch.int64,
                                      device=self.device)
@@ -548,10 +551,9 @@ def profile_run(self) -> None:
             torch.tensor([], dtype=torch.float32, device=self.device)
             for _ in range(self.num_attn_layers)
         ]
-        with set_compile_context(self.cudagraph_batch_sizes):
-            # Trigger compilation for general shape.
-            hidden_states = self._dummy_run(self.model, self.max_num_tokens,
-                                            dummy_kv_caches)
+        # Trigger compilation for general shape.
+        hidden_states = self._dummy_run(self.model, self.max_num_tokens,
+                                        dummy_kv_caches)
         logits = self.model.compute_logits(hidden_states, None)
         logits = logits[:self.max_num_tokens]
         # TODO(woosuk): Consider the memory usage of the sampler.
diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py
index ae18c79c980c8..5697fbbaa2041 100644
--- a/vllm/worker/enc_dec_model_runner.py
+++ b/vllm/worker/enc_dec_model_runner.py
@@ -25,8 +25,7 @@
 from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
 from vllm.worker.model_runner import (GPUModelRunnerBase,
                                       ModelInputForGPUBuilder,
-                                      ModelInputForGPUWithSamplingMetadata,
-                                      _get_graph_batch_size)
+                                      ModelInputForGPUWithSamplingMetadata)
 from vllm.worker.model_runner_base import (
     _add_attn_metadata_broadcastable_dict,
     _add_sampling_metadata_broadcastable_dict)
@@ -465,7 +464,8 @@ def _prepare_encoder_model_input_tensors(
             # We will be using CUDA graph replay for this decode.
             max_len_of_block_table = self.get_max_block_per_batch()
             batch_size = len(encoder_seq_lens)
-            graph_batch_size = _get_graph_batch_size(batch_size)
+            graph_batch_size = self.vllm_config.get_graph_batch_size(
+                batch_size)
             assert graph_batch_size >= batch_size
             cuda_graph_pad_size = graph_batch_size - batch_size
             # extend the cross_block_tables and encoder_seq_lens to match
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 01b7022a763cd..c685dafca442d 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -18,7 +18,6 @@
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.attention.backends.abstract import AttentionState
 from vllm.attention.backends.utils import CommonAttentionState
-from vllm.compilation.compile_context import set_compile_context
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.core.scheduler import SchedulerOutputs
 from vllm.distributed import get_kv_transfer_group, get_pp_group
@@ -63,16 +62,7 @@
 logger = init_logger(__name__)

 LORA_WARMUP_RANK = 8
-_BATCH_SIZE_ALIGNMENT = 8
-# all the token sizes that **can** be captured by cudagraph.
-# they can be arbitrarily large.
-# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
-# the actual sizes to capture will be determined by the model,
-# depending on the model's max_num_seqs.
-# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
-_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
-    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
-]
+
 _NUM_WARMUP_ITERS = 2

 TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU")
@@ -763,7 +753,6 @@ def _use_captured_graph(self,
                             max_decode_seq_len: int,
                             max_encoder_seq_len: int = 0) -> bool:
         return (decode_only and not self.runner.model_config.enforce_eager
-                and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
                 and max_decode_seq_len <= self.runner.max_seq_len_to_capture
                 and max_encoder_seq_len <= self.runner.max_seq_len_to_capture
                 and batch_size <= self.runner.max_batchsize_to_capture)
@@ -811,7 +800,7 @@ def _get_cuda_graph_pad_size(self,
                                        max_encoder_seq_len):
             return -1

-        graph_batch_size = _get_graph_batch_size(batch_size)
+        graph_batch_size = VllmConfig.get_graph_batch_size(batch_size)
         assert graph_batch_size >= batch_size
         return graph_batch_size - batch_size
@@ -1023,7 +1012,7 @@ def __init__(
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
         self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
-        self.max_batchsize_to_capture = _get_max_graph_batch_size(
+        self.max_batchsize_to_capture = VllmConfig.get_max_graph_batch_size(
             self.scheduler_config.max_num_seqs)

         self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
@@ -1333,14 +1322,7 @@ def profile_run(self) -> None:
                                               dtype=self.model_config.dtype,
                                               device=self.device)

-        graph_batch_size = self.max_batchsize_to_capture
-        batch_size_capture_list = [
-            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
-        ]
-        if self.model_config.enforce_eager:
-            batch_size_capture_list = []
-        with set_compile_context(batch_size_capture_list):
-            self.execute_model(model_input, kv_caches, intermediate_tensors)
+        self.execute_model(model_input, kv_caches, intermediate_tensors)
         torch.cuda.synchronize()
         # Cleanup
         if self.lora_config:
@@ -1463,18 +1445,14 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
                                 dtype=self.model_config.dtype,
                                 device=self.device)

-        graph_batch_size = self.max_batchsize_to_capture
-        batch_size_capture_list = [
-            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
-        ]
-
         with self.attn_state.graph_capture(
                 max_batch_size), graph_capture() as graph_capture_context:
             # NOTE: Capturing the largest batch size first may help reduce the
             # memory usage of CUDA graph.
             for virtual_engine in range(
                     self.parallel_config.pipeline_parallel_size):
-                for batch_size in reversed(batch_size_capture_list):
+                for batch_size in \
+                        self.vllm_config.compilation_config.capture_sizes:
                     attn_metadata = (
                         self.attn_state.graph_capture_get_metadata_for_batch(
                             batch_size,
@@ -1997,37 +1975,3 @@ def forward(
             return self.output_buffers["hidden_states"]

         return self.output_buffers
-
-
-def _get_graph_batch_size(batch_size: int) -> int:
-    """Returns the padded batch size given actual batch size.
-
-    Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
-    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
-    """
-    if batch_size <= 2:
-        return batch_size
-    elif batch_size <= 4:
-        return 4
-    else:
-        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
-                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
-
-
-def _get_max_graph_batch_size(max_num_seqs: int) -> int:
-    """
-    max_num_seqs: Maximum number of sequences in a batch.
-    _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
-
-    pad the max_num_seqs if necessary by calling _get_graph_batch_size,
-    which will deal with some edge cases like 1, 2, 4.
-
-    if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size.
-    if not, it means the padded size is larger than the largest size in
-    _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
-    """
-    padded_size = _get_graph_batch_size(max_num_seqs)
-    if padded_size in _BATCH_SIZES_TO_CAPTURE:
-        return padded_size
-    assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
-    return _BATCH_SIZES_TO_CAPTURE[-1]
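Note (not part of the diff): a minimal standalone sketch of the padding rule this change relocates from vllm/worker/model_runner.py onto VllmConfig.get_graph_batch_size, reproduced here only to illustrate the behavior callers now rely on — batch sizes 1 and 2 map to themselves, 3-4 round up to 4, and everything larger rounds up to the next multiple of _BATCH_SIZE_ALIGNMENT (8).

    # Illustrative copy of the rule; in the tree it lives on VllmConfig.
    _BATCH_SIZE_ALIGNMENT = 8

    def get_graph_batch_size(batch_size: int) -> int:
        """Pad an actual batch size up to a CUDA-graph-captured size."""
        if batch_size <= 2:
            return batch_size
        elif batch_size <= 4:
            return 4
        # Round up to the next multiple of _BATCH_SIZE_ALIGNMENT.
        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

    # Example: 1 -> 1, 2 -> 2, 3 -> 4, 5 -> 8, 9 -> 16, 17 -> 24
    assert [get_graph_batch_size(b) for b in (1, 2, 3, 5, 9, 17)] == \
        [1, 2, 4, 8, 16, 24]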