From 464833415a591b9bfae0d27652057eb42f67c4ee Mon Sep 17 00:00:00 2001 From: "yi.chu" Date: Mon, 2 Dec 2024 19:34:07 +0800 Subject: [PATCH] [MiniCPMV] fix precision bug --- models/MiniCPM-V-2_6/README.md | 2 +- models/MiniCPM-V-2_6/compile/README.md | 9 +- models/MiniCPM-V-2_6/compile/export_onnx.py | 11 +- models/MiniCPM-V-2_6/compile/run_compile.sh | 2 +- models/MiniCPM-V-2_6/python_demo/README.md | 6 +- models/MiniCPM-V-2_6/python_demo/chat.cpp | 19 +- models/MiniCPM-V-2_6/python_demo/pipeline.py | 59 +-- .../image_processing_minicpmv.py | 418 ++++++++++++++++++ .../processor_config/preprocessor_config.json | 24 + .../processor_config/processing_minicpmv.py | 240 ++++++++++ .../tokenization_minicpmv_fast.py | 0 .../tokenizer.json | 0 .../tokenizer_config.json | 0 .../vocab.json | 0 14 files changed, 741 insertions(+), 49 deletions(-) create mode 100644 models/MiniCPM-V-2_6/support/processor_config/image_processing_minicpmv.py create mode 100644 models/MiniCPM-V-2_6/support/processor_config/preprocessor_config.json create mode 100644 models/MiniCPM-V-2_6/support/processor_config/processing_minicpmv.py rename models/MiniCPM-V-2_6/support/{token_config => processor_config}/tokenization_minicpmv_fast.py (100%) rename models/MiniCPM-V-2_6/support/{token_config => processor_config}/tokenizer.json (100%) mode change 100755 => 100644 rename models/MiniCPM-V-2_6/support/{token_config => processor_config}/tokenizer_config.json (100%) mode change 100755 => 100644 rename models/MiniCPM-V-2_6/support/{token_config => processor_config}/vocab.json (100%) mode change 100755 => 100644 diff --git a/models/MiniCPM-V-2_6/README.md b/models/MiniCPM-V-2_6/README.md index 22090ce..6c67fe2 100755 --- a/models/MiniCPM-V-2_6/README.md +++ b/models/MiniCPM-V-2_6/README.md @@ -32,7 +32,7 @@ python3 export_onnx.py --model_path your_minicpmv_path 此处介绍如何将onnx模型编译成bmodel。也可以省去编译模型这一步,直接下载编译好的模型: ``` shell -python3 -m dfss --url=open@sophgo.com:/ext_model_information/LLM/LLM-TPU/minicpmv26_bm1684x_int4.bmodel +python3 -m dfss --url=open@sophgo.com:/ext_model_information/LLM/LLM-TPU/minicpmv26_bm1684x_int4_seq1024.bmodel ``` #### 1. 
下载docker,启动容器 diff --git a/models/MiniCPM-V-2_6/compile/README.md b/models/MiniCPM-V-2_6/compile/README.md index 43370b2..177f99c 100755 --- a/models/MiniCPM-V-2_6/compile/README.md +++ b/models/MiniCPM-V-2_6/compile/README.md @@ -3,15 +3,18 @@ ## Export onnx ```shell -pip install transformers_stream_generator einops tiktoken accelerate torch==2.0.1+cpu torchvision==0.15.2 transformers==4.40.0 +pip install torch==2.5.1 --index-url https://download.pytorch.org/whl/cpu +pip install transformers_stream_generator einops tiktoken accelerate transformers==4.40.0 cp files/MiniCPM-V-2_6/modeling_qwen2.py /usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/ cp files/MiniCPM-V-2_6/resampler.py your_torch_model cp files/MiniCPM-V-2_6/modeling_navit_siglip.py your_torch_model ``` your_torch_model是你模型的位置 ```shell -python3 export_onnx.py --model_path your_torch_model --seq_length 512 --device cpu +python3 export_onnx.py --model_path your_torch_model --seq_length 512 --device cpu --image_file ../python_demo/test0.jpg ``` +* image_file:image_file为真实图片的路径,导出模型时,输入size会固定为该图片的size。`image_file请输入你实际的图片` +* 目前不支持多图,不支持图片size可变 ## Compile bmodel 使用io_alone @@ -23,7 +26,7 @@ python3 export_onnx.py --model_path your_torch_model --seq_length 512 --device c 也可以直接下载编译好的模型,不用自己编译 ```shell pip3 install dfss -python3 -m dfss --url=open@sophgo.com:/ext_model_information/LLM/LLM-TPU/minicpm_int4_seq512_1dev.bmodel +python3 -m dfss --url=open@sophgo.com:/ext_model_information/LLM/LLM-TPU/minicpmv26_bm1684x_int4_seq1024.bmodel ``` ### python demo diff --git a/models/MiniCPM-V-2_6/compile/export_onnx.py b/models/MiniCPM-V-2_6/compile/export_onnx.py index ee8b36a..42d4935 100755 --- a/models/MiniCPM-V-2_6/compile/export_onnx.py +++ b/models/MiniCPM-V-2_6/compile/export_onnx.py @@ -286,7 +286,6 @@ def test_net_with_mask(): tgt_sizes = inputs["tgt_sizes"][0].to(dtype).to(device) vit_infer = VisionTransformer(pixel_values, tgt_sizes) vit_embeds = vit_infer(pixel_values) # [1, 64, 3584] - vit_token_length = vit_embeds.shape[1] msgs = [{'role': 'user', 'content': '(./)\n请详细描述一下图片内容'}] prompts_lists = processor.tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True) @@ -295,11 +294,11 @@ def test_net_with_mask(): [[image]], max_slice_nums=MAX_SLICE_NUMS, use_image_id=None, - return_tensors="pt", + return_tensors="pt", max_length=8192 ).to(device) ids = inputs.input_ids[0] - first_offset = int(torch.where(ids==128244)[0][0]) + image_offsets = torch.where(ids==128244)[0].tolist() ids = ids.tolist() ID_IM_END = tokenizer.convert_tokens_to_ids("<|im_end|>") @@ -308,8 +307,10 @@ def test_net_with_mask(): input_ids = torch.tensor(ids).view(SEQ_LENGTH).to(device) out = embed(input_ids).view(1, SEQ_LENGTH, HIDDEN_SIZE) # [1, 512, 3584] - for i in range(vit_embeds.shape[0]): - out[:, first_offset+i*vit_token_length:first_offset+(i+1)*vit_token_length, :] = vit_embeds[i] + patch_num = pixel_values.shape[0] + patch_size = len(image_offsets) // patch_num + for i in range(patch_num): + out[:, image_offsets[i*patch_size]:image_offsets[i*patch_size]+patch_size, :] = vit_embeds[i] position_ids = list(range(token_len)) + (SEQ_LENGTH - token_len) * [0] position_ids = torch.tensor([position_ids]).to(device) diff --git a/models/MiniCPM-V-2_6/compile/run_compile.sh b/models/MiniCPM-V-2_6/compile/run_compile.sh index 78c7967..ad0986a 100644 --- a/models/MiniCPM-V-2_6/compile/run_compile.sh +++ b/models/MiniCPM-V-2_6/compile/run_compile.sh @@ -94,7 +94,7 @@ sudo cp files/${model_name_upper}/resampler.py 
${model_path} sudo cp files/${model_name_upper}/modeling_navit_siglip.py ${model_path} echo "export onnx..." -python export_onnx.py --model_path ${model_path} --seq_length ${seq_length} +python export_onnx.py --model_path ${model_path} --seq_length ${seq_length} --image_file ../python_demo/test0.jpg echo "compile model..." source ${tpu_mlir_path}/envsetup.sh diff --git a/models/MiniCPM-V-2_6/python_demo/README.md b/models/MiniCPM-V-2_6/python_demo/README.md index cc0d2cb..4b16e90 100755 --- a/models/MiniCPM-V-2_6/python_demo/README.md +++ b/models/MiniCPM-V-2_6/python_demo/README.md @@ -9,7 +9,7 @@ pip3 install gradio==3.39.0 mdtex2html==1.2.0 dfss 如果不打算自己编译模型,可以直接用下载好的模型 ``` -python3 -m dfss --url=open@sophgo.com:/ext_model_information/LLM/LLM-TPU/minicpmv26_bm1684x_int4.bmodel +python3 -m dfss --url=open@sophgo.com:/ext_model_information/LLM/LLM-TPU/minicpmv26_bm1684x_int4_seq1024.bmodel ``` 编译库文件 @@ -20,5 +20,5 @@ cd build && cmake .. && make && cp *cpython* .. && cd .. # python demo ``` -python3 pipeline.py --model_path minicpmv26_bm1684x_int4.bmodel --tokenizer_path ../support/token_config/ --devid 0 -``` \ No newline at end of file +python3 pipeline.py --model_path minicpmv26_bm1684x_int4_seq1024.bmodel --processor_path ../support/processor_config/ --devid 0 +``` diff --git a/models/MiniCPM-V-2_6/python_demo/chat.cpp b/models/MiniCPM-V-2_6/python_demo/chat.cpp index 1a647f5..9c90af4 100755 --- a/models/MiniCPM-V-2_6/python_demo/chat.cpp +++ b/models/MiniCPM-V-2_6/python_demo/chat.cpp @@ -48,7 +48,7 @@ class MiniCPMV { void init(int devid, std::string model_path); void deinit(); int forward_first(std::vector &tokens, std::vector &pixel_values, - int img_offset); + std::vector &img_offsets, int patch_num); int forward_next(); std::mt19937 sgen; @@ -160,7 +160,7 @@ void MiniCPMV::deinit() { } int MiniCPMV::forward_first(std::vector &tokens, - std::vector &pixel_values, int img_offset) { + std::vector &pixel_values, std::vector &img_offsets, int patch_num) { std::vector input_ids(SEQLEN, 0); std::vector position_id(SEQLEN, 0); std::vector attention_mask(SEQLEN * SEQLEN, ATTENTION_MASK); @@ -185,7 +185,7 @@ int MiniCPMV::forward_first(std::vector &tokens, bm_memcpy_s2d(bm_handle, in_mem, (void *)input_ids.data()); net_launch(net_embed); // prefil embedding - if (pixel_values.size() * sizeof(float) == IMAGE_BYTES && img_offset > 0) { + if (pixel_values.size() * sizeof(float) == IMAGE_BYTES && img_offsets.size() > 0) { d2d(dev_buffer, out_mem); out_mem = dev_buffer; // forward vision transformer @@ -195,10 +195,15 @@ int MiniCPMV::forward_first(std::vector &tokens, net_launch(net_vit); // concatenante texting embedding and image embedding - int dst_offset = img_offset * HIDDEN_SIZE * 2; - int vit_size = bm_mem_get_device_size(vit_out_mem); - bm_memcpy_d2d_byte(bm_handle, out_mem, dst_offset, vit_out_mem, 0, - vit_size); + int type_byte = sizeof(uint16_t); + int patch_bytes = bm_mem_get_device_size(vit_out_mem) / patch_num; + int patch_size = net_vit->stages[0].output_shapes[0].dims[1]; + for (int i = 0; i < patch_num; i++) { + int vit_offset = i * patch_bytes; + int dst_offset = img_offsets[i * patch_size] * HIDDEN_SIZE * type_byte; + + bm_memcpy_d2d_byte(bm_handle, out_mem, dst_offset, vit_out_mem, vit_offset, patch_bytes); + } } // forward blocks diff --git a/models/MiniCPM-V-2_6/python_demo/pipeline.py b/models/MiniCPM-V-2_6/python_demo/pipeline.py index 6fe66a9..328b5d2 100755 --- a/models/MiniCPM-V-2_6/python_demo/pipeline.py +++ b/models/MiniCPM-V-2_6/python_demo/pipeline.py @@ -3,7 
+3,7 @@ import argparse from PIL import Image import torchvision.transforms as T -from transformers import AutoTokenizer +from transformers import AutoTokenizer, AutoProcessor from torchvision.transforms.functional import InterpolationMode import chat import os @@ -38,18 +38,13 @@ def __init__(self, args): self.device = args.devid # load tokenizer - print("Load " + args.tokenizer_path + " ...") - self.tokenizer = AutoTokenizer.from_pretrained( - args.tokenizer_path, trust_remote_code=True - ) - self.tokenizer.decode([0]) # warm up - + print("Load " + args.processor_path + " ...") self.processor = AutoProcessor.from_pretrained( - args.tokenizer_path, trust_remote_code=True + args.processor_path, trust_remote_code=True ) - # preprocess parameters, such as prompt & tokenizer - self.system_prompt = '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n' + self.tokenizer = self.processor.tokenizer + self.tokenizer.decode([0]) # warm up # load model self.model = chat.MiniCPMV() @@ -58,27 +53,33 @@ def __init__(self, args): self.ID_EOS = self.tokenizer.eos_token_id self.ID_IM_END = self.tokenizer.convert_tokens_to_ids("<|im_end|>") + # parameters + self.MAX_SLICE_NUMS = self.processor.image_processor.max_slice_nums + def encode(self): if not self.image_str: inserted_image_str = "" - self.pixel_values = [] else: inserted_image_str = "(./)\n" - image = Image.open(sample_image_file).convert('RGB') - inputs = processor.image_processor([image], do_pad=True, max_slice_nums=MAX_SLICE_NUMS, return_tensors="pt") - pixel_values = inputs["pixel_values"][0] - - msgs = [{'role': 'user', 'content': '{}{}'.format(self.inserted_image_str, self.input_str)}] - prompt = self.system_prompt + self.input_str + "<|im_end|>\n<|im_start|>assistant\n" - self.input_ids = self.tokenizer.encode(prompt) - self.image_offset = 0 - self.pixel_values = [] - return - self.pixel_values = load_image(self.image_str).flatten().tolist() - msgs = [{'role': 'user', 'content': '(./)\n{}'.format(self.input_str)}] - self.input_ids = processor.tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)[0] - self.image_offset = 0 - breakpoint() + image = Image.open(self.image_str).convert('RGB') + + msgs = [{'role': 'user', 'content': '{}{}'.format(inserted_image_str, self.input_str)}] + prompts_lists = self.processor.tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True) + + inputs = self.processor( + prompts_lists, + [[image]] if image else None, + max_slice_nums=self.MAX_SLICE_NUMS, + use_image_id=None, + return_tensors="pt", + max_length=8192 + ) + self.input_ids = inputs.input_ids[0] + self.pixel_values = torch.cat(inputs["pixel_values"][0], dim=0).flatten().tolist() + self.image_offsets = torch.where(self.input_ids==128244)[0].tolist() + self.patch_num = len(inputs["pixel_values"][0]) + + self.input_ids = self.input_ids.tolist() def chat(self): """ @@ -107,7 +108,7 @@ def chat(self): # Chat first_start = time.time() token = self.model.forward_first( - self.input_ids, self.pixel_values, self.image_offset) + self.input_ids, self.pixel_values, self.image_offsets, self.patch_num) first_end = time.time() tok_num = 1 # Following tokens @@ -142,8 +143,8 @@ def main(args): parser = argparse.ArgumentParser() parser.add_argument('-m', '--model_path', type=str, required=True, help='path to the bmodel file') - parser.add_argument('-t', '--tokenizer_path', type=str, - default="../support/token_config", help='path to the tokenizer file') + parser.add_argument('-p', 
'--processor_path', type=str, + default="../support/processor_config", help='path to the processor file') parser.add_argument('-d', '--devid', type=int, default=0, help='device ID to use') args = parser.parse_args() diff --git a/models/MiniCPM-V-2_6/support/processor_config/image_processing_minicpmv.py b/models/MiniCPM-V-2_6/support/processor_config/image_processing_minicpmv.py new file mode 100644 index 0000000..84f9c5e --- /dev/null +++ b/models/MiniCPM-V-2_6/support/processor_config/image_processing_minicpmv.py @@ -0,0 +1,418 @@ +from typing import Optional, Union, Dict, Any, List + +import torch +import math +import PIL.Image +import PIL.ImageSequence +import numpy as np +import PIL +from PIL import Image + +from transformers.utils import TensorType, requires_backends, is_torch_dtype, is_torch_device +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature +from transformers import AutoImageProcessor +from transformers.image_transforms import to_channel_dimension_format +from transformers.image_utils import ( + ImageInput, + make_list_of_images, + valid_images, + is_torch_tensor, + is_batched, + to_numpy_array, + infer_channel_dimension_format, + ChannelDimension +) + + +def recursive_converter(converter, value): + if isinstance(value, list): + new_value = [] + for v in value: + new_value += [recursive_converter(converter, v)] + return new_value + else: + return converter(value) + + +class MiniCPMVBatchFeature(BatchFeature): + r""" + Extend from BatchFeature for supporting various image size + """ + def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None): + super().__init__(data) + self.convert_to_tensors(tensor_type=tensor_type) + + def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None): + if tensor_type is None: + return self + + is_tensor, as_tensor = self._get_is_as_tensor_fns(tensor_type) + + def converter(value): + try: + if not is_tensor(value): + tensor = as_tensor(value) + return tensor + except: # noqa E722 + if key == "overflowing_values": + raise ValueError("Unable to create tensor returning overflowing values of different lengths. ") + raise ValueError( + "Unable to create tensor, you should probably activate padding " + "with 'padding=True' to have batched tensors with the same length." + ) + + + for key, value in self.items(): + self[key] = recursive_converter(converter, value) + return self + + def to(self, *args, **kwargs) -> "MiniCPMVBatchFeature": + requires_backends(self, ["torch"]) + import torch + + def cast_tensor(v): + # check if v is a floating point + if torch.is_floating_point(v): + # cast and send to device + return v.to(*args, **kwargs) + elif device is not None: + return v.to(device=device) + else: + return v + + new_data = {} + device = kwargs.get("device") + # Check if the args are a device or a dtype + if device is None and len(args) > 0: + # device should be always the first argument + arg = args[0] + if is_torch_dtype(arg): + # The first argument is a dtype + pass + elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): + device = arg + else: + # it's something else + raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. 
This is not supported.") + # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` + for k, v in self.items(): + new_data[k] = recursive_converter(cast_tensor, v) + self.data = new_data + return self + + +class MiniCPMVImageProcessor(BaseImageProcessor): + model_input_names = ["pixel_values"] + + def __init__( + self, + max_slice_nums=9, + scale_resolution=448, + patch_size=14, + **kwargs): + super().__init__(**kwargs) + self.max_slice_nums = max_slice_nums + self.scale_resolution = scale_resolution + self.patch_size = patch_size + self.use_image_id = kwargs.pop("use_image_id", False) + self.image_feature_size = kwargs.pop("image_feature_size", 64) + self.im_start_token = kwargs.pop("im_start", "") + self.im_end_token = kwargs.pop("im_end", "") + self.slice_start_token = kwargs.pop("slice_start", "") + self.slice_end_token = kwargs.pop("slice_end", "") + self.unk_token = kwargs.pop("unk", "") + self.im_id_start = kwargs.pop("im_id_start", "") + self.im_id_end = kwargs.pop("im_id_end", "") + self.slice_mode = kwargs.pop("slice_mode", True) + self.mean = np.array(kwargs.pop("norm_mean", [0.5, 0.5, 0.5])) + self.std = np.array(kwargs.pop("norm_std", [0.5, 0.5, 0.5])) + self.version = kwargs.pop("version", 2.0) + + def ensure_divide(self, length, patch_size): + return max(round(length / patch_size) * patch_size, patch_size) + + def find_best_resize(self, + original_size, + scale_resolution, + patch_size, + allow_upscale=False): + width, height = original_size + if (width * height > + scale_resolution * scale_resolution) or allow_upscale: + r = width / height + height = int(scale_resolution / math.sqrt(r)) + width = int(height * r) + best_width = self.ensure_divide(width, patch_size) + best_height = self.ensure_divide(height, patch_size) + return (best_width, best_height) + + def get_refine_size(self, + original_size, + grid, + scale_resolution, + patch_size, + allow_upscale=False): + width, height = original_size + grid_x, grid_y = grid + + refine_width = self.ensure_divide(width, grid_x) + refine_height = self.ensure_divide(height, grid_y) + + grid_width = refine_width / grid_x + grid_height = refine_height / grid_y + + best_grid_size = self.find_best_resize((grid_width, grid_height), + scale_resolution, + patch_size, + allow_upscale=allow_upscale) + refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y) + return refine_size + + def split_to_patches(self, image, grid): + patches = [] + width, height = image.size + grid_x = int(width / grid[0]) + grid_y = int(height / grid[1]) + for i in range(0, height, grid_y): + images = [] + for j in range(0, width, grid_x): + box = (j, i, j + grid_x, i + grid_y) + patch = image.crop(box) + images.append(patch) + patches.append(images) + return patches + + def slice_image( + self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False + ): + original_size = image.size + source_image = None + best_grid = self.get_sliced_grid(original_size, max_slice_nums, never_split) + patches = [] + + if best_grid is None: + # dont need to slice, upsample + best_size = self.find_best_resize( + original_size, scale_resolution, patch_size, allow_upscale=True + ) + source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC) + else: + # source image, down-sampling and ensure divided by patch_size + best_resize = self.find_best_resize(original_size, scale_resolution, patch_size) + source_image = image.copy().resize(best_resize, resample=Image.Resampling.BICUBIC) + 
refine_size = self.get_refine_size( + original_size, best_grid, scale_resolution, patch_size, allow_upscale=True + ) + refine_image = image.resize(refine_size, resample=Image.Resampling.BICUBIC) + patches = self.split_to_patches(refine_image, best_grid) + + return source_image, patches, best_grid + + def get_grid_placeholder(self, grid): + if grid is None: + return "" + slice_image_placeholder = ( + self.slice_start_token + + self.unk_token * self.image_feature_size + + self.slice_end_token + ) + + cols = grid[0] + rows = grid[1] + slices = [] + for i in range(rows): + lines = [] + for j in range(cols): + lines.append(slice_image_placeholder) + slices.append("".join(lines)) + + slice_placeholder = "\n".join(slices) + return slice_placeholder + + def get_image_id_placeholder(self, idx=0): + return f"{self.im_id_start}{idx}{self.im_id_end}" + + def get_sliced_images(self, image, max_slice_nums=None): + slice_images = [] + + if not self.slice_mode: + return [image] + + max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums) + assert max_slice_nums > 0 + source_image, patches, sliced_grid = self.slice_image( + image, + max_slice_nums, # default: 9 + self.scale_resolution, # default: 448 + self.patch_size # default: 14 + ) + + slice_images.append(source_image) + if len(patches) > 0: + for i in range(len(patches)): + for j in range(len(patches[0])): + slice_images.append(patches[i][j]) + return slice_images + + def get_sliced_grid(self, image_size, max_slice_nums, nerver_split=False): + original_width, original_height = image_size + log_ratio = math.log(original_width / original_height) + ratio = original_width * original_height / (self.scale_resolution * self.scale_resolution) + multiple = min(math.ceil(ratio), max_slice_nums) + if multiple <= 1 or nerver_split: + return None + candidate_split_grids_nums = [] + for i in [multiple - 1, multiple, multiple + 1]: + if i == 1 or i > max_slice_nums: + continue + candidate_split_grids_nums.append(i) + + candidate_grids = [] + for split_grids_nums in candidate_split_grids_nums: + m = 1 + while m <= split_grids_nums: + if split_grids_nums % m == 0: + candidate_grids.append([m, split_grids_nums // m]) + m += 1 + + best_grid = [1, 1] + min_error = float("inf") + for grid in candidate_grids: + error = abs(log_ratio - math.log(grid[0] / grid[1])) + if error < min_error: + best_grid = grid + min_error = error + + return best_grid + + def get_slice_image_placeholder(self, image_size, image_idx=0, max_slice_nums=None, use_image_id=None): + max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums) + assert max_slice_nums > 0 + grid = self.get_sliced_grid(image_size=image_size, max_slice_nums=max_slice_nums) + + image_placeholder = ( + self.im_start_token + + self.unk_token * self.image_feature_size + + self.im_end_token + ) + use_image_id = self.use_image_id if use_image_id is None else bool(use_image_id) + if use_image_id: + final_placeholder = self.get_image_id_placeholder(image_idx) + image_placeholder + else: + final_placeholder = image_placeholder + + if self.slice_mode: + final_placeholder = final_placeholder + self.get_grid_placeholder(grid=grid) + return final_placeholder + + def to_pil_image(self, image, rescale=None) -> PIL.Image.Image: + """ + Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if + needed. + + Args: + image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`): + The image to convert to the PIL Image format. 
+ rescale (`bool`, *optional*): + Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will + default to `True` if the image type is a floating type, `False` otherwise. + """ + if isinstance(image, PIL.Image.Image): + return image + if is_torch_tensor(image): + image = image.numpy() + + if isinstance(image, np.ndarray): + if rescale is None: + # rescale default to the array being of floating type. + rescale = isinstance(image.flat[0], np.floating) + # If the channel as been moved to first dim, we put it back at the end. + if image.ndim == 3 and image.shape[0] in [1, 3]: + image = image.transpose(1, 2, 0) + if rescale: + image = image * 255 + image = image.astype(np.uint8) + return PIL.Image.fromarray(image) + return image + + def reshape_by_patch(self, image): + """ + :param image: shape [3, H, W] + :param patch_size: + :return: [3, patch_size, HW/patch_size] + """ + image = torch.from_numpy(image) + patch_size = self.patch_size + patches = torch.nn.functional.unfold( + image, + (patch_size, patch_size), + stride=(patch_size, patch_size) + ) + + patches = patches.reshape(image.size(0), patch_size, patch_size, -1) + patches = patches.permute(0, 1, 3, 2).reshape(1, image.size(0), patch_size, -1) + return patches.numpy() + + def preprocess( + self, + images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]], + do_pad: Optional[bool] = True, # TODO: add pad for MiniCPM-Llama3-V-2_5 + max_slice_nums: int = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs + ) -> MiniCPMVBatchFeature: + if isinstance(images, Image.Image): + images_list = [[images]] + elif isinstance(images[0], Image.Image): + images_list = [images] + else: + images_list = images + + new_images_list = [] + image_sizes_list = [] + tgt_sizes_list = [] + + for _images in images_list: + if _images is None or len(_images) == 0: + new_images_list.append([]) + image_sizes_list.append([]) + tgt_sizes_list.append([]) + continue + if not valid_images(_images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." 
+ ) + + _images = [self.to_pil_image(image).convert("RGB") for image in _images] + input_data_format = infer_channel_dimension_format(np.array(_images[0])) + + new_images = [] + image_sizes = [image.size for image in _images] + tgt_sizes = [] + for image in _images: + image_patches = self.get_sliced_images(image, max_slice_nums) + image_patches = [to_numpy_array(image).astype(np.float32) / 255 for image in image_patches] + image_patches = [ + self.normalize(image=image, mean=self.mean, std=self.std, input_data_format=input_data_format) + for image in image_patches + ] + image_patches = [ + to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format) + for image in image_patches + ] + for slice_image in image_patches: + new_images.append(self.reshape_by_patch(slice_image)) + tgt_sizes.append(np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size))) + + if tgt_sizes: + tgt_sizes = np.vstack(tgt_sizes) + + new_images_list.append(new_images) + image_sizes_list.append(image_sizes) + tgt_sizes_list.append(tgt_sizes) + return MiniCPMVBatchFeature( + data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list}, tensor_type=return_tensors + ) + +AutoImageProcessor.register("MiniCPMVImageProcessor", MiniCPMVImageProcessor) diff --git a/models/MiniCPM-V-2_6/support/processor_config/preprocessor_config.json b/models/MiniCPM-V-2_6/support/processor_config/preprocessor_config.json new file mode 100644 index 0000000..7111b61 --- /dev/null +++ b/models/MiniCPM-V-2_6/support/processor_config/preprocessor_config.json @@ -0,0 +1,24 @@ +{ + "image_processor_type": "MiniCPMVImageProcessor", + "auto_map": { + "AutoProcessor": "processing_minicpmv.MiniCPMVProcessor", + "AutoImageProcessor": "image_processing_minicpmv.MiniCPMVImageProcessor" + }, + "processor_class": "MiniCPMVProcessor", + "max_slice_nums": 9, + "scale_resolution": 448, + "patch_size": 14, + "use_image_id": true, + "image_feature_size": 64, + "im_start": "", + "im_end": "", + "slice_start": "", + "slice_end": "", + "unk": "", + "im_id_start": "", + "im_id_end": "", + "slice_mode": true, + "norm_mean": [0.5, 0.5, 0.5], + "norm_std": [0.5, 0.5, 0.5], + "version": 2.6 +} \ No newline at end of file diff --git a/models/MiniCPM-V-2_6/support/processor_config/processing_minicpmv.py b/models/MiniCPM-V-2_6/support/processor_config/processing_minicpmv.py new file mode 100644 index 0000000..6e3cbbd --- /dev/null +++ b/models/MiniCPM-V-2_6/support/processor_config/processing_minicpmv.py @@ -0,0 +1,240 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for MiniCPMV. 
+""" + +from typing import List, Optional, Union, Dict, Any +import torch +import re + +from transformers.image_processing_utils import BatchFeature +from transformers.image_utils import ImageInput +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from transformers.utils import TensorType, requires_backends, is_torch_dtype, is_torch_device + +from .image_processing_minicpmv import MiniCPMVBatchFeature + + +class MiniCPMVProcessor(ProcessorMixin): + r""" + Constructs a MiniCPMV processor which wraps a MiniCPMV image processor and a MiniCPMV tokenizer into a single processor. + + [`MiniCPMVProcessor`] offers all the functionalities of [`MiniCPMVImageProcessor`] and [`LlamaTokenizerWrapper`]. See the + [`~MiniCPMVProcessor.__call__`] and [`~MiniCPMVProcessor.decode`] for more information. + + Args: + image_processor ([`MiniCPMVImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`LlamaTokenizerWrapper`], *optional*): + The tokenizer is a required input. + """ + attributes = ["image_processor", "tokenizer"] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, image_processor=None, tokenizer=None): + super().__init__(image_processor, tokenizer) + self.version = image_processor.version + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], + images: ImageInput = None, + max_length: Optional[int] = None, + do_pad: Optional[bool] = True, + max_slice_nums: int = None, + use_image_id: bool = None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + **kwargs + ) -> MiniCPMVBatchFeature: + + if images is not None: + image_inputs = self.image_processor(images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors) + return self._convert_images_texts_to_inputs(image_inputs, text, max_slice_nums=max_slice_nums, use_image_id=use_image_id, max_length=max_length, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + output_ids = args[0] + result_text = [] + for result in output_ids: + result = result[result != 0] + if result[0] == self.tokenizer.bos_id: + result = result[1:] + if result[-1] == self.tokenizer.eos_id: + result = result[:-1] + result_text.append(self.tokenizer.decode(result, *args[1:], **kwargs).strip()) + return result_text + # return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. 
+ """ + result = args[0] + result = result[result != 0] + if result[0] == self.tokenizer.bos_id: + result = result[1:] + if result[-1] == self.tokenizer.eos_id or (hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id): + result = result[:-1] + return self.tokenizer.decode(result, *args[1:], **kwargs).strip() + + def _convert( + self, input_str, max_inp_length: Optional[int] = None + ): + if self.version > 2.5 or not getattr(self.tokenizer, "add_bos_token", False): + input_ids = self.tokenizer.encode(input_str) + else: + input_ids = [self.tokenizer.bos_id] + self.tokenizer.encode(input_str) + if max_inp_length is not None: + input_ids = input_ids[:max_inp_length] + input_ids = torch.tensor(input_ids, dtype=torch.int32) + + start_cond = (input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id) + end_cond = (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id) + + image_start_tokens = torch.where(start_cond)[0] + image_start_tokens += 1 + image_end_tokens = torch.where(end_cond)[0] + + valid_image_nums = max(len(image_start_tokens), len(image_end_tokens)) + + image_bounds = torch.hstack( + [ + image_start_tokens[:valid_image_nums].unsqueeze(-1), + image_end_tokens[:valid_image_nums].unsqueeze(-1), + ] + ) + return input_ids, image_bounds + + def _convert_images_texts_to_inputs( + self, + images, + texts: Union[str, List[str]], + truncation=None, + max_length=None, + max_slice_nums=None, + use_image_id=None, + return_tensors=None, + **kwargs + ): + if images is None or not len(images): + model_inputs = self.tokenizer(texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs) + return MiniCPMVBatchFeature(data={**model_inputs}) + + pattern = "(./)" + images, image_sizes, tgt_sizes = images["pixel_values"], images["image_sizes"], images["tgt_sizes"] + + if isinstance(texts, str): + texts = [texts] + input_ids_list = [] + image_bounds_list = [] + for index, text in enumerate(texts): + image_tags = re.findall(pattern, text) + assert len(image_tags) == len(image_sizes[index]) + text_chunks = text.split(pattern) + final_text = "" + for i in range(len(image_tags)): + final_text = final_text + text_chunks[i] + \ + self.image_processor.get_slice_image_placeholder( + image_sizes[index][i], + i, + max_slice_nums, + use_image_id + ) + final_text += text_chunks[-1] + input_ids, image_bounds = self._convert(final_text, max_length) + input_ids_list.append(input_ids) + image_bounds_list.append(image_bounds) + padded_input_ids, padding_lengths = self.pad( + input_ids_list, + padding_side="left" + ) + for i, length in enumerate(padding_lengths): + image_bounds_list[i] = image_bounds_list[i] + length + attention_mask = padded_input_ids.ne(0) + + return MiniCPMVBatchFeature(data={ + "input_ids": padded_input_ids, + "attention_mask": attention_mask, + "pixel_values": images, + "image_sizes": image_sizes, + "image_bound": image_bounds_list, + "tgt_sizes": tgt_sizes + }) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + + def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"): + items = [] + if isinstance(inputs[0], list): + assert isinstance(inputs[0][0], torch.Tensor) + for it in 
inputs: + for tr in it: + items.append(tr) + else: + assert isinstance(inputs[0], torch.Tensor) + items = inputs + + batch_size = len(items) + shape = items[0].shape + dim = len(shape) + assert dim <= 2 + if max_length is None: + max_length = 0 + max_length = max(max_length, max(item.shape[-1] for item in items)) + min_length = min(item.shape[-1] for item in items) + dtype = items[0].dtype + + if dim == 0: + return torch.stack([item for item in items], dim=0), [0] + elif dim == 1: + if max_length == min_length: + return torch.stack([item for item in items], dim=0), [0] * batch_size + tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value + else: + tensor = ( + torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + + padding_value + ) + + padding_length = [] + for i, item in enumerate(items): + if dim == 1: + if padding_side == "left": + tensor[i, -len(item) :] = item.clone() + else: + tensor[i, : len(item)] = item.clone() + elif dim == 2: + if padding_side == "left": + tensor[i, -len(item) :, :] = item.clone() + else: + tensor[i, : len(item), :] = item.clone() + padding_length.append(tensor.shape[-1] - len(item)) + + return tensor, padding_length diff --git a/models/MiniCPM-V-2_6/support/token_config/tokenization_minicpmv_fast.py b/models/MiniCPM-V-2_6/support/processor_config/tokenization_minicpmv_fast.py similarity index 100% rename from models/MiniCPM-V-2_6/support/token_config/tokenization_minicpmv_fast.py rename to models/MiniCPM-V-2_6/support/processor_config/tokenization_minicpmv_fast.py diff --git a/models/MiniCPM-V-2_6/support/token_config/tokenizer.json b/models/MiniCPM-V-2_6/support/processor_config/tokenizer.json old mode 100755 new mode 100644 similarity index 100% rename from models/MiniCPM-V-2_6/support/token_config/tokenizer.json rename to models/MiniCPM-V-2_6/support/processor_config/tokenizer.json diff --git a/models/MiniCPM-V-2_6/support/token_config/tokenizer_config.json b/models/MiniCPM-V-2_6/support/processor_config/tokenizer_config.json old mode 100755 new mode 100644 similarity index 100% rename from models/MiniCPM-V-2_6/support/token_config/tokenizer_config.json rename to models/MiniCPM-V-2_6/support/processor_config/tokenizer_config.json diff --git a/models/MiniCPM-V-2_6/support/token_config/vocab.json b/models/MiniCPM-V-2_6/support/processor_config/vocab.json old mode 100755 new mode 100644 similarity index 100% rename from models/MiniCPM-V-2_6/support/token_config/vocab.json rename to models/MiniCPM-V-2_6/support/processor_config/vocab.json
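
As a quick sanity check of the offset handling this patch introduces, the following minimal sketch reproduces how the updated `pipeline.py` derives `image_offsets` and `patch_num` from the processor output before calling `forward_first`. The processor path and test image follow this repo's layout (`support/processor_config/`, `python_demo/test0.jpg`) but are assumptions about your local checkout; the placeholder token id `128244` is the one used by this patch.

``` python
# Minimal sketch, assuming the processor_config directory added by this patch
# and the test image shipped in python_demo/. Mirrors export_onnx.py / pipeline.py.
import torch
from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "../support/processor_config/", trust_remote_code=True
)

image = Image.open("../python_demo/test0.jpg").convert("RGB")
msgs = [{"role": "user", "content": "(./)\nDescribe this image."}]
prompt = processor.tokenizer.apply_chat_template(
    msgs, tokenize=False, add_generation_prompt=True
)

inputs = processor(
    prompt,
    [[image]],
    max_slice_nums=processor.image_processor.max_slice_nums,
    use_image_id=None,
    return_tensors="pt",
    max_length=8192,
)

input_ids = inputs.input_ids[0]
# 128244 is the image placeholder id this patch uses to locate image slots.
image_offsets = torch.where(input_ids == 128244)[0].tolist()
patch_num = len(inputs["pixel_values"][0])          # number of ViT patches
patch_size = len(image_offsets) // patch_num        # placeholder tokens per patch

print(f"{patch_num} patches, {patch_size} placeholder tokens each")
```

These are exactly the values (`input_ids`, flattened `pixel_values`, `image_offsets`, `patch_num`) that the demo now passes to `MiniCPMV::forward_first`, which copies each ViT patch embedding to `image_offsets[i * patch_size]` instead of assuming a single contiguous image region.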