From fe3c8c865bd00deaf662e029e3773e4b7bb3e1ce Mon Sep 17 00:00:00 2001 From: matatonic Date: Sat, 18 May 2024 17:44:17 -0400 Subject: [PATCH] 0.14.0 +em2, +360vl --- README.md | 12 ++++++- backend/360vl.py | 65 ++++++++++++++++++++++++++++++++++++ backend/emu.py | 75 ++++++++++++++++++++++++++++++++++++++++++ docker-compose.alt.yml | 2 +- docker-compose.yml | 2 +- model_conf_tests.json | 14 ++++++-- requirements.txt | 3 ++ test_models.py | 2 ++ vision.py | 4 +++ vision_qna.py | 42 +++++++++++++++++++++-- 10 files changed, 213 insertions(+), 8 deletions(-) create mode 100644 backend/360vl.py create mode 100644 backend/emu.py diff --git a/README.md b/README.md index ec4a140..bda994e 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,9 @@ An OpenAI API compatible vision server, it functions like `gpt-4-vision-preview` - [X] [HuggingFaceM4/idefics2](https://huggingface.co/HuggingFaceM4) - - [X] [idefics2-8b](https://huggingface.co/HuggingFaceM4/idefics2-8b) (main docker only, wont gpu split) - - [X] [idefics2-8b-AWQ](https://huggingface.co/HuggingFaceM4/idefics2-8b-AWQ) (main docker only, wont gpu split) +- [X] [qihoo360](https://huggingface.co/qihoo360) +- - [X] [360VL-8B](https://huggingface.co/qihoo360/360VL-8B) +- - [X] [360VL-70B](https://huggingface.co/qihoo360/360VL-70B) (loading error, [see note](https://huggingface.co/qihoo360/360VL-70B/discussions/1)) - [X] [LlavaNext](https://huggingface.co/llava-hf) (main docker only) - - [X] [llava-v1.6-34b-hf](https://huggingface.co/llava-hf/llava-v1.6-34b-hf) (main docker only) - - [X] [llava-v1.6-vicuna-13b-hf](https://huggingface.co/llava-hf/llava-v1.6-vicuna-13b-hf) (main docker only) @@ -39,6 +42,7 @@ An OpenAI API compatible vision server, it functions like `gpt-4-vision-preview` - [X] [qresearch](https://huggingface.co/qresearch/) - - [X] [llama-3-vision-alpha-hf](https://huggingface.co/qresearch/llama-3-vision-alpha-hf) (main docker only, wont gpu split) - [X] [BAAI](https://huggingface.co/BAAI/) +- - [X] [Emu2-Chat](https://huggingface.co/BAAI/Emu2-Chat) (main docker only, may need the --max-memory option to GPU split) - - [X] [Bunny-Llama-3-8B-V](https://huggingface.co/BAAI/Bunny-Llama-3-8B-V) (main docker only) - [X] [TIGER-Lab](https://huggingface.co/TIGER-Lab) - - [X] [Mantis-8B-siglip-llama3](https://huggingface.co/TIGER-Lab/Mantis-8B-siglip-llama3) (main docker only, wont gpu split) @@ -76,6 +80,9 @@ See: [OpenVLM Leaderboard](https://huggingface.co/spaces/opencompass/open_vlm_le Version: 0.14.0 +- docker-compose.yml: Assume the runtime supports the device (ie. nvidia) +- new model support: qihoo360/360VL-8B, qihoo360/360VL-70B (70B loading error, [see note](https://huggingface.co/qihoo360/360VL-70B/discussions/1)) +- new model support: BAAI/Emu2-Chat, Can be slow to load, may need --max-memory option control the loading on multiple gpus - new model support: TIGER-Labs/Mantis: Mantis-8B-siglip-llama3, Mantis-8B-clip-llama3, Mantis-8B-Fuyu @@ -145,7 +152,8 @@ For MiniGemini support the docker image is recommended. See `prepare_minigemini. ## Usage ``` -usage: vision.py [-h] -m MODEL [-b BACKEND] [-f FORMAT] [-d DEVICE] [--device-map DEVICE_MAP] [--no-trust-remote-code] [-4] [-8] [-F] [-P PORT] [-H HOST] [--preload] +usage: vision.py [-h] -m MODEL [-b BACKEND] [-f FORMAT] [-d DEVICE] [--device-map DEVICE_MAP] [--max-memory MAX_MEMORY] [--no-trust-remote-code] [-4] [-8] [-F] + [-P PORT] [-H HOST] [--preload] OpenedAI Vision API Server @@ -161,6 +169,8 @@ options: Set the torch device for the model. Ex. cpu, cuda:1 (default: auto) --device-map DEVICE_MAP Set the default device map policy for the model. (auto, balanced, sequential, balanced_low_0, cuda:1, etc.) (default: auto) + --max-memory MAX_MEMORY + (emu2 only) Set the per cuda device_map max_memory. Ex. 0:22GiB,1:22GiB,cpu:128GiB (default: None) --no-trust-remote-code Don't trust remote code (required for many models) (default: False) -4, --load-in-4bit load in 4bit (doesn't work with all models) (default: False) diff --git a/backend/360vl.py b/backend/360vl.py new file mode 100644 index 0000000..7aafbf4 --- /dev/null +++ b/backend/360vl.py @@ -0,0 +1,65 @@ +from transformers import AutoTokenizer, AutoModelForCausalLM + +import transformers +import warnings +# disable some warnings +transformers.logging.set_verbosity_error() +warnings.filterwarnings('ignore') + +from vision_qna import * +# "qihoo360/360VL-8B" +# "qihoo360/360VL-70B" + +class VisionQnA(VisionQnABase): + model_name: str = "360vl" + format = "llama3" + + def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = {}, format = None): + super().__init__(model_id, device, device_map, extra_params, format) + + if not format: + self.format = guess_model_format(model_id) + + self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=self.params.get('trust_remote_code', False)) + self.model = AutoModelForCausalLM.from_pretrained(**self.params).eval() + + self.vision_tower = self.model.get_vision_tower() + self.vision_tower.load_model() + self.vision_tower.to(device=self.device, dtype=self.dtype) + self.image_processor = self.vision_tower.image_processor + self.tokenizer.pad_token = self.tokenizer.eos_token + self.terminators = [ + self.tokenizer.convert_tokens_to_ids("<|eot_id|>",) + ] + + print(f"Loaded on device: {self.model.device} with dtype: {self.model.dtype}") + + async def chat_with_images(self, request: ImageChatRequest) -> str: + images, prompt = await llama3_prompt_from_messages(request.messages, img_tok = "<|reserved_special_token_44|>\n") + + default_system = "You are a multilingual, helpful, respectful and honest assistant who can respond in the same language, depending on the language of the question. Try to be as helpful as possible while still being safe. Your answer should not contain anything that is false, unhealthy, harmful, immoral, racist, sexist, toxic, dangerous, or illegal, and if the question relates to such content, please decline to answer. Make sure your answer is socially fair and positive. If a question doesn't make any sense, or is inconsistent with the facts, explain why instead of answering the wrong answer. If you don't know the answer to a question, don't share false information." + + input_ids = self.tokenizer.encode(prompt, return_tensors="pt") + + input_id_list = input_ids[0].tolist() + input_id_list[input_id_list.index(128049)]=-200 + input_ids = torch.tensor(input_id_list, dtype=input_ids.dtype, device=input_ids.device).unsqueeze(0) + + image_tensor = self.model.process_images_slid_window(images[0], self.image_processor).unsqueeze(0) + + default_params = dict( + do_sample=False, + num_beams=1, + ) + + params = self.get_generation_params(request, default_params) + + output_ids = self.model.generate( + input_ids=input_ids.to(device=self.device, non_blocking=True), + images=image_tensor.to(dtype=self.dtype, device=self.device, non_blocking=True), + eos_token_id=self.terminators, + **params) + + outputs = self.tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True)[0] + + return outputs.strip() diff --git a/backend/emu.py b/backend/emu.py new file mode 100644 index 0000000..857ebdf --- /dev/null +++ b/backend/emu.py @@ -0,0 +1,75 @@ +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch +from huggingface_hub import snapshot_download + +from vision_qna import * + +# "BAAI/Emu2-Chat" + +class VisionQnA(VisionQnABase): + model_name: str = 'emu' + format: str = 'emu' + + def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = {}, format = None): + super().__init__(model_id, device, device_map, extra_params, format) + + if self.params['torch_dtype'] == torch.bfloat16: + self.params['torch_dtype'] = torch.float16 + + checkpoint = snapshot_download(model_id) + with init_empty_weights(): + self.model = AutoModelForCausalLM.from_pretrained(**self.params) + + max_memory=extra_params.get('max_memory', None) + + device_map = infer_auto_device_map(self.model, max_memory=max_memory, no_split_module_classes=['Block','LlamaDecoderLayer']) + # input and output logits should be on same device + device_map["model.decoder.lm.lm_head"] = 0 + + self.model = load_checkpoint_and_dispatch(self.model, checkpoint=checkpoint, device_map=device_map).eval() + """ + self.model = AutoModelForCausalLM.from_pretrained(**self.params).eval() + """ + + # bitsandbytes already moves the model to the device, so we don't need to do it again. + if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)): + self.model = self.model.to(self.device) + + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + + # self.model.device/dtype are overloaded with some other object + print(f"Loaded on device: {self.device} with dtype: {self.dtype}") + + async def chat_with_images(self, request: ImageChatRequest) -> str: + images, prompt, system = await emu_images_prompt_system_from_messages(request.messages) + + if not system: + system = "You are a helpful assistant, dedicated to delivering comprehensive and meticulous responses." + + prompt = system + prompt + + inputs = self.model.build_input_ids( + text=[prompt], + tokenizer=self.tokenizer, + image=images + ) + # .cuda() + + default_params = { + 'length_penalty': -1, + } + + params = self.get_generation_params(request, default_params) + + with torch.no_grad(): + outputs = self.model.generate( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + image=inputs["image"].to(torch.float16), # should be torch.float16 + **params, + ) + + response = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)[0] + + return response diff --git a/docker-compose.alt.yml b/docker-compose.alt.yml index a9a6b2a..0b337e5 100644 --- a/docker-compose.alt.yml +++ b/docker-compose.alt.yml @@ -15,7 +15,7 @@ services: - ./model_conf_tests.alt.json:/app/model_conf_tests.json ports: - 5006:5006 - runtime: nvidia + #runtime: nvidia deploy: resources: reservations: diff --git a/docker-compose.yml b/docker-compose.yml index ab5ab8d..073f62d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -15,7 +15,7 @@ services: - ./model_conf_tests.json:/app/model_conf_tests.json ports: - 5006:5006 - runtime: nvidia + #runtime: nvidia deploy: resources: reservations: diff --git a/model_conf_tests.json b/model_conf_tests.json index 61b9bf8..1db5a69 100644 --- a/model_conf_tests.json +++ b/model_conf_tests.json @@ -1,11 +1,14 @@ [ - ["TIGER-Lab/Mantis-8B-siglip-llama3", "--use-flash-attn", "--device-map", "cuda:0"], - ["TIGER-Lab/Mantis-8B-clip-llama3", "--use-flash-attn", "--device-map", "cuda:0"], - ["TIGER-Lab/Mantis-8B-Fuyu", "--device-map", "cuda:0"], + ["qihoo360/360VL-8B", "--use-flash-attn"], + ["qihoo360/360VL-70B", "--use-flash-attn"], + ["BAAI/Emu2-Chat", "--max-memory=0:78GiB,1:20GiB"], + ["BAAI/Emu2-Chat", "--load-in-4bit", "--device-map", "cuda:0"], ["vikhyatk/moondream2", "--use-flash-attn"], ["vikhyatk/moondream1"], ["OpenGVLab/InternVL-Chat-V1-5", "--device-map", "cuda:0"], ["HuggingFaceM4/idefics2-8b", "--use-flash-attn", "--device-map", "cuda:0"], + ["qihoo360/360VL-8B", "--use-flash-attn"], + ["qihoo360/360VL-70B", "--use-flash-attn"], ["qnguyen3/nanoLLaVA", "--use-flash-attn", "--device-map", "cuda:0"], ["echo840/Monkey"], ["echo840/Monkey-Chat"], @@ -14,6 +17,9 @@ ["Qwen/Qwen-VL-Chat"], ["BAAI/Bunny-Llama-3-8B-V"], ["qresearch/llama-3-vision-alpha-hf", "--device", "cuda:0"], + ["TIGER-Lab/Mantis-8B-siglip-llama3", "--use-flash-attn", "--device-map", "cuda:0"], + ["TIGER-Lab/Mantis-8B-clip-llama3", "--use-flash-attn", "--device-map", "cuda:0"], + ["TIGER-Lab/Mantis-8B-Fuyu", "--device-map", "cuda:0"], ["adept/fuyu-8b", "--device-map", "cuda:0"], ["internlm/internlm-xcomposer2-4khd-7b", "--use-flash-attn", "--device-map", "cuda:0"], ["internlm/internlm-xcomposer2-7b", "--use-flash-attn", "--device-map", "cuda:0"], @@ -34,6 +40,8 @@ ["OpenGVLab/InternVL-Chat-V1-5-Int8", "--device-map", "cuda:0"], ["OpenGVLab/InternVL-Chat-V1-5", "--load-in-4bit", "--device-map", "cuda:0"], ["HuggingFaceM4/idefics2-8b-AWQ", "--use-flash-attn", "--device-map", "cuda:0"], + ["qihoo360/360VL-8B", "--use-flash-attn", "--load-in-4bit"], + ["qihoo360/360VL-70B", "--use-flash-attn", "--load-in-4bit"], ["qnguyen3/nanoLLaVA", "--use-flash-attn", "--load-in-4bit", "--device-map", "cuda:0"], ["THUDM/cogvlm-chat-hf", "--load-in-4bit"], ["THUDM/cogagent-chat-hf", "--load-in-4bit"], diff --git a/requirements.txt b/requirements.txt index df290e0..e68bb2e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -42,6 +42,9 @@ transformers_stream_generator loguru sse_starlette +# 360vl +logger + # alt #transformers==4.36.2 diff --git a/test_models.py b/test_models.py index b1864a4..a12e3bd 100755 --- a/test_models.py +++ b/test_models.py @@ -5,6 +5,7 @@ import requests import argparse import subprocess +import traceback from datauri import DataURI from openai import OpenAI import torch @@ -106,6 +107,7 @@ def test(cmd_args: list[str]) -> int: try: results = single_round() except Exception as e: + traceback.print_exc() note = f'Test failed with Exception: {e}' print(f"{note}") results = [False] diff --git a/vision.py b/vision.py index b01a871..9d25560 100644 --- a/vision.py +++ b/vision.py @@ -65,6 +65,7 @@ def parse_args(argv=None): parser.add_argument('-f', '--format', action='store', default=None, help="Force a specific chat format. (vicuna, mistral, chatml, llama2, phi15, gemma) (doesn't work with all models)") parser.add_argument('-d', '--device', action='store', default="auto", help="Set the torch device for the model. Ex. cpu, cuda:1") parser.add_argument('--device-map', action='store', default=os.environ.get('OPENEDAI_DEVICE_MAP', "auto"), help="Set the default device map policy for the model. (auto, balanced, sequential, balanced_low_0, cuda:1, etc.)") + parser.add_argument('--max-memory', action='store', default=None, help="(emu2 only) Set the per cuda device_map max_memory. Ex. 0:22GiB,1:22GiB,cpu:128GiB") parser.add_argument('--no-trust-remote-code', action='store_true', help="Don't trust remote code (required for many models)") parser.add_argument('-4', '--load-in-4bit', action='store_true', help="load in 4bit (doesn't work with all models)") parser.add_argument('-8', '--load-in-8bit', action='store_true', help="load in 8bit (doesn't work with all models)") @@ -105,6 +106,9 @@ def parse_args(argv=None): extra_params['use_flash_attn'] = True extra_params['trust_remote_code'] = not args.no_trust_remote_code + if args.max_memory: + dev_map_max_memory = {int(dev_id) if dev_id not in ['cpu', 'disk'] else dev_id: mem for dev_id, mem in [dev_mem.split(':') for dev_mem in args.max_memory.split(',')]} + extra_params['max_memory'] = dev_map_max_memory vision_qna = backend.VisionQnA(args.model, args.device, args.device_map, extra_params, format=args.format) diff --git a/vision_qna.py b/vision_qna.py index 1001c8f..3cc801e 100644 --- a/vision_qna.py +++ b/vision_qna.py @@ -388,6 +388,38 @@ async def fuyu_prompt_from_messages(messages: list[Message], img_tok = "", img_e return images, prompt +async def emu_images_prompt_system_from_messages(messages: list[Message], img_tok = "[]"): + prompt = '' + images = [] + system_message = None + + for m in messages: + if m.role == 'user': + text = '' + has_image = False + + for c in m.content: + if c.type == 'image_url': + images.extend([ await url_to_image(c.image_url.url) ]) + has_image = True + if c.type == 'text': + text = c.text + + img_tag = img_tok if has_image else '' + prompt += f" [USER]: {img_tag}{text}" + elif m.role == 'assistant': + for c in m.content: + if c.type == 'text': + prompt += f" [ASSISTANT]: {c.text}" + elif m.role == 'system': + for c in m.content: + if c.type == 'text': + system_message = c.text + + prompt += " [ASSISTANT]:" + + return images, prompt, system_message + async def prompt_history_images_system_from_messages(messages: list[Message], img_tok = "\n", url_handler = url_to_image): history = [] images = [] @@ -444,7 +476,7 @@ def guess_model_format(model_name: str) -> str: model_format_match_map = { 'llama2': ['bakllava', '8x7b', 'mistral', 'mixtral'], - 'llama3': ['llama-3-vision'], + 'llama3': ['llama-3-vision', '360vl'], 'gemma': ['gemma', '-2b'], 'vicuna': ['vicuna', '13b'], 'vicuna0': ['yi-vl'], @@ -524,4 +556,10 @@ def guess_backend(model_name: str) -> str: return 'bunny' if 'mantis' in model_id: - return 'mantis' \ No newline at end of file + return 'mantis' + + if 'emu' in model_id: + return 'emu' + + if '360vl' in model_id: + return '360vl' \ No newline at end of file