diff --git a/Dockerfile b/Dockerfile
index 4804484..df1177d 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,14 +1,15 @@
FROM python:3.11-slim
-RUN apt-get update && apt-get install -y git gcc
-RUN pip install --no-cache-dir --upgrade pip
+RUN apt-get update && apt-get install -y git gcc \
+ && apt-get clean && rm -rf /var/lib/apt/lists/*
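+# BuildKit cache mount: keeps pip's download cache between builds instead of baking it into the image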
+RUN --mount=type=cache,target=/root/.cache/pip pip install --upgrade pip
-RUN mkdir -p /app
+WORKDIR /app
RUN git clone https://github.com/01-ai/Yi --single-branch /app/Yi
RUN git clone https://github.com/dvlab-research/MGM.git --single-branch /app/MGM
RUN git clone https://github.com/TIGER-AI-Lab/Mantis.git --single-branch /app/Mantis
+RUN git clone https://github.com/togethercomputer/Dragonfly --single-branch /app/Dragonfly
-WORKDIR /app
COPY requirements.txt .
ARG VERSION=latest
RUN if [ "$VERSION" = "alt" ]; then echo "transformers==4.36.2" >> requirements.txt; else echo "transformers>=4.41.2\nautoawq>=0.2.5" >> requirements.txt ; fi
@@ -21,12 +22,14 @@ RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps -e .
WORKDIR /app/Mantis
RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps -e .
+WORKDIR /app/Dragonfly
+RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps -e .
+
WORKDIR /app
COPY *.py .
COPY backend /app/backend
-
-COPY model_conf_tests.json /app/model_conf_tests.json
+COPY model_conf_tests.json .
ENV CLI_COMMAND="python vision.py"
CMD $CLI_COMMAND
diff --git a/README.md b/README.md
index 62c12da..8b616c4 100644
--- a/README.md
+++ b/README.md
@@ -65,6 +65,9 @@ An OpenAI API compatible vision server, it functions like `gpt-4-vision-preview`
- - [X] [Mantis-8B-siglip-llama3](https://huggingface.co/TIGER-Lab/Mantis-8B-siglip-llama3) (wont gpu split)
- - [X] [Mantis-8B-clip-llama3](https://huggingface.co/TIGER-Lab/Mantis-8B-clip-llama3) (wont gpu split)
- - [X] [Mantis-8B-Fuyu](https://huggingface.co/TIGER-Lab/Mantis-8B-Fuyu) (wont gpu split)
+- [X] [Together.ai](https://huggingface.co/togethercomputer)
+- - [X] [Llama-3-8B-Dragonfly-v1](https://huggingface.co/togethercomputer/Llama-3-8B-Dragonfly-v1)
+- - [X] [Llama-3-8B-Dragonfly-Med-v1](https://huggingface.co/togethercomputer/Llama-3-8B-Dragonfly-Med-v1)
- [X] [fuyu-8b](https://huggingface.co/adept/fuyu-8b) [pretrain]
- [X] [falcon-11B-vlm](https://huggingface.co/tiiuae/falcon-11B-vlm)
- [X] [Monkey-Chat](https://huggingface.co/echo840/Monkey-Chat)
@@ -100,6 +103,12 @@ See: [OpenVLM Leaderboard](https://huggingface.co/spaces/opencompass/open_vlm_le
## Recent updates
+Version 0.23.0
+
+- New model support: Together.ai's Llama-3-8B-Dragonfly-v1 and Llama-3-8B-Dragonfly-Med-v1 (a medical image model)
+- Compatibility: chatboxai.app can now use openedai-vision as a backend!
+- Initial support for streaming responses (true streaming for dragonfly and internvl-chat-v1-5, simulated streaming for the rest). More to come.
+
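+A minimal sketch of consuming the new streaming support with the standard OpenAI python client (assumptions: the server is on the default port 5006, the API key can be any placeholder value, and the model name is ignored since the server answers with whichever model it was started with):
+
+```python
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:5006/v1", api_key="skip")
+
+response = client.chat.completions.create(
+    model="gpt-4-vision-preview",
+    messages=[{"role": "user", "content": [
+        {"type": "image_url", "image_url": {"url": "https://example.com/image.jpg"}},
+        {"type": "text", "text": "Describe the image."},
+    ]}],
+    stream=True,
+)
+
+for chunk in response:
+    if chunk.choices[0].delta.content:
+        print(chunk.choices[0].delta.content, end="", flush=True)
+print()
+```
+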
Version 0.22.0
- new model support: THUDM/glm-4v-9b
diff --git a/backend/dragonfly.py b/backend/dragonfly.py
new file mode 100644
index 0000000..5ea6b65
--- /dev/null
+++ b/backend/dragonfly.py
@@ -0,0 +1,71 @@
+from threading import Thread
+from transformers import AutoTokenizer, AutoProcessor, logging
+from dragonfly.models.modeling_dragonfly import DragonflyForCausalLM
+from dragonfly.models.processing_dragonfly import DragonflyProcessor
+
+import warnings
+# disable some warnings
+logging.set_verbosity_error()
+warnings.filterwarnings('ignore')
+
+from vision_qna import *
+
+# togethercomputer/Llama-3-8B-Dragonfly-v1
+# togethercomputer/Llama-3-8B-Dragonfly-Med-v1
+
+class VisionQnA(VisionQnABase):
+ model_name: str = "dragonfly"
+ format: str = 'llama3'
+ vision_layers: List[str] = ['image_encoder', 'vision_model', 'encoder', 'mpl', 'vision_embed_tokens']
+
+ def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = {}, format = None):
+ super().__init__(model_id, device, device_map, extra_params, format)
+
+ del self.params['trust_remote_code']
+
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+ clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ self.processor = DragonflyProcessor(image_processor=clip_processor.image_processor, tokenizer=self.tokenizer, image_encoding_style="llava-hd")
+
+ self.model = DragonflyForCausalLM.from_pretrained(**self.params)
+
+ # bitsandbytes already moves the model to the device, so we don't need to do it again.
+ if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)):
+ self.model = self.model.to(dtype=self.dtype, device=self.device)
+
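+ # <|eot_id|> is llama3's end-of-turn marker; it is used as eos_token_id and trimmed from the streamed output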
+ self.eos_id = "<|eot_id|>"
+ self.eos_token_id = self.tokenizer.encode(self.eos_id, add_special_tokens=False)
+
+ print(f"Loaded {model_id} on device: {self.model.device} with dtype: {self.model.dtype}")
+
+ async def stream_chat_with_images(self, request: ImageChatRequest):
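+ # build the llama3-format prompt without inline image tokens; DragonflyProcessor handles the image encoding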
+ images, prompt = await llama3_prompt_from_messages(request.messages, img_tok='')
+
+ inputs = self.processor(text=[prompt], images=images, max_length=2048, return_tensors="pt", is_generate=True).to(device=self.model.device)
+
+ streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=False, skip_prompt=True)
+
+ default_params = {
+ 'max_new_tokens': 1024,
+ 'eos_token_id': self.eos_token_id,
+ 'pad_token_id': self.eos_token_id[0],
+ }
+
+ params = self.get_generation_params(request, default_params=default_params)
+
+ generation_kwargs = dict(
+ **inputs,
+ **params,
+ streamer=streamer,
+ )
+
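+ # run generate() in a background thread; TextIteratorStreamer yields decoded text as it is produced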
+ t = Thread(target=self.model.generate, kwargs=generation_kwargs)
+ t.start()
+
+ for new_text in streamer:
+ end = new_text.find(self.eos_id)
+ if end == -1:
+ yield new_text
+ else:
+ yield new_text[:end]
+ break
diff --git a/backend/internvl-chat-v1-5.py b/backend/internvl-chat-v1-5.py
index 4dcd7be..06d9539 100644
--- a/backend/internvl-chat-v1-5.py
+++ b/backend/internvl-chat-v1-5.py
@@ -1,4 +1,5 @@
import os
+from threading import Thread
from transformers import AutoTokenizer, AutoModel
from vision_qna import *
import torch
@@ -122,6 +123,8 @@ async def chat_with_images(self, request: ImageChatRequest) -> str:
else:
images, prompt = await chatml_prompt_from_messages(request.messages, img_tok='')
+ # TODO: use detail to set max tiles if detail=low (=512)
+ # if .detail == 'low': max_num=1
images = [load_image(image, max_num=self.max_tiles).to(self.model.dtype).cuda() for image in images]
if len(images) > 1:
pixel_values = torch.cat(images, dim=0)
@@ -153,3 +156,54 @@ async def chat_with_images(self, request: ImageChatRequest) -> str:
response = self.tokenizer.decode(output[0], skip_special_tokens=True)
return response.split(self.eos_token)[0].strip()
+
+ async def stream_chat_with_images(self, request: ImageChatRequest):
+ if self.format == 'phintern':
+ images, prompt = await phintern_prompt_from_messages(request.messages, img_tok='')
+ else:
+ images, prompt = await chatml_prompt_from_messages(request.messages, img_tok='')
+
+ # TODO: use detail to set max tiles if detail=low (=512)
+ # if .detail == 'low': max_num=1
+ images = [load_image(image, max_num=self.max_tiles).to(self.model.dtype).cuda() for image in images]
+ if len(images) > 1:
+ pixel_values = torch.cat(images, dim=0)
+ else:
+ pixel_values = images[0]
+
+ default_params = {
+ 'num_beams': 1,
+ 'max_new_tokens': 512,
+ 'do_sample': False,
+ 'eos_token_id': self.eos_token_id,
+ }
+
+ generation_config = self.get_generation_params(request, default_params)
+
+ del generation_config['use_cache']
+
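+ # expand the image placeholder: num_image_token <IMG_CONTEXT> tokens per image tile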
+ image_tokens = '<img>' + '<IMG_CONTEXT>' * self.model.num_image_token * pixel_values.shape[0] + '</img>\n'
+ model_inputs = self.tokenizer(image_tokens + prompt, return_tensors='pt')
+ input_ids = model_inputs['input_ids'].cuda()
+ attention_mask = model_inputs['attention_mask'].cuda()
+
+ streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=False, skip_prompt=True)
+
+ generation_kwargs = dict(
+ pixel_values=pixel_values,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ **generation_config,
+ streamer=streamer,
+ )
+
+ t = Thread(target=self.model.generate, kwargs=generation_kwargs)
+ t.start()
+
+ for new_text in streamer:
+ end = new_text.find(self.eos_token)
+ if end == -1:
+ yield new_text
+ else:
+ yield new_text[:end]
+ break
diff --git a/chat_with_image.py b/chat_with_image.py
index 9210212..50776ae 100755
--- a/chat_with_image.py
+++ b/chat_with_image.py
@@ -35,6 +35,7 @@ def url_for_api(img_url: str = None, filename: str = None, always_data=False) ->
parser.add_argument('-p', '--top_p', type=float, default=None)
parser.add_argument('-u', '--keep-remote-urls', action='store_true', help="Normally, http urls are converted to data: urls for better latency.")
parser.add_argument('-1', '--single', action='store_true', help='Single turn Q&A, output is only the model response.')
+ parser.add_argument('--no-stream', action='store_true', help='Disable streaming response.')
parser.add_argument('image_url', type=str, help='URL or image file to be tested')
parser.add_argument('questions', type=str, nargs='*', help='The question to ask the image')
args = parser.parse_args()
@@ -48,6 +49,7 @@ def url_for_api(img_url: str = None, filename: str = None, always_data=False) ->
params['temperature'] = args.temperature
if args.top_p is not None:
params['top_p'] = args.top_p
+ params['stream'] = not args.no_stream
image_url = args.image_url
@@ -64,17 +66,30 @@ def url_for_api(img_url: str = None, filename: str = None, always_data=False) ->
while True:
if args.start_with:
messages.extend([{ "role": "assistant", "content": [{ "type": "text", "text": args.start_with }] }])
+
response = client.chat.completions.create(model="gpt-4-vision-preview", messages=messages, **params)
+ if not args.single:
+ print(f"Answer: ", end='', flush=True)
+
+ assistant_text = ''
+
+ if args.no_stream:
+ assistant_text = response.choices[0].message.content
+ print(assistant_text)
+ else:
+ for chunk in response:
+ delta = chunk.choices[0].delta.content or ''
+ assistant_text += delta
+ print(delta, end='', flush=True)
+
+ print('')
+
if args.single:
- print(response.choices[0].message.content)
break
- print(f"Answer: {response.choices[0].message.content}\n")
-
image_url = None
try:
- q = input("Question: ")
+ q = input("\nQuestion: ")
if q.startswith('http') or q.startswith('data:') or q.startswith('file:'):
image_url = q
@@ -90,7 +105,7 @@ def url_for_api(img_url: str = None, filename: str = None, always_data=False) ->
break
content = [{"type": "image_url", "image_url": { "url": image_url } }] if image_url else []
- content.extend([{ 'type': 'text', 'text': response.choices[0].message.content }])
+ content.extend([{ 'type': 'text', 'text': assistant_text }])
messages.extend([{ "role": "assistant", "content": content },
{ "role": "user", "content": [{ 'type': 'text', 'text': q }] }])
diff --git a/model_conf_tests.json b/model_conf_tests.json
index 911ee94..ff6e537 100644
--- a/model_conf_tests.json
+++ b/model_conf_tests.json
@@ -103,6 +103,10 @@
["qresearch/llama-3-vision-alpha-hf", "--device", "cuda:0"],
["tiiuae/falcon-11B-vlm", "--use-flash-attn", "--load-in-4bit"],
["tiiuae/falcon-11B-vlm", "--use-flash-attn"],
+ ["togethercomputer/Llama-3-8B-Dragonfly-Med-v1", "--load-in-4bit"],
+ ["togethercomputer/Llama-3-8B-Dragonfly-Med-v1"],
+ ["togethercomputer/Llama-3-8B-Dragonfly-v1", "--load-in-4bit"],
+ ["togethercomputer/Llama-3-8B-Dragonfly-v1"],
["vikhyatk/moondream2", "--use-flash-attn", "--load-in-4bit"],
["vikhyatk/moondream2", "--use-flash-attn"]
]
diff --git a/openedai.py b/openedai.py
index 64be255..3a610b6 100644
--- a/openedai.py
+++ b/openedai.py
@@ -1,6 +1,7 @@
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
+from loguru import logger
class OpenAIStub(FastAPI):
def __init__(self, **kwargs) -> None:
@@ -15,6 +16,20 @@ def __init__(self, **kwargs) -> None:
allow_headers=["*"]
)
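+ # log every request and response at DEBUG level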
+ @self.middleware("http")
+ async def log_requests(request: Request, call_next):
+ logger.debug(f"Request path: {request.url.path}")
+ logger.debug(f"Request method: {request.method}")
+ logger.debug(f"Request headers: {request.headers}")
+ logger.debug(f"Request query params: {request.query_params}")
+
+ response = await call_next(request)
+
+ logger.debug(f"Response status code: {response.status_code}")
+ logger.debug(f"Response headers: {response.headers}")
+
+ return response
+
@self.get('/v1/billing/usage')
@self.get('/v1/dashboard/billing/usage')
async def handle_billing_usage():
diff --git a/test_models.py b/test_models.py
index 594e307..d6a9340 100755
--- a/test_models.py
+++ b/test_models.py
@@ -178,6 +178,24 @@ def generate_response(image_url, prompt):
answer = response.choices[0].message.content
return answer
+ def generate_stream_response(image_url, prompt):
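+ # like generate_response(), but uses stream=True and reassembles the chunks into a single answer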
+
+ messages = [{ "role": "system", "content": [{ 'type': 'text', 'text': args.system_prompt }] }] if args.system_prompt else []
+ messages.extend([
+ { "role": "user", "content": [
+ { "type": "image_url", "image_url": { "url": image_url } },
+ { "type": "text", "text": prompt },
+ ]}])
+
+ response = client.chat.completions.create(model="gpt-4-vision-preview", messages=messages, **params, stream=True)
+ answer = ''
+ for chunk in response:
+ if chunk.choices[0].delta.content:
+ answer += chunk.choices[0].delta.content
+
+ return answer
+
+
def single_round():
# XXX TODO: timeout
@@ -206,6 +224,17 @@ def single_round():
else:
print(f"{name}[data]: pass{', got: ' + answer if args.verbose else ''}")
+ answer = generate_stream_response(data_url, "What is the subject of the image?")
+ correct = name in answer.lower()
+ results.extend([correct])
+ if not correct:
+ print(f"{name}[data_stream]: fail, got: {answer}")
+ if args.abort_on_fail:
+ break
+ else:
+ print(f"{name}[data_stream]: pass{', got: ' + answer if args.verbose else ''}")
+
+
return results
with open('model_conf_tests.json') as f:
diff --git a/vision.py b/vision.py
index a71ab71..e0e578e 100644
--- a/vision.py
+++ b/vision.py
@@ -3,10 +3,13 @@
import os
import sys
import time
+import json
import argparse
import importlib
from contextlib import asynccontextmanager
import uvicorn
+from sse_starlette import EventSourceResponse
+from loguru import logger
import openedai
import torch
@@ -26,33 +29,78 @@ async def lifespan(app):
@app.post(path="/v1/chat/completions")
async def vision_chat_completions(request: ImageChatRequest):
+ t_id = int(time.time())
+ r_id = f"chatcmpl-{t_id}"
+
+ if request.stream:
+ def chat_streaming_chunk(content):
+ chunk = {
+ "id": r_id,
+ "object": "chat.completions.chunk",
+ "created": t_id,
+ "model": vision_qna.model_name,
+ "choices": [{
+ "index": 0,
+ "finish_reason": None,
+ "delta": {'role': 'assistant', 'content': content},
+ }],
+ }
+ return chunk
+
+ async def streamer():
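+ # first chunk: an empty delta that establishes the assistant role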
+ yield {"data": json.dumps(chat_streaming_chunk(''))}
+
+ # TODO: count tokens
+ dat = ''
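+ # buffer the output and hold it back while it is empty or contains a partially decoded multi-byte character (U+FFFD)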
+ async for resp in vision_qna.stream_chat_with_images(request):
+ print(resp, end='')
+ dat += resp
+ if not resp or chr(0xfffd) in dat: # partial unicode char
+ continue
+
+ yield {"data": json.dumps(chat_streaming_chunk(dat))}
+ dat = ''
+
+ chunk = chat_streaming_chunk(dat)
+ chunk['choices'][0]['finish_reason'] = "stop" # XXX
+ chunk['usage'] = {
+ "prompt_tokens": 1, # XXX
+ "completion_tokens": 1, # XXX
+ "total_tokens": 1, # XXX
+ }
+
+ yield {"data": json.dumps(chunk)}
+
+ return EventSourceResponse(streamer())
+ # else:
+
text = await vision_qna.chat_with_images(request)
- choices = [ {
+ vis_chat_resp = {
+ "id": r_id,
+ "object": "chat.completion", # chat.completions.chunk for stream
+ "created": t_id,
+ "model": vision_qna.model_name,
+ "system_fingerprint": "fp_111111111",
+ "choices": [ {
"index": 0,
"message": {
"role": "assistant",
"content": text,
},
"logprobs": None,
- "finish_reason": "stop"
- }
- ]
- t_id = int(time.time())
- vis_chat_resp = {
- "id": f"chatcmpl-{t_id}",
- "object": "chat.completion",
- "created": t_id,
- "model": vision_qna.model_name,
- "system_fingerprint": "fp_111111111",
- "choices": choices,
+ "finish_reason": "stop", # XXX
+ } ],
"usage": {
- "prompt_tokens": 0,
- "completion_tokens": 0,
- "total_tokens": 0
+ "prompt_tokens": 0, # XXX
+ "completion_tokens": 0, # XXX
+ "total_tokens": 0, # XXX
}
}
+ if os.environ.get('OPENEDAI_DEBUG', False):
+ print(f'Response: {vis_chat_resp}')
+
return vis_chat_resp
def parse_args(argv=None):
@@ -71,6 +119,7 @@ def parse_args(argv=None):
parser.add_argument('-8', '--load-in-8bit', action='store_true', help="load in 8bit (doesn't work with all models)")
parser.add_argument('-F', '--use-flash-attn', action='store_true', help="Use Flash Attention 2 (doesn't work with all models or GPU)")
parser.add_argument('-T', '--max-tiles', action='store', default=None, type=int, help="Change the maximum number of tiles. [1-40+] (uses more VRAM for higher resolution, doesn't work with all models)")
+ parser.add_argument('-L', '--log-level', default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the log level")
parser.add_argument('-P', '--port', action='store', default=5006, type=int, help="Server tcp port")
parser.add_argument('-H', '--host', action='store', default='0.0.0.0', help="Host to listen on, Ex. localhost")
parser.add_argument('--preload', action='store_true', help="Preload model and exit.")
@@ -95,6 +144,9 @@ def parse_args(argv=None):
if args.max_tiles:
extra_params['max_tiles'] = args.max_tiles
+ logger.remove()
+ logger.add(sink=sys.stderr, level=args.log_level)
+
extra_params['trust_remote_code'] = not args.no_trust_remote_code
if args.max_memory:
dev_map_max_memory = {int(dev_id) if dev_id not in ['cpu', 'disk'] else dev_id: mem for dev_id, mem in [dev_mem.split(':') for dev_mem in args.max_memory.split(',')]}
diff --git a/vision.sample.env b/vision.sample.env
index b0a98b3..9f48286 100644
--- a/vision.sample.env
+++ b/vision.sample.env
@@ -4,109 +4,113 @@ HF_HOME=hf_home
HF_HUB_ENABLE_HF_TRANSFER=1
#HF_TOKEN=hf-...
#CUDA_VISIBLE_DEVICES=1,0
-#CLI_COMMAND="python vision.py -m BAAI/Bunny-Llama-3-8B-V --load-in-4bit" # test pass✅, time: 11.9s, mem: 8.3GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m BAAI/Bunny-Llama-3-8B-V" # test pass✅, time: 6.6s, mem: 19.6GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-2B-zh --load-in-4bit" # test pass✅, time: 7.7s, mem: 9.1GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-2B-zh" # test pass✅, time: 4.8s, mem: 10.8GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-3B --load-in-4bit" # test pass✅, time: 11.2s, mem: 8.2GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-3B" # test pass✅, time: 7.2s, mem: 11.8GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-3B-zh --load-in-4bit" # test fail❌, time: 4.9s, mem: 8.3GB, Test failed with Exception: Internal Server Error
-#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-3B-zh" # test pass✅, time: 6.2s, mem: 12.3GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-4B --load-in-4bit" # test pass✅, time: 8.7s, mem: 4.7GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-4B" # test pass✅, time: 5.8s, mem: 12.1GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_1-4B --load-in-4bit" # test pass✅, time: 10.9s, mem: 5.2GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_1-4B" # test pass✅, time: 7.6s, mem: 13.0GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_1-Llama-3-8B-V --load-in-4bit" # test pass✅, time: 11.7s, mem: 9.1GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_1-Llama-3-8B-V" # test pass✅, time: 7.5s, mem: 19.6GB, 8/8 tests passed.
+#CLI_COMMAND="python vision.py -m BAAI/Bunny-Llama-3-8B-V --load-in-4bit" # test pass✅, time: 18.2s, mem: 8.3GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m BAAI/Bunny-Llama-3-8B-V" # test pass✅, time: 8.8s, mem: 19.6GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-2B-zh --load-in-4bit" # test pass✅, time: 10.2s, mem: 9.0GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-2B-zh" # test pass✅, time: 5.8s, mem: 10.8GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-3B --load-in-4bit" # test pass✅, time: 16.5s, mem: 8.2GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-3B" # test pass✅, time: 10.4s, mem: 11.8GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-3B-zh --load-in-4bit" # test fail❌, time: 5.0s, mem: 8.3GB, Test failed with Exception: Internal Server Error
+#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-3B-zh" # test pass✅, time: 8.4s, mem: 12.3GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-4B --load-in-4bit" # test pass✅, time: 13.1s, mem: 4.7GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_0-4B" # test pass✅, time: 7.8s, mem: 12.1GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_1-4B --load-in-4bit" # test pass✅, time: 15.9s, mem: 5.2GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_1-4B" # test pass✅, time: 11.1s, mem: 13.0GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_1-Llama-3-8B-V --load-in-4bit" # test pass✅, time: 17.6s, mem: 9.1GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m BAAI/Bunny-v1_1-Llama-3-8B-V" # test pass✅, time: 11.1s, mem: 19.6GB, 12/12 tests passed.
#CLI_COMMAND="python vision.py -m BAAI/Emu2-Chat --max-memory=0:78GiB,1:20GiB --load-in-4bit" # test fail❌, time: -1.0s, mem: -1.0GB, Error: Server failed to start (exit).
-#CLI_COMMAND="python vision.py -m BAAI/Emu2-Chat --max-memory=0:78GiB,1:20GiB" # test pass✅, time: 22.2s, mem: 78.3GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 8.3s, mem: 10.9GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b --use-flash-attn --device-map cuda:0" # test pass✅, time: 7.1s, mem: 22.5GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b-AWQ --use-flash-attn --device-map cuda:0" # test pass✅, time: 8.6s, mem: 12.7GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b-chatty --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 12.3s, mem: 10.9GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b-chatty --use-flash-attn --device-map cuda:0" # test pass✅, time: 9.0s, mem: 22.4GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b-chatty-AWQ --use-flash-attn --device-map cuda:0" # test pass✅, time: 10.0s, mem: 12.7GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5 --device-map cuda:0 --load-in-4bit" # test pass✅, time: 16.5s, mem: 25.7GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5 --device-map cuda:0 --max-tiles 40 --load-in-4bit" # test pass✅, time: 20.7s, mem: 28.9GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5 --device-map cuda:0 --max-tiles 40" # test pass✅, time: 17.3s, mem: 54.5GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5 --device-map cuda:0" # test pass✅, time: 13.2s, mem: 52.1GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5-Int8 --device-map cuda:0" # test pass✅, time: 26.0s, mem: 31.6GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-2B-V1-5 --load-in-4bit" # test pass✅, time: 4.8s, mem: 5.1GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-2B-V1-5 --max-tiles 40 --load-in-4bit" # test pass✅, time: 5.1s, mem: 6.9GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-2B-V1-5 --max-tiles 40" # test pass✅, time: 4.6s, mem: 8.9GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-2B-V1-5" # test pass✅, time: 3.9s, mem: 7.1GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-4B-V1-5 --load-in-4bit" # test pass✅, time: 7.5s, mem: 6.7GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-4B-V1-5 --max-tiles 40 --load-in-4bit" # test pass✅, time: 9.7s, mem: 12.0GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-4B-V1-5 --max-tiles 40" # test pass✅, time: 9.0s, mem: 15.8GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-4B-V1-5" # test pass✅, time: 6.9s, mem: 11.8GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m Qwen/Qwen-VL-Chat --load-in-4bit" # test fail❌, time: 3.4s, mem: 6.8GB, Test failed with Exception: Internal Server Error
-#CLI_COMMAND="python vision.py -m Qwen/Qwen-VL-Chat" # test pass✅, time: 4.4s, mem: 19.6GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m THUDM/cogagent-chat-hf --load-in-4bit" # test pass✅, time: 18.7s, mem: 12.4GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m THUDM/cogagent-chat-hf" # test pass✅, time: 14.5s, mem: 37.3GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m THUDM/cogvlm-chat-hf --load-in-4bit" # test pass✅, time: 19.4s, mem: 12.0GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m THUDM/cogvlm-chat-hf" # test pass✅, time: 13.5s, mem: 36.3GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m THUDM/cogvlm2-llama3-chat-19B --load-in-4bit" # test pass✅, time: 25.7s, mem: 15.3GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m THUDM/cogvlm2-llama3-chat-19B" # test pass✅, time: 21.5s, mem: 40.7GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m THUDM/cogvlm2-llama3-chinese-chat-19B --load-in-4bit" # test pass✅, time: 79.2s, mem: 15.3GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m THUDM/cogvlm2-llama3-chinese-chat-19B" # test pass✅, time: 69.7s, mem: 40.7GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m THUDM/glm-4v-9b --device-map cuda:0 --load-in-4bit" # test pass✅, time: 60.0s, mem: 16.1GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m THUDM/glm-4v-9b --device-map cuda:0" # test pass✅, time: 35.1s, mem: 27.9GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-Fuyu --device-map cuda:0 --load-in-4bit" # test pass✅, time: 7.0s, mem: 11.2GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-Fuyu --device-map cuda:0" # test pass✅, time: 6.4s, mem: 20.4GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-clip-llama3 --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 9.3s, mem: 7.2GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-clip-llama3 --use-flash-attn --device-map cuda:0" # test pass✅, time: 6.0s, mem: 17.4GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-siglip-llama3 --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 6.9s, mem: 8.0GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-siglip-llama3 --use-flash-attn --device-map cuda:0" # test pass✅, time: 4.7s, mem: 18.1GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m YanweiLi/MGM-2B --use-flash-attn --load-in-4bit" # test fail❌, time: 4.1s, mem: 4.9GB, Test failed with Exception: Internal Server Error
-#CLI_COMMAND="python vision.py -m YanweiLi/MGM-2B --use-flash-attn" # test pass✅, time: 4.0s, mem: 8.4GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m adept/fuyu-8b --device-map cuda:0 --load-in-4bit" # test pass✅, time: 15.2s, mem: 15.6GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m adept/fuyu-8b --device-map cuda:0" # test pass✅, time: 13.4s, mem: 25.0GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m echo840/Monkey --load-in-4bit" # test pass✅, time: 6.4s, mem: 15.6GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m echo840/Monkey" # test pass✅, time: 6.0s, mem: 21.8GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m echo840/Monkey-Chat --load-in-4bit" # test pass✅, time: 10.1s, mem: 15.6GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m echo840/Monkey-Chat" # test pass✅, time: 7.7s, mem: 21.8GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m failspy/Phi-3-vision-128k-instruct-abliterated-alpha --use-flash-attn --load-in-4bit" # test pass✅, time: 7.5s, mem: 6.9GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m failspy/Phi-3-vision-128k-instruct-abliterated-alpha --use-flash-attn" # test pass✅, time: 6.1s, mem: 12.3GB, 8/8 tests passed.
+#CLI_COMMAND="python vision.py -m BAAI/Emu2-Chat --max-memory=0:78GiB,1:20GiB" # test pass✅, time: 31.0s, mem: 78.4GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 10.9s, mem: 11.0GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b --use-flash-attn --device-map cuda:0" # test pass✅, time: 10.3s, mem: 22.5GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b-AWQ --use-flash-attn --device-map cuda:0" # test pass✅, time: 12.4s, mem: 12.7GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b-chatty --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 19.3s, mem: 10.9GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b-chatty --use-flash-attn --device-map cuda:0" # test pass✅, time: 13.3s, mem: 22.5GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m HuggingFaceM4/idefics2-8b-chatty-AWQ --use-flash-attn --device-map cuda:0" # test pass✅, time: 14.5s, mem: 12.8GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5 --device-map cuda:0 --load-in-4bit" # test pass✅, time: 24.8s, mem: 25.8GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5 --device-map cuda:0 --max-tiles 40 --load-in-4bit" # test pass✅, time: 31.2s, mem: 28.9GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5 --device-map cuda:0 --max-tiles 40" # test pass✅, time: 25.6s, mem: 54.6GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5 --device-map cuda:0" # test pass✅, time: 19.4s, mem: 52.0GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m OpenGVLab/InternVL-Chat-V1-5-Int8 --device-map cuda:0" # test pass✅, time: 39.7s, mem: 31.4GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-2B-V1-5 --load-in-4bit" # test pass✅, time: 6.7s, mem: 5.0GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-2B-V1-5 --max-tiles 40 --load-in-4bit" # test pass✅, time: 7.3s, mem: 6.7GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-2B-V1-5 --max-tiles 40" # test pass✅, time: 6.2s, mem: 8.7GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-2B-V1-5" # test pass✅, time: 5.5s, mem: 7.0GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-4B-V1-5 --load-in-4bit" # test pass✅, time: 11.1s, mem: 6.5GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-4B-V1-5 --max-tiles 40 --load-in-4bit" # test pass✅, time: 14.3s, mem: 11.9GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-4B-V1-5 --max-tiles 40" # test pass✅, time: 13.2s, mem: 15.6GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m OpenGVLab/Mini-InternVL-Chat-4B-V1-5" # test pass✅, time: 9.9s, mem: 11.6GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m Qwen/Qwen-VL-Chat --load-in-4bit" # test fail❌, time: 3.4s, mem: 6.6GB, Test failed with Exception: Internal Server Error
+#CLI_COMMAND="python vision.py -m Qwen/Qwen-VL-Chat" # test pass✅, time: 6.3s, mem: 19.4GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m THUDM/cogagent-chat-hf --load-in-4bit" # test pass✅, time: 29.4s, mem: 12.2GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m THUDM/cogagent-chat-hf" # test pass✅, time: 22.1s, mem: 37.1GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m THUDM/cogvlm-chat-hf --load-in-4bit" # test pass✅, time: 30.0s, mem: 11.9GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m THUDM/cogvlm-chat-hf" # test pass✅, time: 20.5s, mem: 36.2GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m THUDM/cogvlm2-llama3-chat-19B --load-in-4bit" # test pass✅, time: 38.0s, mem: 15.2GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m THUDM/cogvlm2-llama3-chat-19B" # test pass✅, time: 32.0s, mem: 40.6GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m THUDM/cogvlm2-llama3-chinese-chat-19B --load-in-4bit" # test pass✅, time: 123.0s, mem: 15.2GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m THUDM/cogvlm2-llama3-chinese-chat-19B" # test pass✅, time: 102.6s, mem: 40.6GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m THUDM/glm-4v-9b --device-map cuda:0 --load-in-4bit" # test pass✅, time: 91.3s, mem: 16.1GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m THUDM/glm-4v-9b --device-map cuda:0" # test pass✅, time: 52.1s, mem: 27.9GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-Fuyu --device-map cuda:0 --load-in-4bit" # test pass✅, time: 9.7s, mem: 11.1GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-Fuyu --device-map cuda:0" # test pass✅, time: 8.8s, mem: 20.5GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-clip-llama3 --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 14.4s, mem: 7.1GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-clip-llama3 --use-flash-attn --device-map cuda:0" # test pass✅, time: 8.9s, mem: 17.3GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-siglip-llama3 --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 9.7s, mem: 8.0GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-siglip-llama3 --use-flash-attn --device-map cuda:0" # test pass✅, time: 6.6s, mem: 18.1GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m YanweiLi/MGM-2B --use-flash-attn --load-in-4bit" # test fail❌, time: 3.9s, mem: 5.0GB, Test failed with Exception: Internal Server Error
+#CLI_COMMAND="python vision.py -m YanweiLi/MGM-2B --use-flash-attn" # test pass✅, time: 5.8s, mem: 8.4GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m adept/fuyu-8b --device-map cuda:0 --load-in-4bit" # test pass✅, time: 23.6s, mem: 15.7GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m adept/fuyu-8b --device-map cuda:0" # test pass✅, time: 19.0s, mem: 25.1GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m echo840/Monkey --load-in-4bit" # test pass✅, time: 8.9s, mem: 15.6GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m echo840/Monkey" # test pass✅, time: 8.5s, mem: 21.9GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m echo840/Monkey-Chat --load-in-4bit" # test pass✅, time: 14.8s, mem: 15.7GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m echo840/Monkey-Chat" # test pass✅, time: 10.9s, mem: 21.9GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m failspy/Phi-3-vision-128k-instruct-abliterated-alpha --use-flash-attn --load-in-4bit" # test pass✅, time: 11.0s, mem: 7.0GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m failspy/Phi-3-vision-128k-instruct-abliterated-alpha --use-flash-attn" # test pass✅, time: 9.1s, mem: 12.4GB, 12/12 tests passed.
#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-4khd-7b --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: -1.0s, mem: -1.0GB, Error: Server failed to start (exit).
-#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-4khd-7b --use-flash-attn --device-map cuda:0" # test pass✅, time: 17.9s, mem: 25.8GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-7b --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 3.7s, mem: 5.7GB, Test failed with Exception: Internal Server Error
-#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-7b --use-flash-attn --device-map cuda:0" # test pass✅, time: 17.7s, mem: 19.1GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-7b-4bit --use-flash-attn" # test pass✅, time: 10.4s, mem: 9.5GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-1_8b --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 3.3s, mem: 2.7GB, Test failed with Exception: Internal Server Error
-#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-1_8b --use-flash-attn --device-map cuda:0" # test pass✅, time: 5.6s, mem: 7.2GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-7b --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 3.7s, mem: 6.0GB, Test failed with Exception: Internal Server Error
-#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-7b --use-flash-attn --device-map cuda:0" # test pass✅, time: 15.6s, mem: 20.3GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-7b-4bit --use-flash-attn" # test pass✅, time: 9.9s, mem: 11.0GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m llava-hf/bakLlava-v1-hf --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 2.2s, mem: 5.5GB, 0/8 tests passed.
-#CLI_COMMAND="python vision.py -m llava-hf/bakLlava-v1-hf --use-flash-attn --device-map cuda:0" # test fail❌, time: 1.8s, mem: 15.4GB, 0/8 tests passed.
-#CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-13b-hf --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 10.8s, mem: 8.3GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-13b-hf --use-flash-attn --device-map cuda:0" # test pass✅, time: 6.4s, mem: 26.6GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-7b-hf --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 8.5s, mem: 5.0GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-7b-hf --use-flash-attn --device-map cuda:0" # test pass✅, time: 5.5s, mem: 14.3GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-34b-hf --use-flash-attn --load-in-4bit" # test pass✅, time: 49.3s, mem: 21.7GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-34b-hf --use-flash-attn" # test pass✅, time: 45.6s, mem: 68.6GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-mistral-7b-hf --use-flash-attn --load-in-4bit" # test pass✅, time: 16.4s, mem: 7.8GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-mistral-7b-hf --use-flash-attn" # test pass✅, time: 12.7s, mem: 17.3GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-vicuna-13b-hf --use-flash-attn --load-in-4bit" # test pass✅, time: 13.3s, mem: 16.6GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-vicuna-13b-hf --use-flash-attn" # test pass✅, time: 9.2s, mem: 33.6GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-vicuna-7b-hf --use-flash-attn --load-in-4bit" # test pass✅, time: 13.8s, mem: 9.3GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-vicuna-7b-hf --use-flash-attn" # test pass✅, time: 8.2s, mem: 18.9GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m microsoft/Phi-3-vision-128k-instruct --use-flash-attn --load-in-4bit" # test pass✅, time: 7.5s, mem: 7.1GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m microsoft/Phi-3-vision-128k-instruct --use-flash-attn" # test pass✅, time: 6.2s, mem: 12.4GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-Llama3-V-2_5 --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 12.8s, mem: 12.5GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-Llama3-V-2_5 --use-flash-attn --device-map cuda:0" # test pass✅, time: 8.2s, mem: 21.9GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-V --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 3.1s, mem: 3.4GB, Test failed with Exception: Internal Server Error
-#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-V --use-flash-attn --device-map cuda:0" # test pass✅, time: 5.8s, mem: 7.7GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-V-2 --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 3.1s, mem: 3.5GB, Test failed with Exception: Internal Server Error
-#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-V-2 --use-flash-attn --device-map cuda:0" # test pass✅, time: 6.5s, mem: 11.5GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m qihoo360/360VL-70B --use-flash-attn --load-in-4bit" # test fail❌, time: 4.3s, mem: 37.7GB, Test failed with Exception: Internal Server Error
-#CLI_COMMAND="python vision.py -m qihoo360/360VL-8B --use-flash-attn --load-in-4bit" # test pass✅, time: 10.2s, mem: 8.2GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m qihoo360/360VL-8B --use-flash-attn" # test pass✅, time: 5.5s, mem: 17.3GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m qnguyen3/nanoLLaVA --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 10.0s, mem: 7.7GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m qnguyen3/nanoLLaVA --use-flash-attn --device-map cuda:0" # test pass✅, time: 5.4s, mem: 8.1GB, 8/8 tests passed.
+#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-4khd-7b --use-flash-attn --device-map cuda:0" # test pass✅, time: 26.5s, mem: 25.8GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-7b --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 4.8s, mem: 5.8GB, Test failed with Exception: Internal Server Error
+#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-7b --use-flash-attn --device-map cuda:0" # test pass✅, time: 24.1s, mem: 19.1GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-7b-4bit --use-flash-attn" # test pass✅, time: 15.4s, mem: 9.5GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-1_8b --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 3.5s, mem: 2.8GB, Test failed with Exception: Internal Server Error
+#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-1_8b --use-flash-attn --device-map cuda:0" # test pass✅, time: 8.4s, mem: 7.3GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-7b --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 3.4s, mem: 6.1GB, Test failed with Exception: Internal Server Error
+#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-7b --use-flash-attn --device-map cuda:0" # test pass✅, time: 22.4s, mem: 20.2GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-7b-4bit --use-flash-attn" # test pass✅, time: 20.3s, mem: 10.8GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m llava-hf/bakLlava-v1-hf --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 3.0s, mem: 5.5GB, 0/12 tests passed.
+#CLI_COMMAND="python vision.py -m llava-hf/bakLlava-v1-hf --use-flash-attn --device-map cuda:0" # test fail❌, time: 2.1s, mem: 15.5GB, 0/12 tests passed.
+#CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-13b-hf --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 16.6s, mem: 8.3GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-13b-hf --use-flash-attn --device-map cuda:0" # test pass✅, time: 9.3s, mem: 26.7GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-7b-hf --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 13.1s, mem: 5.0GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-7b-hf --use-flash-attn --device-map cuda:0" # test pass✅, time: 7.9s, mem: 14.4GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-34b-hf --use-flash-attn --load-in-4bit" # test pass✅, time: 76.5s, mem: 21.7GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-34b-hf --use-flash-attn" # test pass✅, time: 68.9s, mem: 68.6GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-mistral-7b-hf --use-flash-attn --load-in-4bit" # test pass✅, time: 25.8s, mem: 7.8GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-mistral-7b-hf --use-flash-attn" # test pass✅, time: 18.3s, mem: 17.5GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-vicuna-13b-hf --use-flash-attn --load-in-4bit" # test pass✅, time: 20.3s, mem: 16.7GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-vicuna-13b-hf --use-flash-attn" # test pass✅, time: 13.4s, mem: 33.9GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-vicuna-7b-hf --use-flash-attn --load-in-4bit" # test pass✅, time: 21.3s, mem: 9.5GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m llava-hf/llava-v1.6-vicuna-7b-hf --use-flash-attn" # test pass✅, time: 12.4s, mem: 19.1GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m microsoft/Phi-3-vision-128k-instruct --use-flash-attn --load-in-4bit" # test pass✅, time: 11.0s, mem: 7.3GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m microsoft/Phi-3-vision-128k-instruct --use-flash-attn" # test pass✅, time: 8.7s, mem: 12.4GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-Llama3-V-2_5 --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 18.6s, mem: 12.6GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-Llama3-V-2_5 --use-flash-attn --device-map cuda:0" # test pass✅, time: 11.9s, mem: 21.9GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-V --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 2.9s, mem: 3.5GB, Test failed with Exception: Internal Server Error
+#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-V --use-flash-attn --device-map cuda:0" # test pass✅, time: 8.1s, mem: 7.7GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-V-2 --use-flash-attn --device-map cuda:0 --load-in-4bit" # test fail❌, time: 3.0s, mem: 3.6GB, Test failed with Exception: Internal Server Error
+#CLI_COMMAND="python vision.py -m openbmb/MiniCPM-V-2 --use-flash-attn --device-map cuda:0" # test pass✅, time: 9.3s, mem: 11.6GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m qihoo360/360VL-70B --use-flash-attn --load-in-4bit" # test fail❌, time: 4.3s, mem: 37.8GB, Test failed with Exception: Internal Server Error
+#CLI_COMMAND="python vision.py -m qihoo360/360VL-8B --use-flash-attn --load-in-4bit" # test pass✅, time: 15.5s, mem: 8.1GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m qihoo360/360VL-8B --use-flash-attn" # test pass✅, time: 7.8s, mem: 17.4GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m qnguyen3/nanoLLaVA --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 15.4s, mem: 7.8GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m qnguyen3/nanoLLaVA --use-flash-attn --device-map cuda:0" # test pass✅, time: 7.2s, mem: 8.2GB, 12/12 tests passed.
#CLI_COMMAND="python vision.py -m qresearch/llama-3-vision-alpha-hf --device cuda:0 --load-in-4bit" # test fail❌, time: -1.0s, mem: -1.0GB, Error: Server failed to start (exit).
-#CLI_COMMAND="python vision.py -m qresearch/llama-3-vision-alpha-hf --device cuda:0" # test pass✅, time: 6.4s, mem: 19.1GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m tiiuae/falcon-11B-vlm --use-flash-attn --load-in-4bit" # test pass✅, time: 12.1s, mem: 16.5GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m tiiuae/falcon-11B-vlm --use-flash-attn" # test pass✅, time: 10.9s, mem: 32.3GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m vikhyatk/moondream2 --use-flash-attn --load-in-4bit" # test pass✅, time: 5.2s, mem: 2.9GB, 8/8 tests passed.
-#CLI_COMMAND="python vision.py -m vikhyatk/moondream2 --use-flash-attn" # test pass✅, time: 3.7s, mem: 4.6GB, 8/8 tests passed.
\ No newline at end of file
+#CLI_COMMAND="python vision.py -m qresearch/llama-3-vision-alpha-hf --device cuda:0" # test pass✅, time: 9.6s, mem: 19.2GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m tiiuae/falcon-11B-vlm --use-flash-attn --load-in-4bit" # test pass✅, time: 18.4s, mem: 16.6GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m tiiuae/falcon-11B-vlm --use-flash-attn" # test pass✅, time: 16.0s, mem: 32.4GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m togethercomputer/Llama-3-8B-Dragonfly-Med-v1 --load-in-4bit" # test pass✅, time: 12.1s, mem: 7.6GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m togethercomputer/Llama-3-8B-Dragonfly-Med-v1" # test pass✅, time: 9.7s, mem: 17.3GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m togethercomputer/Llama-3-8B-Dragonfly-v1 --load-in-4bit" # test pass✅, time: 15.0s, mem: 7.6GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m togethercomputer/Llama-3-8B-Dragonfly-v1" # test pass✅, time: 13.2s, mem: 17.3GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m vikhyatk/moondream2 --use-flash-attn --load-in-4bit" # test pass✅, time: 7.6s, mem: 2.9GB, 12/12 tests passed.
+#CLI_COMMAND="python vision.py -m vikhyatk/moondream2 --use-flash-attn" # test pass✅, time: 5.3s, mem: 4.7GB, 12/12 tests passed.
diff --git a/vision_qna.py b/vision_qna.py
index 07f7ad6..cafd1ee 100644
--- a/vision_qna.py
+++ b/vision_qna.py
@@ -1,13 +1,13 @@
+import asyncio
import io
import uuid
import requests
from datauri import DataURI
from PIL import Image
import torch
-from typing import Optional, List, Literal
+from typing import Optional, List, Literal, AsyncGenerator
from pydantic import BaseModel
-from transformers import BitsAndBytesConfig
-from transformers.image_utils import load_image
+from transformers import BitsAndBytesConfig, TextIteratorStreamer
class ImageURL(BaseModel):
url: str
@@ -28,6 +28,7 @@ class ImageChatRequest(BaseModel):
max_tokens: int = 512
temperature: float = None
top_p: float = None
+ stream: bool = False
class VisionQnABase:
model_name: str = None
@@ -91,8 +92,13 @@ def select_device_dtype(self, device):
dtype = self.select_dtype(device)
return device, dtype
+ # implement one or both of the stream/chat_with_images functions
async def chat_with_images(self, request: ImageChatRequest) -> str:
- pass
+ return ''.join([r async for r in self.stream_chat_with_images(request)])
+
+ # implement one or both of the stream/chat_with_images functions
+ async def stream_chat_with_images(self, request: ImageChatRequest):
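+ # default "fake" streaming: yield the complete non-streamed answer as a single chunk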
+ yield await self.chat_with_images(request)
def get_generation_params(self, request: ImageChatRequest, default_params = {}) -> dict:
params = {
@@ -117,7 +123,6 @@ def get_generation_params(self, request: ImageChatRequest, default_params = {})
return params
async def url_to_image(img_url: str) -> Image.Image:
- #return load_image(img_url)
if img_url.startswith('http'):
response = requests.get(img_url)
@@ -600,7 +605,7 @@ async def glm4v_prompt_from_messages(messages: list[Message], img_tok = "<|begin
for c in m.content:
if c.type == 'image_url':
- images.extend([ await url_to_image(c.image_url.url) ])
+ images.extend([ await url_handler(c.image_url.url) ])
img_tag += img_tok
for c in m.content:
@@ -746,4 +751,6 @@ def guess_backend(model_name: str) -> str:
if 'falcon' in model_id:
return 'llavanext'
-
\ No newline at end of file
+
+ if 'dragonfly' in model_id:
+ return 'dragonfly'
\ No newline at end of file