From 4f7e50b86a1c99b706f78efa0543ce4ccc5f5628 Mon Sep 17 00:00:00 2001 From: Lyu Han Date: Fri, 6 Dec 2024 18:39:01 +0800 Subject: [PATCH 01/10] update supported models (#2849) * update supported models * update deepseek-v2.5 * update --- README.md | 3 +++ README_ja.md | 3 +++ README_zh-CN.md | 3 +++ docs/en/supported_models/supported_models.md | 19 +++++++++++------ .../supported_models/supported_models.md | 21 ++++++++++++------- 5 files changed, 36 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index d160338aa6..8ef7b7994f 100644 --- a/README.md +++ b/README.md @@ -125,6 +125,8 @@ For detailed inference benchmarks in more devices and more settings, please refe
  • Qwen1.5 (0.5B - 110B)
  • Qwen1.5 - MoE (0.5B - 72B)
  • Qwen2 (0.5B - 72B)
+ • Qwen2-MoE (57BA14B)
+ • Qwen2.5 (0.5B - 32B)
  • Baichuan (7B)
  • Baichuan2 (7B-13B)
  • Code Llama (7B - 34B)
@@ -136,6 +138,7 @@ For detailed inference benchmarks in more devices and more settings, please refe
  • Mistral (7B)
  • DeepSeek-MoE (16B)
  • DeepSeek-V2 (16B, 236B)
+ • DeepSeek-V2.5 (236B)
  • Mixtral (8x7B, 8x22B)
  • Gemma (2B - 7B)
  • Dbrx (132B)
diff --git a/README_ja.md b/README_ja.md
index fda176229e..77badaac36 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -122,6 +122,8 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
  • Qwen1.5 (0.5B - 110B)
  • Qwen1.5 - MoE (0.5B - 72B)
  • Qwen2 (0.5B - 72B)
+ • Qwen2-MoE (57BA14B)
+ • Qwen2.5 (0.5B - 32B)
  • Baichuan (7B)
  • Baichuan2 (7B-13B)
  • Code Llama (7B - 34B)
@@ -133,6 +135,7 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
  • Mistral (7B)
  • DeepSeek-MoE (16B)
  • DeepSeek-V2 (16B, 236B)
+ • DeepSeek-V2.5 (236B)
  • Mixtral (8x7B, 8x22B)
  • Gemma (2B - 7B)
  • Dbrx (132B)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 6c24b2e500..9f3cd40a64 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -126,6 +126,8 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力，在各种规模的模型
  • Qwen1.5 (0.5B - 110B)
  • Qwen1.5 - MoE (0.5B - 72B)
  • Qwen2 (0.5B - 72B)
+ • Qwen2-MoE (57BA14B)
+ • Qwen2.5 (0.5B - 32B)
  • Baichuan (7B)
  • Baichuan2 (7B-13B)
  • Code Llama (7B - 34B)
@@ -137,6 +139,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力，在各种规模的模型
  • Mistral (7B)
  • DeepSeek-MoE (16B)
  • DeepSeek-V2 (16B, 236B)
+ • DeepSeek-V2.5 (236B)
  • Mixtral (8x7B, 8x22B)
  • Gemma (2B - 7B)
  • Dbrx (132B)
  • diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md index 469ece487f..dd8ceb4ffa 100644 --- a/docs/en/supported_models/supported_models.md +++ b/docs/en/supported_models/supported_models.md @@ -10,7 +10,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine | Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | | Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | | Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | -| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | +| Llama3.2 | 1B, 3B | LLM | Yes | Yes\* | Yes\* | Yes | | InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | | InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | | InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | @@ -18,9 +18,13 @@ The following tables detail the models supported by LMDeploy's TurboMind engine | InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes | | Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | | Qwen1.5 | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes | -| Qwen2 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes | +| Qwen2 | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes | +| Qwen2-MoE | 57BA14B | LLM | Yes | Yes | Yes | Yes | +| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes | | Mistral | 7B | LLM | Yes | Yes | Yes | No | | Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes | +| DeepSeek-V2 | 16B, 236B | LLM | Yes | Yes | Yes | No | +| DeepSeek-V2.5 | 236B | LLM | Yes | Yes | Yes | No | | Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes | | DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes | | Baichuan | 7B | LLM | Yes | Yes | Yes | Yes | @@ -29,7 +33,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine | YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | | LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes | | InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes | -| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes | Yes | Yes | +| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes\* | Yes\* | Yes | | ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes | | MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes | | MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes | @@ -41,7 +45,8 @@ The following tables detail the models supported by LMDeploy's TurboMind engine "-" means not verified yet. ```{note} -The TurboMind engine doesn't support window attention. Therefore, for models that have applied window attention and have the corresponding switch "use_sliding_window" enabled, such as Mistral, Qwen1.5 and etc., please choose the PyTorch engine for inference. +* The TurboMind engine doesn't support window attention. Therefore, for models that have applied window attention and have the corresponding switch "use_sliding_window" enabled, such as Mistral, Qwen1.5 and etc., please choose the PyTorch engine for inference. +* When the head_dim of a model is not 128, such as llama3.2-1B, qwen2-0.5B and internvl2-1B, turbomind doesn't support its kv cache 4/8 bit quantization and inference ``` ## PyTorchEngine on CUDA Platform @@ -68,11 +73,13 @@ The TurboMind engine doesn't support window attention. 
Therefore, for models tha | QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes | | QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No | | QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | +| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | | QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | No | | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | | DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | +| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No | | MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No | -| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | Yes | Yes | +| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes | | Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No | | Dbrx | 132B | LLM | Yes | Yes | Yes | No | No | | StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No | @@ -81,7 +88,7 @@ The TurboMind engine doesn't support window attention. Therefore, for models tha | CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - | | CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - | | LLaVA(1.5,1.6) | 7B-34B | MLLM | Yes | Yes | Yes | - | - | -| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | Yes | Yes | +| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes | | InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | - | - | | Mono-InternVL | 2B | MLLM | Yes\* | Yes | Yes | - | - | | ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - | diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md index d734523282..3ec3688e1b 100644 --- a/docs/zh_cn/supported_models/supported_models.md +++ b/docs/zh_cn/supported_models/supported_models.md @@ -10,7 +10,7 @@ | Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | | Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | | Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | -| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | +| Llama3.2 | 1B, 3B | LLM | Yes | Yes\* | Yes\* | Yes | | InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | | InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | | InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | @@ -18,9 +18,13 @@ | InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes | | Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | | Qwen1.5 | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes | -| Qwen2 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes | +| Qwen2 | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes | +| Qwen2-MoE | 57BA14B | LLM | Yes | Yes | Yes | Yes | +| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes | | Mistral | 7B | LLM | Yes | Yes | Yes | No | | Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes | +| DeepSeek-V2 | 16B, 236B | LLM | Yes | Yes | Yes | No | +| DeepSeek-V2.5 | 236B | LLM | Yes | Yes | Yes | No | | Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes | | DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes | | Baichuan | 7B | LLM | Yes | Yes | Yes | Yes | @@ -29,7 +33,7 @@ | YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | | LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes | | InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes | -| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes | Yes | Yes | +| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes\* | Yes\* | Yes | | ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes | | MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes | | MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes | @@ -41,7 +45,8 @@ “-” 表示还没有验证。 ```{note} -turbomind 引擎不支持 window attention。所以,对于应用了 window attention,并开启了对应的开关"use_sliding_window"的模型,比如 Mistral、Qwen1.5 等,在推理时,请选择 pytorch engine +* turbomind 引擎不支持 window 
attention。所以,对于应用了 window attention,并开启了对应的开关"use_sliding_window"的模型,比如 Mistral、Qwen1.5 等,在推理时,请选择 pytorch engine +* 当模型的 head_dim 非 128 时,turbomind 不支持它的 kv cache 4/8 bit 量化和推理。比如,llama3.2-1B,qwen2-0.5B,internvl2-1B 等等 ``` ## PyTorchEngine CUDA 平台 @@ -68,11 +73,13 @@ turbomind 引擎不支持 window attention。所以,对于应用了 window att | QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes | | QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No | | QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | +| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes | | QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | No | | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | | DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | +| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No | | MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No | -| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | Yes | Yes | +| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes | | Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No | | Dbrx | 132B | LLM | Yes | Yes | Yes | No | No | | StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No | @@ -81,7 +88,7 @@ turbomind 引擎不支持 window attention。所以,对于应用了 window att | CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - | | CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - | | LLaVA(1.5,1.6) | 7B-34B | MLLM | Yes | Yes | Yes | - | - | -| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | Yes | Yes | +| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes | | InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | - | - | | Mono-InternVL | 2B | MLLM | Yes\* | Yes | Yes | - | - | | ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - | @@ -94,7 +101,7 @@ turbomind 引擎不支持 window attention。所以,对于应用了 window att | Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - | ```{note} -* Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead. 
+* 目前,Mono-InternVL不支持FP16,因为数值不稳定。请改用BF16。 ``` ## PyTorchEngine 华为昇腾平台 From af0d95be0aeedfd135b3929f0377e53ef9a581f9 Mon Sep 17 00:00:00 2001 From: jinminxi104 Date: Mon, 9 Dec 2024 11:39:39 +0800 Subject: [PATCH 02/10] Update dlinfer-ascend version in runtime_ascend.txt (#2865) --- requirements/runtime_ascend.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/runtime_ascend.txt b/requirements/runtime_ascend.txt index 05d74bbe72..c5d44cc995 100644 --- a/requirements/runtime_ascend.txt +++ b/requirements/runtime_ascend.txt @@ -1,5 +1,5 @@ accelerate>=0.29.3 -dlinfer-ascend>=0.1.2 +dlinfer-ascend>=0.1.3 einops fastapi fire From fb3f8cc1b53e8402c3a6968c31e10923c985d764 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Mon, 9 Dec 2024 13:09:39 +0800 Subject: [PATCH 03/10] warn glm4v does not support multi images --- lmdeploy/vl/model/glm_4v.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/lmdeploy/vl/model/glm_4v.py b/lmdeploy/vl/model/glm_4v.py index 2a72ee18fe..aa3372a5f7 100644 --- a/lmdeploy/vl/model/glm_4v.py +++ b/lmdeploy/vl/model/glm_4v.py @@ -39,16 +39,27 @@ def build_preprocessor(self): def preprocess(self, messages: List[Dict]) -> List[Dict]: """refers to the spec of `super.preprocess()""" - images = self.collect_images(messages) outputs = [] - for image, params in images: - image = image.convert('RGB') - pixel_values = self.image_transform(image) - outputs.append( - dict(pixel_values=pixel_values, - image_size=image.size, + for message in messages: + if not isinstance(message['content'], List): + continue + images = [ + x['image'] for x in message['content'] if x['type'] == 'image' + ] + if len(images) > 1: + logger.warning( + f'glm4v does not support the input of multiple images' + f' in a single chat round, but got {len(images)} images.') + # we still pass all the images to the model and let the + # model decide what to do + images = [x.convert('RGB') for x in images] + pixel_values = [self.image_transform(x) for x in images] + outputs.extend([ + dict(pixel_values=_2, + image_size=_1.size, image_tokens=self.n_token_per_image, - image_token_id=0)) + image_token_id=0) for _1, _2 in zip(images, pixel_values) + ]) messages.append(dict(role='preprocess', content=outputs)) return messages From ed2efb357b67475954383de4aca48cdec9999c97 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Mon, 9 Dec 2024 17:10:59 +0800 Subject: [PATCH 04/10] fix deepseek-vl --- lmdeploy/serve/vl_async_engine.py | 2 ++ lmdeploy/vl/model/deepseek.py | 20 +++++++++++++------- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/lmdeploy/serve/vl_async_engine.py b/lmdeploy/serve/vl_async_engine.py index cf35959bd4..c6f140a113 100644 --- a/lmdeploy/serve/vl_async_engine.py +++ b/lmdeploy/serve/vl_async_engine.py @@ -124,6 +124,8 @@ async def async_convert_to_pil_images(cls, def _inner_call(i, in_messages, out_messages): role = in_messages[i]['role'] content = in_messages[i]['content'] + assert role in ['sytem', 'user', 'assistant'], \ + f'unsupported role "{role}"' if role != 'user' or isinstance(content, str): # the content is a user's prompt or an assistant's prompt, # returning it directly diff --git a/lmdeploy/vl/model/deepseek.py b/lmdeploy/vl/model/deepseek.py index af682fb3bb..33f54784d5 100644 --- a/lmdeploy/vl/model/deepseek.py +++ b/lmdeploy/vl/model/deepseek.py @@ -153,22 +153,28 @@ def proc_messages(cls, messages, chat_template, sequence_start): x['text'] for x in message['content'] if x['type'] == 'text' ] content = 
content[0] - if IMAGE_TOKEN not in content: + n_image = sum( + [1 for x in message['content'] if x['type'] == 'image']) + n_placeholder = content.count(IMAGE_TOKEN) + if n_placeholder == 0: logger.warning( f"""for deepseek-vl model, the user should insert the {IMAGE_TOKEN} to user prompt manually, please read https://lmdeploy.readthedocs.io/en/latest/inference/vl_pipeline.html for more details.""") # noqa - n_images = len( - [1 for x in message['content'] if x['type'] == 'image']) - if n_images == 1: + if n_placeholder != 0 and n_placeholder != n_image: + logger.error( + f'unmatched placeholder and image: {n_placeholder} vs ' + f'{n_image}. Ignore the placeholder') + content = content.replace(IMAGE_TOKEN, '') + n_placeholder = 0 + if n_placeholder == 0: + if n_image == 1: content = f'{IMAGE_TOKEN}{content}' else: content = ''.join([ f'{IMAGE_TOKEN} is Figure {str(i)}.\n' - for i in range(n_images) + for i in range(n_image) ]) + content - else: - logger.error('TODO deepseek-vl') prompt_messages.append(dict(role='user', content=content)) prompt = chat_template.messages2prompt(prompt_messages, sequence_start) return prompt, IMAGE_TOKEN From 18b38e9a0823d4e268b9947fc1e3cc00a392bdbc Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Mon, 9 Dec 2024 18:47:19 +0800 Subject: [PATCH 05/10] fix internvl --- lmdeploy/vl/model/internvl.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/lmdeploy/vl/model/internvl.py b/lmdeploy/vl/model/internvl.py index b5abb55fdd..6bc6dbdf0e 100644 --- a/lmdeploy/vl/model/internvl.py +++ b/lmdeploy/vl/model/internvl.py @@ -175,10 +175,7 @@ def _forward_v1_5(self, inputs, max_batch_size): pixel_values = [ x['pixel_values'] for x in inputs[idx:idx + max_batch_size] ] - split = [ - x['pixel_values'].shape[0] - for x in inputs[idx:idx + max_batch_size] - ] + split = [x.shape[0] for x in pixel_values] pixel_values = torch.cat(pixel_values, dim=0) pixel_values = pixel_values.to(self.model.device, dtype=torch.float16) @@ -202,7 +199,7 @@ def _forward(self, inputs, max_batch_size): pixel_values = [ x['pixel_values'] for x in inputs[idx:idx + max_batch_size] ] - pixel_values = torch.cat(outputs, dim=0) + pixel_values = torch.cat(pixel_values, dim=0) pixel_values = pixel_values.to(self.model.device, dtype=torch.float16) logger.info(f'vision forward shape: {pixel_values.shape}') From db367f40741506bd2b82c820c22d28d1ed388757 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Tue, 10 Dec 2024 14:41:34 +0800 Subject: [PATCH 06/10] fix minicpm 2.6 --- lmdeploy/vl/model/minicpmv.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/lmdeploy/vl/model/minicpmv.py b/lmdeploy/vl/model/minicpmv.py index a31135be4c..decb5a0054 100644 --- a/lmdeploy/vl/model/minicpmv.py +++ b/lmdeploy/vl/model/minicpmv.py @@ -199,14 +199,18 @@ def forward(self, patch_attn_mask = torch.zeros((B, 1, max_patches), dtype=torch.bool, device=self.model.device) - for i in range(B): - patch_attn_mask[i, :tgt_sizes[i][0] * tgt_sizes[i][1]] = True logger.info(f'vision forward shape: {pixel_values.shape}') if self.version == '2.5': + for j in range(B): + patch_attn_mask[j, :tgt_sizes[j][0] * + tgt_sizes[j][1]] = True embeddings = self.model.vpm( pixel_values.type(torch.half), patch_attention_mask=patch_attn_mask).last_hidden_state else: + for j in range(B): + patch_attn_mask[j, 0, :tgt_sizes[j][0] * + tgt_sizes[j][1]] = True embeddings = self.model.vpm( pixel_values.type(torch.half), patch_attention_mask=patch_attn_mask, From 715fbb313a4aef9663a92dbd7b97e390e5ca60c3 Mon Sep 17 
00:00:00 2001 From: lvhan028 Date: Tue, 10 Dec 2024 17:33:47 +0800 Subject: [PATCH 07/10] fix minicpm v2.5 --- lmdeploy/vl/model/minicpmv.py | 102 ++++++++++++++-------------------- 1 file changed, 42 insertions(+), 60 deletions(-) diff --git a/lmdeploy/vl/model/minicpmv.py b/lmdeploy/vl/model/minicpmv.py index decb5a0054..368c5051c6 100644 --- a/lmdeploy/vl/model/minicpmv.py +++ b/lmdeploy/vl/model/minicpmv.py @@ -3,7 +3,6 @@ import warnings from typing import Dict, List -import numpy as np import torch from PIL.Image import Image from transformers import AutoConfig, AutoModelForCausalLM @@ -109,12 +108,13 @@ def _preprocess_v2_5(self, image: Image, params: Dict = None) -> Dict: # pixel_values, tgt_sizes are list of torch tensors pixel_values, tgt_sizes = self._reshape_by_patch(slice_images) num_patches = len(pixel_values) - return dict(pixel_values=pixel_values, - tgt_sizes=tgt_sizes, - best_grid=best_grid, - num_patches=num_patches, - image_tokens=1, - image_token_id=0) + return dict( + pixel_values=pixel_values, # a list + tgt_sizes=tgt_sizes, # a list + best_grid=best_grid, + num_patches=num_patches, + image_tokens=1, + image_token_id=0) def _preprocess_v2_6(self, image: Image, params: Dict = None) -> Dict: """image preprocessing for MiniCPM-V-2_6.""" @@ -130,7 +130,6 @@ def _preprocess_v2_6(self, image: Image, params: Dict = None) -> Dict: tgt_sizes = [torch.as_tensor(x) for x in tgt_sizes] grid = self.image_processor.get_sliced_grid( image_size=image.size, max_slice_nums=max_slice_nums) - return dict( pixel_values=pixel_values, # a list tgt_sizes=tgt_sizes, # a list @@ -173,12 +172,24 @@ def forward(self, Return: the message list with forwarding results included """ - for i, message in enumerate(messages): - if 'preprocess' not in message.keys(): - continue - inputs = message['preprocess'] - tgt_sizes = [x['tgt_sizes'] for x in inputs] - pixel_values = [x['pixel_values'] for x in inputs] + # collect preprocess results into a list + inputs = [] + inputs = [ + x['preprocess'] for x in messages if 'preprocess' in x.keys() + ] + # flatten the list + inputs = list(itertools.chain(*inputs)) + outputs = [] + for idx in range(0, len(inputs), max_batch_size): + tgt_sizes = [ + x['tgt_sizes'] for x in inputs[idx:idx + max_batch_size] + ] + pixel_values = [ + x['pixel_values'] for x in inputs[idx:idx + max_batch_size] + ] + num_patches = [ + x['num_patches'] for x in inputs[idx:idx + max_batch_size] + ] # flatten the list tgt_sizes = list(itertools.chain(*tgt_sizes)) pixel_values = list(itertools.chain(*pixel_values)) @@ -207,6 +218,11 @@ def forward(self, embeddings = self.model.vpm( pixel_values.type(torch.half), patch_attention_mask=patch_attn_mask).last_hidden_state + embeddings = self.model.resampler(embeddings, tgt_sizes) + embeddings = torch.split(embeddings, num_patches, 0) + for embedding in embeddings: + embedding = embedding.split(1, dim=0) + outputs.extend([x.squeeze() for x in embedding]) else: for j in range(B): patch_attn_mask[j, 0, :tgt_sizes[j][0] * @@ -215,8 +231,12 @@ def forward(self, pixel_values.type(torch.half), patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes).last_hidden_state - embeddings = self.model.resampler(embeddings, tgt_sizes) - messages[i].update(dict(forward=embeddings)) + embeddings = self.model.resampler(embeddings, tgt_sizes) + embeddings = torch.split(embeddings, num_patches, 0) + for embedding in embeddings: + embedding = embedding.split(1, dim=0) + outputs.extend([_ for _ in embedding]) + messages.append(dict(role='forward', content=outputs)) 
return messages def proc_messages(self, messages, chat_template, sequence_start): @@ -230,6 +250,7 @@ def proc_messages(self, messages, chat_template, sequence_start): continue if 'preprocess' not in message.keys(): continue + prompts = [] for x in message['preprocess']: prompt = f'{IMAGE_TOKEN}' if x.get('use_image_id', False): @@ -247,11 +268,12 @@ def proc_messages(self, messages, chat_template, sequence_start): [f'{IMAGE_TOKEN}' * grid[0]] * grid[1]) prompt = prompt + slice - prompt += '\n' + prompt += '\n' + prompts.append(prompt) content = [ x['text'] for x in message['content'] if x['type'] == 'text' ] - prompt += content[0] + prompt = ''.join(prompts) + content[0] prompt_messages.append(dict(role='user', content=prompt)) prompt = chat_template.messages2prompt(prompt_messages, sequence_start) return prompt, IMAGE_TOKEN @@ -265,45 +287,5 @@ def to_pytorch(self, messages, chat_template, tokenizer, sequence_start): def to_turbomind(self, messages, chat_template, tokenizer, sequence_start): prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start) - features = [] - for message in messages: - if 'preprocess' not in message.keys(): - continue - assert 'forward' in message.keys() - inputs = message.pop('preprocess', None) - embeddings = message.pop('forward', None) - num_patches = [x['num_patches'] for x in inputs] - embeddings = torch.split(embeddings, num_patches, 0) - embeddings = [emmbedding.split(1) for emmbedding in embeddings] - embeddings = list(itertools.chain(*embeddings)) - features.extend(embeddings) - - # flatten the list - features = list(itertools.chain(*features)) - features = [x.cpu().numpy() for x in features] - - # split prompt into segments and validate data - segs = prompt.split(IMAGE_TOKEN) - assert len(segs) == len(features) + 1, ( - f'the number of {IMAGE_TOKEN} is not equal ' - f'to input images, {len(segs) - 1} vs {len(features)}') - - # tokenizer prompt, and get input_embeddings and input_embedding_ranges - input_ids = [] - begins = [] - ends = [] - IMAGE_DUMMY_TOKEN_INDEX = 0 - for i, seg in enumerate(segs): - if i > 0 and i <= len(features): - image_dim = features[i - 1].shape[0] - begins.append(len(input_ids)) - ends.append(begins[-1] + image_dim) - input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim) - seg_ids = tokenizer.encode(seg, - add_bos=((i == 0) and sequence_start)) - input_ids.extend(seg_ids) - ranges = np.stack([begins, ends], axis=1).tolist() - return dict(prompt=prompt, - input_ids=input_ids, - input_embeddings=features, - input_embedding_ranges=ranges) + return super().to_turbomind_aux(messages, prompt, IMAGE_TOKEN, + tokenizer, sequence_start) From 1a6d88f5b5f392e53750ddafceac10fbe8334d33 Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Tue, 10 Dec 2024 17:57:01 +0800 Subject: [PATCH 08/10] fix minicpm v2.6 --- lmdeploy/vl/model/minicpmv.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/lmdeploy/vl/model/minicpmv.py b/lmdeploy/vl/model/minicpmv.py index 368c5051c6..129bc9766f 100644 --- a/lmdeploy/vl/model/minicpmv.py +++ b/lmdeploy/vl/model/minicpmv.py @@ -218,11 +218,6 @@ def forward(self, embeddings = self.model.vpm( pixel_values.type(torch.half), patch_attention_mask=patch_attn_mask).last_hidden_state - embeddings = self.model.resampler(embeddings, tgt_sizes) - embeddings = torch.split(embeddings, num_patches, 0) - for embedding in embeddings: - embedding = embedding.split(1, dim=0) - outputs.extend([x.squeeze() for x in embedding]) else: for j in range(B): patch_attn_mask[j, 0, 
:tgt_sizes[j][0] * @@ -231,11 +226,12 @@ def forward(self, pixel_values.type(torch.half), patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes).last_hidden_state - embeddings = self.model.resampler(embeddings, tgt_sizes) - embeddings = torch.split(embeddings, num_patches, 0) - for embedding in embeddings: - embedding = embedding.split(1, dim=0) - outputs.extend([_ for _ in embedding]) + + embeddings = self.model.resampler(embeddings, tgt_sizes) + embeddings = torch.split(embeddings, num_patches, 0) + for embedding in embeddings: + embedding = embedding.split(1, dim=0) + outputs.extend([x.squeeze() for x in embedding]) messages.append(dict(role='forward', content=outputs)) return messages @@ -269,6 +265,9 @@ def proc_messages(self, messages, chat_template, sequence_start): grid[1]) prompt = prompt + slice prompt += '\n' + else: + prompt = (prompt + + '\n' if self.version == '2.6' else prompt) prompts.append(prompt) content = [ x['text'] for x in message['content'] if x['type'] == 'text' From f0a442266ec388ac9f7f06b5b40d075df2588bde Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Tue, 10 Dec 2024 18:13:25 +0800 Subject: [PATCH 09/10] update llava_next.py --- lmdeploy/vl/model/llava_next.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/lmdeploy/vl/model/llava_next.py b/lmdeploy/vl/model/llava_next.py index ab58b105da..3b73f514fe 100644 --- a/lmdeploy/vl/model/llava_next.py +++ b/lmdeploy/vl/model/llava_next.py @@ -4,7 +4,6 @@ from typing import Dict, List import torch -from transformers import AutoProcessor from lmdeploy.utils import get_logger from lmdeploy.vl.model.llava_hf import VISION_MODELS, LlavaHfVisionModel @@ -20,12 +19,7 @@ class LlavaNextVisionModel(LlavaHfVisionModel): _arch = 'LlavaNextForConditionalGeneration' def build_preprocessor(self): - processor = AutoProcessor.from_pretrained(self.model_path, - trust_remote_code=True) - if hasattr(processor, 'tokenizer'): - del processor.tokenizer - processor.prtokenizer = None - self.processor = processor.image_processor + super().build_preprocessor() # build the model with empty weights. 
The model will be used in # `preprocess` to get the image token number from accelerate import init_empty_weights @@ -94,10 +88,10 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: patch_size=self.hf_config.vision_config.image_size, ) for imsize in result['image_sizes'] ] - # TODO(remove hardcode 576) + hidden_size = self.hf_config.text_config.hidden_size fake_image_features = torch.zeros( - [image_num_patches[0], 576, hidden_size]) + [image_num_patches[0], self.n_token_per_image, hidden_size]) image_sizes = result['image_sizes'] image_newline = torch.randn(self.hf_config.text_config.hidden_size) strategy = self.hf_config.vision_feature_select_strategy From ee022adea2ffce65bc1e6e75ec2a53dfc3d89bba Mon Sep 17 00:00:00 2001 From: lvhan028 Date: Tue, 10 Dec 2024 19:17:27 +0800 Subject: [PATCH 10/10] remove hardcode from xcomposer2.py --- lmdeploy/vl/model/xcomposer2.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/lmdeploy/vl/model/xcomposer2.py b/lmdeploy/vl/model/xcomposer2.py index 8df093c0fe..67101216e0 100644 --- a/lmdeploy/vl/model/xcomposer2.py +++ b/lmdeploy/vl/model/xcomposer2.py @@ -197,20 +197,24 @@ def build_model(self): def _preprocess_2d5(self, image: Image, params: Dict) -> Dict: """image preprocessing for internlm-xcomposer2d5-7b.""" hd_num = params.get('hd_num', 24) - pixel_values = self.HD_transform(image, hd_num=hd_num) - pixel_values = self.vis_processor(pixel_values).unsqueeze(0).half() - return pixel_values + image = self.HD_transform(image, hd_num=hd_num) + pixel_values = self.vis_processor(image).unsqueeze(0).half() + w, h = image.size + n_token_per_image = int((h * w + 1) * 400 + 1 + (h + 1) * 20) + return pixel_values, n_token_per_image def _preprocess_7b(self, image: Image, params: Dict) -> Dict: """image preprocessing for internlm-xcomposer2-7b.""" pixel_values = self.vis_processor(image).unsqueeze(0).half() - return pixel_values + return pixel_values, 256 def _preprocess_4khd_7b(self, image: Image, params: Dict) -> Dict: """image preprocessing for internlm-xcomposer2-4khd-7b.""" - pixel_values = self.HD_transform(image, hd_num=25) - pixel_values = self.vis_processor(pixel_values).unsqueeze(0).half() - return pixel_values + image = self.HD_transform(image, hd_num=25) + pixel_values = self.vis_processor(image).unsqueeze(0).half() + w, h = image.size + n_token_per_image = int((h * w + 1) * 144 + 1 + (h + 1) * 12) + return pixel_values, n_token_per_image def preprocess(self, messages: List[Dict]) -> List[Dict]: """refer to `super().preprocess() for spec.""" @@ -218,11 +222,11 @@ def preprocess(self, messages: List[Dict]) -> List[Dict]: outputs = [] for image, params in images: image = image.convert('RGB') - pixel_values = self.preprocess_func(image, params) + pixel_values, n_token = self.preprocess_func(image, params) outputs.append( dict(pixel_values=pixel_values, image_size=image.size, - image_tokens=576, + image_tokens=n_token, image_token_id=0)) messages.append(dict(role='preprocess', content=outputs)) return messages
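
For reference, a minimal usage sketch (not part of these patches) of the vision-language pipeline that the patched preprocessors serve. The model names and image URL below are placeholders, and the explicit `PytorchEngineConfig` fallback follows the head_dim note added in PATCH 01/10:

```python
from lmdeploy import pipeline, PytorchEngineConfig
from lmdeploy.vl import load_image

# Placeholder image; any VL model from the updated support tables can be used.
image = load_image('https://example.com/demo.jpeg')

# Default (TurboMind) backend for a model whose head_dim is 128.
pipe = pipeline('OpenGVLab/InternVL2-8B')
print(pipe(('describe this image', image)).text)

# For models whose head_dim is not 128 (e.g. internvl2-1B, qwen2-0.5B), the note
# added in PATCH 01/10 says TurboMind cannot do 4/8-bit kv cache quantization, so
# the PyTorch engine can be selected explicitly instead.
pipe = pipeline('OpenGVLab/InternVL2-1B', backend_config=PytorchEngineConfig())
print(pipe(('describe this image', image)).text)
```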