0.6.1 +gptq 4bit for internlm
matatonic committed Apr 6, 2024
1 parent 7724a24 commit 6a24f4f
Showing 7 changed files with 64 additions and 14 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -20,16 +20,18 @@ Model support:
- [ ] [openbmb/OmniLMM-12B](https://huggingface.co/openbmb/OmniLMM-12B)
- [ ] [echo840/Monkey](https://huggingface.co/echo840/Monkey)
- [ ] [YanweiLi/MiniGemini](https://huggingface.co/collections/YanweiLi/)
- [ ] [NousResearch/Obsidian-3B-V0.5](https://huggingface.co/NousResearch/Obsidian-3B-V0.5)
- [ ] ...


Some vision systems include their own OpenAI compatible API server. Also included are some pre-built images and docker-compose for them:
- [X] [THUDM/CogVLM](https://github.com/THUDM/CogVLM) ([cogvlm-chat-hf](https://huggingface.co/THUDM/cogvlm-chat-hf), [cogagent-chat-hf](https://huggingface.co/THUDM/cogagent-chat-hf)), `docker-compose.cogvlm.yml` **Recommended for 16GB-40GB GPU**s
- [X] [01-ai](https://huggingface.co/01-ai)/Yi-VL ([Yi-VL-6B](https://huggingface.co/01-ai/Yi-VL-6B), [Yi-VL-34B](https://huggingface.co/01-ai/Yi-VL-34B)), `docker-compose.yi-vl.yml`

Version: 0.6.0
Version: 0.6.1

Recent updates:
- AutoGPTQ support for internlm/internlm-xcomposer2-7b-4bit, internlm/internlm-xcomposer2-vl-7b-4bit
- Automatic selection of backend, based on the model name
- Enable trust_remote_code by default
- Improved parameter support: temperature, top_p, max_tokens, system prompts
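For context, a minimal sketch of exercising the new parameter support (temperature, top_p, max_tokens, system prompt) against the server's OpenAI-compatible chat endpoint. The port follows the docker-compose file further down; the model name, api_key value, and image URL are illustrative assumptions — the server answers with whichever backend it was started with.

```python
from openai import OpenAI

# Local OpenAI-compatible server; port 5006 is taken from docker-compose.yml below.
client = OpenAI(base_url="http://localhost:5006/v1", api_key="none")

response = client.chat.completions.create(
    model="internlm/internlm-xcomposer2-vl-7b-4bit",  # illustrative; backend selection is server-side
    messages=[
        {"role": "system", "content": [{"type": "text", "text": "You are a concise assistant."}]},
        {"role": "user", "content": [
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.jpg"}},
            {"type": "text", "text": "What is in this image?"},
        ]},
    ],
    temperature=0.2,
    top_p=0.9,
    max_tokens=256,
)
print(response.choices[0].message.content)
```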
27 changes: 25 additions & 2 deletions backend/xcomposer2-vl.py
@@ -1,9 +1,23 @@
import os
from transformers import AutoTokenizer, AutoModel

from vision_qna import *
import auto_gptq
import torch

# internlm/internlm-xcomposer2-vl-7b
# internlm/internlm-xcomposer2-vl-7b-4bit

class InternLMXComposer2QForCausalLM(auto_gptq.modeling.BaseGPTQForCausalLM):
layers_block_name = "model.layers"
outside_layer_modules = [
'vit', 'vision_proj', 'model.tok_embeddings', 'model.norm', 'output',
]
inside_layer_modules = [
["attention.wqkv.linear"],
["attention.wo.linear"],
["feed_forward.w1.linear", "feed_forward.w3.linear"],
["feed_forward.w2.linear"],
]

class VisionQnA(VisionQnABase):
model_name: str = "xcomposer2-vl"
@@ -12,7 +26,16 @@ def __init__(self, model_id: str, device: str, extra_params = {}, format = None)
super().__init__(model_id, device, extra_params, format)

self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=self.params.get('trust_remote_code', False))
self.model = AutoModel.from_pretrained(**self.params).eval()

if '-4bit' in model_id:
if self.params['torch_dtype'] == torch.bfloat16:
self.params['torch_dtype'] = torch.float16

torch.set_grad_enabled(False)
auto_gptq.modeling._base.SUPPORTED_MODELS = ["internlm"]
self.model = InternLMXComposer2QForCausalLM.from_quantized(model_name_or_path=model_id, **self.params).eval()
else:
self.model = AutoModel.from_pretrained(**self.params).eval()

print(f"Loaded on device: {self.model.device} with dtype: {self.model.dtype}")

30 changes: 27 additions & 3 deletions backend/xcomposer2.py
@@ -1,7 +1,23 @@
import os
from transformers import AutoTokenizer, AutoModel

from vision_qna import *
import auto_gptq
import torch

# internlm/internlm-xcomposer2-7b
# internlm/internlm-xcomposer2-7b-4bit

class InternLMXComposer2QForCausalLM(auto_gptq.modeling.BaseGPTQForCausalLM):
layers_block_name = "model.layers"
outside_layer_modules = [
'vit', 'vision_proj', 'model.tok_embeddings', 'model.norm', 'output',
]
inside_layer_modules = [
["attention.wqkv.linear"],
["attention.wo.linear"],
["feed_forward.w1.linear", "feed_forward.w3.linear"],
["feed_forward.w2.linear"],
]

# internlm/internlm-xcomposer2-7b

@@ -12,8 +28,16 @@ def __init__(self, model_id: str, device: str, extra_params = {}, format = None)
super().__init__(model_id, device, extra_params, format)

self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=self.params.get('trust_remote_code', False))
self.model = AutoModel.from_pretrained(**self.params).eval()

if '-4bit' in model_id:
if self.params['torch_dtype'] == torch.bfloat16:
self.params['torch_dtype'] = torch.float16

torch.set_grad_enabled(False)
auto_gptq.modeling._base.SUPPORTED_MODELS = ["internlm"]
self.model = InternLMXComposer2QForCausalLM.from_quantized(model_name_or_path=model_id, **self.params).eval()
else:
self.model = AutoModel.from_pretrained(**self.params).eval()

print(f"Loaded on device: {self.model.device} with dtype: {self.model.dtype}")

async def chat_with_images(self, request: ImageChatRequest) -> str:
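For context on the new 4-bit branch shared by both backends: AutoGPTQ only accepts architectures in its SUPPORTED_MODELS list, so the commit registers "internlm" and loads the checkpoint through the custom BaseGPTQForCausalLM subclass that names the quantized linear modules. A stripped-down sketch of that load path, assuming the InternLMXComposer2QForCausalLM class from the diff above is in scope (the device and dtype values here are illustrative):

```python
import auto_gptq
import torch

torch.set_grad_enabled(False)

# auto_gptq checks the architecture name against this list, so "internlm"
# must be registered before from_quantized() is called.
auto_gptq.modeling._base.SUPPORTED_MODELS = ["internlm"]

model = InternLMXComposer2QForCausalLM.from_quantized(
    model_name_or_path="internlm/internlm-xcomposer2-7b-4bit",
    torch_dtype=torch.float16,  # fp16 rather than bfloat16, mirroring the fallback in the diff above
    trust_remote_code=True,
    device="cuda:0",
).eval()
```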
10 changes: 4 additions & 6 deletions chat_with_image.py
@@ -48,10 +48,8 @@ def url_to_data_url(img_url: str) -> str:
image_url = str(DataURI.from_file(image_url))

messages = [{ "role": "system", "content": [{ 'type': 'text', 'text': args.system_prompt }] }] if args.system_prompt else []
content = [
{ "type": "image_url", "image_url": { "url": image_url } },
{ "type": "text", "text": ' '.join(args.questions) },
]
content = [{ "type": "image_url", "image_url": { "url": image_url } },
{ "type": "text", "text": ' '.join(args.questions) }]
messages.extend([{ "role": "user", "content": content }])

while True:
@@ -70,8 +68,8 @@ def url_to_data_url(img_url: str) -> str:
break

content = [{"type": "image_url", "image_url": { "url": image_url } }] if image_url else []
content.extend([{ 'type': 'text', 'text': response.choices[0].message.content } ])
content.extend([{ 'type': 'text', 'text': response.choices[0].message.content }])
messages.extend([{ "role": "assistant", "content": content },
{ "role": "user", "content": [ { 'type': 'text', 'text': q } ] } ])
{ "role": "user", "content": [{ 'type': 'text', 'text': q }] }])


2 changes: 2 additions & 0 deletions docker-compose.yml
@@ -13,7 +13,9 @@ services:
ports:
- 5006:5006
#command: ["python", "vision.py", "-m", "internlm/internlm-xcomposer2-7b", "--use-flash-attn"]
#command: ["python", "vision.py", "-m", "internlm/internlm-xcomposer2-7b-4bit", "--use-flash-attn"]
#command: ["python", "vision.py", "-m", "internlm/internlm-xcomposer2-vl-7b", "--use-flash-attn"]
#command: ["python", "vision.py", "-m", "internlm/internlm-xcomposer2-vl-7b-4bit", "--use-flash-attn"]
#command: ["python", "vision.py", "--host", "0.0.0.0", "--port", "5006", "--model", "llava-hf/llava-v1.6-34b-hf", "--load-in-4bit", "--use-flash-attn"] # WIP
#command: ["python", "vision.py", "-m", "echo840/Monkey-Chat"] # broken
#command: ["python", "vision.py", "-m", "openbmb/OmniLMM-12B"] # WIP
3 changes: 2 additions & 1 deletion requirements.txt
@@ -10,4 +10,5 @@ bitsandbytes
flash_attn
sentencepiece
protobuf
peft
peft
auto_gptq
2 changes: 1 addition & 1 deletion vision_qna.py
@@ -47,7 +47,7 @@ def __init__(self, model_id: str, device: str, extra_params = {}, format = None)
'quantization_config': {
'load_in_4bit': True,
'bnb_4bit_quant_type': "nf4",
'bnb_4bit_use_double_quant': True,
'bnb_4bit_use_double_quant': True, # XXX make this an option
'bnb_4bit_compute_dtype': self.dtype,
}
}
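For reference, the quantization_config dict above corresponds to the keyword arguments of transformers' BitsAndBytesConfig — presumably the bitsandbytes 4-bit path behind the --load-in-4bit option seen in docker-compose.yml, as opposed to the GPTQ path added in this commit. A minimal equivalent sketch; the model id and compute dtype below are illustrative assumptions:

```python
import torch
from transformers import AutoModel, BitsAndBytesConfig

# Explicit-config equivalent of the dict built in vision_qna.py above.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,  # the "XXX make this an option" knob
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModel.from_pretrained(
    "internlm/internlm-xcomposer2-vl-7b",  # illustrative model id
    quantization_config=bnb_config,
    trust_remote_code=True,
).eval()
```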
