From 2c899318ee276f97f37814ba8ee471ba9d0d3fd8 Mon Sep 17 00:00:00 2001 From: sixsixcoder Date: Thu, 10 Oct 2024 06:33:36 +0000 Subject: [PATCH] add pytest for glm4v --- examples/offline_inference_vision_language.py | 6 -- .../decoder_only/vision_language/test_glm4.py | 77 +++++++-------- vllm/model_executor/models/chatglm.py | 96 ++++++------------- 3 files changed, 62 insertions(+), 117 deletions(-) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index d9f344d9fa229..4da82c68c2b8a 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -11,10 +11,6 @@ from vllm.assets.image import ImageAsset from vllm.assets.video import VideoAsset from vllm.utils import FlexibleArgumentParser -import os - -os.environ['CUDA_VISIBLE_DEVICES'] = '3' - # LLaVA-1.5 def run_llava(question): @@ -201,7 +197,6 @@ def run_qwen2_vl(question): def run_glm4v(question): model_name = "THUDM/glm-4v-9b" - model_name = "/workspace/siaowei/model/glm-4v-9b" llm = LLM( model=model_name, @@ -210,7 +205,6 @@ def run_glm4v(question): trust_remote_code=True, # gpu_memory_utilization=0.5, enforce_eager=True) - # prompt = f"[gMASK]<|user|>{question}<|begin_of_image|><|endoftext|><|end_of_image|><|assistant|>" prompt = question stop_token_ids = [151329, 151336, 151338] return llm, prompt, stop_token_ids diff --git a/tests/models/decoder_only/vision_language/test_glm4.py b/tests/models/decoder_only/vision_language/test_glm4.py index 0bf279cb28ab1..1465525370716 100644 --- a/tests/models/decoder_only/vision_language/test_glm4.py +++ b/tests/models/decoder_only/vision_language/test_glm4.py @@ -1,15 +1,9 @@ # tests/models/decoder_only/vision_language/test_glm4v.py -from typing import List, Optional, Tuple, Type, Dict - import pytest -import os -from vllm.sequence import SampleLogprobs +from typing import List, Optional, Tuple, Type from vllm.multimodal.utils import rescale_image_size -from vllm.utils import is_cpu -from transformers import AutoConfig, AutoTokenizer - - -from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner) +from ....conftest import (IMAGE_ASSETS, HfRunner, + PromptImageInput, VllmRunner) from ...utils import check_logprobs_close HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ @@ -18,13 +12,9 @@ "cherry_blossom": "What is the season?", }) -os.environ['CUDA_VISIBLE_DEVICES'] = '3' - -models = ["/workspace/siaowei/model/glm-4v-9b"] +models = ["THUDM/glm-4v-9b"] target_dtype = "bfloat16" -if is_cpu(): - target_dtype = "bfloat16" def run_test( hf_runner: Type[HfRunner], @@ -38,49 +28,48 @@ def run_test( mm_limit: int, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, -): +): # max_model_len should be greater than image_feature_size - with vllm_runner(model, - max_model_len=4096, - max_num_seqs=1, - dtype=dtype, - limit_mm_per_prompt={"image": mm_limit}, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - # gpu_memory_utilization=0.9, - enforce_eager=True) as vllm_model: + with vllm_runner( + model, + max_model_len=4096, + max_num_seqs=1, + dtype=dtype, + limit_mm_per_prompt={"image": mm_limit}, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + # gpu_memory_utilization=0.9, + enforce_eager=True) as vllm_model: # tokenizer = vllm_model.model.get_tokenizer() - # pass stop_token_ids = [151329, 151336, 151338] vllm_outputs_per_image = [ - 
             vllm_model.generate_greedy_logprobs(prompts,
                                                 max_tokens,
                                                 num_logprobs=num_logprobs,
                                                 images=images,
                                                 stop_token_ids=stop_token_ids)
             for prompts, images in inputs
         ]

     with hf_runner(model, dtype=dtype) as hf_model:
-        eos_token_id = hf_model.tokenizer.eos_token_id
+        # ChatGLM's HF implementation does not expose its LM head through
+        # get_output_embeddings(), so point it at transformer.output_layer
+        # to let the HF runner turn hidden states into logprobs.
+        hf_model.model.get_output_embeddings = lambda: \
+            hf_model.model.transformer.output_layer
         hf_outputs_per_image = [
             hf_model.generate_greedy_logprobs_limit(prompts,
                                                     max_tokens,
                                                     num_logprobs=num_logprobs,
                                                     images=images,
-                                                    eos_token_id=eos_token_id
-                                                    # tokenizer=tokenizer
                                                     )
             for prompts, images in inputs
         ]

     for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
                                         vllm_outputs_per_image):
         check_logprobs_close(
             outputs_0_lst=hf_outputs,
             outputs_1_lst=vllm_outputs,
             name_0="hf",
             name_1="vllm",
         )


 @pytest.mark.parametrize("model", models)
@@ -90,11 +79,11 @@ def run_test(
         # No image
         [],
         # Single-scale
-        # [1.0],
+        [1.0],
         # Single-scale, batched
-        # [1.0, 1.0, 1.0],
+        [1.0, 1.0, 1.0],
         # Multi-scale
-        # [0.25, 0.5, 1.0],
+        [0.25, 0.5, 1.0],
     ],
 )
 @pytest.mark.parametrize("dtype", [target_dtype])
@@ -118,4 +107,4 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
         num_logprobs=num_logprobs,
         mm_limit=1,
         tensor_parallel_size=1,
-    )
\ No newline at end of file
+    )
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index 15c421b197225..37f2b0759be49 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -4,17 +4,16 @@
 """Inference-only ChatGLM model compatible with THUDM weights."""
 from argparse import Namespace
 from array import array
-from typing import (Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict)
-from PIL import Image
-
+from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict
 import torch
+from PIL import Image
 from torch import nn
 from torch.nn import LayerNorm
-from vllm.logger import init_logger
 from vllm.attention import Attention, AttentionMetadata
 from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
+from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -30,8 +29,9 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.glm4_vision_encoder import EVA2CLIPModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal.base import MultiModalData
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
+from vllm.multimodal.base import MultiModalData
+from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
                            SequenceData)
 from vllm.transformers_utils.configs import ChatGLMConfig
@@ -45,62 +45,26 @@ def calculate_image_placeholder(vision_config):
     return (vision_config["image_size"] // vision_config["patch_size"] // 2)**2


-# @lru_cache
-def cached_get_image_processor(
-    processor_name: str,
-    *args,
-
trust_remote_code: bool = False, - **kwargs, -): - """Gets an image processor for the given model name via HuggingFace.""" - # don't put this import at the top level - # it will call torch.cuda.device_count() - from transformers import AutoTokenizer - - try: - processor = AutoTokenizer.from_pretrained( - processor_name, - *args, - trust_remote_code=trust_remote_code, - **kwargs) - image_processor = processor.apply_chat_template - except ValueError as e: - # If the error pertains to the processor class not existing or not - # currently being imported, suggest using the --trust-remote-code flag. - # Unlike AutoTokenizer, AutoImageProcessor does not separate such errors - if not trust_remote_code: - err_msg = ( - "Failed to load the image processor. If the image processor is " - "a custom processor not yet available in the HuggingFace " - "transformers library, consider setting " - "`trust_remote_code=True` in LLM or using the " - "`--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e - else: - raise e - - return image_processor - - def mm_input_mapper_for_glmv( ctx: InputContext, data: MultiModalData[object], ) -> Dict: model_config = ctx.model_config - image_processor = cached_get_image_processor( - model_config.model, trust_remote_code=model_config.trust_remote_code) - if image_processor is None: + tokenizer = cached_get_tokenizer(model_config.tokenizer, + trust_remote_code=True) + if tokenizer is None: raise RuntimeError("No HuggingFace processor is available " "to process the image object") try: - raw_batch_data = image_processor(conversation=[{ - "role": "user", - "image": data - }], - add_generation_prompt=True, - tokenize=True, - return_tensors="pt", - return_dict=True).data + raw_batch_data = tokenizer.apply_chat_template( + conversation=[{ + "role": "user", + "image": data + }], + add_generation_prompt=True, + tokenize=True, + return_tensors="pt", + return_dict=True).data except Exception: logger.error("Failed to process image (%s)", data) raise @@ -196,23 +160,21 @@ def input_processor_for_glmv(ctx: InputContext, llm_inputs: LLMInputs): input_ids = llm_inputs.get("prompt_token_ids") position_ids = llm_inputs.get("position_ids") - - image_processor = cached_get_image_processor( + tokenizer = cached_get_tokenizer( ctx.model_config.model, trust_remote_code=ctx.model_config.trust_remote_code) + try: - raw_batch_data = image_processor(conversation=[{ - "role": - "user", - "image": - llm_inputs['multi_modal_data']["image"], - "content": - llm_inputs['prompt'] - }], - add_generation_prompt=True, - tokenize=True, - return_tensors="pt", - return_dict=True).data + raw_batch_data = tokenizer.apply_chat_template( + conversation=[{ + "role": "user", + "image": llm_inputs['multi_modal_data']["image"], + "content": llm_inputs['prompt'] + }], + add_generation_prompt=True, + tokenize=True, + return_tensors="pt", + return_dict=True).data except Exception: logger.error("Failed to process content (%s)", llm_inputs['prompt']) raise
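
For reviewers who want to exercise the GLM-4v path outside of pytest, here is a
minimal offline sketch (not part of the patch) in the spirit of run_glm4v() in
examples/offline_inference_vision_language.py. The image path and max_tokens
value are placeholder assumptions; max_model_len=4096 mirrors the test's
requirement that the context exceed the image feature size, and the
stop_token_ids are the GLM-4 special tokens used in the test above.

    # Sketch: greedy GLM-4v generation through vLLM, mirroring the test setup.
    from PIL import Image

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="THUDM/glm-4v-9b",  # same checkpoint as the test's `models` list
        max_model_len=4096,       # must be greater than the image feature size
        trust_remote_code=True,
        enforce_eager=True,
    )

    image = Image.open("example.jpg")  # placeholder path; use any local image
    outputs = llm.generate(
        {
            "prompt": "What is the content of this image?",
            "multi_modal_data": {"image": image},
        },
        SamplingParams(temperature=0.0,  # greedy decoding, as in the test
                       max_tokens=64,
                       stop_token_ids=[151329, 151336, 151338]),
    )
    print(outputs[0].outputs[0].text)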