
Commit

0.22.0 +glm-4v-9b
matatonic committed Jun 5, 2024
1 parent 02e30bd commit ab6dd4d
Showing 5 changed files with 108 additions and 3 deletions.
5 changes: 5 additions & 0 deletions README.md
@@ -19,6 +19,7 @@ An OpenAI API compatible vision server, it functions like `gpt-4-vision-preview`
- - [X] [cogvlm2-llama3-chinese-chat-19B](https://huggingface.co/THUDM/cogvlm2-llama3-chinese-chat-19B)
- - [X] [cogvlm-chat-hf](https://huggingface.co/THUDM/cogvlm-chat-hf)
- - [X] [cogagent-chat-hf](https://huggingface.co/THUDM/cogagent-chat-hf)
- - [X] [glm-4v-9b](https://huggingface.co/THUDM/glm-4v-9b) (won't gpu split)
- [X] [InternLM](https://huggingface.co/internlm/)
- - [X] [XComposer2-4KHD-7b](https://huggingface.co/internlm/internlm-xcomposer2-4khd-7b) (won't gpu split)
- - [X] [XComposer2-7b](https://huggingface.co/internlm/internlm-xcomposer2-7b) [finetune] (won't gpu split)
@@ -99,6 +100,10 @@ See: [OpenVLM Leaderboard](https://huggingface.co/spaces/opencompass/open_vlm_le

## Recent updates

Version 0.22.0

- new model support: THUDM/glm-4v-9b

Version 0.21.0

- new model support: Salesforce/xgen-mm-phi3-mini-instruct-r-v1
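As a quick usage sketch (not part of this commit): once the server is launched with the CLI command recorded in vision.sample.env below (python vision.py -m THUDM/glm-4v-9b --device-map cuda:0), the new model can be queried through the OpenAI-compatible chat endpoint. The port 5006, the dummy API key, and the model name in the request are assumptions; adjust them to your deployment.

# Sketch only: query the running openedai-vision server with the official OpenAI client.
# The base_url port and the model field are assumptions, not values defined by this commit.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:5006/v1", api_key="sk-ip")  # any dummy key

response = client.chat.completions.create(
    model="gpt-4-vision-preview",  # assumption; the server serves the model vision.py was launched with
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image."},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
        ],
    }],
    max_tokens=256,
)
print(response.choices[0].message.content)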
66 changes: 66 additions & 0 deletions backend/glm-4v.py
@@ -0,0 +1,66 @@
from transformers import AutoTokenizer, AutoModelForCausalLM
from torchvision import transforms
import torch
from vision_qna import *

# THUDM/glm-4v-9b

class VisionQnA(VisionQnABase):
model_name: str = "glm-4v"
format: str = 'glm-4v'
vision_layers: List[str] = ['transformer.vision']

def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = {}, format = None):
super().__init__(model_id, device, device_map, extra_params, format)

self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=self.params.get('trust_remote_code', False))
self.model = AutoModelForCausalLM.from_pretrained(**self.params).eval()

# bitsandbytes already moves the model to the device, so we don't need to do it again.
if not (extra_params.get('load_in_4bit', False) or extra_params.get('load_in_8bit', False)):
self.model = self.model.to(self.device)

self.transform = transforms.Compose(
[
transforms.Resize(
(self.model.config.vision_config['image_size'], self.model.config.vision_config['image_size']), interpolation=transforms.InterpolationMode.BICUBIC
),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
]
)

print(f"Loaded on device: {self.model.device} with dtype: {self.model.dtype}")

async def chat_with_images(self, request: ImageChatRequest) -> str:
images, prompt = await glm4v_prompt_from_messages(request.messages)

images = torch.stack([ self.transform(img) for img in images ])

input_ids = self.tokenizer.encode(prompt)
inputs = self.tokenizer.batch_encode_plus(
[input_ids],
padding=False,
truncation=False,
max_length=None,
return_tensors="pt",
is_split_into_words=True,
add_special_tokens=False
)

inputs["images"] = images
inputs = inputs.to(device=self.device)

default_params = {
'max_new_tokens': 2500,
'do_sample': False,
}

params = self.get_generation_params(request, default_params)

with torch.no_grad():
outputs = self.model.generate(**inputs, **params)

answer = self.tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True).strip()

return answer
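For reference, a standalone sketch of the image preprocessing this backend applies. The image size of 1120 and the file name are assumptions for illustration; the backend itself reads the size from model.config.vision_config['image_size'] at load time.

# Sketch of the glm-4v image preprocessing above, outside the server.
# image_size=1120 and "example.jpg" are assumptions for illustration only.
from PIL import Image
from torchvision import transforms

image_size = 1120  # the real value comes from model.config.vision_config['image_size']
transform = transforms.Compose([
    transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    # CLIP-style normalization, matching the constants used in the backend
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])

pixels = transform(Image.open("example.jpg").convert("RGB"))  # tensor of shape [3, image_size, image_size]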
2 changes: 2 additions & 0 deletions model_conf_tests.json
@@ -44,6 +44,8 @@
["THUDM/cogvlm2-llama3-chat-19B"],
["THUDM/cogvlm2-llama3-chinese-chat-19B", "--load-in-4bit"],
["THUDM/cogvlm2-llama3-chinese-chat-19B"],
["THUDM/glm-4v-9b", "--device-map", "cuda:0", "--load-in-4bit"],
["THUDM/glm-4v-9b", "--device-map", "cuda:0"],
["TIGER-Lab/Mantis-8B-Fuyu", "--device-map", "cuda:0", "--load-in-4bit"],
["TIGER-Lab/Mantis-8B-Fuyu", "--device-map", "cuda:0"],
["TIGER-Lab/Mantis-8B-clip-llama3", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
2 changes: 2 additions & 0 deletions vision.sample.env
@@ -49,6 +49,8 @@ HF_HUB_ENABLE_HF_TRANSFER=1
#CLI_COMMAND="python vision.py -m THUDM/cogvlm2-llama3-chat-19B" # test pass✅, time: 21.5s, mem: 40.7GB, 8/8 tests passed.
#CLI_COMMAND="python vision.py -m THUDM/cogvlm2-llama3-chinese-chat-19B --load-in-4bit" # test pass✅, time: 79.2s, mem: 15.3GB, 8/8 tests passed.
#CLI_COMMAND="python vision.py -m THUDM/cogvlm2-llama3-chinese-chat-19B" # test pass✅, time: 69.7s, mem: 40.7GB, 8/8 tests passed.
#CLI_COMMAND="python vision.py -m THUDM/glm-4v-9b --device-map cuda:0 --load-in-4bit" # test pass✅, time: 60.0s, mem: 16.1GB, 8/8 tests passed.
#CLI_COMMAND="python vision.py -m THUDM/glm-4v-9b --device-map cuda:0" # test pass✅, time: 35.1s, mem: 27.9GB, 8/8 tests passed.
#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-Fuyu --device-map cuda:0 --load-in-4bit" # test pass✅, time: 7.0s, mem: 11.2GB, 8/8 tests passed.
#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-Fuyu --device-map cuda:0" # test pass✅, time: 6.4s, mem: 20.4GB, 8/8 tests passed.
#CLI_COMMAND="python vision.py -m TIGER-Lab/Mantis-8B-clip-llama3 --use-flash-attn --device-map cuda:0 --load-in-4bit" # test pass✅, time: 9.3s, mem: 7.2GB, 8/8 tests passed.
36 changes: 33 additions & 3 deletions vision_qna.py
@@ -485,7 +485,6 @@ async def phi3_prompt_from_messages(messages: list[Message], img_tok = "<image>\

return images, prompt


async def phintern_prompt_from_messages(messages: list[Message], img_tok = "<image>\n"):
prompt = ''
images = []
@@ -554,7 +553,6 @@ async def falcon_prompt_from_messages(messages: list[Message], img_tok = "<image

return images, prompt


async def prompt_history_images_system_from_messages(messages: list[Message], img_tok = "<image>\n", url_handler = url_to_image):
history = []
images = []
@@ -585,7 +583,33 @@ async def prompt_history_images_system_from_messages(messages: list[Message], im

return prompt, history, images, system_prompt

async def glm4v_prompt_from_messages(messages: list[Message], img_tok = "<|begin_of_image|><|endoftext|><|end_of_image|>", url_handler = url_to_image):
prompt = '[gMASK]<sop>'
images = []
generation_msg = '<|assistant|>\n'

if messages and messages[-1].role == 'assistant':
generation_msg += messages[-1].content[0].text
messages.pop(-1)

for m in messages:
img_tag = ''
metadata = '' # not used

# TODO: handle tool role and build system prompt?

for c in m.content:
if c.type == 'image_url':
                images.append(await url_handler(c.image_url.url))
img_tag += img_tok

for c in m.content:
if c.type == 'text':
prompt += f"<|{m.role}|>{metadata}\n{img_tag}{c.text}"

prompt += generation_msg

return images, prompt
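To make the template concrete, here is a minimal sketch (not the project's code) that rebuilds the string the helper above produces for a single user turn with one image; the question text is a made-up example.

# Sketch: the prompt glm4v_prompt_from_messages builds for one user message
# containing one image and the (hypothetical) question "What do you see?".
img_tok = "<|begin_of_image|><|endoftext|><|end_of_image|>"

def render_glm4v_user_turn(text: str, n_images: int = 1) -> str:
    # [gMASK]<sop> opens the conversation, each turn is <|role|>\n + image tags + text,
    # and the trailing <|assistant|>\n cues the model to generate its reply.
    return "[gMASK]<sop>" + f"<|user|>\n{img_tok * n_images}{text}" + "<|assistant|>\n"

print(render_glm4v_user_turn("What do you see?"))
# [gMASK]<sop><|user|>
# <|begin_of_image|><|endoftext|><|end_of_image|>What do you see?<|assistant|>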


async def prompt_from_messages(messages: list[Message], format: str) -> str:
@@ -594,6 +618,7 @@ async def prompt_from_messages(messages: list[Message], format: str) -> str:
'falcon': falcon_prompt_from_messages,
'fuyu': fuyu_prompt_from_messages,
'gemma': gemma_prompt_from_messages,
'glm4v': glm4v_prompt_from_messages,
'llama2': llama2_prompt_from_messages,
'llama3': llama3_prompt_from_messages,
'mistral': llama2_prompt_from_messages, # simplicity
@@ -617,6 +642,7 @@ def guess_model_format(model_name: str) -> str:
'falcon': ['falcon'],
'fuyu': ['fuyu'],
'gemma': ['gemma', '-2b'],
'glm4v': ['glm-4v'],
'llama2': ['bakllava', '8x7b', 'mistral', 'mixtral'],
'llama3': ['llama-3-vision', '360vl'],
'phi15': ['moondream1', 'moondream2', 'monkey'],
@@ -684,6 +710,9 @@ def guess_backend(model_name: str) -> str:
if 'cogagent-' in model_id or 'cogvlm-' in model_id:
return 'cogvlm'

if 'glm-4v' in model_id:
return 'glm-4v'

if 'fuyu' in model_id:
return 'fuyu'

@@ -716,4 +745,5 @@ def guess_backend(model_name: str) -> str:
return 'phi3'

if 'falcon' in model_id:
return 'llavanext'
return 'llavanext'
