[MODEL] InternVL2 support #970

Draft · wants to merge 3 commits into main
1 change: 1 addition & 0 deletions gptqmodel/models/_const.py
@@ -99,6 +99,7 @@ def get_best_device(backend: BACKEND=BACKEND.AUTO) -> torch.device:
"baichuan",
"internlm",
"internlm2",
"internvl_chat",
"qwen",
"xverse",
"deci",
2 changes: 2 additions & 0 deletions gptqmodel/models/auto.py
@@ -3,6 +3,7 @@
import os
import sys

from .definitions.internvl_chat import InternVLChatGPTQ

# TODO: waiting for pytorch implementgation of aten ops for MPS
if sys.platform == "darwin":
@@ -94,6 +95,7 @@
"baichuan": BaiChuanGPTQ,
"internlm": InternLMGPTQ,
"internlm2": InternLM2GPTQ,
"internvl_chat": InternVLChatGPTQ,
"qwen": QwenGPTQ,
"mistral": MistralGPTQ,
"Yi": YiGPTQ,
1 change: 1 addition & 0 deletions gptqmodel/models/definitions/__init__.py
@@ -21,6 +21,7 @@
from .hymba import HymbaGPTQ
from .internlm import InternLMGPTQ
from .internlm2 import InternLM2GPTQ
from .internvl_chat import InternVLChatGPTQ
from .llama import LlamaGPTQ
from .longllama import LongLlamaGPTQ
from .minicpm3 import MiniCPM3GPTQ
153 changes: 153 additions & 0 deletions gptqmodel/models/definitions/internvl_chat.py
@@ -0,0 +1,153 @@
from typing import Dict

import torch
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoTokenizer

from ..base import BaseGPTQModel
from ...utils.calibration import batched
from ...utils.image import fetch_image
from ...utils.model import MODALITY

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform


def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio


def dynamic_preprocess(image, min_num=1, max_num=12, image_size=448, use_thumbnail=False):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height
    # calculate the existing image aspect ratio
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
    # find the closest aspect ratio to the target
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)
    # calculate the target width and height
    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
    # resize the image
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        # split the image
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images


def load_image(image, input_size=448, max_num=12):
    image = image.convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values


class InternVLChatGPTQ(BaseGPTQModel):
    IMG_START_TOKEN = '<img>'
    IMG_END_TOKEN = '</img>'
    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'

    require_pkgs_version = ["transformers<=4.44.2", "timm>=1.0.12", "torchvision>=0.20.1"]

    base_modules = ["language_model.model.tok_embeddings", "language_model.model.norm"]

    layers_node = "language_model.model.layers"
    layer_type = "InternLM2DecoderLayer"
    layer_modules = [
        ["attention.wqkv", "attention.wo"],

        ["feed_forward.w1", "feed_forward.w3"],
        ["feed_forward.w2"],
    ]

    modality = [MODALITY.TEXT, MODALITY.IMAGE_TO_TEXT]

    def preprocess_dataset(self, sample: Dict) -> Dict:
        # build the chat prompt from the model's conversation template
        template = self.model.conv_template
        template.append_message(template.roles[0], sample["question"])
        template.append_message(template.roles[1], sample["answer"])
        query = template.get_prompt()

        # tile the image and expand the <image> placeholder into image tokens
        pixel_values = load_image(fetch_image(sample), max_num=12).to(torch.bfloat16)
        num_patches = pixel_values.size(0)
        image_tokens = self.IMG_START_TOKEN + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + self.IMG_END_TOKEN
        query = query.replace('<image>', image_tokens, 1)
        image_flags = torch.tensor([1] * num_patches, dtype=torch.long)
        return {
            "query": query,
            "pixel_values": pixel_values,
            "image_flags": image_flags,
        }

    def prepare_dataset(
            self,
            calibration_dataset,
            batch_size: int = 1,
            tokenizer=None,
    ):
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(self.model_local_path, trust_remote_code=True)

        tokenizer.padding_side = 'left'

        calib_data = []
        for batch in batched(calibration_dataset, batch_size, process_func=self.preprocess_dataset):
            queries, pixel_values, image_flags = tuple(
                [instance[key] for instance in batch] for key in ("query", "pixel_values", "image_flags"))
            # tokenize the text queries and batch the per-sample image tensors
            model_inputs = tokenizer(queries, return_tensors='pt', padding=True)
            input_ids = model_inputs['input_ids']
            attention_mask = model_inputs['attention_mask']

            pixel_values = torch.cat(pixel_values, dim=0)
            image_flags = torch.cat(image_flags, dim=0)

            calib_data.append({
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "image_flags": image_flags,
            })
        return calib_data
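For context, a minimal end-to-end usage sketch (not part of this diff) of quantizing InternVL2 through the new InternVLChatGPTQ definition. Only model.quantize() mirrors code visible in this PR; the GPTQModel.load()/save() calls, the QuantizeConfig values, the model id, and the hand-built calibration sample are assumptions.

# Minimal sketch, not part of the PR. load()/save(), the config values, the
# model id, and the calibration sample are assumptions; only model.quantize()
# mirrors the test code changed below.
from PIL import Image
from gptqmodel import GPTQModel, QuantizeConfig

model_id = "OpenGVLab/InternVL2-8B-MPO"  # assumed model path

# One sample shaped like format_internlm2_vl_dataset() output: an image plus a
# question containing the "<image>" placeholder and a reference answer.
calibration_dataset = [
    {
        "image": Image.open("example.jpg"),  # hypothetical local image
        "question": "<image>\nDescribe the image in detail.",
        "answer": "A short reference caption.",
    }
]

quant_config = QuantizeConfig(bits=4, group_size=128)
model = GPTQModel.load(model_id, quant_config, trust_remote_code=True)

model.quantize(calibration_dataset, batch_size=1)  # runs preprocess/prepare_dataset above
model.save("InternVL2-8B-MPO-gptq-4bit")  # assumed output directory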
5 changes: 2 additions & 3 deletions tests/models/model_test.py
@@ -138,9 +138,8 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, torch_dtype="aut

        is_quantized = model.quantized

        # ovis cannot load processor
        is_ovis_model = model.__class__.__name__ == "OvisGPTQ"
        need_create_processor = is_image_to_text_model and not is_ovis_model
        is_qwen2vl_model = model.__class__.__name__ == "Qwen2VLGPTQ"
        need_create_processor = is_image_to_text_model and is_qwen2vl_model
        if not is_quantized:
            model.quantize(calibration_dataset, batch_size=batch_size)

11 changes: 10 additions & 1 deletion tests/models/ovis/image_to_test_dataset.py
@@ -1,4 +1,4 @@
from gptqmodel.models import OvisGPTQ, Qwen2VLGPTQ
from gptqmodel.models import OvisGPTQ, Qwen2VLGPTQ, InternVLChatGPTQ


def format_ovis_dataset(image, assistant):
@@ -29,6 +29,12 @@ def format_qwen2_vl_dataset(image, assistant):
{"role": "assistant", "content": assistant},
]

def format_internlm2_vl_dataset(image, assistant):
    return {
        "image": image,
        "question": "<image>\nDescribe the image in detail.",
        "answer": assistant,
    }

def prepare_dataset(format_func, n_sample: int = 20) -> list[list[dict]]:
    from datasets import load_dataset
@@ -49,4 +55,7 @@ def get_calib_dataset(model):
    if isinstance(model, Qwen2VLGPTQ):
        return prepare_dataset(format_qwen2_vl_dataset, n_sample=1)

    if isinstance(model, InternVLChatGPTQ):
        return prepare_dataset(format_internlm2_vl_dataset, n_sample=1)

    raise NotImplementedError(f"Unsupported MODEL: {model.__class__}")
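As a quick illustration (not part of the diff) of the new branch: get_calib_dataset() dispatches on the model class and wraps each calibration image with format_internlm2_vl_dataset(); the placeholder image and caption below are assumptions standing in for real dataset entries.

# Sketch; the placeholder image and caption stand in for real dataset entries.
from PIL import Image

image = Image.new("RGB", (640, 480))
sample = format_internlm2_vl_dataset(image, "A solid-color test image.")
# sample == {
#     "image": <PIL.Image.Image>,
#     "question": "<image>\nDescribe the image in detail.",
#     "answer": "A solid-color test image.",
# }
# InternVLChatGPTQ.preprocess_dataset() later tiles the image with load_image(),
# swaps "<image>" for IMG_START/IMG_CONTEXT/IMG_END tokens, and returns
# pixel_values plus image_flags for calibration.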
20 changes: 20 additions & 0 deletions tests/models/test_internvl_chat.py
@@ -0,0 +1,20 @@
from model_test import ModelTest


class TestInternlm2_VL(ModelTest):
    NATIVE_MODEL_ID = "/monster/data/model/InternVL2-8B-MPO"
    NATIVE_ARC_CHALLENGE_ACC = 0.3217
    NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3575
    APPLY_CHAT_TEMPLATE = True
    TRUST_REMOTE_CODE = True
    BATCH_SIZE = 6
    USE_VLLM = False

    def test_internlm2_5(self):
        # runs correctly with transformers<=4.44.2
        model, tokenizer, processor = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=self.TRUST_REMOTE_CODE,
                                                      torch_dtype=self.TORCH_DTYPE, use_flash_attn=False)
