0.33.0 +minimonkey, +fix qwen2-vl with qwen-agent
matatonic committed Sep 22, 2024
1 parent 4833b6e · commit 1cc0b8c
Showing 4 changed files with 69 additions and 30 deletions.
README.md: 6 changes (6 additions, 0 deletions)
@@ -70,6 +70,7 @@ Can't decide which to use? See the [OpenVLM Leaderboard](https://huggingface.co/
- - [X] [Florence-2-base-ft](https://huggingface.co/microsoft/Florence-2-base-ft) (wont gpu split)
- [X] [Mistral AI](https://huggingface.co/mistralai)
- - [X] [Pixtral-12B](https://huggingface.co/mistralai/Pixtral-12B-2409)
- [X] [mx262/MiniMonkey](https://huggingface.co/mx262/MiniMonkey)
- [X] [omlab/omchat-v2.0-13B-single-beta_hf](https://huggingface.co/omlab/omchat-v2.0-13B-single-beta_hf)
- [X] [openbmb](https://huggingface.co/openbmb)
- - [X] [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) (video not supported yet)
@@ -144,6 +145,11 @@ If you can't find your favorite model, you can [open a new issue](https://github

## Recent updates

Version 0.33.0

- new model support: mx262/MiniMonkey, thanks [@white2018](https://github.com/white2018)
- Fix qwen2-vl when used with qwen-agent and multiple system prompts (tools), thanks [@cedonley](https://github.com/cedonley)

Version 0.32.0

- new model support: From AIDC-AI, Ovis1.5-Gemma2-9B and Ovis1.5-Llama3-8B
backend/minimonkey.py: 86 changes (58 additions, 28 deletions)
@@ -1,11 +1,18 @@

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
import transformers
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer

transformers.logging.set_verbosity_error()

from vision_qna import *

# mx262/MiniMonkey

IMG_START_TOKEN='<img>'
IMG_END_TOKEN='</img>'
IMG_CONTEXT_TOKEN='<IMG_CONTEXT>'

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

@@ -131,44 +138,67 @@ def load_image2(image, input_size=448, min_num=1, max_num=12, target_aspect_rati
pixel_values = torch.stack(pixel_values)
return pixel_values


from vision_qna import *

# mx262/MiniMonkey
import transformers
transformers.logging.set_verbosity_error()

class VisionQnA(VisionQnABase):
model_name: str = "minimonkey"
format: str = '' # phi15-ish
#vision_layers: List[str] = ["vision", "vision_tower", "resampler", "visual", "in_proj","out_proj","c_fc","c_proj"]
format: str = 'chatml'
vision_layers: List[str] = ["vision_model"]

def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = {}, format = None):
super().__init__(model_id, device, device_map, extra_params, format)

self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=self.params.get('trust_remote_code', False))
self.model = AutoModel.from_pretrained(**self.params).eval()

self.model.img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN)
self.eos_token = '<|im_end|>'
self.eos_token_id = self.tokenizer.convert_tokens_to_ids(self.eos_token)

# bitsandbytes already moves the model to the device, so we don't need to do it again.
if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)):
self.model = self.model.to(self.device)

self.loaded_banner()

async def stream_chat_with_images(self, request: ImageChatRequest) -> AsyncGenerator[str, None]:
query, history, images, system_message = await prompt_history_images_system_from_messages(
request.messages, img_tok='', url_handler=url_to_image)
images, prompt = await prompt_from_messages(request.messages, self.format)

# set the max number of tiles in `max_num`
pixel_values, target_aspect_ratio = load_image(images[0], min_num=4, max_num=12)
pixel_values = pixel_values.to(torch.bfloat16).to(self.model.device)
pixel_values2 = load_image2(images[0], min_num=3, max_num=7, target_aspect_ratio=target_aspect_ratio)
pixel_values2 = pixel_values2.to(torch.bfloat16).to(self.model.device)
pixel_values = torch.cat([pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0)
if len(images) > 0:
# set the max number of tiles in `max_num`, XXX make an option
pixel_values, target_aspect_ratio = load_image(images[-1], min_num=4, max_num=12)
pixel_values2 = load_image2(images[-1], min_num=3, max_num=7, target_aspect_ratio=target_aspect_ratio)
pixel_values = torch.cat([pixel_values2[:-1], pixel_values[:-1], pixel_values2[-1:]], 0).to(dtype=self.dtype, device=self.model.device)

generation_config = dict(do_sample=False, max_new_tokens=512)
for num_patches in [pixel_values.shape[0]]:
image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.model.num_image_token * num_patches + IMG_END_TOKEN
prompt = prompt.replace('<image>', image_tokens, 1)

answer, history = self.model.chat(self.tokenizer, pixel_values, target_aspect_ratio, query, generation_config, history=None, return_history=True)
else:
pixel_values = None
target_aspect_ratio = None

model_inputs = self.tokenizer(prompt, return_tensors='pt')
input_ids = model_inputs['input_ids'].to(self.device)
attention_mask = model_inputs['attention_mask'].to(self.device)

inputs = dict(
input_ids=input_ids,
pixel_values=pixel_values,
attention_mask=attention_mask,
target_aspect_ratio=target_aspect_ratio,
)

default_params = dict(
do_sample=False,
eos_token_id=[self.eos_token_id, self.tokenizer.eos_token_id]
)
params = self.get_generation_params(request, default_params=default_params)

del params['use_cache']

if isinstance(answer, str):
answer = [answer]
generation_kwargs = dict(
**inputs,
**params,
)

for new_text in answer:
if isinstance(new_text, str):
yield new_text
for new_text in threaded_streaming_generator(generate=self.model.generate, tokenizer=self.tokenizer, generation_kwargs=generation_kwargs):
yield new_text
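
The rewritten MiniMonkey backend builds a chatml prompt with prompt_from_messages and streams tokens through threaded_streaming_generator instead of calling model.chat(). The key step is expanding each '<image>' placeholder into the model's image tokens before tokenizing. A minimal sketch of that expansion, assuming a model that exposes num_image_token (256 here is an illustrative value, not taken from this commit):

IMG_START_TOKEN = '<img>'
IMG_END_TOKEN = '</img>'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'

def expand_image_placeholder(prompt: str, num_patches: int, num_image_token: int = 256) -> str:
    # Each image tile contributes num_image_token context tokens; the
    # placeholder is replaced once per image, in order of appearance.
    image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * num_image_token * num_patches + IMG_END_TOKEN
    return prompt.replace('<image>', image_tokens, 1)

# e.g. with the combined tiles from load_image and load_image2:
# expand_image_placeholder('<image>\nDescribe this picture.', num_patches=13)
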
backend/qwen2-vl.py: 5 changes (3 additions, 2 deletions)
@@ -8,6 +8,7 @@
# Qwen/Qwen2-VL-2B-Instruct
# Qwen/Qwen2-VL-7B-Instruct-AWQ
# Qwen/Qwen2-VL-7B-Instruct
# Qwen/Qwen2-VL-72B-Instruct-AWQ
# X Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int4
# X Qwen/Qwen2-VL-2B-Instruct-GPTQ-Int8
# X Qwen/Qwen2-VL-7B-Instruct-GPTQ-Int4
@@ -50,8 +51,8 @@ async def stream_chat_with_images(self, request: ImageChatRequest) -> AsyncGener
elif c.type == 'video': # not likely to work.
msg['content'].extend([{'type': c.type, 'video': c.image_url.url}])
else:
#msg = { 'role': m.role, 'content': [{ 'type': 'text', 'text': c.text }] }
msg = { 'role': m.role, 'content': c.text }
ctext = "".join([c.text for c in m.content]) # fix for multiple system prompt contents #19
msg = { 'role': m.role, 'content': [{ 'type': 'text', 'text': ctext }] }

messages.extend([msg])

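
The one-line qwen2-vl change addresses requests from qwen-agent, which can send a system message whose content is a list of several text parts (the system prompt plus tool definitions). Previously each text part overwrote the message, so only the last part survived; the fix concatenates all parts into a single text entry. A minimal sketch of that merge, with illustrative message contents:

# Illustrative qwen-agent style system message with multiple text parts.
message = {
    'role': 'system',
    'content': [
        {'type': 'text', 'text': 'You are a helpful assistant.\n\n'},
        {'type': 'text', 'text': '# Tools\n\nYou may call one or more tools...'},
    ],
}

# The fix: join all text parts into one content entry instead of keeping
# only the last one.
ctext = "".join(part['text'] for part in message['content'])
merged = {'role': message['role'], 'content': [{'type': 'text', 'text': ctext}]}
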
model_conf_tests.json: 2 changes (2 additions, 0 deletions)
@@ -97,6 +97,8 @@
["microsoft/Phi-3.5-vision-instruct", "-A", "flash_attention_2", "--load-in-4bit"],
["microsoft/Phi-3.5-vision-instruct", "-A", "flash_attention_2"],
["mistralai/Pixtral-12B-2409"],
["mx262/MiniMonkey", "-A", "flash_attention_2", "--load-in-4bit"],
["mx262/MiniMonkey", "-A", "flash_attention_2"],
["omlab/omchat-v2.0-13B-single-beta_hf", "-A", "flash_attention_2"],
["openbmb/MiniCPM-V-2_6-int4", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["openbmb/MiniCPM-V-2_6", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
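
The two new model_conf_tests.json entries exercise MiniMonkey with flash attention, with and without 4-bit quantization. A hypothetical usage sketch (not part of this commit) for querying the server once it is started with one of those configurations; the base_url, port, and API key are assumptions, so adjust them to your deployment:

from openai import OpenAI

# Any placeholder API key typically works for a locally hosted
# OpenAI-compatible server; the port below is an assumption.
client = OpenAI(base_url='http://localhost:5006/v1', api_key='skip')

response = client.chat.completions.create(
    model='mx262/MiniMonkey',
    messages=[{
        'role': 'user',
        'content': [
            {'type': 'text', 'text': 'Describe this image.'},
            {'type': 'image_url', 'image_url': {'url': 'https://example.com/photo.jpg'}},
        ],
    }],
    max_tokens=256,
)
print(response.choices[0].message.content)
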
