Skip to content

Commit

Permalink
[torch.compile] remove compilation_context and simplify code (vllm-pr…
Browse files Browse the repository at this point in the history
…oject#10838)

Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao authored and weilong.yu committed Dec 13, 2024
1 parent 441a663 commit 3257258
Show file tree
Hide file tree
Showing 14 changed files with 128 additions and 143 deletions.
9 changes: 4 additions & 5 deletions tests/compile/piecewise/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from torch import nn
from torch.library import Library

from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
Expand Down Expand Up @@ -81,6 +80,7 @@ def test_simple_piecewise_compile():
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_copy_inputs=True,
cudagraph_capture_sizes=[1, 2],
))
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix='')
Expand All @@ -96,11 +96,10 @@ def test_simple_piecewise_compile():
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):

with set_compile_context([1, 2]):
model(inputs)
model(inputs)

model(torch.randn(2).cuda())
model(torch.randn(1).cuda())
model(torch.randn(2).cuda())
model(torch.randn(1).cuda())

input = torch.zeros(2).cuda()
global global_counter
Expand Down
33 changes: 17 additions & 16 deletions tests/compile/piecewise/test_toy_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from torch import nn
from torch.library import Library

from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
Expand Down Expand Up @@ -256,6 +255,7 @@ def run_model(llama_config,
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
cudagraph_capture_sizes=[1, 2],
)
if split_attn:
compilation_config.splitting_ops = ["silly.attention"]
Expand All @@ -273,10 +273,9 @@ def run_model(llama_config,
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
positions = torch.arange(B).cuda()

with set_compile_context([1, 2]):
model(input_ids, positions)
model(input_ids[:2], positions[:2])
model(input_ids[:1], positions[:1])
model(input_ids, positions)
model(input_ids[:2], positions[:2])
model(input_ids[:1], positions[:1])

input_ids[:2].zero_()
output = model(input_ids[:2], positions[:2])
Expand Down Expand Up @@ -379,10 +378,13 @@ def benchmark():
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=cudagraph_sizes,
)
else:
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE, )
level=CompilationLevel.PIECEWISE,
cudagraph_capture_sizes=cudagraph_sizes,
)

vllm_config = VllmConfig(compilation_config=compilation_config)
with set_current_vllm_config(vllm_config):
Expand All @@ -396,17 +398,16 @@ def benchmark():

graphs = {}

with set_compile_context(cudagraph_sizes):
model(input_ids, positions)
for b in cudagraph_sizes[::-1]:
if not piecewise:
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=pool):
output = model(input_ids[:b], positions[:b])
graphs[b] = (graph, output)
else:
model(input_ids, positions)
for b in cudagraph_sizes[::-1]:
if not piecewise:
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=pool):
output = model(input_ids[:b], positions[:b])
graphs[b] = (model, output)
graphs[b] = (graph, output)
else:
output = model(input_ids[:b], positions[:b])
graphs[b] = (model, output)
for b in cudagraph_sizes:
if piecewise:
# noqa is for `Function definition does not bind loop variable`
Expand Down
5 changes: 3 additions & 2 deletions tests/models/decoder_only/language/test_jamba.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pytest

from tests.utils import multi_gpu_test
from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.worker.model_runner import _get_graph_batch_size

from ...utils import check_outputs_equal

Expand Down Expand Up @@ -189,7 +189,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
while len(example_prompts) == VllmConfig.get_graph_batch_size(
len(example_prompts)):
example_prompts.append(example_prompts[0])

try:
Expand Down
5 changes: 3 additions & 2 deletions tests/models/decoder_only/language/test_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.worker.model_runner import _get_graph_batch_size

from ...utils import check_outputs_equal

Expand Down Expand Up @@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
while len(example_prompts) == VllmConfig.get_graph_batch_size(
len(example_prompts)):
example_prompts.append(example_prompts[0])

try:
Expand Down
4 changes: 2 additions & 2 deletions tests/worker/test_encoder_decoder_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import pytest
import torch

from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import _get_graph_batch_size

BATCH_SIZES = [1, 4, 16, 64, 256]

Expand Down Expand Up @@ -548,7 +548,7 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
# With CUDA Graph capture and replay enabled, the decoder and encoder
# input sequences will be padded. Create the expected padded tensors
# accordingly.
graph_batch_size = _get_graph_batch_size(expanded_batch_size)
graph_batch_size = VllmConfig.get_graph_batch_size(expanded_batch_size)
cuda_graph_pad_size = graph_batch_size - expanded_batch_size
padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
padded_encoder_seq_lens = encoder_seq_lens + list(
Expand Down
5 changes: 3 additions & 2 deletions tests/worker/test_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import pytest
import torch

from vllm.config import VllmConfig
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import get_open_port
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
from vllm.worker.model_runner import ModelRunner


def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
Expand Down Expand Up @@ -176,7 +177,7 @@ def test_prepare_decode_cuda_graph(batch_size):
model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
assert len(slot_mapping) == len(input_tokens)

expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
expected_bs = VllmConfig.get_graph_batch_size(len(seq_group_metadata_list))
# Verify input metadata is correct for prompts.
device = model_runner.device
assert attn_metadata.num_prefills == 0
Expand Down
4 changes: 0 additions & 4 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
assert not self._called, "VllmBackend can only be called once"

self.graph = graph
# config is updated now, because only here can
# we get the sizes to capture for cudagraph
# from compilation context
self.compilation_configs.init_during_runtime()
self.configure_post_pass()

self.split_gm, self.piecewise_graphs = split_graph(
Expand Down
23 changes: 0 additions & 23 deletions vllm/compilation/compile_context.py

This file was deleted.

83 changes: 75 additions & 8 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2357,15 +2357,10 @@ def init_backend(self) -> Union[str, Callable]:
from vllm.compilation.backends import VllmBackend
return VllmBackend(self)

def init_during_runtime(self):
def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
"""To complete the initialization of config,
we need to know the compile context, which is only available
during the first run of the model.
"""
from vllm.compilation.compile_context import get_compile_context
context = get_compile_context()
context = copy.deepcopy(context) if context is not None else []
sizes_to_specialize: List[int] = context
we need to know the cudagraph sizes."""

if self.cudagraph_capture_sizes is None:
self.capture_sizes = sizes_to_specialize
else:
Expand All @@ -2386,6 +2381,21 @@ def init_during_runtime(self):
self.inductor_compile_sizes = []
self.compile_sizes = self.inductor_compile_sizes

# sort to make sure cudagraph capture sizes are in descending order
self.capture_sizes.sort(reverse=True)


_BATCH_SIZE_ALIGNMENT = 8
# all the token sizes that **can** be captured by cudagraph.
# they can be arbitrarily large.
# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
# the actual sizes to capture will be determined by the model,
# depending on the model's max_num_seqs.
# NOTE: get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
]


@dataclass
class VllmConfig:
Expand Down Expand Up @@ -2413,6 +2423,41 @@ class VllmConfig:
kv_transfer_config: KVTransferConfig = field(default=None,
init=True) # type: ignore

@staticmethod
def get_graph_batch_size(batch_size: int) -> int:
"""Returns the padded batch size given actual batch size.
Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
"""
if batch_size <= 2:
return batch_size
elif batch_size <= 4:
return 4
else:
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

@staticmethod
def get_max_graph_batch_size(max_num_seqs: int) -> int:
"""
max_num_seqs: Maximum number of sequences in a batch.
_BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
pad the max_num_seqs if necessary by calling get_graph_batch_size,
which will deal with some edge cases like 1, 2, 4.
if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded
size. if not, it means the padded size is larger than the largest size
in _BATCH_SIZES_TO_CAPTURE, return the largest size in
_BATCH_SIZES_TO_CAPTURE.
"""
padded_size = VllmConfig.get_graph_batch_size(max_num_seqs)
if padded_size in _BATCH_SIZES_TO_CAPTURE:
return padded_size
assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
return _BATCH_SIZES_TO_CAPTURE[-1]

@staticmethod
def _get_quantization_config(
model_config: ModelConfig,
Expand Down Expand Up @@ -2496,6 +2541,28 @@ def __post_init__(self):
self.compilation_config.pass_config.enable_reshape = False
self.compilation_config.level = CompilationLevel.PIECEWISE

if not envs.VLLM_USE_V1:
max_batchsize_to_capture = 0
if self.scheduler_config is not None and \
self.model_config is not None and \
not self.model_config.enforce_eager:
max_batchsize_to_capture = \
self.get_max_graph_batch_size(
self.scheduler_config.max_num_seqs)
batch_size_capture_list = [
size for size in _BATCH_SIZES_TO_CAPTURE
if size <= max_batchsize_to_capture
]
else:
batch_size_capture_list = []
if self.model_config is not None and \
not self.model_config.enforce_eager:
batch_size_capture_list = [1, 2, 4
] + [i for i in range(8, 513, 8)]

self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list)

if self.cache_config is not None and \
self.cache_config.cpu_offload_gb > 0 and \
self.compilation_config.level != CompilationLevel.NO_COMPILATION:
Expand Down
6 changes: 2 additions & 4 deletions vllm/model_executor/models/jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
Expand All @@ -25,8 +25,6 @@
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size)

from .interfaces import HasInnerState, SupportsLoRA
from .utils import maybe_prefix
Expand Down Expand Up @@ -404,7 +402,7 @@ def forward(self,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
max_batch_size = (_get_graph_batch_size(
max_batch_size = (VllmConfig.get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2)

Expand Down
6 changes: 2 additions & 4 deletions vllm/model_executor/models/mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from transformers import MambaConfig

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
Expand All @@ -23,8 +23,6 @@
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size)

from .utils import maybe_prefix

Expand Down Expand Up @@ -187,7 +185,7 @@ def forward(self,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
max_batch_size = (_get_graph_batch_size(
max_batch_size = (VllmConfig.get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2)
self.mamba_cache = MambaCacheManager(
Expand Down
Loading

0 comments on commit 3257258

Please sign in to comment.