0.23.0 +streaming, +dragonfly
matatonic committed Jun 8, 2024
1 parent ab6dd4d commit c62212d
Showing 11 changed files with 400 additions and 137 deletions.
15 changes: 9 additions & 6 deletions Dockerfile
@@ -1,14 +1,15 @@
FROM python:3.11-slim

RUN apt-get update && apt-get install -y git gcc
RUN pip install --no-cache-dir --upgrade pip
RUN apt-get update && apt-get install -y git gcc \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
RUN --mount=type=cache,target=/root/.cache/pip pip install --upgrade pip

RUN mkdir -p /app
WORKDIR /app
RUN git clone https://github.com/01-ai/Yi --single-branch /app/Yi
RUN git clone https://github.com/dvlab-research/MGM.git --single-branch /app/MGM
RUN git clone https://github.com/TIGER-AI-Lab/Mantis.git --single-branch /app/Mantis
RUN git clone https://github.com/togethercomputer/Dragonfly --single-branch /app/Dragonfly

WORKDIR /app
COPY requirements.txt .
ARG VERSION=latest
RUN if [ "$VERSION" = "alt" ]; then echo "transformers==4.36.2" >> requirements.txt; else echo "transformers>=4.41.2\nautoawq>=0.2.5" >> requirements.txt ; fi
@@ -21,12 +22,14 @@ RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps -e .
WORKDIR /app/Mantis
RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps -e .

WORKDIR /app/Dragonfly
RUN --mount=type=cache,target=/root/.cache/pip pip install --no-deps -e .

WORKDIR /app

COPY *.py .
COPY backend /app/backend

COPY model_conf_tests.json /app/model_conf_tests.json
COPY model_conf_tests.json .

ENV CLI_COMMAND="python vision.py"
CMD $CLI_COMMAND
9 changes: 9 additions & 0 deletions README.md
@@ -65,6 +65,9 @@ An OpenAI API compatible vision server, it functions like `gpt-4-vision-preview`
- - [X] [Mantis-8B-siglip-llama3](https://huggingface.co/TIGER-Lab/Mantis-8B-siglip-llama3) (wont gpu split)
- - [X] [Mantis-8B-clip-llama3](https://huggingface.co/TIGER-Lab/Mantis-8B-clip-llama3) (wont gpu split)
- - [X] [Mantis-8B-Fuyu](https://huggingface.co/TIGER-Lab/Mantis-8B-Fuyu) (wont gpu split)
- [X] [Together.ai](https://huggingface.co/togethercomputer)
- - [X] [Llama-3-8B-Dragonfly-v1](https://huggingface.co/togethercomputer/Llama-3-8B-Dragonfly-v1)
- - [X] [Llama-3-8B-Dragonfly-Med-v1](https://huggingface.co/togethercomputer/Llama-3-8B-Dragonfly-Med-v1)
- [X] [fuyu-8b](https://huggingface.co/adept/fuyu-8b) [pretrain]
- [X] [falcon-11B-vlm](https://huggingface.co/tiiuae/falcon-11B-vlm)
- [X] [Monkey-Chat](https://huggingface.co/echo840/Monkey-Chat)
@@ -100,6 +103,12 @@ See: [OpenVLM Leaderboard](https://huggingface.co/spaces/opencompass/open_vlm_le

## Recent updates

Version 0.23.0

- New model support: Together.ai's Llama-3-8B-Dragonfly-v1, Llama-3-8B-Dragonfly-Med-v1 (medical image model)
- Compatibility: chatboxai.app can now use openedai-vision as a backend!
- Initial support for streaming (real streaming for some [dragonfly, internvl-chat-v1-5], fake streaming for the rest). More to come.

Version 0.22.0

- new model support: THUDM/glm-4v-9b
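As a point of reference for the streaming note in the README update above, here is a minimal sketch of how a client might consume the new streaming support through the OpenAI-compatible endpoint, following the same pattern the updated `chat_with_image.py` and `test_models.py` use further down. The `base_url`, `api_key`, port, and image URL are assumptions for a local openedai-vision instance, not part of this commit.

```python
from openai import OpenAI

# Assumed local server address; adjust to wherever openedai-vision is listening.
client = OpenAI(base_url="http://localhost:5006/v1", api_key="skip")

messages = [{"role": "user", "content": [
    {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},  # hypothetical image URL
    {"type": "text", "text": "What is the subject of the image?"},
]}]

# stream=True yields delta chunks; backends without real streaming still return
# the answer as chunked ("fake streamed") output, so the client code is the same.
response = client.chat.completions.create(model="gpt-4-vision-preview",
                                           messages=messages, stream=True)
for chunk in response:
    if chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end='', flush=True)
print()
```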
71 changes: 71 additions & 0 deletions backend/dragonfly.py
@@ -0,0 +1,71 @@
from threading import Thread
from transformers import AutoTokenizer, AutoProcessor, logging
from dragonfly.models.modeling_dragonfly import DragonflyForCausalLM
from dragonfly.models.processing_dragonfly import DragonflyProcessor

import warnings
# disable some warnings
logging.set_verbosity_error()
warnings.filterwarnings('ignore')

from vision_qna import *

# togethercomputer/Llama-3-8B-Dragonfly-v1
# togethercomputer/Llama-3-8B-Dragonfly-Med-v1

class VisionQnA(VisionQnABase):
    model_name: str = "dragonfly"
    format: str = 'llama3'
    vision_layers: List[str] = ['image_encoder', 'vision_model', 'encoder', 'mpl', 'vision_embed_tokens']

    def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = {}, format = None):
        super().__init__(model_id, device, device_map, extra_params, format)

        del self.params['trust_remote_code']

        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        clip_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = DragonflyProcessor(image_processor=clip_processor.image_processor, tokenizer=self.tokenizer, image_encoding_style="llava-hd")

        self.model = DragonflyForCausalLM.from_pretrained(**self.params)

        # bitsandbytes already moves the model to the device, so we don't need to do it again.
        if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)):
            self.model = self.model.to(dtype=self.dtype, device=self.device)

        self.eos_id = "<|eot_id|>"
        self.eos_token_id = self.tokenizer.encode(self.eos_id, add_special_tokens=False)

        print(f"Loaded {model_id} on device: {self.model.device} with dtype: {self.model.dtype}")

    async def stream_chat_with_images(self, request: ImageChatRequest):
        images, prompt = await llama3_prompt_from_messages(request.messages, img_tok='')

        inputs = self.processor(text=[prompt], images=images, max_length=2048, return_tensors="pt", is_generate=True).to(device=self.model.device)

        streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=False, skip_prompt=True)

        default_params = {
            'max_new_tokens': 1024,
            'eos_token_id': self.eos_token_id,
            'pad_token_id': self.eos_token_id[0],
        }

        params = self.get_generation_params(request, default_params=default_params)

        generation_kwargs = dict(
            **inputs,
            **params,
            streamer=streamer,
        )

        t = Thread(target=self.model.generate, kwargs=generation_kwargs)
        t.start()

        for new_text in streamer:
            end = new_text.find(self.eos_id)
            if end == -1:
                yield new_text
            else:
                yield new_text[:end]
                break
54 changes: 54 additions & 0 deletions backend/internvl-chat-v1-5.py
@@ -1,4 +1,5 @@
import os
from threading import Thread
from transformers import AutoTokenizer, AutoModel
from vision_qna import *
import torch
@@ -122,6 +123,8 @@ async def chat_with_images(self, request: ImageChatRequest) -> str:
        else:
            images, prompt = await chatml_prompt_from_messages(request.messages, img_tok='')

        # TODO: use detail to set max tiles if detail=low (=512)
        # if .detail == 'low': max_num=1
        images = [load_image(image, max_num=self.max_tiles).to(self.model.dtype).cuda() for image in images]
        if len(images) > 1:
            pixel_values = torch.cat(images, dim=0)
@@ -153,3 +156,54 @@ async def chat_with_images(self, request: ImageChatRequest) -> str:
        response = self.tokenizer.decode(output[0], skip_special_tokens=True)

        return response.split(self.eos_token)[0].strip()

    async def stream_chat_with_images(self, request: ImageChatRequest):
        if self.format == 'phintern':
            images, prompt = await phintern_prompt_from_messages(request.messages, img_tok='')
        else:
            images, prompt = await chatml_prompt_from_messages(request.messages, img_tok='')

        # TODO: use detail to set max tiles if detail=low (=512)
        # if .detail == 'low': max_num=1
        images = [load_image(image, max_num=self.max_tiles).to(self.model.dtype).cuda() for image in images]
        if len(images) > 1:
            pixel_values = torch.cat(images, dim=0)
        else:
            pixel_values = images[0]

        default_params = {
            'num_beams': 1,
            'max_new_tokens': 512,
            'do_sample': False,
            'eos_token_id': self.eos_token_id,
        }

        generation_config = self.get_generation_params(request, default_params)

        del generation_config['use_cache']

        image_tokens = '<img>' + '<IMG_CONTEXT>' * self.model.num_image_token * pixel_values.shape[0] + '</img>\n'
        model_inputs = self.tokenizer(image_tokens + prompt, return_tensors='pt')
        input_ids = model_inputs['input_ids'].cuda()
        attention_mask = model_inputs['attention_mask'].cuda()

        streamer = TextIteratorStreamer(self.tokenizer, skip_special_tokens=False, skip_prompt=True)

        generation_kwargs = dict(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            **generation_config,
            streamer=streamer,
        )

        t = Thread(target=self.model.generate, kwargs=generation_kwargs)
        t.start()

        for new_text in streamer:
            end = new_text.find(self.eos_token)
            if end == -1:
                yield new_text
            else:
                yield new_text[:end]
                break
25 changes: 20 additions & 5 deletions chat_with_image.py
@@ -35,6 +35,7 @@ def url_for_api(img_url: str = None, filename: str = None, always_data=False) ->
parser.add_argument('-p', '--top_p', type=float, default=None)
parser.add_argument('-u', '--keep-remote-urls', action='store_true', help="Normally, http urls are converted to data: urls for better latency.")
parser.add_argument('-1', '--single', action='store_true', help='Single turn Q&A, output is only the model response.')
parser.add_argument('--no-stream', action='store_true', help='Disable streaming response.')
parser.add_argument('image_url', type=str, help='URL or image file to be tested')
parser.add_argument('questions', type=str, nargs='*', help='The question to ask the image')
args = parser.parse_args()
@@ -48,6 +49,7 @@ def url_for_api(img_url: str = None, filename: str = None, always_data=False) ->
    params['temperature'] = args.temperature
if args.top_p is not None:
    params['top_p'] = args.top_p
params['stream'] = not args.no_stream

image_url = args.image_url

@@ -64,17 +66,30 @@ def url_for_api(img_url: str = None, filename: str = None, always_data=False) ->
while True:
    if args.start_with:
        messages.extend([{ "role": "assistant", "content": [{ "type": "text", "text": args.start_with }] }])

    response = client.chat.completions.create(model="gpt-4-vision-preview", messages=messages, **params)

    if not args.single:
        print(f"Answer: ", end='', flush=True)

    assistant_text = ''

    if args.no_stream:
        assistant_text = response.choices[0].message.content
        print(assistant_text)
    else:
        for chunk in response:
            assistant_text += chunk.choices[0].delta.content
            print(chunk.choices[0].delta.content, end='', flush=True)

        print('')

    if args.single:
        print(response.choices[0].message.content)
        break

    print(f"Answer: {response.choices[0].message.content}\n")

    image_url = None
    try:
        q = input("Question: ")
        q = input("\nQuestion: ")

        if q.startswith('http') or q.startswith('data:') or q.startswith('file:'):
            image_url = q
@@ -90,7 +105,7 @@ def url_for_api(img_url: str = None, filename: str = None, always_data=False) ->
        break

    content = [{"type": "image_url", "image_url": { "url": image_url } }] if image_url else []
    content.extend([{ 'type': 'text', 'text': response.choices[0].message.content }])
    content.extend([{ 'type': 'text', 'text': assistant_text }])
    messages.extend([{ "role": "assistant", "content": content },
                     { "role": "user", "content": [{ 'type': 'text', 'text': q }] }])

4 changes: 4 additions & 0 deletions model_conf_tests.json
@@ -103,6 +103,10 @@
["qresearch/llama-3-vision-alpha-hf", "--device", "cuda:0"],
["tiiuae/falcon-11B-vlm", "--use-flash-attn", "--load-in-4bit"],
["tiiuae/falcon-11B-vlm", "--use-flash-attn"],
["togethercomputer/Llama-3-8B-Dragonfly-Med-v1", "--load-in-4bit"],
["togethercomputer/Llama-3-8B-Dragonfly-Med-v1"],
["togethercomputer/Llama-3-8B-Dragonfly-v1", "--load-in-4bit"],
["togethercomputer/Llama-3-8B-Dragonfly-v1"],
["vikhyatk/moondream2", "--use-flash-attn", "--load-in-4bit"],
["vikhyatk/moondream2", "--use-flash-attn"]
]
17 changes: 16 additions & 1 deletion openedai.py
@@ -1,6 +1,7 @@
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from loguru import logger

class OpenAIStub(FastAPI):
    def __init__(self, **kwargs) -> None:
@@ -15,6 +16,20 @@ def __init__(self, **kwargs) -> None:
            allow_headers=["*"]
        )

        @self.middleware("http")
        async def log_requests(request: Request, call_next):
            logger.debug(f"Request path: {request.url.path}")
            logger.debug(f"Request method: {request.method}")
            logger.debug(f"Request headers: {request.headers}")
            logger.debug(f"Request query params: {request.query_params}")

            response = await call_next(request)

            logger.debug(f"Response status code: {response.status_code}")
            logger.debug(f"Response headers: {response.headers}")

            return response

        @self.get('/v1/billing/usage')
        @self.get('/v1/dashboard/billing/usage')
        async def handle_billing_usage():
29 changes: 29 additions & 0 deletions test_models.py
@@ -178,6 +178,24 @@ def generate_response(image_url, prompt):
    answer = response.choices[0].message.content
    return answer

def generate_stream_response(image_url, prompt):

    messages = [{ "role": "system", "content": [{ 'type': 'text', 'text': args.system_prompt }] }] if args.system_prompt else []
    messages.extend([
        { "role": "user", "content": [
            { "type": "image_url", "image_url": { "url": image_url } },
            { "type": "text", "text": prompt },
        ]}])

    response = client.chat.completions.create(model="gpt-4-vision-preview", messages=messages, **params, stream=True)
    answer = ''
    for chunk in response:
        if chunk.choices[0].delta.content:
            answer += chunk.choices[0].delta.content

    return answer



def single_round():
    # XXX TODO: timeout
@@ -206,6 +224,17 @@ def single_round():
    else:
        print(f"{name}[data]: pass{', got: ' + answer if args.verbose else ''}")

    answer = generate_stream_response(data_url, "What is the subject of the image?")
    correct = name in answer.lower()
    results.extend([correct])
    if not correct:
        print(f"{name}[data_stream]: fail, got: {answer}")
        if args.abort_on_fail:
            break
    else:
        print(f"{name}[data_stream]: pass{', got: ' + answer if args.verbose else ''}")


    return results

with open('model_conf_tests.json') as f: