From 606fddd827535084a6764a58dabc12ce2ae7cad8 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 29 Sep 2024 00:54:35 +0800 Subject: [PATCH] [CI/Build] Update models tests & examples (#8874) Co-authored-by: Roger Wang --- .buildkite/test-pipeline.yaml | 51 +++--- examples/offline_inference_vision_language.py | 28 ++-- ...e_inference_vision_language_multi_image.py | 13 +- tests/conftest.py | 84 +++++----- .../vision_language/test_llava_onevision.py | 29 ++-- .../vision_language/test_minicpmv.py | 2 +- .../vision_language/test_phi3v.py | 2 +- .../decoder_only/vision_language/test_qwen.py | 2 +- .../vision_language/test_broadcast.py | 35 ++++ .../vision_language/test_mllama.py | 153 ++++++++---------- tests/models/utils.py | 9 +- vllm/inputs/registry.py | 12 +- .../layers/quantization/utils/w8a8_utils.py | 3 +- 13 files changed, 239 insertions(+), 184 deletions(-) create mode 100644 tests/models/encoder_decoder/vision_language/test_broadcast.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index d9dcacf5d991e..bb42b5f29a725 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -9,6 +9,7 @@ # label(str): the name of the test. emoji allowed. # fast_check(bool): whether to run this on each commit on fastcheck pipeline. # fast_check_only(bool): run this test on fastcheck pipeline only +# optional(bool): never run this test by default (i.e. need to unblock manually) # command(str): the single command to run for tests. incompatible with commands. # commands(list): the list of commands to run for test. incompatbile with command. # mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd] @@ -39,7 +40,7 @@ steps: # Check API reference (if it fails, you may have missing mock imports) - grep \"sig sig-object py\" build/html/dev/sampling_params.html -- label: Async Engine, Inputs, Utils, Worker Test # 15min +- label: Async Engine, Inputs, Utils, Worker Test # 24min fast_check: true source_file_dependencies: - vllm/ @@ -81,7 +82,7 @@ steps: commands: - pytest -v -s core -- label: Entrypoints Test # 20min +- label: Entrypoints Test # 40min working_dir: "/vllm-workspace/tests" fast_check: true mirror_hardwares: [amd] @@ -151,7 +152,7 @@ steps: # OOM in the CI unless we run this separately - pytest -v -s tokenization -- label: Examples Test # 12min +- label: Examples Test # 15min working_dir: "/vllm-workspace/examples" #mirror_hardwares: [amd] source_file_dependencies: @@ -169,7 +170,7 @@ steps: - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference_encoder_decoder.py -- label: Prefix Caching Test # 7min +- label: Prefix Caching Test # 9min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ @@ -177,7 +178,7 @@ steps: commands: - pytest -v -s prefix_caching -- label: Samplers Test # 18min +- label: Samplers Test # 36min source_file_dependencies: - vllm/model_executor/layers - vllm/sampling_metadata.py @@ -193,7 +194,7 @@ steps: - tests/test_logits_processor command: pytest -v -s test_logits_processor.py -- label: Speculative decoding tests # 22min +- label: Speculative decoding tests # 30min source_file_dependencies: - vllm/spec_decode - tests/spec_decode @@ -203,7 +204,7 @@ steps: - pytest -v -s spec_decode/e2e/test_multistep_correctness.py - pytest -v -s spec_decode 
--ignore=spec_decode/e2e/test_multistep_correctness.py -- label: LoRA Test %N # 30min each +- label: LoRA Test %N # 15min each mirror_hardwares: [amd] source_file_dependencies: - vllm/lora @@ -211,7 +212,7 @@ steps: command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py parallelism: 4 -- label: "PyTorch Fullgraph Smoke Test" +- label: "PyTorch Fullgraph Smoke Test" # 9min fast_check: true source_file_dependencies: - vllm/ @@ -219,14 +220,14 @@ steps: commands: - pytest -v -s compile/test_full_graph_smoke.py -- label: "PyTorch Fullgraph Test" +- label: "PyTorch Fullgraph Test" # 18min source_file_dependencies: - vllm/ - tests/compile commands: - pytest -v -s compile/test_full_graph.py -- label: Kernels Test %N # 30min each +- label: Kernels Test %N # 1h each mirror_hardwares: [amd] source_file_dependencies: - csrc/ @@ -256,7 +257,7 @@ steps: - pip install aiohttp - bash run-benchmarks.sh -- label: Quantization Test # 15min +- label: Quantization Test # 33min source_file_dependencies: - csrc/ - vllm/model_executor/layers/quantization @@ -300,7 +301,7 @@ steps: - pytest -v -s models/test_oot_registration.py # it needs a clean process - pytest -v -s models/*.py --ignore=models/test_oot_registration.py -- label: Decoder-only Language Models Test # 1h3min +- label: Decoder-only Language Models Test # 1h36min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ @@ -308,7 +309,7 @@ steps: commands: - pytest -v -s models/decoder_only/language -- label: Decoder-only Multi-Modal Models Test # 56min +- label: Decoder-only Multi-Modal Models Test # 1h31min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ @@ -318,15 +319,25 @@ steps: - pytest -v -s models/decoder_only/audio_language - pytest -v -s models/decoder_only/vision_language -- label: Other Models Test # 5min +- label: Other Models Test # 6min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ - tests/models/embedding/language - tests/models/encoder_decoder/language + - tests/models/encoder_decoder/vision_language commands: - pytest -v -s models/embedding/language - pytest -v -s models/encoder_decoder/language + - pytest -v -s models/encoder_decoder/vision_language + +- label: Custom Models Test + #mirror_hardwares: [amd] + optional: true + commands: + # PR authors can temporarily add commands below to test individual models + # e.g. 
pytest -v -s models/encoder_decoder/vision_language/test_mllama.py + # *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR* ##### 1 GPU test ##### ##### multi gpus test ##### @@ -359,7 +370,7 @@ steps: - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed' -- label: Distributed Tests (2 GPUs) # 28min +- label: Distributed Tests (2 GPUs) # 40min #mirror_hardwares: [amd] working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -376,14 +387,16 @@ steps: - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus # Avoid importing model tests that cause CUDA reinitialization error - - pytest models/encoder_decoder/language/test_bart.py models/decoder_only/vision_language/test_broadcast.py -v -s -m distributed_2_gpus + - pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus + - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus + - pytest models/decoder_only/vision_language/test_broadcast.py -v -s -m distributed_2_gpus - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s distributed/test_distributed_oot.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py -- label: Multi-step Tests (4 GPUs) # 21min +- label: Multi-step Tests (4 GPUs) # 36min working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -401,7 +414,7 @@ steps: - pytest -v -s multi_step/test_correctness_async_llm.py - pytest -v -s multi_step/test_correctness_llm.py -- label: Pipeline Parallelism Test # 23min +- label: Pipeline Parallelism Test # 45min working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -427,7 +440,7 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s -x lora/test_long_context.py -- label: Weight Loading Multiple GPU Test +- label: Weight Loading Multiple GPU Test # 33min working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 6d34621a8a9bc..b94ef537d783f 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -12,6 +12,10 @@ from vllm.assets.video import VideoAsset from vllm.utils import FlexibleArgumentParser +# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on +# lower-end GPUs. +# Unless specified, these settings have been tested to work on a single L4. 
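The hunks in this file only adjust the per-model run_* helpers; as a quick orientation, the sketch below shows how such a helper (run_llava, defined just below) is typically driven end to end with the usual vLLM offline-inference API. The question text, sampling values, and bundled image asset are illustrative assumptions, not part of this patch.

from vllm import SamplingParams
from vllm.assets.image import ImageAsset


def main_sketch():
    # Any run_* helper from this file can be substituted here; run_llava is
    # used purely as an illustration.
    llm, prompt, stop_token_ids = run_llava("What is in this image?", "image")

    # Sampling values are placeholders, not tested defaults.
    sampling_params = SamplingParams(temperature=0.2,
                                     max_tokens=64,
                                     stop_token_ids=stop_token_ids)

    # A bundled asset keeps the sketch self-contained; any PIL image works.
    image = ImageAsset("stop_sign").pil_image

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"image": image},
        },
        sampling_params=sampling_params,
    )
    for o in outputs:
        print(o.outputs[0].text)
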
+ # LLaVA-1.5 def run_llava(question, modality): @@ -19,7 +23,7 @@ def run_llava(question, modality): prompt = f"USER: \n{question}\nASSISTANT:" - llm = LLM(model="llava-hf/llava-1.5-7b-hf") + llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096) stop_token_ids = None return llm, prompt, stop_token_ids @@ -57,7 +61,7 @@ def run_llava_onevision(question, modality): <|im_start|>assistant\n" llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf", - max_model_len=32768) + max_model_len=16384) stop_token_ids = None return llm, prompt, stop_token_ids @@ -67,7 +71,7 @@ def run_fuyu(question, modality): assert modality == "image" prompt = f"{question}\n" - llm = LLM(model="adept/fuyu-8b") + llm = LLM(model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2) stop_token_ids = None return llm, prompt, stop_token_ids @@ -99,7 +103,8 @@ def run_phi3v(question, modality): llm = LLM( model="microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True, - max_num_seqs=5, + max_model_len=4096, + max_num_seqs=2, mm_processor_kwargs={"num_crops": 16}, ) stop_token_ids = None @@ -122,7 +127,7 @@ def run_chameleon(question, modality): assert modality == "image" prompt = f"{question}" - llm = LLM(model="facebook/chameleon-7b") + llm = LLM(model="facebook/chameleon-7b", max_model_len=4096) stop_token_ids = None return llm, prompt, stop_token_ids @@ -145,6 +150,8 @@ def run_minicpmv(question, modality): trust_remote_code=True) llm = LLM( model=model_name, + max_model_len=4096, + max_num_seqs=2, trust_remote_code=True, ) # NOTE The stop_token_ids are different for various versions of MiniCPM-V @@ -177,7 +184,7 @@ def run_internvl(question, modality): llm = LLM( model=model_name, trust_remote_code=True, - max_num_seqs=5, + max_model_len=4096, ) tokenizer = AutoTokenizer.from_pretrained(model_name, @@ -215,7 +222,8 @@ def run_qwen_vl(question, modality): llm = LLM( model="Qwen/Qwen-VL", trust_remote_code=True, - max_num_seqs=5, + max_model_len=1024, + max_num_seqs=2, ) prompt = f"{question}Picture 1: \n" @@ -229,8 +237,10 @@ def run_qwen2_vl(question, modality): model_name = "Qwen/Qwen2-VL-7B-Instruct" + # Tested on L40 llm = LLM( model=model_name, + max_model_len=8192, max_num_seqs=5, ) @@ -252,10 +262,10 @@ def run_mllama(question, modality): # max_model_len (131072) for this model may cause OOM. # You may lower either to run this example on lower-end GPUs. - # The configuration below has been confirmed to launch on a - # single H100 GPU. + # The configuration below has been confirmed to launch on a single L40 GPU. llm = LLM( model=model_name, + max_model_len=4096, max_num_seqs=16, enforce_eager=True, ) diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 8c5f1a7b7af08..1e99c02234d01 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -28,12 +28,18 @@ class ModelRequestData(NamedTuple): chat_template: Optional[str] +# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on +# lower-end GPUs. +# Unless specified, these settings have been tested to work on a single L4. 
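For the multi-image script, a similar consumption sketch follows; the ModelRequestData fields used here (llm, prompt, stop_token_ids, image_data) are assumed from how the load_* helpers are constructed, and the sampling values are placeholders rather than anything this patch prescribes.

from typing import List

from vllm import SamplingParams


def run_multi_image_sketch(question: str, image_urls: List[str]):
    # load_qwenvl_chat is defined just below; any other load_* helper in this
    # file can be substituted.
    req = load_qwenvl_chat(question, image_urls)

    sampling_params = SamplingParams(temperature=0.0,
                                     max_tokens=128,
                                     stop_token_ids=req.stop_token_ids)

    # Passing a list of images alongside a prompt that contains one
    # placeholder per image relies on limit_mm_per_prompt set by the helper.
    outputs = req.llm.generate(
        {
            "prompt": req.prompt,
            "multi_modal_data": {"image": req.image_data},
        },
        sampling_params=sampling_params,
    )
    for o in outputs:
        print(o.outputs[0].text)
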
+ + def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData: model_name = "Qwen/Qwen-VL-Chat" llm = LLM( model=model_name, trust_remote_code=True, - max_num_seqs=5, + max_model_len=1024, + max_num_seqs=2, limit_mm_per_prompt={"image": len(image_urls)}, ) placeholders = "".join(f"Picture {i}: \n" @@ -83,6 +89,7 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData: model="microsoft/Phi-3.5-vision-instruct", trust_remote_code=True, max_model_len=4096, + max_num_seqs=2, limit_mm_per_prompt={"image": len(image_urls)}, mm_processor_kwargs={"num_crops": 4}, ) @@ -106,7 +113,6 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData: llm = LLM( model=model_name, trust_remote_code=True, - max_num_seqs=5, max_model_len=4096, limit_mm_per_prompt={"image": len(image_urls)}, ) @@ -148,10 +154,11 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData: model_name = "Qwen/Qwen2-VL-7B-Instruct" + # Tested on L40 llm = LLM( model=model_name, - max_num_seqs=5, max_model_len=32768 if process_vision_info is None else 4096, + max_num_seqs=5, limit_mm_per_prompt={"image": len(image_urls)}, ) diff --git a/tests/conftest.py b/tests/conftest.py index db71d8bc3af1e..45dc5e8323ca4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -246,17 +246,14 @@ def video_assets() -> _VideoAssets: class HfRunner: - def wrap_device(self, input: _T) -> _T: - if not is_cpu(): - # Check if the input is already on the GPU - if hasattr(input, 'device') and input.device.type == "cuda": - return input # Already on GPU, no need to move - return input.to("cuda") - else: - # Check if the input is already on the CPU - if hasattr(input, 'device') and input.device.type == "cpu": - return input # Already on CPU, no need to move - return input.to("cpu") + def wrap_device(self, input: _T, device: Optional[str] = None) -> _T: + if device is None: + return self.wrap_device(input, "cpu" if is_cpu() else "cuda") + + if hasattr(input, "device") and input.device.type == device: + return input + + return input.to(device) def __init__( self, @@ -333,7 +330,7 @@ def generate( inputs = self.postprocess_inputs(inputs) output_ids = self.model.generate( - **self.wrap_device(inputs), + **self.wrap_device(inputs, device=self.model.device.type), use_cache=True, **kwargs, ) @@ -406,7 +403,7 @@ def generate_greedy_logprobs( inputs = self.postprocess_inputs(inputs) output = self.model.generate( - **self.wrap_device(inputs), + **self.wrap_device(inputs, device=self.model.device.type), use_cache=True, do_sample=False, max_new_tokens=max_tokens, @@ -414,40 +411,39 @@ def generate_greedy_logprobs( return_dict_in_generate=True, **kwargs, ) - seq_logprobs: List[torch.Tensor] = [] - for hidden_states in output.hidden_states: - last_hidden_states = hidden_states[-1][0] - logits = torch.matmul( - last_hidden_states, - self.model.get_output_embeddings().weight.t(), - ) - if self.model.get_output_embeddings().bias is not None: - logits += self.model.get_output_embeddings( - ).bias.unsqueeze(0) - logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) - seq_logprobs.append(logprobs) + seq_logprobs = self._hidden_states_to_seq_logprobs( + output.hidden_states) all_logprobs.append(seq_logprobs) return all_logprobs - def _hidden_states_to_logprobs( + def _hidden_states_to_seq_logprobs( self, - hidden_states, - num_logprobs, - ) -> Tuple[List[Dict[int, float]], int]: + hidden_states: Tuple[Tuple[torch.Tensor, ...], ...], + ) -> List[torch.Tensor]: + output_embeddings = 
self.model.get_output_embeddings() + seq_logprobs: List[torch.Tensor] = [] - output_len = len(hidden_states) for _, hidden_state in enumerate(hidden_states): last_hidden_states = hidden_state[-1][0] logits = torch.matmul( - last_hidden_states, - self.model.get_output_embeddings().weight.t(), + last_hidden_states.to(output_embeddings.weight.device), + output_embeddings.weight.t(), ) - if getattr(self.model.get_output_embeddings(), "bias", - None) is not None: - logits += self.model.get_output_embeddings().bias.unsqueeze(0) + if getattr(output_embeddings, "bias", None) is not None: + logits += output_embeddings.bias.unsqueeze(0) logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) seq_logprobs.append(logprobs) + return seq_logprobs + + def _hidden_states_to_logprobs( + self, + hidden_states: Tuple[Tuple[torch.Tensor, ...], ...], + num_logprobs: int, + ) -> Tuple[List[Dict[int, float]], int]: + seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states) + output_len = len(hidden_states) + # convert to dict seq_logprobs_lst: List[Dict[int, float]] = [] for tok_idx, tok_logprobs in enumerate(seq_logprobs): @@ -500,7 +496,7 @@ def generate_greedy_logprobs_limit( inputs = self.postprocess_inputs(inputs) output = self.model.generate( - **self.wrap_device(inputs), + **self.wrap_device(inputs, device=self.model.device.type), use_cache=True, do_sample=False, max_new_tokens=max_tokens, @@ -543,12 +539,20 @@ def generate_encoder_decoder_greedy_logprobs_limit( for (encoder_prompt, decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts): + encoder_input_ids = self.wrap_device( - self.tokenizer(encoder_prompt, return_tensors="pt").input_ids) - decoder_input_ids = ( - None if decoder_prompt is None else self.wrap_device( + self.tokenizer(encoder_prompt, return_tensors="pt").input_ids, + device=self.model.device.type, + ) + + if decoder_prompt is None: + decoder_input_ids = None + else: + decoder_input_ids = self.wrap_device( self.tokenizer(decoder_prompt, - return_tensors="pt").input_ids)) + return_tensors="pt").input_ids, + device=self.model.device.type, + ) output = self.model.generate( encoder_input_ids, diff --git a/tests/models/decoder_only/vision_language/test_llava_onevision.py b/tests/models/decoder_only/vision_language/test_llava_onevision.py index 978631feacb8c..2c4cd3fb85297 100644 --- a/tests/models/decoder_only/vision_language/test_llava_onevision.py +++ b/tests/models/decoder_only/vision_language/test_llava_onevision.py @@ -16,8 +16,7 @@ # Video test HF_VIDEO_PROMPTS = VIDEO_ASSETS.prompts({ "sample_demo_1": - "<|im_start|>user