From 14ef2fc36547b898d49d84a3355d1b7a2540a171 Mon Sep 17 00:00:00 2001 From: matatonic Date: Fri, 13 Sep 2024 13:38:57 -0400 Subject: [PATCH] 0.30.0 +pixtral, attn_changes --- README.md | 7 +- backend/pixtral.py | 48 ++++++++ chat_with_image.py | 17 +-- model_conf_tests.alt.json | 22 ++-- model_conf_tests.json | 124 ++++++++++---------- requirements.txt | 4 + test_api_model.py | 208 ++++++++++++++++++++++++++++++++++ vision.py | 20 +++- vision.sample.env | 230 +++++++++++++++++++------------------- vision_qna.py | 44 ++++++-- 10 files changed, 515 insertions(+), 209 deletions(-) create mode 100644 backend/pixtral.py create mode 100755 test_api_model.py diff --git a/README.md b/README.md index 8749076..5e31717 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,8 @@ Can't decide which to use? See the [OpenVLM Leaderboard](https://huggingface.co/ - - [X] [Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct) - - [X] [Florence-2-large-ft](https://huggingface.co/microsoft/Florence-2-large-ft) (wont gpu split) - - [X] [Florence-2-base-ft](https://huggingface.co/microsoft/Florence-2-base-ft) (wont gpu split) +- [X] [Mistral AI](https://huggingface.co/mistralai) +- - [X] [Pixtral-12B](https://huggingface.co/mistralai/Pixtral-12B-2409) - [X] [openbmb](https://huggingface.co/openbmb) - - [X] [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) (video not supported yet) - - [X] [MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5) @@ -131,10 +133,13 @@ If you can't find your favorite model, you can [open a new issue](https://github ## Recent updates + Version 0.30.0 -- Update moondream2 to version 2024-08-26 +- new model support: mistralai/Pixtral-12B-2409 (no streaming yet, no quants yet) - new model support: LMMs-Lab's llava-onevision-qwen2, 0.5b, 7b and 72b (72b untested, 4bit support doesn't seem to work properly yet) +- Update moondream2 to version 2024-08-26 +- Performance fixed: idefics2-8b-AWQ, idefics2-8b-chatty-AWQ Version 0.29.0 diff --git a/backend/pixtral.py b/backend/pixtral.py new file mode 100644 index 0000000..3240ab2 --- /dev/null +++ b/backend/pixtral.py @@ -0,0 +1,48 @@ + +from huggingface_hub import snapshot_download +from safetensors import safe_open +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +from mistral_inference.transformer import Transformer +from mistral_inference.generate import generate +from mistral_common.tokens.tokenizers.mistral import MistralTokenizer + +from vision_qna import * + +# mistralai/Pixtral-12B-2409 + +class VisionQnA(VisionQnABase): + model_name: str = "pixtral" + format: str = "pixtral" + visual_layers: List[str] = ["vision_encoder", 'vision_language_adapter'] + + def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = {}, format = None): + super().__init__(model_id, device, device_map, extra_params, format) + + mistral_models_path = snapshot_download(repo_id=model_id, allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"]) + + self.tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json") + self.model = Transformer.from_folder(mistral_models_path, device=self.device, dtype=self.dtype) + + # bitsandbytes already moves the model to the device, so we don't need to do it again. 
+ #if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)): + # self.model = self.model.to(self.device) + + self.loaded_banner() + + async def chat_with_images(self, request: ImageChatRequest) -> str: + prompt = await pixtral_messages(request.messages) + + # tokenize image urls and text + tokenized = self.tokenizer.encode_chat_completion(prompt) + + generation_kwargs = dict( + eos_id = self.tokenizer.instruct_tokenizer.tokenizer.eos_id, + max_tokens = request.max_tokens, + temperature= 0.35 if request.temperature is None else request.temperature, + ) + + out_tokens, _ = generate([tokenized.tokens], self.model, images=[tokenized.images], **generation_kwargs) + + return self.tokenizer.decode(out_tokens[0]) diff --git a/chat_with_image.py b/chat_with_image.py index 50776ae..b1521e0 100755 --- a/chat_with_image.py +++ b/chat_with_image.py @@ -1,15 +1,16 @@ #!/usr/bin/env python +try: + import dotenv + dotenv.load_dotenv(override=True) +except: + pass + import os import requests import argparse from datauri import DataURI from openai import OpenAI -try: - import dotenv - dotenv.load_dotenv(override=True) -except: - pass def url_for_api(img_url: str = None, filename: str = None, always_data=False) -> str: if img_url.startswith('http'): @@ -29,6 +30,7 @@ def url_for_api(img_url: str = None, filename: str = None, always_data=False) -> parser = argparse.ArgumentParser(description='Test vision using OpenAI', formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser.add_argument('-s', '--system-prompt', type=str, default=None) + parser.add_argument('--openai-model', type=str, default="gpt-4-vision-preview") parser.add_argument('-S', '--start-with', type=str, default=None, help="Start reply with, ex. 'Sure, ' (doesn't work with all models)") parser.add_argument('-m', '--max-tokens', type=int, default=None) parser.add_argument('-t', '--temperature', type=float, default=None) @@ -40,7 +42,8 @@ def url_for_api(img_url: str = None, filename: str = None, always_data=False) -> parser.add_argument('questions', type=str, nargs='*', help='The question to ask the image') args = parser.parse_args() - client = OpenAI(base_url=os.environ.get('OPENAI_BASE_URL', 'http://localhost:5006/v1'), api_key='skip') + client = OpenAI(base_url=os.environ.get('OPENAI_BASE_URL', 'http://localhost:5006/v1'), + api_key=os.environ.get('OPENAI_API_KEY', 'sk-ip')) params = {} if args.max_tokens is not None: @@ -67,7 +70,7 @@ def url_for_api(img_url: str = None, filename: str = None, always_data=False) -> if args.start_with: messages.extend([{ "role": "assistant", "content": [{ "type": "text", "text": args.start_with }] }]) - response = client.chat.completions.create(model="gpt-4-vision-preview", messages=messages, **params) + response = client.chat.completions.create(model=args.openai_model, messages=messages, **params) if not args.single: print(f"Answer: ", end='', flush=True) diff --git a/model_conf_tests.alt.json b/model_conf_tests.alt.json index a0654be..96aebe3 100644 --- a/model_conf_tests.alt.json +++ b/model_conf_tests.alt.json @@ -12,15 +12,15 @@ ["THUDM/cogvlm2-llama3-chat-19B"], ["THUDM/cogvlm2-llama3-chinese-chat-19B", "--load-in-4bit"], ["THUDM/cogvlm2-llama3-chinese-chat-19B"], - ["cognitivecomputations/dolphin-vision-72b", "--use-flash-attn", "--load-in-4bit", "--device-map", "cuda:0"], - ["cognitivecomputations/dolphin-vision-7b", "--use-flash-attn", "--load-in-4bit", "--device-map", "cuda:0"], - ["cognitivecomputations/dolphin-vision-7b", "--use-flash-attn", "--device-map", 
"cuda:0"], - ["llava-hf/llava-v1.6-mistral-7b-hf", "--use-flash-attn", "--load-in-4bit"], - ["llava-hf/llava-v1.6-mistral-7b-hf", "--use-flash-attn"], - ["openbmb/MiniCPM-V", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["openbmb/MiniCPM-V", "--use-flash-attn", "--device-map", "cuda:0"], - ["openbmb/MiniCPM-V-2", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["openbmb/MiniCPM-V-2", "--use-flash-attn", "--device-map", "cuda:0"], - ["tiiuae/falcon-11B-vlm", "--use-flash-attn", "--load-in-4bit"], - ["tiiuae/falcon-11B-vlm", "--use-flash-attn"] + ["cognitivecomputations/dolphin-vision-72b", "-A", "flash_attention_2", "--load-in-4bit", "--device-map", "cuda:0"], + ["cognitivecomputations/dolphin-vision-7b", "-A", "flash_attention_2", "--load-in-4bit", "--device-map", "cuda:0"], + ["cognitivecomputations/dolphin-vision-7b", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["llava-hf/llava-v1.6-mistral-7b-hf", "-A", "flash_attention_2", "--load-in-4bit"], + ["llava-hf/llava-v1.6-mistral-7b-hf", "-A", "flash_attention_2"], + ["openbmb/MiniCPM-V", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["openbmb/MiniCPM-V", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["openbmb/MiniCPM-V-2", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["openbmb/MiniCPM-V-2", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["tiiuae/falcon-11B-vlm", "-A", "flash_attention_2", "--load-in-4bit"], + ["tiiuae/falcon-11B-vlm", "-A", "flash_attention_2"] ] diff --git a/model_conf_tests.json b/model_conf_tests.json index fb99cac..94bb960 100644 --- a/model_conf_tests.json +++ b/model_conf_tests.json @@ -14,12 +14,12 @@ ["BAAI/Bunny-v1_1-Llama-3-8B-V"], ["BAAI/Emu2-Chat", "--load-in-4bit"], ["BAAI/Emu2-Chat", "--max-memory=0:78GiB,1:20GiB"], - ["HuggingFaceM4/idefics2-8b", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["HuggingFaceM4/idefics2-8b", "--use-flash-attn", "--device-map", "cuda:0"], - ["HuggingFaceM4/idefics2-8b-AWQ", "--use-flash-attn", "--device-map", "cuda:0"], - ["HuggingFaceM4/idefics2-8b-chatty", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["HuggingFaceM4/idefics2-8b-chatty", "--use-flash-attn", "--device-map", "cuda:0"], - ["HuggingFaceM4/idefics2-8b-chatty-AWQ", "--use-flash-attn", "--device-map", "cuda:0"], + ["HuggingFaceM4/idefics2-8b", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["HuggingFaceM4/idefics2-8b", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["HuggingFaceM4/idefics2-8b-AWQ", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["HuggingFaceM4/idefics2-8b-chatty", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["HuggingFaceM4/idefics2-8b-chatty", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["HuggingFaceM4/idefics2-8b-chatty-AWQ", "-A", "flash_attention_2", "--device-map", "cuda:0"], ["OpenGVLab/InternVL-Chat-V1-5", "--device-map", "cuda:0", "--load-in-4bit"], ["OpenGVLab/InternVL-Chat-V1-5", "--device-map", "cuda:0", "--max-tiles", "40", "--load-in-4bit"], ["OpenGVLab/InternVL-Chat-V1-5", "--device-map", "cuda:0", "--max-tiles", "40"], @@ -29,6 +29,8 @@ ["OpenGVLab/InternVL2-1B", "--device-map", "cuda:0"], ["OpenGVLab/InternVL2-2B", "--device-map", "cuda:0", "--load-in-4bit"], ["OpenGVLab/InternVL2-2B", "--device-map", "cuda:0"], + ["OpenGVLab/InternVL2-4B", "--device-map", "cuda:0", "--load-in-4bit"], + ["OpenGVLab/InternVL2-4B", "--device-map", 
"cuda:0"], ["OpenGVLab/InternVL2-8B", "--device-map", "cuda:0", "--load-in-4bit"], ["OpenGVLab/InternVL2-8B", "--device-map", "cuda:0"], ["OpenGVLab/InternVL2-26B", "--device-map", "cuda:0", "--load-in-4bit"], @@ -40,6 +42,9 @@ ["OpenGVLab/Mini-InternVL-Chat-2B-V1-5", "--max-tiles", "40", "--load-in-4bit"], ["OpenGVLab/Mini-InternVL-Chat-2B-V1-5", "--max-tiles", "40"], ["OpenGVLab/Mini-InternVL-Chat-2B-V1-5"], + ["OpenGVLab/Mini-InternVL-Chat-4B-V1-5", "--max-tiles", "40", "--load-in-4bit"], + ["OpenGVLab/Mini-InternVL-Chat-4B-V1-5", "--load-in-4bit"], + ["OpenGVLab/Mini-InternVL-Chat-4B-V1-5"], ["Qwen/Qwen-VL-Chat", "--load-in-4bit"], ["Qwen/Qwen-VL-Chat"], ["Salesforce/xgen-mm-phi3-mini-instruct-dpo-r-v1.5"], @@ -50,68 +55,67 @@ ["THUDM/glm-4v-9b", "--device-map", "cuda:0"], ["TIGER-Lab/Mantis-8B-Fuyu", "--device-map", "cuda:0", "--load-in-4bit"], ["TIGER-Lab/Mantis-8B-Fuyu", "--device-map", "cuda:0"], - ["TIGER-Lab/Mantis-8B-clip-llama3", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["TIGER-Lab/Mantis-8B-clip-llama3", "--use-flash-attn", "--device-map", "cuda:0"], - ["TIGER-Lab/Mantis-8B-siglip-llama3", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["TIGER-Lab/Mantis-8B-siglip-llama3", "--use-flash-attn", "--device-map", "cuda:0"], + ["TIGER-Lab/Mantis-8B-clip-llama3", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["TIGER-Lab/Mantis-8B-clip-llama3", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["TIGER-Lab/Mantis-8B-siglip-llama3", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["TIGER-Lab/Mantis-8B-siglip-llama3", "-A", "flash_attention_2", "--device-map", "cuda:0"], ["adept/fuyu-8b", "--device-map", "cuda:0", "--load-in-4bit"], ["adept/fuyu-8b", "--device-map", "cuda:0"], - ["echo840/Monkey", "--load-in-4bit"], - ["echo840/Monkey"], ["echo840/Monkey-Chat", "--load-in-4bit"], ["echo840/Monkey-Chat"], - ["failspy/Phi-3-vision-128k-instruct-abliterated-alpha", "--use-flash-attn", "--load-in-4bit"], - ["failspy/Phi-3-vision-128k-instruct-abliterated-alpha", "--use-flash-attn"], - ["fancyfeast/joy-caption-pre-alpha", "--load-in-4bit", "--use-flash-attn"], - ["fancyfeast/joy-caption-pre-alpha", "--use-flash-attn"], - ["internlm/internlm-xcomposer2d5-7b", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["internlm/internlm-xcomposer2d5-7b", "--use-flash-attn", "--device-map", "cuda:0"], - ["internlm/internlm-xcomposer2-4khd-7b", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["internlm/internlm-xcomposer2-4khd-7b", "--use-flash-attn", "--device-map", "cuda:0"], - ["internlm/internlm-xcomposer2-7b", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["internlm/internlm-xcomposer2-7b", "--use-flash-attn", "--device-map", "cuda:0"], - ["internlm/internlm-xcomposer2-7b-4bit", "--use-flash-attn"], - ["internlm/internlm-xcomposer2-vl-1_8b", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["internlm/internlm-xcomposer2-vl-1_8b", "--use-flash-attn", "--device-map", "cuda:0"], - ["internlm/internlm-xcomposer2-vl-7b", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["internlm/internlm-xcomposer2-vl-7b", "--use-flash-attn", "--device-map", "cuda:0"], - ["internlm/internlm-xcomposer2-vl-7b-4bit", "--use-flash-attn"], - ["llava-hf/llava-1.5-13b-hf", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["llava-hf/llava-1.5-13b-hf", "--use-flash-attn", "--device-map", "cuda:0"], - 
["llava-hf/llava-1.5-7b-hf", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["llava-hf/llava-1.5-7b-hf", "--use-flash-attn", "--device-map", "cuda:0"], - ["llava-hf/llava-v1.6-34b-hf", "--use-flash-attn", "--load-in-4bit"], - ["llava-hf/llava-v1.6-34b-hf", "--use-flash-attn"], - ["llava-hf/llava-v1.6-vicuna-13b-hf", "--use-flash-attn", "--load-in-4bit"], - ["llava-hf/llava-v1.6-vicuna-13b-hf", "--use-flash-attn"], - ["llava-hf/llava-v1.6-vicuna-7b-hf", "--use-flash-attn", "--load-in-4bit"], - ["llava-hf/llava-v1.6-vicuna-7b-hf", "--use-flash-attn"], - ["lmms-lab/llava-onevision-qwen2-0.5b-ov", "--use-flash-attn"], - ["lmms-lab/llava-onevision-qwen2-7b-ov", "--use-flash-attn"], - ["microsoft/Florence-2-base-ft", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["microsoft/Florence-2-base-ft", "--use-flash-attn", "--device-map", "cuda:0"], - ["microsoft/Florence-2-large-ft", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["microsoft/Florence-2-large-ft", "--use-flash-attn", "--device-map", "cuda:0"], - ["microsoft/Phi-3-vision-128k-instruct", "--use-flash-attn", "--load-in-4bit"], - ["microsoft/Phi-3-vision-128k-instruct", "--use-flash-attn"], - ["microsoft/Phi-3.5-vision-instruct", "--use-flash-attn", "--load-in-4bit"], - ["microsoft/Phi-3.5-vision-instruct", "--use-flash-attn"], - ["openbmb/MiniCPM-V-2_6", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["openbmb/MiniCPM-V-2_6", "--use-flash-attn", "--device-map", "cuda:0"], - ["openbmb/MiniCPM-Llama3-V-2_5", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["openbmb/MiniCPM-Llama3-V-2_5", "--use-flash-attn", "--device-map", "cuda:0"], - ["qihoo360/360VL-8B", "--use-flash-attn", "--load-in-4bit"], - ["qihoo360/360VL-8B", "--use-flash-attn"], - ["qnguyen3/nanoLLaVA", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["qnguyen3/nanoLLaVA", "--use-flash-attn", "--device-map", "cuda:0"], - ["qnguyen3/nanoLLaVA-1.5", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"], - ["qnguyen3/nanoLLaVA-1.5", "--use-flash-attn", "--device-map", "cuda:0"], + ["failspy/Phi-3-vision-128k-instruct-abliterated-alpha", "-A", "flash_attention_2", "--load-in-4bit"], + ["failspy/Phi-3-vision-128k-instruct-abliterated-alpha", "-A", "flash_attention_2"], + ["fancyfeast/joy-caption-pre-alpha", "--load-in-4bit", "-A", "flash_attention_2"], + ["fancyfeast/joy-caption-pre-alpha", "-A", "flash_attention_2"], + ["internlm/internlm-xcomposer2d5-7b", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["internlm/internlm-xcomposer2d5-7b", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["internlm/internlm-xcomposer2-4khd-7b", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["internlm/internlm-xcomposer2-4khd-7b", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["internlm/internlm-xcomposer2-7b", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["internlm/internlm-xcomposer2-7b", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["internlm/internlm-xcomposer2-7b-4bit", "-A", "flash_attention_2"], + ["internlm/internlm-xcomposer2-vl-1_8b", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["internlm/internlm-xcomposer2-vl-1_8b", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["internlm/internlm-xcomposer2-vl-7b", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["internlm/internlm-xcomposer2-vl-7b", "-A", 
"flash_attention_2", "--device-map", "cuda:0"], + ["internlm/internlm-xcomposer2-vl-7b-4bit", "-A", "flash_attention_2"], + ["llava-hf/llava-1.5-13b-hf", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["llava-hf/llava-1.5-13b-hf", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["llava-hf/llava-1.5-7b-hf", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["llava-hf/llava-1.5-7b-hf", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["llava-hf/llava-v1.6-34b-hf", "-A", "flash_attention_2", "--load-in-4bit"], + ["llava-hf/llava-v1.6-34b-hf", "-A", "flash_attention_2"], + ["llava-hf/llava-v1.6-vicuna-13b-hf", "-A", "flash_attention_2", "--load-in-4bit"], + ["llava-hf/llava-v1.6-vicuna-13b-hf", "-A", "flash_attention_2"], + ["llava-hf/llava-v1.6-vicuna-7b-hf", "-A", "flash_attention_2", "--load-in-4bit"], + ["llava-hf/llava-v1.6-vicuna-7b-hf", "-A", "flash_attention_2"], + ["lmms-lab/llava-onevision-qwen2-0.5b-ov", "-A", "flash_attention_2"], + ["lmms-lab/llava-onevision-qwen2-7b-ov", "-A", "flash_attention_2"], + ["microsoft/Florence-2-base-ft", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["microsoft/Florence-2-base-ft", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["microsoft/Florence-2-large-ft", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["microsoft/Florence-2-large-ft", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["microsoft/Phi-3-vision-128k-instruct", "-A", "flash_attention_2", "--load-in-4bit"], + ["microsoft/Phi-3-vision-128k-instruct", "-A", "flash_attention_2"], + ["microsoft/Phi-3.5-vision-instruct", "-A", "flash_attention_2", "--load-in-4bit"], + ["microsoft/Phi-3.5-vision-instruct", "-A", "flash_attention_2"], + ["mistralai/Pixtral-12B-2409"], + ["openbmb/MiniCPM-V-2_6", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["openbmb/MiniCPM-V-2_6", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["openbmb/MiniCPM-Llama3-V-2_5", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["openbmb/MiniCPM-Llama3-V-2_5", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["qihoo360/360VL-8B", "-A", "flash_attention_2", "--load-in-4bit"], + ["qihoo360/360VL-8B", "-A", "flash_attention_2"], + ["qnguyen3/nanoLLaVA", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["qnguyen3/nanoLLaVA", "-A", "flash_attention_2", "--device-map", "cuda:0"], + ["qnguyen3/nanoLLaVA-1.5", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"], + ["qnguyen3/nanoLLaVA-1.5", "-A", "flash_attention_2", "--device-map", "cuda:0"], ["qresearch/llama-3-vision-alpha-hf", "--device", "cuda:0", "--load-in-4bit"], ["qresearch/llama-3-vision-alpha-hf", "--device", "cuda:0"], ["togethercomputer/Llama-3-8B-Dragonfly-Med-v1", "--load-in-4bit"], ["togethercomputer/Llama-3-8B-Dragonfly-Med-v1"], ["togethercomputer/Llama-3-8B-Dragonfly-v1", "--load-in-4bit"], ["togethercomputer/Llama-3-8B-Dragonfly-v1"], - ["vikhyatk/moondream2", "--use-flash-attn", "--load-in-4bit"], - ["vikhyatk/moondream2", "--use-flash-attn"] + ["vikhyatk/moondream2", "-A", "flash_attention_2", "--load-in-4bit"], + ["vikhyatk/moondream2", "-A", "flash_attention_2"] ] diff --git a/requirements.txt b/requirements.txt index dbd70cb..3183f53 100644 --- a/requirements.txt +++ b/requirements.txt @@ -50,3 +50,7 @@ logger # llava-onevision git+https://github.com/LLaVA-VL/LLaVA-NeXT.git + +# mistral +mistral_inference>=1.4.0 
+mistral_common>=1.4.0
diff --git a/test_api_model.py b/test_api_model.py
new file mode 100755
index 0000000..26407ca
--- /dev/null
+++ b/test_api_model.py
@@ -0,0 +1,208 @@
+#!/usr/bin/env python
+try:
+    import dotenv
+    dotenv.load_dotenv(override=True)
+except:
+    pass
+
+import time
+import json
+import sys
+import os
+import requests
+import argparse
+import subprocess
+import traceback
+from datauri import DataURI
+from openai import OpenAI
+import torch
+
+# tests are configured with model_conf_tests.json
+
+all_results = []
+
+client = OpenAI(
+    base_url=os.environ.get("OPENAI_BASE_URL", 'http://localhost:5006/v1'),
+    api_key=os.environ.get("OPENAI_API_KEY", 'sk-ip'),
+)
+
+urls = {
+    'tree': 'https://images.freeimages.com/images/large-previews/e59/autumn-tree-1408307.jpg',
+    'waterfall': 'https://images.freeimages.com/images/large-previews/242/waterfall-1537490.jpg',
+    'horse': 'https://images.freeimages.com/images/large-previews/5fa/attenborough-nature-reserve-1398791.jpg',
+    'leaf': 'https://images.freeimages.com/images/large-previews/cd7/gingko-biloba-1058537.jpg',
+}
+
+quality_urls = {
+    '98.21': ('What is the total bill?', 'https://ocr.space/Content/Images/receipt-ocr-original.webp'),
+    'walmart': ('What store is the receipt from?', 'https://ocr.space/Content/Images/receipt-ocr-original.webp'),
+}
+
+no_image = {
+    '5': 'In the integer sequence: 1, 2, 3, 4, ... What number comes next after 4?'
+}
+
+green_pass = '\033[92mpass\033[0m✅'
+red_fail = '\033[91mfail\033[0m❌'
+
+
+def data_url_from_url(img_url: str) -> str:
+    response = requests.get(img_url)
+
+    img_data = response.content
+    content_type = response.headers['content-type']
+    return str(DataURI.make(mimetype=content_type, charset='utf-8', base64=True, data=img_data))
+
+def record_result(cmd_args, results, t, mem, note):
+    # update all_results with the test data
+    all_results.extend([{
+        'args': cmd_args,
+        'results': results,
+        'time': t,
+        'mem': mem,
+        'note': note
+    }])
+    result = all(results)
+    print(f"test {green_pass if result else red_fail}, time: {t:.1f}s, mem: {mem:.1f}GB, {note}")
+
+if __name__ == '__main__':
+    # Initialize argparse
+    parser = argparse.ArgumentParser(description='Test vision using OpenAI')
+    parser.add_argument('-s', '--system-prompt', type=str, default=None)
+    parser.add_argument('-m', '--max-tokens', type=int, default=None)
+    parser.add_argument('-t', '--temperature', type=float, default=None)
+    parser.add_argument('-p', '--top_p', type=float, default=None)
+    parser.add_argument('-v', '--verbose', action='store_true', help="Verbose")
+    parser.add_argument('--openai-model', type=str, default="gpt-4-vision-preview")
+    parser.add_argument('--abort-on-fail', action='store_true', help="Abort testing on fail.")
+    parser.add_argument('--quiet', action='store_true', help="Less test noise.")
+    parser.add_argument('-L', '--log-level', default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the log level")
+    args = parser.parse_args()
+
+
+    params = {}
+    if args.max_tokens is not None:
+        params['max_tokens'] = args.max_tokens
+    if args.temperature is not None:
+        params['temperature'] = args.temperature
+    if args.top_p is not None:
+        params['top_p'] = args.top_p
+
+    def generate_response(image_url, prompt):
+
+        messages = [{ "role": "system", "content": [{ 'type': 'text', 'text': args.system_prompt }] }] if args.system_prompt else []
+        messages.extend([
+            { "role": "user", "content": [
+                { "type": "image_url", "image_url": { "url": image_url } },
+                { "type": "text", "text": 
prompt }, + ]}]) + + response = client.chat.completions.create(model=args.openai_model, messages=messages, **params) + answer = response.choices[0].message.content + return answer + + def generate_stream_response(image_url, prompt): + + messages = [{ "role": "system", "content": [{ 'type': 'text', 'text': args.system_prompt }] }] if args.system_prompt else [] + messages.extend([ + { "role": "user", "content": [ + { "type": "image_url", "image_url": { "url": image_url } }, + { "type": "text", "text": prompt }, + ]}]) + + response = client.chat.completions.create(model=args.openai_model, messages=messages, **params, stream=True) + answer = '' + for chunk in response: + if chunk.choices[0].delta.content: + answer += chunk.choices[0].delta.content + + return answer + + if True: + # XXX TODO: timeout + results = [] + ### Single round + + test_time = time.time() + + # url tests + for name, url in urls.items(): + answer = generate_response(url, "What is the subject of the image?") + correct = name in answer.lower() + results.extend([correct]) + if not correct: + print(f"{name}[url]: fail, got: {answer}") + if args.abort_on_fail: + break + else: + print(f"{name}[url]: pass{', got: ' + answer if args.verbose else ''}") + + data_url = data_url_from_url(url) + answer = generate_response(data_url, "What is the subject of the image?") + correct = name in answer.lower() + results.extend([correct]) + if not correct: + print(f"{name}[data]: fail, got: {answer}") + if args.abort_on_fail: + break + else: + print(f"{name}[data]: pass{', got: ' + answer if args.verbose else ''}") + + answer = generate_stream_response(data_url, "What is the subject of the image?") + correct = name in answer.lower() + results.extend([correct]) + if not correct: + print(f"{name}[data_stream]: fail, got: {answer}") + if args.abort_on_fail: + break + else: + print(f"{name}[data_stream]: pass{', got: ' + answer if args.verbose else ''}") + + + ## OCR tests + quality_urls = { + '98.21': ('What is the total bill?', 'https://ocr.space/Content/Images/receipt-ocr-original.webp'), + 'walmart': ('What store is the receipt from?', 'https://ocr.space/Content/Images/receipt-ocr-original.webp'), + } + for name, question in quality_urls.items(): + prompt, data_url = question + answer = generate_stream_response(data_url, prompt) + correct = name in answer.lower() or 'wal-mart' in answer.lower() + results.extend([correct]) + if not correct: + print(f"{name}[quality]: fail, got: {answer}") + if args.abort_on_fail: + break + else: + print(f"{name}[quality]: pass{', got: ' + answer if args.verbose else ''}") + + # No image tests + no_image = { + '5': 'In the sequence of numbers: 1, 2, 3, 4, ... What number comes next after 4?' + } + + def no_image_response(prompt): + messages = [{ "role": "system", "content": [{ 'type': 'text', 'text': args.system_prompt }] }] if args.system_prompt else [] + messages.extend([{ "role": "user", "content": prompt }]) + + response = client.chat.completions.create(model=args.openai_model, messages=messages, **params, max_tokens=5) + answer = response.choices[0].message.content + return answer + + for name, prompt in no_image.items(): + answer = no_image_response(prompt) + correct = True #name in answer.lower() # - no exceptions is enough. 
+            results.extend([correct])
+            if not correct:
+                print(f"{name}[no_img]: fail, got: {answer}")
+                if args.abort_on_fail:
+                    break
+            else:
+                print(f"{name}[no_img]: pass{', got: ' + answer if args.verbose else ''}")
+
+        test_time = time.time() - test_time
+
+        result = all(results)
+        note = f'{results.count(True)}/{len(results)} tests passed.'
+
+        print(f"test {green_pass if result else red_fail}, time: {test_time:.1f}s, {note}")
diff --git a/vision.py b/vision.py
index 9d83c98..e241f3a 100644
--- a/vision.py
+++ b/vision.py
@@ -121,17 +121,20 @@ def parse_args(argv=None):
     parser.add_argument('-b', '--backend', action='store', default=None, help="Force the backend to use (phi3, idefics2, llavanext, llava, etc.)")
     parser.add_argument('-f', '--format', action='store', default=None, help="Force a specific chat format. (vicuna, mistral, chatml, llama2, phi15, etc.) (doesn't work with all models)")
     parser.add_argument('-d', '--device', action='store', default="auto", help="Set the torch device for the model. Ex. cpu, cuda:1")
+    #parser.add_argument('-t', '--dtype', action='store', default="auto", help="Set the torch dtype, ex. 'float16'")
     parser.add_argument('--device-map', action='store', default=os.environ.get('OPENEDAI_DEVICE_MAP', "auto"), help="Set the default device map policy for the model. (auto, balanced, sequential, balanced_low_0, cuda:1, etc.)")
     parser.add_argument('--max-memory', action='store', default=None, help="(emu2 only) Set the per cuda device_map max_memory. Ex. 0:22GiB,1:22GiB,cpu:128GiB")
     parser.add_argument('--no-trust-remote-code', action='store_true', help="Don't trust remote code (required for many models)")
     parser.add_argument('-4', '--load-in-4bit', action='store_true', help="load in 4bit (doesn't work with all models)")
     parser.add_argument('-8', '--load-in-8bit', action='store_true', help="load in 8bit (doesn't work with all models)")
-    parser.add_argument('-F', '--use-flash-attn', action='store_true', help="Use Flash Attention 2 (doesn't work with all models or GPU)")
+    parser.add_argument('-F', '--use-flash-attn', action='store_true', help="DEPRECATED: use --attn_implementation flash_attention_2 or -A flash_attention_2")
+    parser.add_argument('-A', '--attn_implementation', default='sdpa', type=str, help="Set the attn_implementation", choices=['sdpa', 'eager', 'flash_attention_2'])
     parser.add_argument('-T', '--max-tiles', action='store', default=None, type=int, help="Change the maximum number of tiles. [1-55+] (uses more VRAM for higher resolution, doesn't work with all models)")
+    parser.add_argument('--preload', action='store_true', help="Preload model and exit.")
+    parser.add_argument('-L', '--log-level', default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the log level")
-    parser.add_argument('-P', '--port', action='store', default=5006, type=int, help="Server tcp port")
     parser.add_argument('-H', '--host', action='store', default='0.0.0.0', help="Host to listen on, Ex. 
localhost") - parser.add_argument('--preload', action='store_true', help="Preload model and exit.") + parser.add_argument('-P', '--port', action='store', default=5006, type=int, help="Server tcp port") return parser.parse_args() if __name__ == "__main__": @@ -143,13 +146,18 @@ def parse_args(argv=None): logger.info(f"Loading VisionQnA[{args.backend}] with {args.model}") backend = importlib.import_module(f'backend.{args.backend}') - extra_params = {} + if args.use_flash_attn: + #logger.warning("The -F/--use-flash-attn option is deprecated and will be removed in a future release. Please use -A/--attn_implementation flash_attention_2 instead.") + args.attn_implementation = "flash_attention_2" + + extra_params = dict( + attn_implementation = args.attn_implementation + ) + if args.load_in_4bit: extra_params['load_in_4bit'] = True if args.load_in_8bit: extra_params['load_in_8bit'] = True - if args.use_flash_attn: - extra_params['use_flash_attn'] = True if args.max_tiles: extra_params['max_tiles'] = args.max_tiles diff --git a/vision.sample.env b/vision.sample.env index f4b391f..868710d 100644 --- a/vision.sample.env +++ b/vision.sample.env @@ -4,119 +4,117 @@ HF_HOME=hf_home HF_HUB_ENABLE_HF_TRANSFER=1 #HF_TOKEN=hf-... #CUDA_VISIBLE_DEVICES=1,0 -#CLI_COMMAND="python vision.py -m BAAI/Bunny-Llama-3-8B-V --load-in-4bit" # test pass✅, time: 8.9s, mem: 8.7GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m BAAI/Bunny-Llama-3-8B-V" # test pass✅, time: 8.2s, mem: 19.6GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-2B-zh --load-in-4bit" # test pass✅, time: 8.8s, mem: 9.1GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-2B-zh" # test pass✅, time: 5.7s, mem: 10.8GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-3B --load-in-4bit" # test pass✅, time: 12.5s, mem: 8.4GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-3B" # test pass✅, time: 9.0s, mem: 11.8GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-3B-zh" # test pass✅, time: 7.9s, mem: 12.5GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-4B --load-in-4bit" # test pass✅, time: 9.8s, mem: 5.1GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-4B" # test pass✅, time: 7.3s, mem: 12.2GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_1-4B --load-in-4bit" # test pass✅, time: 9.9s, mem: 5.8GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_1-4B" # test pass✅, time: 8.7s, mem: 13.1GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_1-Llama-3-8B-V --load-in-4bit" # test pass✅, time: 10.7s, mem: 9.3GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_1-Llama-3-8B-V" # test pass✅, time: 10.7s, mem: 19.6GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m BAAI/Emu2-Chat --load-in-4bit" # test pass✅, time: 26.3s, mem: 29.3GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m BAAI/Emu2-Chat --max-memory=0:78GiB,1:20GiB" # test pass✅, time: 20.4s, mem: 71.9GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 12.1s, mem: 12.3GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b --use-flash-attn --device-map cuda:0" # test pass✅, time: 11.8s, mem: 21.9GB, 13/13 tests passed. 
-#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b-AWQ --use-flash-attn --device-map cuda:0" # test pass✅, time: 62.8s, mem: 12.6GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b-chatty --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 17.2s, mem: 12.3GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b-chatty --use-flash-attn --device-map cuda:0" # test pass✅, time: 14.8s, mem: 21.9GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b-chatty-AWQ --use-flash-attn --device-map cuda:0" # test pass✅, time: 90.2s, mem: 13.0GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5 --device-map cuda:0 --load-in-4bit" # test pass✅, time: 21.5s, mem: 27.6GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5 --device-map cuda:0 --max-tiles 40 --load-in-4bit" # test pass✅, time: 27.6s, mem: 30.5GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5 --device-map cuda:0 --max-tiles 40" # test pass✅, time: 25.7s, mem: 55.1GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5 --device-map cuda:0" # test pass✅, time: 19.2s, mem: 52.5GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5-Int8 --device-map cuda:0" # test pass✅, time: 38.4s, mem: 32.3GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-1B --device-map cuda:0 --load-in-4bit" # test pass✅, time: 9.5s, mem: 4.4GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-1B --device-map cuda:0" # test pass✅, time: 8.5s, mem: 5.0GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-2B --device-map cuda:0 --load-in-4bit" # test pass✅, time: 15.8s, mem: 5.2GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-2B --device-map cuda:0" # test pass✅, time: 8.8s, mem: 7.3GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-8B --device-map cuda:0 --load-in-4bit" # test pass✅, time: 9.0s, mem: 9.0GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-8B --device-map cuda:0" # test pass✅, time: 8.3s, mem: 18.9GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-26B --device-map cuda:0 --load-in-4bit" # test pass✅, time: 22.8s, mem: 27.4GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-26B --device-map cuda:0" # test pass✅, time: 20.5s, mem: 52.4GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-40B --device-map cuda:0 --load-in-4bit" # test pass✅, time: 54.6s, mem: 32.6GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-40B --device-map cuda:0" # test pass✅, time: 48.1s, mem: 77.5GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-Llama3-76B --device-map cuda:0 --load-in-4bit" # test pass✅, time: 40.4s, mem: 53.8GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-2B-V1-5 --load-in-4bit" # test pass✅, time: 6.5s, mem: 5.6GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-2B-V1-5 --max-tiles 40 --load-in-4bit" # test pass✅, time: 7.2s, mem: 7.3GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-2B-V1-5 --max-tiles 40" # test pass✅, time: 7.0s, mem: 9.4GB, 13/13 tests passed. 
-#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-2B-V1-5" # test pass✅, time: 6.2s, mem: 7.5GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m Qwen/Qwen-VL-Chat --load-in-4bit" # test pass✅, time: 10.2s, mem: 11.6GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m Qwen/Qwen-VL-Chat" # test pass✅, time: 7.1s, mem: 19.9GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m Salesforce/xgen-mm-phi3-mini-instruct-dpo-r-v1.5" # test pass✅, time: 9.2s, mem: 10.0GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5" # test pass✅, time: 4.4s, mem: 10.0GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m Salesforce/xgen-mm-phi3-mini-instruct-singleimg-r-v1.5" # test pass✅, time: 9.9s, mem: 10.1GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m Salesforce/xgen-mm-phi3-mini-instruct-r-v1" # test pass✅, time: 11.0s, mem: 10.5GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m THUDM/glm-4v-9b --device-map cuda:0 --load-in-4bit" # test pass✅, time: 69.8s, mem: 16.8GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m THUDM/glm-4v-9b --device-map cuda:0" # test pass✅, time: 53.7s, mem: 28.5GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-Fuyu --device-map cuda:0 --load-in-4bit" # test pass✅, time: 10.2s, mem: 11.9GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-Fuyu --device-map cuda:0" # test pass✅, time: 9.7s, mem: 20.7GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-clip-llama3 --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 10.4s, mem: 7.9GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-clip-llama3 --use-flash-attn --device-map cuda:0" # test pass✅, time: 8.8s, mem: 17.7GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-siglip-llama3 --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 8.1s, mem: 8.7GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-siglip-llama3 --use-flash-attn --device-map cuda:0" # test pass✅, time: 6.6s, mem: 18.4GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m adept/fuyu-8b --device-map cuda:0 --load-in-4bit" # test pass✅, time: 15.2s, mem: 16.3GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m adept/fuyu-8b --device-map cuda:0" # test pass✅, time: 21.3s, mem: 25.4GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m echo840/Monkey --load-in-4bit" # test pass✅, time: 9.4s, mem: 16.3GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m echo840/Monkey" # test pass✅, time: 8.9s, mem: 21.9GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m echo840/Monkey-Chat --load-in-4bit" # test pass✅, time: 13.0s, mem: 16.3GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m echo840/Monkey-Chat" # test pass✅, time: 12.0s, mem: 21.9GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m failspy/Phi-3-vision-128k-instruct-abliterated-alpha --use-flash-attn --load-in-4bit" # test pass✅, time: 11.5s, mem: 7.0GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m failspy/Phi-3-vision-128k-instruct-abliterated-alpha --use-flash-attn" # test pass✅, time: 9.4s, mem: 11.9GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m fancyfeast/joy-caption-pre-alpha --load-in-4bit --use-flash-attn" # test pass✅, time: 106.4s, mem: 9.3GB, 13/13 tests passed. 
-#CLI_COMMAND="python vision.py -m fancyfeast/joy-caption-pre-alpha --use-flash-attn" # test pass✅, time: 61.1s, mem: 18.8GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2d5-7b --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 3.4s, mem: 9.2GB, 1/13 tests passed. -#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2d5-7b --use-flash-attn --device-map cuda:0" # test pass✅, time: 27.5s, mem: 28.5GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-4khd-7b --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 2.3s, mem: 7.3GB, 1/13 tests passed. -#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-4khd-7b --use-flash-attn --device-map cuda:0" # test pass✅, time: 19.8s, mem: 20.8GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-7b --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 1.9s, mem: 6.7GB, 1/13 tests passed. -#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-7b --use-flash-attn --device-map cuda:0" # test pass✅, time: 26.2s, mem: 18.8GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-7b-4bit --use-flash-attn" # test pass✅, time: 60.9s, mem: 9.2GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-1_8b --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 1.5s, mem: 3.3GB, 1/13 tests passed. -#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-1_8b --use-flash-attn --device-map cuda:0" # test pass✅, time: 8.5s, mem: 7.8GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-7b --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 1.4s, mem: 7.0GB, 1/13 tests passed. -#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-7b --use-flash-attn --device-map cuda:0" # test pass✅, time: 26.2s, mem: 19.9GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-7b-4bit --use-flash-attn" # test pass✅, time: 45.5s, mem: 10.5GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-13b-hf --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 13.5s, mem: 10.0GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-13b-hf --use-flash-attn --device-map cuda:0" # test pass✅, time: 10.1s, mem: 26.9GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-7b-hf --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 9.3s, mem: 6.2GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-7b-hf --use-flash-attn --device-map cuda:0" # test pass✅, time: 8.4s, mem: 14.7GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-34b-hf --use-flash-attn --load-in-4bit" # test pass✅, time: 62.1s, mem: 23.9GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-34b-hf --use-flash-attn" # test pass✅, time: 69.8s, mem: 69.1GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-vicuna-13b-hf --use-flash-attn --load-in-4bit" # test pass✅, time: 16.8s, mem: 13.4GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-vicuna-13b-hf --use-flash-attn" # test pass✅, time: 16.7s, mem: 30.0GB, 13/13 tests passed. 
-#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-vicuna-7b-hf --use-flash-attn --load-in-4bit" # test pass✅, time: 16.6s, mem: 8.9GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-vicuna-7b-hf --use-flash-attn" # test pass✅, time: 14.9s, mem: 17.2GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m lmms-lab/llava-onevision-qwen2-0.5b-ov --use-flash-attn" # test pass✅, time: 8.8s, mem: 15.4GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m lmms-lab/llava-onevision-qwen2-7b-ov --use-flash-attn" # test pass✅, time: 19.2s, mem: 29.1GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m lmms-lab/llava-onevision-qwen2-72b-ov --use-flash-attn --load-in-4bit" # test fail❌, time: 5.7s, mem: 79.1GB, Test failed with InternalServerError -#CLI_COMMAND="python vision.py -m microsoft/Florence-2-base-ft --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 2.9s, mem: 1.4GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m microsoft/Florence-2-base-ft --use-flash-attn --device-map cuda:0" # test pass✅, time: 2.2s, mem: 1.6GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m microsoft/Florence-2-large-ft --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 3.7s, mem: 1.9GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m microsoft/Florence-2-large-ft --use-flash-attn --device-map cuda:0" # test pass✅, time: 2.9s, mem: 2.8GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m microsoft/Phi-3-vision-128k-instruct --use-flash-attn --load-in-4bit" # test pass✅, time: 11.4s, mem: 7.1GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m microsoft/Phi-3-vision-128k-instruct --use-flash-attn" # test pass✅, time: 9.3s, mem: 12.0GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m microsoft/Phi-3.5-vision-instruct --use-flash-attn --load-in-4bit" # test pass✅, time: 9.8s, mem: 5.1GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m microsoft/Phi-3.5-vision-instruct --use-flash-attn" # test pass✅, time: 8.6s, mem: 9.9GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-V-2_6 --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 13.2s, mem: 9.6GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-V-2_6 --use-flash-attn --device-map cuda:0" # test pass✅, time: 11.9s, mem: 18.4GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-Llama3-V-2_5 --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 26.9s, mem: 9.9GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-Llama3-V-2_5 --use-flash-attn --device-map cuda:0" # test pass✅, time: 22.6s, mem: 19.6GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m qihoo360/360VL-8B --use-flash-attn --load-in-4bit" # test pass✅, time: 11.6s, mem: 8.6GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m qihoo360/360VL-8B --use-flash-attn" # test pass✅, time: 8.2s, mem: 17.6GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m qnguyen3/nanoLLaVA --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 10.7s, mem: 8.2GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m qnguyen3/nanoLLaVA --use-flash-attn --device-map cuda:0" # test pass✅, time: 6.8s, mem: 8.5GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m qnguyen3/nanoLLaVA-1.5 --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 8.0s, mem: 8.1GB, 12/13 tests passed. 
-#CLI_COMMAND="python vision.py -m qnguyen3/nanoLLaVA-1.5 --use-flash-attn --device-map cuda:0" # test pass✅, time: 7.2s, mem: 8.5GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m qresearch/llama-3-vision-alpha-hf --device cuda:0 --load-in-4bit" # test pass✅, time: 8.0s, mem: 8.5GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m qresearch/llama-3-vision-alpha-hf --device cuda:0" # test pass✅, time: 8.2s, mem: 17.9GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m togethercomputer/Llama-3-8B-Dragonfly-Med-v1 --load-in-4bit" # test pass✅, time: 11.7s, mem: 8.4GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m togethercomputer/Llama-3-8B-Dragonfly-Med-v1" # test pass✅, time: 10.1s, mem: 18.1GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m togethercomputer/Llama-3-8B-Dragonfly-v1 --load-in-4bit" # test pass✅, time: 10.3s, mem: 8.5GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m togethercomputer/Llama-3-8B-Dragonfly-v1" # test pass✅, time: 12.5s, mem: 18.1GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m vikhyatk/moondream2 --use-flash-attn --load-in-4bit" # test pass✅, time: 6.2s, mem: 3.1GB, 13/13 tests passed. -#CLI_COMMAND="python vision.py -m vikhyatk/moondream2 --use-flash-attn" # test pass✅, time: 5.4s, mem: 4.7GB, 13/13 tests passed. \ No newline at end of file +#CLI_COMMAND="python vision.py -m BAAI/Bunny-Llama-3-8B-V --load-in-4bit" # test pass✅, time: 10.4s, mem: 8.7GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m BAAI/Bunny-Llama-3-8B-V" # test pass✅, time: 7.8s, mem: 19.6GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-2B-zh --load-in-4bit" # test pass✅, time: 6.3s, mem: 9.3GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-2B-zh" # test pass✅, time: 5.2s, mem: 10.8GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-3B --load-in-4bit" # test pass✅, time: 10.1s, mem: 8.4GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-3B" # test pass✅, time: 8.6s, mem: 11.9GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-3B-zh" # test pass✅, time: 7.1s, mem: 12.4GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-4B --load-in-4bit" # test pass✅, time: 8.7s, mem: 5.3GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-4B" # test pass✅, time: 6.6s, mem: 12.2GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_1-4B --load-in-4bit" # test pass✅, time: 10.3s, mem: 5.9GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_1-4B" # test pass✅, time: 7.9s, mem: 13.1GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_1-Llama-3-8B-V --load-in-4bit" # test pass✅, time: 11.2s, mem: 9.5GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_1-Llama-3-8B-V" # test pass✅, time: 9.9s, mem: 19.6GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m BAAI/Emu2-Chat --load-in-4bit" # test pass✅, time: 31.2s, mem: 29.5GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m BAAI/Emu2-Chat --max-memory=0:78GiB,1:20GiB" # test pass✅, time: 20.2s, mem: 71.9GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test pass✅, time: 24.3s, mem: 30.2GB, 13/13 tests passed. 
+#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 21.6s, mem: 38.8GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b-AWQ -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 25.9s, mem: 29.0GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b-chatty -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test pass✅, time: 31.4s, mem: 30.4GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b-chatty -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 25.2s, mem: 38.9GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b-chatty-AWQ -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 28.8s, mem: 29.0GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5 --device-map cuda:0 --load-in-4bit" # test pass✅, time: 22.0s, mem: 27.1GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5 --device-map cuda:0 --max-tiles 40 --load-in-4bit" # test pass✅, time: 28.7s, mem: 30.1GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5 --device-map cuda:0 --max-tiles 40" # test pass✅, time: 25.2s, mem: 54.9GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5 --device-map cuda:0" # test pass✅, time: 18.4s, mem: 52.3GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5-Int8 --device-map cuda:0" # test fail❌, time: -1.0s, mem: -1.0GB, Error: Server failed to start (exit). +#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-1B --device-map cuda:0 --load-in-4bit" # test pass✅, time: 30.7s, mem: 4.0GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-1B --device-map cuda:0" # test pass✅, time: 8.0s, mem: 4.6GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-2B --device-map cuda:0 --load-in-4bit" # test pass✅, time: 16.5s, mem: 4.9GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-2B --device-map cuda:0" # test pass✅, time: 8.2s, mem: 6.7GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-8B --device-map cuda:0 --load-in-4bit" # test pass✅, time: 8.8s, mem: 8.7GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-8B --device-map cuda:0" # test pass✅, time: 7.3s, mem: 18.3GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-26B --device-map cuda:0 --load-in-4bit" # test pass✅, time: 24.9s, mem: 26.9GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-26B --device-map cuda:0" # test pass✅, time: 19.6s, mem: 52.1GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-40B --device-map cuda:0 --load-in-4bit" # test pass✅, time: 36.8s, mem: 32.3GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-40B --device-map cuda:0" # test pass✅, time: 46.9s, mem: 77.3GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL2-Llama3-76B --device-map cuda:0 --load-in-4bit" # test pass✅, time: 40.4s, mem: 53.4GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-2B-V1-5 --load-in-4bit" # test pass✅, time: 5.9s, mem: 5.3GB, 13/13 tests passed. 
+#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-2B-V1-5 --max-tiles 40 --load-in-4bit" # test pass✅, time: 7.4s, mem: 6.9GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-2B-V1-5 --max-tiles 40" # test pass✅, time: 6.6s, mem: 8.8GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-2B-V1-5" # test pass✅, time: 6.0s, mem: 7.1GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m Qwen/Qwen-VL-Chat --load-in-4bit" # test pass✅, time: 9.0s, mem: 11.2GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m Qwen/Qwen-VL-Chat" # test pass✅, time: 6.1s, mem: 19.5GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m Salesforce/xgen-mm-phi3-mini-instruct-dpo-r-v1.5" # test pass✅, time: 9.0s, mem: 9.6GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m Salesforce/xgen-mm-phi3-mini-instruct-interleave-r-v1.5" # test pass✅, time: 3.5s, mem: 9.7GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m Salesforce/xgen-mm-phi3-mini-instruct-singleimg-r-v1.5" # test pass✅, time: 10.3s, mem: 9.8GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m Salesforce/xgen-mm-phi3-mini-instruct-r-v1" # test pass✅, time: 10.9s, mem: 10.1GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m THUDM/glm-4v-9b --device-map cuda:0 --load-in-4bit" # test pass✅, time: 78.1s, mem: 16.7GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m THUDM/glm-4v-9b --device-map cuda:0" # test pass✅, time: 53.8s, mem: 28.3GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-Fuyu --device-map cuda:0 --load-in-4bit" # test pass✅, time: 9.5s, mem: 11.6GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-Fuyu --device-map cuda:0" # test pass✅, time: 9.0s, mem: 20.5GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-clip-llama3 -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test pass✅, time: 11.4s, mem: 7.9GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-clip-llama3 -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 8.1s, mem: 17.5GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-siglip-llama3 -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test pass✅, time: 8.2s, mem: 8.5GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-siglip-llama3 -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 5.9s, mem: 18.1GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m adept/fuyu-8b --device-map cuda:0 --load-in-4bit" # test pass✅, time: 21.3s, mem: 16.2GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m adept/fuyu-8b --device-map cuda:0" # test pass✅, time: 21.0s, mem: 25.1GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m echo840/Monkey-Chat --load-in-4bit" # test pass✅, time: 14.2s, mem: 16.0GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m echo840/Monkey-Chat" # test pass✅, time: 11.0s, mem: 21.6GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m failspy/Phi-3-vision-128k-instruct-abliterated-alpha -A flash_attention_2 --load-in-4bit" # test pass✅, time: 9.6s, mem: 7.0GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m failspy/Phi-3-vision-128k-instruct-abliterated-alpha -A flash_attention_2" # test pass✅, time: 8.4s, mem: 11.9GB, 13/13 tests passed. 
+#CLI_COMMAND="python vision.py -m fancyfeast/joy-caption-pre-alpha --load-in-4bit -A flash_attention_2" # test pass✅, time: 107.5s, mem: 9.1GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m fancyfeast/joy-caption-pre-alpha -A flash_attention_2" # test pass✅, time: 60.4s, mem: 18.6GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2d5-7b -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test fail❌, time: 2.9s, mem: 8.9GB, 1/13 tests passed. +#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2d5-7b -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 28.1s, mem: 28.2GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-4khd-7b -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test fail❌, time: 1.7s, mem: 7.0GB, 1/13 tests passed. +#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-4khd-7b -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 21.0s, mem: 21.3GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-7b -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test fail❌, time: 1.5s, mem: 6.4GB, 1/13 tests passed. +#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-7b -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 28.1s, mem: 18.3GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-7b-4bit -A flash_attention_2" # test pass✅, time: 53.1s, mem: 8.7GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-1_8b -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test fail❌, time: 0.9s, mem: 2.9GB, 1/13 tests passed. +#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-1_8b -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 8.9s, mem: 7.2GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-7b -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test fail❌, time: 0.9s, mem: 6.6GB, 1/13 tests passed. +#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-7b -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 23.2s, mem: 19.7GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-7b-4bit -A flash_attention_2" # test pass✅, time: 49.7s, mem: 10.0GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-13b-hf -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test pass✅, time: 11.7s, mem: 9.6GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-13b-hf -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 8.2s, mem: 26.4GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-7b-hf -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test pass✅, time: 9.2s, mem: 5.7GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-7b-hf -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 7.3s, mem: 14.2GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-34b-hf -A flash_attention_2 --load-in-4bit" # test pass✅, time: 56.0s, mem: 23.5GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-34b-hf -A flash_attention_2" # test pass✅, time: 72.2s, mem: 68.8GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-vicuna-13b-hf -A flash_attention_2 --load-in-4bit" # test pass✅, time: 15.5s, mem: 13.2GB, 13/13 tests passed. 
+#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-vicuna-13b-hf -A flash_attention_2" # test pass✅, time: 13.6s, mem: 29.6GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-vicuna-7b-hf -A flash_attention_2 --load-in-4bit" # test pass✅, time: 15.5s, mem: 8.5GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-vicuna-7b-hf -A flash_attention_2" # test pass✅, time: 10.7s, mem: 16.9GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m lmms-lab/llava-onevision-qwen2-0.5b-ov -A flash_attention_2" # test pass✅, time: 8.2s, mem: 15.7GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m lmms-lab/llava-onevision-qwen2-7b-ov -A flash_attention_2" # test pass✅, time: 18.9s, mem: 29.2GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m microsoft/Florence-2-base-ft -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test pass✅, time: 2.2s, mem: 1.2GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m microsoft/Florence-2-base-ft -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 1.7s, mem: 1.5GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m microsoft/Florence-2-large-ft -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test pass✅, time: 2.6s, mem: 1.6GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m microsoft/Florence-2-large-ft -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 1.9s, mem: 2.5GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m microsoft/Phi-3-vision-128k-instruct -A flash_attention_2 --load-in-4bit" # test pass✅, time: 12.6s, mem: 6.9GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m microsoft/Phi-3-vision-128k-instruct -A flash_attention_2" # test pass✅, time: 9.0s, mem: 11.9GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m microsoft/Phi-3.5-vision-instruct -A flash_attention_2 --load-in-4bit" # test pass✅, time: 9.5s, mem: 4.7GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m microsoft/Phi-3.5-vision-instruct -A flash_attention_2" # test pass✅, time: 8.3s, mem: 9.7GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m mistralai/Pixtral-12B-2409" # test pass✅, time: 16.5s, mem: 35.8GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-V-2_6 -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test pass✅, time: 14.3s, mem: 10.7GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-V-2_6 -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 11.7s, mem: 19.7GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-Llama3-V-2_5 -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test pass✅, time: 24.0s, mem: 9.6GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-Llama3-V-2_5 -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 22.3s, mem: 19.2GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m qihoo360/360VL-8B -A flash_attention_2 --load-in-4bit" # test pass✅, time: 12.5s, mem: 8.3GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m qihoo360/360VL-8B -A flash_attention_2" # test pass✅, time: 7.2s, mem: 17.3GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m qnguyen3/nanoLLaVA -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test pass✅, time: 8.9s, mem: 7.7GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m qnguyen3/nanoLLaVA -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 6.6s, mem: 8.1GB, 13/13 tests passed. 
+#CLI_COMMAND="python vision.py -m qnguyen3/nanoLLaVA-1.5 -A flash_attention_2 --device-map cuda:0 --load-in-4bit" # test pass✅, time: 7.1s, mem: 7.8GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m qnguyen3/nanoLLaVA-1.5 -A flash_attention_2 --device-map cuda:0" # test pass✅, time: 5.7s, mem: 8.3GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m qresearch/llama-3-vision-alpha-hf --device cuda:0 --load-in-4bit" # test pass✅, time: 11.9s, mem: 8.3GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m qresearch/llama-3-vision-alpha-hf --device cuda:0" # test pass✅, time: 8.6s, mem: 17.6GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m togethercomputer/Llama-3-8B-Dragonfly-Med-v1 --load-in-4bit" # test pass✅, time: 10.8s, mem: 8.1GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m togethercomputer/Llama-3-8B-Dragonfly-Med-v1" # test pass✅, time: 9.2s, mem: 17.7GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m togethercomputer/Llama-3-8B-Dragonfly-v1 --load-in-4bit" # test pass✅, time: 12.8s, mem: 8.3GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m togethercomputer/Llama-3-8B-Dragonfly-v1" # test pass✅, time: 12.8s, mem: 17.9GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m vikhyatk/moondream2 -A flash_attention_2 --load-in-4bit" # test pass✅, time: 6.1s, mem: 3.0GB, 13/13 tests passed. +#CLI_COMMAND="python vision.py -m vikhyatk/moondream2 -A flash_attention_2" # test pass✅, time: 4.7s, mem: 4.7GB, 13/13 tests passed. \ No newline at end of file diff --git a/vision_qna.py b/vision_qna.py index 0242604..049576f 100644 --- a/vision_qna.py +++ b/vision_qna.py @@ -11,6 +11,8 @@ from pydantic import BaseModel from transformers import BitsAndBytesConfig, TextIteratorStreamer from loguru import logger +from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk, SystemMessage, AssistantMessage, ToolMessage +from mistral_common.protocol.instruct.request import ChatCompletionRequest # When models require an image but no image given black_pixel_url = 'data:image/png;charset=utf-8;base64,iVBORw0KGgoAAAANSUhEUgAAAAgAAAAICAIAAABLbSncAAAADElEQVQI12NgGB4AAADIAAF8Y2l9AAAAAElFTkSuQmCC' @@ -60,7 +62,7 @@ def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_p load_in_4bit_params = { 'quantization_config': BitsAndBytesConfig( load_in_4bit=True, -# bnb_4bit_quant_type='nf4', + bnb_4bit_quant_type='nf4', # bnb_4bit_use_double_quant=True, # XXX gone for now, make this an option bnb_4bit_compute_dtype=self.dtype, llm_int8_skip_modules=self.vision_layers, @@ -76,12 +78,6 @@ def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_p } self.params.update(load_in_8bit_params) - if extra_params.get('use_flash_attn', False): - flash_attn_params = { - "attn_implementation": "flash_attention_2", - } - self.params.update(flash_attn_params) - if extra_params.get('trust_remote_code', False): self.params.update({"trust_remote_code": True }) @@ -697,6 +693,35 @@ async def florence_prompt_from_messages(messages: list[Message], url_handler = u return images, prompt +async def pixtral_messages(messages: list[Message]): + pix_messages = [] + + # generation_msg = '' + +# if messages and messages[-1].role == 'assistant': +# generation_msg += messages[-1].content[0].text +# messages.pop(-1) + + for m in messages: + content = [] + text = '' + for c in m.content: + if c.type == 'text' and c.text: + text = c.text + content.extend([TextChunk(text=c.text)]) + if c.type == 'image_url': + 
+                content.extend([ ImageURLChunk(image_url=c.image_url.url) ])
+
+        if m.role == 'user':
+            pix_messages.extend([UserMessage(content=content)])
+        elif m.role == 'assistant':
+            pix_messages.extend([AssistantMessage(content=text)])
+        elif m.role == 'system':
+            pix_messages.extend([SystemMessage(content=text)])
+#        elif m.role == 'tool':
+#            pix_messages.extend([ToolMessage(content=text, tool_call_id=])
+
+    return ChatCompletionRequest(messages=pix_messages, model="pixtral")
 
 async def prompt_from_messages(messages: list[Message], format: str) -> str:
     known_formats = {
@@ -861,4 +886,7 @@ def guess_backend(model_name: str) -> str:
         return 'dv-qwen'
 
     if 'fancyfeast/joy-caption-pre-alpha' in model_id:
-        return 'joy-caption-pre-alpha'
\ No newline at end of file
+        return 'joy-caption-pre-alpha'
+
+    if 'pixtral' in model_id:
+        return 'pixtral'
\ No newline at end of file
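For reference, a minimal sketch of the mapping performed by the new pixtral_messages() helper above: an OpenAI-style message whose content parts carry "text" and "image_url" entries becomes mistral_common TextChunk/ImageURLChunk objects inside a UserMessage, wrapped in a ChatCompletionRequest. The prompt text and data URI below are placeholders, not values from this patch.

    # Illustrative only; mirrors the conversion in pixtral_messages() above.
    # The question text and image data URI are hypothetical placeholders.
    from mistral_common.protocol.instruct.messages import UserMessage, TextChunk, ImageURLChunk
    from mistral_common.protocol.instruct.request import ChatCompletionRequest

    # {"type": "text", "text": ...}                      -> TextChunk(text=...)
    # {"type": "image_url", "image_url": {"url": ...}}   -> ImageURLChunk(image_url=...)
    user_msg = UserMessage(content=[
        TextChunk(text="Describe this image."),
        ImageURLChunk(image_url="data:image/png;base64,..."),
    ])
    request = ChatCompletionRequest(messages=[user_msg], model="pixtral")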