diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index df201cdc7c554..329cc42558da6 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -31,8 +31,8 @@ cleanup_docker() { echo "Disk usage is above $threshold%. Cleaning up Docker images and volumes..." # Remove dangling images (those that are not tagged and not used by any container) docker image prune -f - # Remove unused volumes - docker volume prune -f + # Remove unused volumes / force the system prune for old images as well. + docker volume prune -f && docker system prune --force --filter "until=72h" --all echo "Docker images and volumes cleanup completed." else echo "Disk usage is below $threshold%. No cleanup needed." diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 32eed1a771718..9444dc43ea97e 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -9,6 +9,7 @@ # label(str): the name of the test. emoji allowed. # fast_check(bool): whether to run this on each commit on fastcheck pipeline. # fast_check_only(bool): run this test on fastcheck pipeline only +# nightly(bool): run this test in nightly pipeline only # optional(bool): never run this test by default (i.e. need to unblock manually) # command(str): the single command to run for tests. incompatible with commands. # commands(list): the list of commands to run for test. incompatible with command. @@ -330,18 +331,28 @@ steps: commands: - pytest -v -s models/decoder_only/language --ignore=models/decoder_only/language/test_models.py --ignore=models/decoder_only/language/test_big_models.py -- label: Decoder-only Multi-Modal Models Test # 1h31min +- label: Decoder-only Multi-Modal Models Test (Standard) #mirror_hardwares: [amd] source_file_dependencies: - vllm/ - tests/models/decoder_only/audio_language - tests/models/decoder_only/vision_language commands: - - pytest -v -s models/decoder_only/audio_language + - pytest -v -s models/decoder_only/audio_language -m core_model + - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m core_model + +- label: Decoder-only Multi-Modal Models Test (Extended) + nightly: true + source_file_dependencies: + - vllm/ + - tests/models/decoder_only/audio_language + - tests/models/decoder_only/vision_language + commands: + - pytest -v -s models/decoder_only/audio_language -m 'not core_model' # HACK - run phi3v tests separately to sidestep this transformers bug # https://github.com/huggingface/transformers/issues/34307 - pytest -v -s models/decoder_only/vision_language/test_phi3v.py - - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language + - pytest -v -s --ignore models/decoder_only/vision_language/test_phi3v.py models/decoder_only/vision_language -m 'not core_model' - label: Other Models Test # 6min #mirror_hardwares: [amd] diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 6fddca0d6e4b9..a21acd9671eeb 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -5,3 +5,19 @@ updates: directory: "/" schedule: interval: "weekly" + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "weekly" + labels: ["dependencies"] + open-pull-requests-limit: 5 + reviewers: ["khluu", "simon-mo"] + allow: + - dependency-type: "all" + groups: + patch-update: + applies-to: version-updates + update-types: ["patch"] + minor-update: + applies-to: version-updates + update-types: ["minor"] diff --git
a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 3a464c5f327ad..498d069c05f0d 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -418,6 +418,31 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize); } out += kChunkSize; + + int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize); + // in case the final state is separated between the last "smem_exchange" and + // the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2), + // (which occurs when `final_state_position` is a non-positive index) + // we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it + if (final_state_position < 0 && seqlen > kWidth){ + input_t vals_load[kNElts] = {0}; + if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){ + // chunk = n_chunks - 2, a segment of the final state sits in the last index + reinterpret_cast(vals_load)[0] = smem_exchange[kNThreads - 1]; + #pragma unroll + for (int w = 0; w < -final_state_position; ++w){ + conv_states[w] = vals_load[kNElts + final_state_position + w]; + } + } + if ((chunk == n_chunks - 1) && tidx == 0){ + // chunk = n_chunks - 1, the second segment of the final state sits in the first positions + reinterpret_cast(vals_load)[0] = smem_exchange[0]; + for (int w = -final_state_position; w < kWidth - 1; ++w){ + conv_states[w] = vals_load[w + final_state_position]; + } + return; + } + } } // Final state is stored in the smem_exchange last token slot, // in case seqlen < kWidth, we would need to take the final state from the @@ -446,9 +471,14 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { } else { // in case the final state is in between the threads data - reinterpret_cast(x_vals_load)[1] = smem_exchange[last_thread + 1]; - reinterpret_cast(x_vals_load)[0] = smem_exchange[last_thread]; const int offset = ((seqlen - (kWidth - 1)) % (kNElts)); + if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){ + // In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in an + // illegal access error on H100. + // Therefore, we access last_thread + 1, only if the final state data sits there + reinterpret_cast(x_vals_load)[1] = smem_exchange[last_thread + 1]; + } + reinterpret_cast(x_vals_load)[0] = smem_exchange[last_thread]; #pragma unroll for (int w = 0; w < kWidth - 1; ++w){ conv_states[w] = x_vals_load[offset + w ]; diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index ff893b613f150..3279e7a108232 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -277,7 +277,7 @@ Text Generation * - :code:`QWenLMHeadModel` - Qwen - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. - - + - ✅︎ - ✅︎ * - :code:`Qwen2ForCausalLM` - Qwen2 @@ -516,7 +516,7 @@ Text Generation - Qwen-VL - T + I\ :sup:`E+` - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - - + - ✅︎ - ✅︎ * - :code:`Qwen2AudioForConditionalGeneration` - Qwen2-Audio @@ -540,6 +540,9 @@ Text Generation | :sup:`E` Pre-computed embeddings can be inputted for this modality. | :sup:`+` Multiple items can be inputted per text prompt for this modality. +.. note:: + vLLM currently only supports adding LoRA to the language backbone of multimodal models. + ..
note:: For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 diff --git a/docs/source/serving/compatibility_matrix.rst b/docs/source/serving/compatibility_matrix.rst index 20a81f4cad1d1..cab19e4ec5b6c 100644 --- a/docs/source/serving/compatibility_matrix.rst +++ b/docs/source/serving/compatibility_matrix.rst @@ -283,7 +283,7 @@ Feature x Feature - ✅ - ✅ - ✅ - - `✗ `__ + - `✗ `__ - ? - ✅ - ✅ diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 83d2548a506e4..60cdb186331fe 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -262,10 +262,9 @@ def run_qwen2_vl(question: str, modality: str): model_name = "Qwen/Qwen2-VL-7B-Instruct" - # Tested on L40 llm = LLM( model=model_name, - max_model_len=8192, + max_model_len=4096, max_num_seqs=5, # Note - mm_processor_kwargs can also be passed to generate/chat calls mm_processor_kwargs={ diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index a34d33efba1d8..d151d62516b07 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -6,18 +6,22 @@ import torch from torch import nn +from torch.library import Library from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.compilation.levels import CompilationLevel +from vllm.utils import direct_register_custom_op os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE) global_counter = 0 +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa + -@torch.library.custom_op("silly::attention", mutates_args=["out"]) def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out: torch.Tensor) -> None: global global_counter @@ -27,12 +31,20 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out[0] += 1 -@silly_attention.register_fake -def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: return +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) + + @support_torch_compile class SillyModel(nn.Module): diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index db6a983d70feb..e3e5a7d0fc5a5 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -8,6 +8,7 @@ import torch from torch import nn +from torch.library import Library from vllm.compilation.compile_context import set_compile_context from vllm.compilation.config import CompilationConfig @@ -15,9 +16,12 @@ from vllm.compilation.decorators import support_torch_compile from vllm.compilation.levels import CompilationLevel from vllm.plugins import set_compilation_config +from vllm.utils import direct_register_custom_op + +# create a library to hold the custom op +silly_lib = Library("silly", "FRAGMENT") # noqa -@torch.library.custom_op("silly::attention", mutates_args=["out"]) def silly_attention(q: torch.Tensor, k: torch.Tensor, v: 
torch.Tensor, out: torch.Tensor) -> None: out.copy_(q) @@ -25,12 +29,20 @@ def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, out += v -@silly_attention.register_fake -def _(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - out: torch.Tensor) -> None: +def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + out: torch.Tensor) -> None: return +direct_register_custom_op( + op_name="attention", + op_func=silly_attention, + mutates_args=["out"], + fake_impl=silly_attention_fake, + target_lib=silly_lib, +) + + @dataclass class LlamaConfig: hidden_size: int = 128 diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index 2f92ff73845f5..833589ba5dc9f 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -1,3 +1,4 @@ +import dataclasses from typing import Dict, List, Optional import pytest @@ -8,33 +9,109 @@ from ..utils import compare_all_settings +@dataclasses.dataclass +class TestSetting: + model: str + model_args: List[str] + pp_size: int + tp_size: int + attn_backend: str + method: str + fullgraph: bool + + +# representative settings for testing +test_settings = [ + # basic llama model + TestSetting( + model="meta-llama/Llama-3.2-1B", + model_args=[], + pp_size=2, + tp_size=2, + attn_backend="FLASHINFER", + method="generate", + fullgraph=True, + ), + # llama model with quantization + TestSetting( + model="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", + model_args=["--quantization", "gptq"], + pp_size=1, + tp_size=1, + attn_backend="FLASH_ATTN", + method="generate", + fullgraph=True, + ), + # MoE model + TestSetting( + model="ibm/PowerMoE-3b", + model_args=[], + pp_size=1, + tp_size=2, + attn_backend="FLASH_ATTN", + method="generate", + fullgraph=True, + ), + # embedding model + TestSetting( + model="BAAI/bge-multilingual-gemma2", + model_args=["--task", "embedding"], + pp_size=1, + tp_size=1, + attn_backend="FLASHINFER", + method="encode", + fullgraph=True, + ), + # vision language model + TestSetting( + model="microsoft/Phi-3.5-vision-instruct", + model_args=["--trust-remote-code", "--max-model-len", "2048"], + pp_size=2, + tp_size=1, + attn_backend="FLASH_ATTN", + method="generate_with_image", + fullgraph=False, + ), +] + + # we cannot afford testing the full Cartesian product # of all models and all levels -@pytest.mark.parametrize( - "model, model_args, pp_size, tp_size, attn_backend, method, fullgraph", - [ - ("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASHINFER", "generate", True), - ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", - ["--quantization", "compressed-tensors" - ], 1, 1, "FLASH_ATTN", "generate", True), - ("ibm/PowerMoE-3b", [], 1, 2, "FLASH_ATTN", "generate", True), - # TODO: add multi-modality test for llava - ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False) - ]) -def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend, - method, fullgraph): +@pytest.mark.parametrize("test_setting", test_settings) +def test_compile_correctness(test_setting: TestSetting): # this test is run under multiple suites, with different GPUs. # make sure we only run the test with correct CUDA devices. # don't use "<", as it will duplicate the tests.
+ model = test_setting.model + model_args = test_setting.model_args + pp_size = test_setting.pp_size + tp_size = test_setting.tp_size + attn_backend = test_setting.attn_backend + method = test_setting.method + fullgraph = test_setting.fullgraph if cuda_device_count_stateless() != pp_size * tp_size: pytest.skip("Not correct CUDA devices for the test.") import os os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend - all_args = [["--enforce-eager"] + model_args + ["-pp", str(pp_size)] + - ["-tp", str(tp_size)]] * 3 - # don't test VLLM_TORCH_COMPILE_LEVEL == 3 case - # inductor will change the output, so we cannot compare them. + final_args = ["--enforce-eager"] + model_args + ["-pp", str(pp_size)] + \ + ["-tp", str(tp_size)] + all_envs: List[Optional[Dict[str, str]]] = [] + + for level in [ + CompilationLevel.NO_COMPILATION, + CompilationLevel.PIECEWISE, + ]: + all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)}) + + # inductor will change the output, so we only compare if the output + # is close, not exactly the same. + compare_all_settings( + model, [final_args] * 2, + all_envs, + method=method if method != "generate" else "generate_close") + all_envs.clear() + for level in [ CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS, @@ -46,4 +123,4 @@ def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend, all_envs[-1][ "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0" # type: ignore - compare_all_settings(model, all_args, all_envs, method=method) + compare_all_settings(model, [final_args] * 3, all_envs, method=method) diff --git a/tests/entrypoints/openai/test_prompt_validation.py b/tests/entrypoints/openai/test_prompt_validation.py index 58075f7023821..1ae64ef492d5b 100644 --- a/tests/entrypoints/openai/test_prompt_validation.py +++ b/tests/entrypoints/openai/test_prompt_validation.py @@ -35,3 +35,23 @@ async def test_out_of_vocab_token_ids(): prompt=[999999], max_tokens=5, temperature=0.0) + + +@pytest.mark.asyncio +async def test_reject_multistep_with_guided_decoding(): + model_name = "gpt2" + server_args = ["--enforce-eager", "--num-scheduler-steps", "8"] + with RemoteOpenAIServer(model_name, server_args) as remote_server: + client = remote_server.get_async_client() + + with pytest.raises(openai.BadRequestError, + match=re.compile( + '.*Guided decoding .* multi-step decoding.*')): + await client.completions.create( + model=model_name, + prompt="Hello", + max_tokens=5, + temperature=0.0, + extra_body={"response_format": { + "type": "json_object" + }}) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 96bfe06d74ae5..f9b11018288be 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -151,7 +151,7 @@ def causal_conv1d_opcheck_fn(x: torch.Tensor, @pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("width", [4]) @pytest.mark.parametrize( - 'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) + 'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 1025, 2048, 4096]) @pytest.mark.parametrize('dim', [64]) @pytest.mark.parametrize('batch', [1]) def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, @@ -420,7 +420,10 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias, unpadded_out = out[:, :out_ref_tensor.shape[-1]] assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol) - assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol) + assert 
torch.allclose(final_states[state_indices], + final_states_ref[state_indices], + rtol=rtol, + atol=atol) causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(), padded_state_indices, has_initial_states, diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index bf7ff3b5c59b8..ad05a97685351 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -555,7 +555,7 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) if itype == torch.bfloat16: - rtol, atol = 7e-2, 7e-2 + rtol, atol = 1e-1, 1e-1 if torch.version.hip: atol *= 2 # set seed @@ -610,8 +610,8 @@ def test_selective_state_update_with_batch_indices(with_padding, dim, dstate, dt_bias=dt_bias, dt_softplus=True) - print("Output diff max", (out - out_ref[0]).max()) - print("Output diff mean", (out - out_ref[0]).mean()) + print("Output diff max", (out[:batch_size] - out_ref).max()) + print("Output diff mean", (out[:batch_size] - out_ref).mean()) print("Output state diff max", (state[state_indices, :] - state_ref).max()) print("Output state diff mean", (state[state_indices, :] - state_ref).mean()) diff --git a/tests/models/decoder_only/audio_language/test_ultravox.py b/tests/models/decoder_only/audio_language/test_ultravox.py index ad6c2d854d1f0..b9089e75ffab8 100644 --- a/tests/models/decoder_only/audio_language/test_ultravox.py +++ b/tests/models/decoder_only/audio_language/test_ultravox.py @@ -158,6 +158,7 @@ def run_multi_audio_test( assert all(tokens for tokens, *_ in vllm_outputs) +@pytest.mark.core_model @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) @@ -178,6 +179,7 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, ) +@pytest.mark.core_model @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py index 5c90e7f7a267c..c23fbedf0c6ae 100644 --- a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py +++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_qwen2_vl.py @@ -17,7 +17,7 @@ # Fixtures lazy import to avoid initializing CUDA during test collection -# NOTE: Qwen2vl supports multiple input modalities, so it registers multiple +# NOTE: Qwen2VL supports multiple input modalities, so it registers multiple # input mappers. @pytest.fixture() def image_input_mapper_for_qwen2_vl(): diff --git a/tests/models/decoder_only/vision_language/test_models.py b/tests/models/decoder_only/vision_language/test_models.py index 9370527e3cd57..d738647c91b66 100644 --- a/tests/models/decoder_only/vision_language/test_models.py +++ b/tests/models/decoder_only/vision_language/test_models.py @@ -75,6 +75,63 @@ # this is a good idea for checking your command first, since tests are slow. 
VLM_TEST_SETTINGS = { + #### Core tests to always run in the CI + "llava": VLMTestInfo( + models=["llava-hf/llava-1.5-7b-hf"], + test_type=( + VLMTestType.EMBEDDING, + VLMTestType.IMAGE, + VLMTestType.CUSTOM_INPUTS + ), + prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", + convert_assets_to_embeddings=model_utils.get_llava_embeddings, + max_model_len=4096, + auto_cls=AutoModelForVision2Seq, + vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output, + custom_test_opts=[CustomTestOptions( + inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs( + formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:" + ), + limit_mm_per_prompt={"image": 4}, + )], + marks=[pytest.mark.core_model], + ), + "paligemma": VLMTestInfo( + models=["google/paligemma-3b-mix-224"], + test_type=VLMTestType.IMAGE, + prompt_formatter=identity, + img_idx_to_prompt = lambda idx: "", + # Paligemma uses its own sample prompts because the default one fails + single_image_prompts=IMAGE_ASSETS.prompts({ + "stop_sign": "caption es", + "cherry_blossom": "What is in the picture?", + }), + auto_cls=AutoModelForVision2Seq, + postprocess_inputs=model_utils.get_key_type_post_processor( + "pixel_values" + ), + vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output, + dtype="half" if current_platform.is_rocm() else ("half", "float"), + marks=[pytest.mark.core_model], + ), + "qwen2_vl": VLMTestInfo( + models=["Qwen/Qwen2-VL-2B-Instruct"], + test_type=( + VLMTestType.IMAGE, + VLMTestType.MULTI_IMAGE, + VLMTestType.VIDEO + ), + prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 + img_idx_to_prompt=lambda idx: "<|vision_start|><|image_pad|><|vision_end|>", # noqa: E501 + video_idx_to_prompt=lambda idx: "<|vision_start|><|video_pad|><|vision_end|>", # noqa: E501 + max_model_len=4096, + max_num_seqs=2, + auto_cls=AutoModelForVision2Seq, + vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output, + marks=[pytest.mark.core_model], + image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)], + ), + #### Extended model tests "blip2": VLMTestInfo( models=["Salesforce/blip2-opt-2.7b"], test_type=VLMTestType.IMAGE, @@ -151,25 +208,6 @@ use_tokenizer_eos=True, patch_hf_runner=model_utils.internvl_patch_hf_runner, ), - "llava": VLMTestInfo( - models=["llava-hf/llava-1.5-7b-hf"], - test_type=( - VLMTestType.EMBEDDING, - VLMTestType.IMAGE, - VLMTestType.CUSTOM_INPUTS - ), - prompt_formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:", - convert_assets_to_embeddings=model_utils.get_llava_embeddings, - max_model_len=4096, - auto_cls=AutoModelForVision2Seq, - vllm_output_post_proc=model_utils.llava_image_vllm_to_hf_output, - custom_test_opts=[CustomTestOptions( - inputs=custom_inputs.multi_image_multi_aspect_ratio_inputs( - formatter=lambda img_prompt: f"USER: {img_prompt}\nASSISTANT:" - ), - limit_mm_per_prompt={"image": 4}, - )], - ), "llava_next": VLMTestInfo( models=["llava-hf/llava-v1.6-mistral-7b-hf"], test_type=(VLMTestType.IMAGE, VLMTestType.CUSTOM_INPUTS), @@ -200,12 +238,12 @@ vllm_output_post_proc=model_utils.llava_onevision_vllm_to_hf_output, # Llava-one-vision tests fixed sizes & the default size factors image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))], - runner_mm_key="videos", custom_test_opts=[CustomTestOptions( inputs=custom_inputs.multi_video_multi_aspect_ratio_inputs( formatter=lambda vid_prompt: f"<|im_start|>user\n{vid_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501 ), 
limit_mm_per_prompt={"video": 4}, + runner_mm_key="videos", )], ), # FIXME @@ -218,9 +256,11 @@ auto_cls=AutoModelForVision2Seq, vllm_output_post_proc=model_utils.llava_video_vllm_to_hf_output, image_sizes=[((1669, 2560), (2560, 1669), (183, 488), (488, 183))], - runner_mm_key="videos", marks=[ - pytest.mark.skip(reason="LLava next video tests currently fail.") + pytest.mark.skipif( + transformers.__version__.startswith("4.46"), + reason="Model broken with changes in transformers 4.46" + ) ], ), "minicpmv": VLMTestInfo( @@ -234,23 +274,6 @@ postprocess_inputs=model_utils.wrap_inputs_post_processor, hf_output_post_proc=model_utils.minicmpv_trunc_hf_output, ), - "paligemma": VLMTestInfo( - models=["google/paligemma-3b-mix-224"], - test_type=VLMTestType.IMAGE, - prompt_formatter=identity, - img_idx_to_prompt = lambda idx: "", - # Paligemma uses its own sample prompts because the default one fails - single_image_prompts=IMAGE_ASSETS.prompts({ - "stop_sign": "caption es", - "cherry_blossom": "What is in the picture?", - }), - auto_cls=AutoModelForVision2Seq, - postprocess_inputs=model_utils.get_key_type_post_processor( - "pixel_values" - ), - vllm_output_post_proc=model_utils.paligemma_vllm_to_hf_output, - dtype="half" if current_platform.is_rocm() else ("half", "float"), - ), # Tests for phi3v currently live in another file because of a bug in # transformers. Once this issue is fixed, we can enable them here instead. # https://github.com/huggingface/transformers/issues/34307 diff --git a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py index 6856e8df81a13..e925934db0e7c 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/model_utils.py @@ -56,6 +56,17 @@ def qwen_vllm_to_hf_output( return output_ids, hf_output_str, out_logprobs +def qwen2_vllm_to_hf_output( + vllm_output: RunnerOutput, + model: str) -> Tuple[List[int], str, Optional[SampleLogprobs]]: + """Sanitize vllm output [qwen2 models] to be comparable with hf output.""" + output_ids, output_str, out_logprobs = vllm_output + + hf_output_str = output_str + "<|im_end|>" + + return output_ids, hf_output_str, out_logprobs + + def llava_image_vllm_to_hf_output(vllm_output: RunnerOutput, model: str) -> RunnerOutput: config = AutoConfig.from_pretrained(model) diff --git a/tests/models/decoder_only/vision_language/vlm_utils/runners.py b/tests/models/decoder_only/vision_language/vlm_utils/runners.py index 5a3f9e820dad0..2d3b39fe3594e 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/runners.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/runners.py @@ -29,6 +29,7 @@ def run_single_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, num_logprobs=test_case.num_logprobs, limit_mm_per_prompt={"image": 1}, distributed_executor_backend=test_case.distributed_executor_backend, + runner_mm_key="images", **model_test_info.get_non_parametrized_runner_kwargs()) @@ -51,6 +52,7 @@ def run_multi_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo, num_logprobs=test_case.num_logprobs, limit_mm_per_prompt={"image": len(image_assets)}, distributed_executor_backend=test_case.distributed_executor_backend, + runner_mm_key="images", **model_test_info.get_non_parametrized_runner_kwargs()) @@ -74,6 +76,7 @@ def run_embedding_test(*, model_test_info: VLMTestInfo, limit_mm_per_prompt={"image": 1}, vllm_embeddings=vllm_embeddings, 
distributed_executor_backend=test_case.distributed_executor_backend, + runner_mm_key="images", **model_test_info.get_non_parametrized_runner_kwargs()) @@ -101,6 +104,7 @@ def run_video_test( num_logprobs=test_case.num_logprobs, limit_mm_per_prompt={"video": len(video_assets)}, distributed_executor_backend=test_case.distributed_executor_backend, + runner_mm_key="videos", **model_test_info.get_non_parametrized_runner_kwargs()) @@ -115,7 +119,11 @@ def run_custom_inputs_test(*, model_test_info: VLMTestInfo, inputs = test_case.custom_test_opts.inputs limit_mm_per_prompt = test_case.custom_test_opts.limit_mm_per_prompt - assert inputs is not None and limit_mm_per_prompt is not None + runner_mm_key = test_case.custom_test_opts.runner_mm_key + # Inputs, limit_mm_per_prompt, and runner_mm_key should all be set + assert inputs is not None + assert limit_mm_per_prompt is not None + assert runner_mm_key is not None core.run_test( hf_runner=hf_runner, @@ -127,4 +135,5 @@ def run_custom_inputs_test(*, model_test_info: VLMTestInfo, num_logprobs=test_case.num_logprobs, limit_mm_per_prompt=limit_mm_per_prompt, distributed_executor_backend=test_case.distributed_executor_backend, + runner_mm_key=runner_mm_key, **model_test_info.get_non_parametrized_runner_kwargs()) diff --git a/tests/models/decoder_only/vision_language/vlm_utils/types.py b/tests/models/decoder_only/vision_language/vlm_utils/types.py index 4d18d53af30fa..fd18c7c8346f0 100644 --- a/tests/models/decoder_only/vision_language/vlm_utils/types.py +++ b/tests/models/decoder_only/vision_language/vlm_utils/types.py @@ -52,6 +52,8 @@ class SizeType(Enum): class CustomTestOptions(NamedTuple): inputs: List[Tuple[List[str], List[Union[List[Image], Image]]]] limit_mm_per_prompt: Dict[str, int] + # kwarg used to pass multimodal data to the vllm/hf runner instances.
+ runner_mm_key: str = "images" class ImageSizeWrapper(NamedTuple): @@ -141,9 +143,6 @@ class VLMTestInfo(NamedTuple): Callable[[PosixPath, str, Union[List[ImageAsset], _ImageAssets]], str]] = None # noqa: E501 - # kwarg to pass multimodal data in as to vllm/hf runner instances - runner_mm_key: str = "images" - # Allows configuring a test to run with custom inputs custom_test_opts: Optional[List[CustomTestOptions]] = None @@ -168,7 +167,6 @@ def get_non_parametrized_runner_kwargs(self): "get_stop_token_ids": self.get_stop_token_ids, "model_kwargs": self.model_kwargs, "patch_hf_runner": self.patch_hf_runner, - "runner_mm_key": self.runner_mm_key, } diff --git a/tests/models/embedding/vision_language/test_llava_next.py b/tests/models/embedding/vision_language/test_llava_next.py index a8d0ac4fc160d..9fab5898a06ba 100644 --- a/tests/models/embedding/vision_language/test_llava_next.py +++ b/tests/models/embedding/vision_language/test_llava_next.py @@ -2,6 +2,7 @@ import pytest import torch.nn.functional as F +import transformers from transformers import AutoModelForVision2Seq from ....conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner @@ -85,8 +86,8 @@ def _run_test( ) -# FIXME -@pytest.mark.skip(reason="LLava next embedding tests currently fail") +@pytest.mark.skipif(transformers.__version__.startswith("4.46"), + reason="Model broken with changes in transformers 4.46") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["half"]) def test_models_text( diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 366b030eaa399..fd6564bbfe630 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -5,6 +5,7 @@ import pytest from tests.kernels.utils import override_backend_env_variable +from vllm import SamplingParams, TokensPrompt from ..models.utils import check_outputs_equal @@ -12,6 +13,14 @@ "facebook/opt-125m", ] +UNSTABLE_PROMPT_SEQUENCE = [ + ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([3] * 1), + ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([5] * 50), + ([0] * 588) + ([1] * 1332) + ([2] * 30) + ([6] * 95), + ([0] * 588) + ([1] * 1332) + ([4] * 3) + ([7] * 174), + ([0] * 588) + ([8] * 1539), +] + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) @@ -57,3 +66,22 @@ def test_mixed_requests( name_0="hf", name_1="vllm", ) + + +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) +def test_unstable_prompt_sequence( + vllm_runner, + backend: str, + monkeypatch, +) -> None: + override_backend_env_variable(monkeypatch, backend) + + with vllm_runner( + "Qwen/Qwen2.5-0.5B-Instruct", + enable_chunked_prefill=True, + enable_prefix_caching=True, + max_model_len=4096, + ) as vllm_model: + for prompt in UNSTABLE_PROMPT_SEQUENCE: + vllm_model.generate(TokensPrompt(prompt_token_ids=prompt), + SamplingParams(max_tokens=1)) diff --git a/tests/utils.py b/tests/utils.py index e8aad9cb3268f..16e21f68c7c96 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,5 @@ import asyncio +import copy import functools import os import signal @@ -8,13 +9,14 @@ import warnings from contextlib import contextmanager from pathlib import Path -from typing import Any, Callable, Dict, List, Literal, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type, Union import openai import pytest import requests +import torch from openai.types.completion import Completion 
-from typing_extensions import ParamSpec, assert_never +from typing_extensions import ParamSpec import vllm.envs as envs from tests.models.utils import TextTextLogprobs @@ -272,6 +274,31 @@ def _test_completion( return results +def _test_completion_close( + client: openai.OpenAI, + model: str, + prompt: str, +): + results = [] + + # test with text prompt + completion = client.completions.create(model=model, + prompt=prompt, + max_tokens=1, + logprobs=5, + temperature=0.0) + + logporbs = completion.choices[0].logprobs.top_logprobs[0] + logporbs = {k: round(v, 2) for k, v in logporbs.items()} + + results.append({ + "test": "completion_close", + "logprobs": logporbs, + }) + + return results + + def _test_embeddings( client: openai.OpenAI, model: str, @@ -295,13 +322,81 @@ def _test_embeddings( return results +def _test_image_text( + client: openai.OpenAI, + model_name: str, + image_url: str, +): + results = [] + + # test pure text input + messages = [{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "How do you feel today?" + }, + ], + }] + + chat_completion = client.chat.completions.create(model=model_name, + messages=messages, + temperature=0.0, + max_tokens=1, + logprobs=True, + top_logprobs=5) + top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs + + for x in top_logprobs: + x.logprob = round(x.logprob, 2) + + results.append({ + "test": "pure_text", + "logprobs": top_logprobs, + }) + + messages = [{ + "role": + "user", + "content": [ + { + "type": "image_url", + "image_url": { + "url": image_url + } + }, + { + "type": "text", + "text": "What's in this image?" + }, + ], + }] + + chat_completion = client.chat.completions.create(model=model_name, + messages=messages, + temperature=0.0, + max_tokens=1, + logprobs=True, + top_logprobs=5) + top_logprobs = chat_completion.choices[0].logprobs.content[0].top_logprobs + + results.append({ + "test": "text_image", + "logprobs": top_logprobs, + }) + + return results + + def compare_two_settings(model: str, arg1: List[str], arg2: List[str], env1: Optional[Dict[str, str]] = None, env2: Optional[Dict[str, str]] = None, *, - method: Literal["generate", "encode"] = "generate", + method: str = "generate", max_wait_seconds: Optional[float] = None) -> None: """ Launch API server with two different sets of arguments/environments @@ -328,7 +423,7 @@ def compare_all_settings(model: str, all_args: List[List[str]], all_envs: List[Optional[Dict[str, str]]], *, - method: Literal["generate", "encode"] = "generate", + method: str = "generate", max_wait_seconds: Optional[float] = None) -> None: """ Launch API server with several different sets of arguments/environments @@ -397,10 +492,17 @@ def compare_all_settings(model: str, if method == "generate": results += _test_completion(client, model, prompt, token_ids) + elif method == "generate_close": + results += _test_completion_close(client, model, prompt) + elif method == "generate_with_image": + results += _test_image_text( + client, model, + "https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png" + ) elif method == "encode": results += _test_embeddings(client, model, prompt) else: - assert_never(method) + raise ValueError(f"Unknown method: {method}") if i > 0: # if any setting fails, raise an error early @@ -410,6 +512,18 @@ def compare_all_settings(model: str, compare_envs = all_envs[i] for ref_result, compare_result in zip(ref_results, compare_results): + ref_result = copy.deepcopy(ref_result) + compare_result = copy.deepcopy(compare_result) + if "embedding" in 
ref_result and method == "encode": + ref_embedding = torch.tensor(ref_result["embedding"]) + compare_embedding = torch.tensor( + compare_result["embedding"]) + mse = ((ref_embedding - compare_embedding)**2).mean() + assert mse < 1e-6, ( + f"Embedding for {model=} are not the same.\n" + f"mse={mse}\n") + del ref_result["embedding"] + del compare_result["embedding"] assert ref_result == compare_result, ( f"Results for {model=} are not the same.\n" f"{ref_args=} {ref_envs=}\n" diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index ffa05e80623ac..c294fcf7f08fe 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -14,7 +14,8 @@ compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.forward_context import get_forward_context -from vllm.utils import async_tensor_h2d, make_tensor_with_pad +from vllm.utils import (async_tensor_h2d, direct_register_custom_op, + make_tensor_with_pad) if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -595,8 +596,6 @@ def forward( return output -@torch.library.custom_op("vllm::unified_flash_attention", - mutates_args=["kv_cache"]) def unified_flash_attention( query: torch.Tensor, key: torch.Tensor, @@ -755,8 +754,7 @@ def unified_flash_attention( return output.view(num_tokens, hidden_size) -@unified_flash_attention.register_fake -def _( +def unified_flash_attention_fake( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -773,3 +771,11 @@ def _( logits_soft_cap: Optional[float] = None, ) -> torch.Tensor: return torch.empty_like(query) + + +direct_register_custom_op( + op_name="unified_flash_attention", + op_func=unified_flash_attention, + mutates_args=["kv_cache"], + fake_impl=unified_flash_attention_fake, +) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 5ea101ae0432f..234c87d5c4edb 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -28,8 +28,8 @@ is_block_tables_empty) from vllm.attention.ops.paged_attn import PagedAttention from vllm.forward_context import get_forward_context -from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, - make_tensor_with_pad) +from vllm.utils import (async_tensor_h2d, direct_register_custom_op, + get_kv_cache_torch_dtype, make_tensor_with_pad) if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -785,8 +785,6 @@ def forward( ) -@torch.library.custom_op("vllm::unified_flash_infer", - mutates_args=["kv_cache"]) def unified_flash_infer( query: torch.Tensor, key: torch.Tensor, @@ -906,8 +904,7 @@ def unified_flash_infer( return output.view(num_tokens, hidden_size) -@unified_flash_infer.register_fake -def _( +def unified_flash_infer_fake( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -924,3 +921,11 @@ def _( logits_soft_cap: Optional[float] = None, ) -> torch.Tensor: return torch.empty_like(query).contiguous() + + +direct_register_custom_op( + op_name="unified_flash_infer", + op_func=unified_flash_infer, + mutates_args=["kv_cache"], + fake_impl=unified_flash_infer_fake, +) diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index d1a44f3e8bfa6..32fccd0dfb496 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -138,7 +138,6 @@ def _add_seq_group( chunked_prefill_enabled: bool): is_prompt = inter_data.is_prompt block_tables = inter_data.block_tables - computed_block_nums = 
inter_data.computed_block_nums for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len, curr_sliding_window_block) in zip( @@ -164,10 +163,14 @@ def _add_seq_group( # NOTE: This only works for oooooooxxx style attention. block_table = [] if inter_data.prefix_cache_hit: - block_table = computed_block_nums + block_table = block_tables[seq_id] elif ((chunked_prefill_enabled or not is_prompt) and block_tables is not None): - block_table = block_tables[seq_id][-curr_sliding_window_block:] + if curr_sliding_window_block == 0: + block_table = block_tables[seq_id] + else: + block_table = block_tables[seq_id][ + -curr_sliding_window_block:] self.block_tables.append(block_table) # Compute slot mapping. diff --git a/vllm/config.py b/vllm/config.py index e9559c40dbdfb..c2a8c956b374a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -84,9 +84,6 @@ class ModelConfig: disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. If None, the user did not specify, so default to False. - max_context_len_to_capture: Maximum context len covered by CUDA graphs. - When a sequence has context length larger than this, we fall back - to eager mode (DEPRECATED. Use max_seq_len_to_capture instead). max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. Additionally for encoder-decoder models, if the @@ -147,7 +144,6 @@ def __init__( quantization: Optional[str] = None, quantization_param_path: Optional[str] = None, enforce_eager: Optional[bool] = None, - max_context_len_to_capture: Optional[int] = None, max_seq_len_to_capture: Optional[int] = None, max_logprobs: int = 20, disable_sliding_window: bool = False, @@ -181,9 +177,6 @@ def __init__( self.quantization = quantization self.quantization_param_path = quantization_param_path self.enforce_eager = enforce_eager - if max_context_len_to_capture is not None: - raise ValueError("`max_context_len_to_capture` is deprecated. " - "Use `max_seq_len_to_capture` instead.") self.max_seq_len_to_capture = max_seq_len_to_capture self.max_logprobs = max_logprobs self.disable_sliding_window = disable_sliding_window diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index b04bbc478534c..94ba41a016f6d 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -37,7 +37,7 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils import supports_custom_op +from vllm.utils import direct_register_custom_op, supports_custom_op @dataclass @@ -99,8 +99,6 @@ def _register_group(group: "GroupCoordinator") -> None: if supports_custom_op(): - @torch.library.custom_op("vllm::inplace_all_reduce", - mutates_args=["tensor"]) def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: assert group_name in _groups, f"Group {group_name} is not found." 
group = _groups[group_name]() @@ -108,11 +106,16 @@ def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: raise ValueError(f"Group {group_name} is destroyed.") group._all_reduce_in_place(tensor) - @inplace_all_reduce.register_fake - def _(tensor: torch.Tensor, group_name: str) -> None: + def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None: return - @torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[]) + direct_register_custom_op( + op_name="inplace_all_reduce", + op_func=inplace_all_reduce, + mutates_args=["tensor"], + fake_impl=inplace_all_reduce_fake, + ) + def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor: assert group_name in _groups, f"Group {group_name} is not found." @@ -121,10 +124,17 @@ def outplace_all_reduce(tensor: torch.Tensor, raise ValueError(f"Group {group_name} is destroyed.") return group._all_reduce_out_place(tensor) - @outplace_all_reduce.register_fake - def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + def outplace_all_reduce_fake(tensor: torch.Tensor, + group_name: str) -> torch.Tensor: return torch.empty_like(tensor) + direct_register_custom_op( + op_name="outplace_all_reduce", + op_func=outplace_all_reduce, + mutates_args=[], + fake_impl=outplace_all_reduce_fake, + ) + class GroupCoordinator: """ @@ -338,6 +348,11 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: if self.world_size == 1: return input_ + if input_.is_cpu: + import intel_extension_for_pytorch as ipex + ipex.distributed.all_reduce(input_, group=self.device_group) + return input_ + if not supports_custom_op(): self._all_reduce_in_place(input_) return input_ @@ -369,9 +384,6 @@ def _all_reduce_in_place(self, input_: torch.Tensor) -> None: pynccl_comm = self.pynccl_comm if (pynccl_comm is not None and not pynccl_comm.disabled): pynccl_comm.all_reduce(input_) - elif input_.is_cpu: - import intel_extension_for_pytorch as ipex - ipex.distributed.all_reduce(input_, group=self.device_group) else: torch.distributed.all_reduce(input_, group=self.device_group) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index de886c98e51bd..b1f0f8b9df925 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -126,7 +126,6 @@ class EngineArgs: tokenizer_revision: Optional[str] = None quantization: Optional[str] = None enforce_eager: Optional[bool] = None - max_context_len_to_capture: Optional[int] = None max_seq_len_to_capture: int = 8192 disable_custom_all_reduce: bool = False tokenizer_pool_size: int = 0 @@ -504,14 +503,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help='Always use eager-mode PyTorch. If False, ' 'will use eager mode and CUDA graph in hybrid ' 'for maximal performance and flexibility.') - parser.add_argument('--max-context-len-to-capture', - type=int, - default=EngineArgs.max_context_len_to_capture, - help='Maximum context length covered by CUDA ' - 'graphs. When a sequence has context length ' - 'larger than this, we fall back to eager mode. ' - '(DEPRECATED. 
Use --max-seq-len-to-capture instead' - ')') parser.add_argument('--max-seq-len-to-capture', type=int, default=EngineArgs.max_seq_len_to_capture, @@ -939,7 +930,6 @@ def create_model_config(self) -> ModelConfig: quantization=self.quantization, quantization_param_path=self.quantization_param_path, enforce_eager=self.enforce_eager, - max_context_len_to_capture=self.max_context_len_to_capture, max_seq_len_to_capture=self.max_seq_len_to_capture, max_logprobs=self.max_logprobs, disable_sliding_window=self.disable_sliding_window, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 3fd34fadee1ca..edef1f30a9e91 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -829,6 +829,13 @@ def add_request( raise ValueError(f"Got priority {priority} but " "Priority scheduling is not enabled.") + if isinstance(params, SamplingParams) \ + and (params.guided_decoding or params.logits_processors) \ + and self.scheduler_config.num_scheduler_steps > 1: + raise ValueError( + "Guided decoding and logits processors are not supported " + "in multi-step decoding") + if arrival_time is None: arrival_time = time.time() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 083b67c2f8e7d..3d62cb3598477 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -93,9 +93,6 @@ class LLM: enforce_eager: Whether to enforce eager execution. If True, we will disable CUDA graph and always execute the model in eager mode. If False, we will use CUDA graph and eager execution in hybrid. - max_context_len_to_capture: Maximum context len covered by CUDA graphs. - When a sequence has context length larger than this, we fall back - to eager mode (DEPRECATED. Use `max_seq_len_to_capture` instead). max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs. When a sequence has context length larger than this, we fall back to eager mode. 
Additionally for encoder-decoder models, if the @@ -152,7 +149,6 @@ def __init__( swap_space: float = 4, cpu_offload_gb: float = 0, enforce_eager: Optional[bool] = None, - max_context_len_to_capture: Optional[int] = None, max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, disable_async_output_proc: bool = False, @@ -193,7 +189,6 @@ def __init__( swap_space=swap_space, cpu_offload_gb=cpu_offload_gb, enforce_eager=enforce_eager, - max_context_len_to_capture=max_context_len_to_capture, max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, disable_async_output_proc=disable_async_output_proc, diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index 93019d0d0abb6..4741d69de11ac 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -8,6 +8,7 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, moe_align_block_size, try_get_optimal_moe_config) from vllm.scalar_type import scalar_types +from vllm.utils import direct_register_custom_op def get_scalar_type(num_bits: int, has_zp: bool): @@ -18,7 +19,6 @@ def get_scalar_type(num_bits: int, has_zp: bool): return scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128 -@torch.library.custom_op("vllm::single_marlin_moe", mutates_args=[]) def single_marlin_moe( hidden_states: torch.Tensor, w: torch.Tensor, @@ -119,8 +119,7 @@ def single_marlin_moe( return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) -@single_marlin_moe.register_fake -def _( +def single_marlin_moe_fake( hidden_states: torch.Tensor, w: torch.Tensor, scales: torch.Tensor, @@ -136,7 +135,14 @@ def _( return torch.empty_like(hidden_states) -@torch.library.custom_op("vllm::fused_marlin_moe", mutates_args=[]) +direct_register_custom_op( + op_name="single_marlin_moe", + op_func=single_marlin_moe, + mutates_args=[], + fake_impl=single_marlin_moe_fake, +) + + def fused_marlin_moe( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -324,8 +330,7 @@ def fused_marlin_moe( dim=1) -@fused_marlin_moe.register_fake -def _( +def fused_marlin_moe_fake( hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -344,3 +349,11 @@ def _( is_k_full: bool = True, ) -> torch.Tensor: return torch.empty_like(hidden_states) + + +direct_register_custom_op( + op_name="fused_marlin_moe", + op_func=fused_marlin_moe, + mutates_args=[], + fake_impl=fused_marlin_moe_fake, +) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 1cf5c2253ca0b..340da32263c1c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -12,6 +12,7 @@ from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op logger = init_logger(__name__) @@ -466,8 +467,6 @@ def get_config_dtype_str(dtype: torch.dtype, return None -@torch.library.custom_op("vllm::inplace_fused_experts", - mutates_args=["hidden_states"]) def inplace_fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, @@ -484,22 +483,29 @@ def inplace_fused_experts(hidden_states: torch.Tensor, a1_scale, a2_scale) -@inplace_fused_experts.register_fake -def _(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: 
torch.Tensor, - topk_ids: torch.Tensor, - use_fp8_w8a8: bool = False, - use_int8_w8a16: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None) -> None: +def inplace_fused_experts_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None) -> None: pass -@torch.library.custom_op("vllm::outplace_fused_experts", mutates_args=[]) +direct_register_custom_op( + op_name="inplace_fused_experts", + op_func=inplace_fused_experts, + mutates_args=["hidden_states"], + fake_impl=inplace_fused_experts_fake, +) + + def outplace_fused_experts( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -517,21 +523,29 @@ def outplace_fused_experts( w2_scale, a1_scale, a2_scale) -@outplace_fused_experts.register_fake -def _(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - use_fp8_w8a8: bool = False, - use_int8_w8a16: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None, - a1_scale: Optional[torch.Tensor] = None, - a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: +def outplace_fused_experts_fake( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: return torch.empty_like(hidden_states) +direct_register_custom_op( + op_name="outplace_fused_experts", + op_func=outplace_fused_experts, + mutates_args=[], + fake_impl=outplace_fused_experts_fake, +) + + def fused_experts(hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index 7a039a78f09b8..718967a065192 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -119,7 +119,12 @@ def get_scaled_act_names(self) -> List[str]: def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]): - return any(module_name in prefix for module_name in llm_int8_skip_modules) + # Split the prefix into its dot-separated components + components = prefix.split('.') + + # Check if any of the skip modules exactly matches any component + return any(module_name in components + for module_name in llm_int8_skip_modules) class BitsAndBytesLinearMethod(LinearMethodBase): diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index eda99c029881f..27055e7ced865 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -493,13 +493,9 @@ def forward( :class:`LlavaImageInputs` """ if intermediate_tensors is not None: - input_ids = None inputs_embeds = None else: - # always pass the input via `inputs_embeds` - # to make sure the computation graph is consistent image_input = self._parse_and_validate_image_input(**kwargs) - if image_input 
is not None: vision_embeddings = self._process_image_input(image_input) inputs_embeds = self.language_model.model.get_input_embeddings( @@ -511,7 +507,11 @@ def forward( else: inputs_embeds = self.language_model.model.get_input_embeddings( input_ids) - input_ids = None + + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + # for `torch.compile` integration + input_ids = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 0fc4556831fd7..4928e447d5b9e 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -679,7 +679,6 @@ def forward(self, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object): if intermediate_tensors is not None: - input_ids = None inputs_embeds = None else: image_input = self._parse_and_validate_image_input(**kwargs) @@ -690,9 +689,14 @@ def forward(self, inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, vision_embeddings, self.image_token_id) - input_ids = None else: - inputs_embeds = None + inputs_embeds = self.language_model.model.embed_tokens( + input_ids) + + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + # for `torch.compile` integration + input_ids = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index f189d5857751e..61665768eacf5 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -1050,7 +1050,7 @@ def get_mm_mapping(self) -> MultiModelKeys: @MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) -class QWenLMHeadModel(QWenBaseModel): +class QWenLMHeadModel(QWenBaseModel, SupportsLoRA): """ QWenLMHeadModel is not only applicable to LLM but also to VL, which is not conducive to the current integration logic of LoRA in vLLM. 
Therefore, it diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 5e191c6e715e0..5c6df5aaf5446 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -485,8 +485,8 @@ def __repr__(self) -> str: f"skip_special_tokens={self.skip_special_tokens}, " "spaces_between_special_tokens=" f"{self.spaces_between_special_tokens}, " - f"truncate_prompt_tokens={self.truncate_prompt_tokens}), " - f"guided_decoding={self.guided_decoding}") + f"truncate_prompt_tokens={self.truncate_prompt_tokens}, " + f"guided_decoding={self.guided_decoding})") class BeamSearchParams( diff --git a/vllm/utils.py b/vllm/utils.py index 03cdbe6a0dc7b..5488719cc99b0 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -32,6 +32,7 @@ import torch.types import yaml from packaging.version import Version +from torch.library import Library from typing_extensions import ParamSpec, TypeIs, assert_never import vllm.envs as envs @@ -1512,3 +1513,47 @@ def weak_ref_tensors( if isinstance(tensors, tuple): return tuple(weak_ref_tensor(t) for t in tensors) raise ValueError("Invalid type for tensors") + + +def is_in_doc_build() -> bool: + try: + from sphinx.ext.autodoc.mock import _MockModule + return isinstance(torch, _MockModule) + except ModuleNotFoundError: + return False + + +# create a library to hold the custom op +vllm_lib = Library("vllm", "FRAGMENT") # noqa + + +def direct_register_custom_op( + op_name: str, + op_func: Callable, + mutates_args: List[str], + fake_impl: Optional[Callable] = None, + target_lib: Optional[Library] = None, +): + """ + `torch.library.custom_op` can have significant overhead because it + needs to consider complicated dispatching logic. This function + directly registers a custom op and dispatches it to the CUDA backend. + See https://gist.github.com/youkaichao/ecbea9ec9fc79a45d2adce1784d7a9a5 + for more details. + + By default, the custom op is registered to the vLLM library. If you + want to register it to a different library, you can pass the library + object to the `target_lib` argument. + + IMPORTANT: the lifetime of the operator is tied to the lifetime of the + library object. If you want to bind the operator to a different library, + make sure the library object is alive when the operator is used. 
+ """ + if is_in_doc_build(): + return + schema_str = torch.library.infer_schema(op_func, mutates_args=mutates_args) + my_lib = target_lib or vllm_lib + my_lib.define(op_name + schema_str) + my_lib.impl(op_name, op_func, "CUDA") + if fake_impl is not None: + my_lib._register_fake(op_name, fake_impl) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index ec07464e6a12a..b2af89ebf854a 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -7,6 +7,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, AttentionMetadata, AttentionType) from vllm.forward_context import get_forward_context +from vllm.utils import direct_register_custom_op from vllm.vllm_flash_attn import flash_attn_varlen_func @@ -152,8 +153,6 @@ def forward( return output -@torch.library.custom_op("vllm::unified_flash_attention", - mutates_args=["kv_cache"]) def unified_flash_attention( query: torch.Tensor, key: torch.Tensor, @@ -217,8 +216,7 @@ def unified_flash_attention( return output.view(num_tokens, hidden_size) -@unified_flash_attention.register_fake -def _( +def unified_flash_attention_fake( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -235,3 +233,11 @@ def _( logits_soft_cap: Optional[float] = None, ) -> torch.Tensor: return torch.empty_like(query) + + +direct_register_custom_op( + op_name="unified_flash_attention", + op_func=unified_flash_attention, + mutates_args=["kv_cache"], + fake_impl=unified_flash_attention_fake, +) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 233a9e664d845..891637dafbb14 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -995,7 +995,7 @@ def __init__( # Python can be expensive. To optimize this, we cache the block table # in numpy and only copy the actual input content at every iteration. # The shape of the cached block table will be - # (max batch size to capture, max context len to capture / block size). + # (max batch size to capture, max seq len to capture / block size). self.graph_block_tables = np.zeros( (self.max_batchsize_to_capture, self.get_max_block_per_batch()), dtype=np.int32)
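For reference, the sketch below shows how the new direct_register_custom_op helper added in vllm/utils.py is exercised, mirroring the silly_lib registration pattern in tests/compile/piecewise/test_simple.py. The library name, op name, and function bodies here are illustrative only and are not part of this change; the helper registers the real implementation for the CUDA backend, so the op should be invoked on CUDA tensors.

import torch
from torch.library import Library

from vllm.utils import direct_register_custom_op

# standalone library to hold the example op; per the docstring of
# direct_register_custom_op, the library object must stay alive for as
# long as the op is used
toy_lib = Library("toy", "FRAGMENT")  # noqa


def toy_copy(q: torch.Tensor, out: torch.Tensor) -> None:
    # real implementation, registered for the CUDA backend
    out.copy_(q)


def toy_copy_fake(q: torch.Tensor, out: torch.Tensor) -> None:
    # fake (meta) implementation used during tracing/compilation
    return


direct_register_custom_op(
    op_name="copy",
    op_func=toy_copy,
    mutates_args=["out"],
    fake_impl=toy_copy_fake,
    target_lib=toy_lib,
)

# The op is then dispatched as torch.ops.toy.copy(q, out) on CUDA tensors,
# avoiding the torch.library.custom_op overhead described in vllm/utils.py.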