diff --git a/README.md b/README.md
index 0b64f38..8a050ea 100644
--- a/README.md
+++ b/README.md
@@ -29,6 +29,7 @@ An OpenAI API compatible vision server, it functions like `gpt-4-vision-preview`
 - [X] [01-ai/Yi-VL](https://huggingface.co/01-ai)
 - - [ ] [Yi-VL-6B](https://huggingface.co/01-ai/Yi-VL-6B) (currently errors)
 - - [ ] [Yi-VL-34B](https://huggingface.co/01-ai/Yi-VL-34B) (currently errors)
+- [X] [fuyu-8b](https://huggingface.co/adept/fuyu-8b) [pretrain]
 - [X] [Monkey-Chat](https://huggingface.co/echo840/Monkey-Chat)
 - [X] [Monkey](https://huggingface.co/echo840/Monkey)
 - [X] [Qwen-VL-Chat](https://huggingface.co/Qwen/Qwen-VL-Chat)
@@ -56,10 +57,11 @@ An OpenAI API compatible vision server, it functions like `gpt-4-vision-preview`
 
 See: [OpenVLM Leaderboard](https://huggingface.co/spaces/opencompass/open_vlm_leaderboard)
 
-Version: 0.9.1
+Version: 0.10.0
 
 ## Recent updates
 
+- new model support: adept/fuyu-8b
 - new model support: MiniCPM-V-2
 - new model support: MiniGemini-7B -> MiniGemini-8x7B-HD, alternate docker.
 - new openai_example.sh shell script for simple command line generation.
@@ -68,16 +70,6 @@ Version: 0.9.1
 - Fix: moondream1 (use alt container)
 - Split images into main (transformers>=4.39.0) and alt (transformers==4.36.2)
 - Big performance gains (10x) for some models, especially llava-v1.6-34B (`use_cache` missing from many models, all llava* models, more.)
-- new model support: qnguyen3/nanoLLaVA (sub 1B model)
-- Updated chat_with_image.py to include --single (-1) answer mode
-- More testing
-- `sample.env` contains VRAM usage and some notes about model configurations.
-- new model support: MiniGemini-2B (it's still a bit complex to use, see `prepare_minigemini.sh`)
-- new model support: echo840/Monkey-Chat, echo840/Monkey
-- AutoGPTQ support for internlm/internlm-xcomposer2-7b-4bit, internlm/internlm-xcomposer2-vl-7b-4bit
-- Automatic selection of backend, based on the model name
-
-
 
 ## API Documentation
 
diff --git a/backend/fuyu.py b/backend/fuyu.py
new file mode 100644
index 0000000..1d3e3d5
--- /dev/null
+++ b/backend/fuyu.py
@@ -0,0 +1,34 @@
+from transformers import FuyuProcessor, FuyuForCausalLM
+
+from vision_qna import *
+
+# "adept/fuyu-8b"
+
+class VisionQnA(VisionQnABase):
+    model_name: str = "fuyu"
+    format: str = "fuyu"
+
+    def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = {}, format = None):
+        super().__init__(model_id, device, device_map, extra_params, format)
+
+        if not format:
+            self.format = guess_model_format(model_id)
+
+        del self.params['trust_remote_code'] # not needed.
+
+        self.processor = FuyuProcessor.from_pretrained(model_id)
+        self.model = FuyuForCausalLM.from_pretrained(**self.params)
+
+        print(f"Loaded on device: {self.model.device} with dtype: {self.model.dtype}")
+
+    async def chat_with_images(self, request: ImageChatRequest) -> str:
+        images, prompt = await prompt_from_messages(request.messages, self.format)
+
+        inputs = self.processor(text=prompt, images=images[0], return_tensors="pt").to(self.model.device)
+
+        params = self.get_generation_params(request)
+
+        output = self.model.generate(**inputs, **params)
+        response = self.processor.decode(output[0][inputs.input_ids.size(1):].cpu(), skip_special_tokens=True)
+
+        return response.strip()
diff --git a/chat_with_image.py b/chat_with_image.py
index 76bc934..a60e47b 100755
--- a/chat_with_image.py
+++ b/chat_with_image.py
@@ -1,9 +1,15 @@
 #!/usr/bin/env python
+import os
 import requests
 import argparse
 from datauri import DataURI
 from openai import OpenAI
 
+try:
+    import dotenv
+    dotenv.load_dotenv(override=True)
+except:
+    pass
 
 def url_for_api(img_url: str = None, filename: str = None, always_data=False) -> str:
     if img_url.startswith('http'):
@@ -31,7 +37,7 @@ def url_for_api(img_url: str = None, filename: str = None, always_data=False) ->
     parser.add_argument('questions', type=str, nargs='*', help='The question to ask the image')
     args = parser.parse_args()
 
-    client = OpenAI(base_url='http://localhost:5006/v1', api_key='skip')
+    client = OpenAI(base_url=os.environ.get('OPENAI_BASE_URL', 'http://localhost:5006/v1'), api_key='skip')
 
     params = {}
     if args.max_tokens is not None:
diff --git a/model_conf_tests.alt.json b/model_conf_tests.alt.json
index 90844bc..14a44f6 100644
--- a/model_conf_tests.alt.json
+++ b/model_conf_tests.alt.json
@@ -1,12 +1,21 @@
 [
     ["vikhyatk/moondream2", "--use-flash-attn"],
     ["vikhyatk/moondream1"],
-    ["echo840/Monkey"],
     ["echo840/Monkey-Chat"],
     ["THUDM/cogvlm-chat-hf"],
     ["THUDM/cogagent-chat-hf"],
     ["Qwen/Qwen-VL-Chat"],
+    ["YanweiLi/Mini-Gemini-2B", "--use-flash-attn"],
+    ["YanweiLi/Mini-Gemini-7B", "--use-flash-attn"],
+    ["YanweiLi/Mini-Gemini-7B-HD", "--use-flash-attn"],
+    ["YanweiLi/Mini-Gemini-13B", "--use-flash-attn"],
+    ["YanweiLi/Mini-Gemini-13B-HD", "--use-flash-attn"],
+    ["YanweiLi/Mini-Gemini-34B", "--use-flash-attn"],
+    ["YanweiLi/Mini-Gemini-34B-HD", "--use-flash-attn"],
+    ["YanweiLi/Mini-Gemini-8x7B", "--use-flash-attn"],
+    ["YanweiLi/Mini-Gemini-8x7B-HD", "--use-flash-attn"],
+    ["adept/fuyu-8b", "--device-map", "cuda:0"],
     ["internlm/internlm-xcomposer2-7b", "--use-flash-attn", "--device-map", "cuda:0"],
     ["internlm/internlm-xcomposer2-vl-7b", "--use-flash-attn", "--device-map", "cuda:0"],
     ["openbmb/MiniCPM-V-2", "--use-flash-attn", "--device-map", "cuda:0"],
@@ -30,15 +39,5 @@
     ["YanweiLi/Mini-Gemini-34B", "--load-in-4bit", "--use-flash-attn"],
     ["YanweiLi/Mini-Gemini-34B-HD", "--load-in-4bit", "--use-flash-attn"],
     ["YanweiLi/Mini-Gemini-8x7B", "--load-in-4bit", "--use-flash-attn"],
-    ["YanweiLi/Mini-Gemini-8x7B-HD", "--load-in-4bit", "--use-flash-attn"],
-
-    ["YanweiLi/Mini-Gemini-2B", "--use-flash-attn"],
-    ["YanweiLi/Mini-Gemini-7B", "--use-flash-attn"],
-    ["YanweiLi/Mini-Gemini-7B-HD", "--use-flash-attn"],
-    ["YanweiLi/Mini-Gemini-13B", "--use-flash-attn"],
-    ["YanweiLi/Mini-Gemini-13B-HD", "--use-flash-attn"],
-    ["YanweiLi/Mini-Gemini-34B", "--use-flash-attn"],
-    ["YanweiLi/Mini-Gemini-34B-HD", "--use-flash-attn"],
-    ["YanweiLi/Mini-Gemini-8x7B", "--use-flash-attn"],
-    ["YanweiLi/Mini-Gemini-8x7B-HD", "--use-flash-attn"]
+    ["YanweiLi/Mini-Gemini-8x7B-HD", "--load-in-4bit", "--use-flash-attn"]
 ]
diff --git a/model_conf_tests.json b/model_conf_tests.json
index baa1e49..7fe744f 100644
--- a/model_conf_tests.json
+++ b/model_conf_tests.json
@@ -1,13 +1,13 @@
 [
     ["vikhyatk/moondream2", "--use-flash-attn"],
     ["vikhyatk/moondream1"],
-    ["qnguyen3/nanoLLaVA", "--use-flash-attn"],
     ["echo840/Monkey"],
     ["echo840/Monkey-Chat"],
     ["THUDM/cogvlm-chat-hf"],
     ["THUDM/cogagent-chat-hf"],
     ["Qwen/Qwen-VL-Chat"],
+    ["adept/fuyu-8b", "--device-map", "cuda:0"],
     ["internlm/internlm-xcomposer2-7b", "--use-flash-attn", "--device-map", "cuda:0"],
     ["internlm/internlm-xcomposer2-vl-7b", "--use-flash-attn", "--device-map", "cuda:0"],
     ["openbmb/MiniCPM-V-2", "--use-flash-attn", "--device-map", "cuda:0"],
diff --git a/prepare_minigemini.sh b/prepare_minigemini.sh
index dc6ef80..0a8fde5 100755
--- a/prepare_minigemini.sh
+++ b/prepare_minigemini.sh
@@ -1,6 +1,11 @@
 #!/bin/bash
 export HF_HOME=hf_home
 
+if [ -z "$(which huggingface-cli)" ]; then
+    echo "First install huggingface-hub: pip install huggingface-hub"
+    exit 1
+fi
+
 echo "Edit this script and uncomment which models to download"
 
 huggingface-cli download OpenAI/clip-vit-large-patch14-336 --local-dir model_zoo/OpenAI/clip-vit-large-patch14-336
diff --git a/test_models.py b/test_models.py
index ea93665..4a394a1 100755
--- a/test_models.py
+++ b/test_models.py
@@ -48,7 +48,7 @@ def record_result(cmd_args, results, t, mem, note):
         'note': note
     }])
     result = all(results)
-    print(f"\n#CLI_COMMAND={cmd_args} # test {'pass' if result else 'fail'}, time: {t:.1f}s, mem: {mem:.1f}GB, {note}")
+    print(f"#CLI_COMMAND=\"python vision.py -m {' '.join(cmd_args)}\" # test {'pass' if result else 'fail'}, time: {t:.1f}s, mem: {mem:.1f}GB, {note}")
 
 
 torch_memory_baseline = 0
diff --git a/vision-alt.sample.env b/vision-alt.sample.env
index 21fff2a..065da26 100644
--- a/vision-alt.sample.env
+++ b/vision-alt.sample.env
@@ -9,6 +9,16 @@ HF_HOME=hf_home
 #CLI_COMMAND="python vision.py -m THUDM/cogvlm-chat-hf" # test pass, time: 13.4s, mem: 36.3GB, All tests passed.
 #CLI_COMMAND="python vision.py -m THUDM/cogagent-chat-hf" # test pass, time: 14.7s, mem: 37.2GB, All tests passed.
 #CLI_COMMAND="python vision.py -m Qwen/Qwen-VL-Chat" # test pass, time: 4.9s, mem: 19.5GB, All tests passed.
+#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-2B --use-flash-attn" # test fail, time: -1.0s, mem: -1.0GB, Error: Server failed to start (exit).
+#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-7B --use-flash-attn" # test pass, time: 5.4s, mem: 15.6GB, All tests passed.
+#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-7B-HD --use-flash-attn" # test pass, time: 15.8s, mem: 18.8GB, All tests passed.
+#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-13B --use-flash-attn" # test pass, time: 21.3s, mem: 27.6GB, All tests passed.
+#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-13B-HD --use-flash-attn" # test pass, time: 15.9s, mem: 31.7GB, All tests passed.
+#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-34B --use-flash-attn" # test pass, time: 11.1s, mem: 67.2GB, All tests passed.
+#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-34B-HD --use-flash-attn" # test pass, time: 145.1s, mem: 70.3GB, All tests passed.
+#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-8x7B --use-flash-attn" # test pass, time: 14.3s, mem: 91.3GB, All tests passed.
+#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-8x7B-HD --use-flash-attn" # test pass, time: 18.5s, mem: 96.1GB, All tests passed.
+#CLI_COMMAND="python vision.py -m adept/fuyu-8b --device-map cuda:0" # test pass, time: 13.4s, mem: 25.0GB, All tests passed. #CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-7b --use-flash-attn --device-map cuda:0" # test pass, time: 18.2s, mem: 19.0GB, All tests passed. #CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-7b --use-flash-attn --device-map cuda:0" # test pass, time: 16.7s, mem: 20.2GB, All tests passed. #CLI_COMMAND="python vision.py -m openbmb/MiniCPM-V-2 --use-flash-attn --device-map cuda:0" # test pass, time: 6.6s, mem: 11.4GB, All tests passed. @@ -16,10 +26,10 @@ HF_HOME=hf_home #CLI_COMMAND="python vision.py -m llava-hf/bakLlava-v1-hf --use-flash-attn --device-map cuda:0" # test fail, time: 2.0s, mem: 15.6GB, #CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-7b-hf --use-flash-attn --device-map cuda:0" # test pass, time: 5.4s, mem: 14.5GB, All tests passed. #CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-13b-hf --use-flash-attn --device-map cuda:0" # test pass, time: 6.6s, mem: 26.9GB, All tests passed. -#CLI_COMMAND="python vision.py -m THUDM/cogvlm-chat-hf --load-in-4bit" # test pass, time: 19.5s, mem: 12.2GB, All tests passed. -#CLI_COMMAND="python vision.py -m THUDM/cogagent-chat-hf --load-in-4bit" # test pass, time: 20.4s, mem: 12.2GB, All tests passed. #CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-7b-4bit --use-flash-attn --device cuda:0" # test pass, time: 10.5s, mem: 9.5GB, All tests passed. #CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-7b-4bit --use-flash-attn --device cuda:0" # test pass, time: 11.6s, mem: 10.9GB, All tests passed. +#CLI_COMMAND="python vision.py -m THUDM/cogvlm-chat-hf --load-in-4bit" # test pass, time: 19.5s, mem: 12.2GB, All tests passed. +#CLI_COMMAND="python vision.py -m THUDM/cogagent-chat-hf --load-in-4bit" # test pass, time: 20.4s, mem: 12.2GB, All tests passed. #CLI_COMMAND="python vision.py -m llava-hf/bakLlava-v1-hf --load-in-4bit --use-flash-attn" # test fail, time: 2.5s, mem: 6.0GB, #CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-7b-hf --load-in-4bit --use-flash-attn" # test pass, time: 9.2s, mem: 5.6GB, All tests passed. #CLI_COMMAND="python vision.py -m llava-hf/llava-1.5-13b-hf --load-in-4bit --use-flash-attn" # test pass, time: 10.0s, mem: 9.0GB, All tests passed. @@ -30,13 +40,4 @@ HF_HOME=hf_home #CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-34B --load-in-4bit --use-flash-attn" # test pass, time: 16.8s, mem: 21.5GB, All tests passed. #CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-34B-HD --load-in-4bit --use-flash-attn" # test pass, time: 215.3s, mem: 24.2GB, All tests passed. #CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-8x7B --load-in-4bit --use-flash-attn" # test pass, time: 22.2s, mem: 26.3GB, All tests passed. -#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-8x7B-HD --load-in-4bit --use-flash-attn" # test pass, time: 24.7s, mem: 29.5GB, All tests passed. -#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-2B --use-flash-attn" # test fail, time: -1.0s, mem: -1.0GB, Error: Server failed to start (exit). -#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-7B --use-flash-attn" # test pass, time: 5.4s, mem: 15.6GB, All tests passed. -#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-7B-HD --use-flash-attn" # test pass, time: 15.8s, mem: 18.8GB, All tests passed. -#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-13B --use-flash-attn" # test pass, time: 21.3s, mem: 27.6GB, All tests passed. 
-#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-13B-HD --use-flash-attn" # test pass, time: 15.9s, mem: 31.7GB, All tests passed. -#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-34B --use-flash-attn" # test pass, time: 11.1s, mem: 67.2GB, All tests passed. -#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-34B-HD --use-flash-attn" # test pass, time: 145.1s, mem: 70.3GB, All tests passed. -#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-8x7B --use-flash-attn" # test pass, time: 14.3s, mem: 91.3GB, All tests passed. -#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-8x7B-HD --use-flash-attn" # test pass, time: 18.5s, mem: 96.1GB, All tests passed. \ No newline at end of file +#CLI_COMMAND="python vision.py -m YanweiLi/Mini-Gemini-8x7B-HD --load-in-4bit --use-flash-attn" # test pass, time: 24.7s, mem: 29.5GB, All tests passed. \ No newline at end of file diff --git a/vision.sample.env b/vision.sample.env index d98fccd..c4da3cd 100644 --- a/vision.sample.env +++ b/vision.sample.env @@ -10,6 +10,7 @@ HF_HOME=hf_home #CLI_COMMAND="python vision.py -m THUDM/cogvlm-chat-hf" # test pass, time: 14.1s, mem: 36.2GB, All tests passed. #CLI_COMMAND="python vision.py -m THUDM/cogagent-chat-hf" # test pass, time: 14.7s, mem: 37.2GB, All tests passed. #CLI_COMMAND="python vision.py -m Qwen/Qwen-VL-Chat" # test pass, time: 4.8s, mem: 19.5GB, All tests passed. +#CLI_COMMAND="python vision.py -m adept/fuyu-8b --device-map cuda:0" # test pass, time: 13.4s, mem: 25.0GB, All tests passed. #CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-7b --use-flash-attn --device-map cuda:0" # test pass, time: 18.3s, mem: 19.0GB, All tests passed. #CLI_COMMAND="python vision.py -m internlm/internlm-xcomposer2-vl-7b --use-flash-attn --device-map cuda:0" # test pass, time: 14.9s, mem: 20.2GB, All tests passed. #CLI_COMMAND="python vision.py -m openbmb/MiniCPM-V-2 --use-flash-attn --device-map cuda:0" # test pass, time: 6.7s, mem: 11.5GB, All tests passed. diff --git a/vision_qna.py b/vision_qna.py index e348c48..737879b 100644 --- a/vision_qna.py +++ b/vision_qna.py @@ -320,6 +320,31 @@ async def gemma_prompt_from_messages(messages: list[Message], img_tok = " return images, prompt +async def fuyu_prompt_from_messages(messages: list[Message], img_tok = "", img_end = ''): + prompt = '' + images = [] + + for m in messages: + if m.role == 'user': + p = '' + for c in m.content: + if c.type == 'image_url': + images.extend([ await url_to_image(c.image_url.url) ]) + p = img_tok + p + img_end + if c.type == 'text': + p += f"{c.text}\n\n" # Question: + prompt += p + elif m.role == 'assistant': + for c in m.content: + if c.type == 'text': + prompt += f"\x04{c.text}\n" + elif m.role == 'system': + for c in m.content: + if c.type == 'text': + prompt += f"{c.text}\n\n" # fake system prompt doesn't work. 
+
+    return images, prompt
+
 async def prompt_history_images_system_from_messages(messages: list[Message], img_tok = "\n", url_handler = url_to_image):
     history = []
     images = []
@@ -361,7 +386,8 @@ async def prompt_from_messages(messages: list[Message], format: str) -> str:
         'llama2': llama2_prompt_from_messages,
         'mistral': llama2_prompt_from_messages, # simplicity
         'chatml': chatml_prompt_from_messages,
-        'gemma': gemma_prompt_from_messages
+        'gemma': gemma_prompt_from_messages,
+        'fuyu': fuyu_prompt_from_messages,
     }
 
     if format not in known_formats:
@@ -379,6 +405,7 @@ def guess_model_format(model_name: str) -> str:
         'vicuna0': ['yi-vl'],
         'phi15': ['moondream1', 'moondream2', 'monkey'],
         'chatml': ['34b', 'yi-6b', 'nanollava'],
+        'fuyu': ['fuyu'],
     }
     for format, options in model_format_match_map.items():
         if any(x in model_id for x in options):
@@ -434,4 +461,8 @@ def guess_backend(model_name: str) -> str:
         return 'yi-vl'
 
     if 'thudm/cog' in model_id:
-        return 'cogvlm'
\ No newline at end of file
+        return 'cogvlm'
+
+    if 'fuyu' in model_id:
+        return 'fuyu'
+    
\ No newline at end of file
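
Illustrative usage (not part of the patch above): once the server is started with the new backend, e.g. `python vision.py -m adept/fuyu-8b --device-map cuda:0` as recorded in vision.sample.env, it can be queried through the OpenAI-compatible chat API the same way chat_with_image.py does. A minimal sketch, assuming the default base_url and api_key from chat_with_image.py; the image URL, question, model name, and max_tokens value are placeholders:

# sketch only -- not part of this commit, mirrors what chat_with_image.py sends
from openai import OpenAI

client = OpenAI(base_url='http://localhost:5006/v1', api_key='skip')

response = client.chat.completions.create(
    model='gpt-4-vision-preview',  # nominal; the server answers with whichever model vision.py loaded via -m
    messages=[{
        'role': 'user',
        'content': [
            {'type': 'image_url', 'image_url': {'url': 'https://example.com/photo.jpg'}},  # placeholder image
            {'type': 'text', 'text': 'Describe the image.'},
        ],
    }],
    max_tokens=512,
)
print(response.choices[0].message.content)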