Skip to content

Commit

Permalink
0.30.0 +pixtral, attn_changes
Browse files Browse the repository at this point in the history
  • Loading branch information
matatonic committed Sep 13, 2024
1 parent 0a9f5e6 commit 14ef2fc
Show file tree
Hide file tree
Showing 10 changed files with 515 additions and 209 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ Can't decide which to use? See the [OpenVLM Leaderboard](https://huggingface.co/
- - [X] [Phi-3-vision-128k-instruct](https://huggingface.co/microsoft/Phi-3-vision-128k-instruct)
- - [X] [Florence-2-large-ft](https://huggingface.co/microsoft/Florence-2-large-ft) (won't gpu split)
- - [X] [Florence-2-base-ft](https://huggingface.co/microsoft/Florence-2-base-ft) (won't gpu split)
- [X] [Mistral AI](https://huggingface.co/mistralai)
- - [X] [Pixtral-12B](https://huggingface.co/mistralai/Pixtral-12B-2409)
- [X] [openbmb](https://huggingface.co/openbmb)
- - [X] [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) (video not supported yet)
- - [X] [MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V-2_5)
Expand Down Expand Up @@ -131,10 +133,13 @@ If you can't find your favorite model, you can [open a new issue](https://github

## Recent updates


Version 0.30.0

- Update moondream2 to version 2024-08-26
- new model support: mistralai/Pixtral-12B-2409 (no streaming yet, no quants yet)
- new model support: LMMs-Lab's llava-onevision-qwen2, 0.5b, 7b and 72b (72b untested, 4bit support doesn't seem to work properly yet)
- Update moondream2 to version 2024-08-26
- Performance fixed: idefics2-8b-AWQ, idefics2-8b-chatty-AWQ

Version 0.29.0

Expand Down
48 changes: 48 additions & 0 deletions backend/pixtral.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@

from huggingface_hub import snapshot_download
from safetensors import safe_open
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

from vision_qna import *

# mistralai/Pixtral-12B-2409

class VisionQnA(VisionQnABase):
    """Vision Q&A backend for mistralai/Pixtral-12B-2409.

    Loads the model with ``mistral_inference`` and tokenizes chat requests
    (text and images) with ``mistral_common``'s ``MistralTokenizer``.
    """

    model_name: str = "pixtral"
    format: str = "pixtral"
    visual_layers: List[str] = ["vision_encoder", 'vision_language_adapter']

    def __init__(self, model_id: str, device: str, device_map: str = 'auto', extra_params = None, format = None):
        """Download the model files and load the tokenizer and transformer.

        Args:
            model_id: Hugging Face repo id passed to ``snapshot_download``.
            device: Forwarded to the base class.
            device_map: Forwarded to the base class.
            extra_params: Optional dict of extra loader options, forwarded to
                the base class. ``None`` (the default) is treated as an empty
                dict.
            format: Forwarded to the base class.
        """
        # Build a fresh dict per call; a `{}` default would be one shared
        # object reused (and possibly mutated) across every instantiation.
        if extra_params is None:
            extra_params = {}

        super().__init__(model_id, device, device_map, extra_params, format)

        # Only fetch the three files the mistral loaders below need.
        mistral_models_path = snapshot_download(
            repo_id=model_id,
            allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
        )

        self.tokenizer = MistralTokenizer.from_file(f"{mistral_models_path}/tekken.json")
        # self.device / self.dtype are not assigned in this class; presumably
        # set by VisionQnABase.__init__ -- confirm against the base class.
        self.model = Transformer.from_folder(mistral_models_path, device=self.device, dtype=self.dtype)

        # NOTE: this loader does not read 'load_in_4bit' / 'load_in_8bit' from
        # extra_params itself; the model is placed via the `device` argument above.

        self.loaded_banner()

    async def chat_with_images(self, request: ImageChatRequest) -> str:
        """Generate one reply for a chat request that may contain images.

        Args:
            request: The chat request; ``messages``, ``max_tokens`` and
                ``temperature`` are read from it.

        Returns:
            The decoded text of the first generated sequence.
        """
        prompt = await pixtral_messages(request.messages)

        # Tokenize image urls and text together.
        tokenized = self.tokenizer.encode_chat_completion(prompt)

        generation_kwargs = dict(
            eos_id=self.tokenizer.instruct_tokenizer.tokenizer.eos_id,
            max_tokens=request.max_tokens,
            # Fall back to 0.35 only when the request leaves temperature unset;
            # an explicit 0 is passed through unchanged.
            temperature=0.35 if request.temperature is None else request.temperature,
        )

        # NOTE(review): generate() is called synchronously (not awaited), so
        # this coroutine does not yield until generation finishes.
        out_tokens, _ = generate([tokenized.tokens], self.model, images=[tokenized.images], **generation_kwargs)

        return self.tokenizer.decode(out_tokens[0])
17 changes: 10 additions & 7 deletions chat_with_image.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
#!/usr/bin/env python
try:
import dotenv
dotenv.load_dotenv(override=True)
except:
pass

import os
import requests
import argparse
from datauri import DataURI
from openai import OpenAI

try:
import dotenv
dotenv.load_dotenv(override=True)
except:
pass

def url_for_api(img_url: str = None, filename: str = None, always_data=False) -> str:
if img_url.startswith('http'):
Expand All @@ -29,6 +30,7 @@ def url_for_api(img_url: str = None, filename: str = None, always_data=False) ->
parser = argparse.ArgumentParser(description='Test vision using OpenAI',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-s', '--system-prompt', type=str, default=None)
parser.add_argument('--openai-model', type=str, default="gpt-4-vision-preview")
parser.add_argument('-S', '--start-with', type=str, default=None, help="Start reply with, ex. 'Sure, ' (doesn't work with all models)")
parser.add_argument('-m', '--max-tokens', type=int, default=None)
parser.add_argument('-t', '--temperature', type=float, default=None)
Expand All @@ -40,7 +42,8 @@ def url_for_api(img_url: str = None, filename: str = None, always_data=False) ->
parser.add_argument('questions', type=str, nargs='*', help='The question to ask the image')
args = parser.parse_args()

client = OpenAI(base_url=os.environ.get('OPENAI_BASE_URL', 'http://localhost:5006/v1'), api_key='skip')
client = OpenAI(base_url=os.environ.get('OPENAI_BASE_URL', 'http://localhost:5006/v1'),
api_key=os.environ.get('OPENAI_API_KEY', 'sk-ip'))

params = {}
if args.max_tokens is not None:
Expand All @@ -67,7 +70,7 @@ def url_for_api(img_url: str = None, filename: str = None, always_data=False) ->
if args.start_with:
messages.extend([{ "role": "assistant", "content": [{ "type": "text", "text": args.start_with }] }])

response = client.chat.completions.create(model="gpt-4-vision-preview", messages=messages, **params)
response = client.chat.completions.create(model=args.openai_model, messages=messages, **params)

if not args.single:
print(f"Answer: ", end='', flush=True)
Expand Down
22 changes: 11 additions & 11 deletions model_conf_tests.alt.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,15 @@
["THUDM/cogvlm2-llama3-chat-19B"],
["THUDM/cogvlm2-llama3-chinese-chat-19B", "--load-in-4bit"],
["THUDM/cogvlm2-llama3-chinese-chat-19B"],
["cognitivecomputations/dolphin-vision-72b", "--use-flash-attn", "--load-in-4bit", "--device-map", "cuda:0"],
["cognitivecomputations/dolphin-vision-7b", "--use-flash-attn", "--load-in-4bit", "--device-map", "cuda:0"],
["cognitivecomputations/dolphin-vision-7b", "--use-flash-attn", "--device-map", "cuda:0"],
["llava-hf/llava-v1.6-mistral-7b-hf", "--use-flash-attn", "--load-in-4bit"],
["llava-hf/llava-v1.6-mistral-7b-hf", "--use-flash-attn"],
["openbmb/MiniCPM-V", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["openbmb/MiniCPM-V", "--use-flash-attn", "--device-map", "cuda:0"],
["openbmb/MiniCPM-V-2", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["openbmb/MiniCPM-V-2", "--use-flash-attn", "--device-map", "cuda:0"],
["tiiuae/falcon-11B-vlm", "--use-flash-attn", "--load-in-4bit"],
["tiiuae/falcon-11B-vlm", "--use-flash-attn"]
["cognitivecomputations/dolphin-vision-72b", "-A", "flash_attention_2", "--load-in-4bit", "--device-map", "cuda:0"],
["cognitivecomputations/dolphin-vision-7b", "-A", "flash_attention_2", "--load-in-4bit", "--device-map", "cuda:0"],
["cognitivecomputations/dolphin-vision-7b", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["llava-hf/llava-v1.6-mistral-7b-hf", "-A", "flash_attention_2", "--load-in-4bit"],
["llava-hf/llava-v1.6-mistral-7b-hf", "-A", "flash_attention_2"],
["openbmb/MiniCPM-V", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["openbmb/MiniCPM-V", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["openbmb/MiniCPM-V-2", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["openbmb/MiniCPM-V-2", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["tiiuae/falcon-11B-vlm", "-A", "flash_attention_2", "--load-in-4bit"],
["tiiuae/falcon-11B-vlm", "-A", "flash_attention_2"]
]
124 changes: 64 additions & 60 deletions model_conf_tests.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
["BAAI/Bunny-v1_1-Llama-3-8B-V"],
["BAAI/Emu2-Chat", "--load-in-4bit"],
["BAAI/Emu2-Chat", "--max-memory=0:78GiB,1:20GiB"],
["HuggingFaceM4/idefics2-8b", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["HuggingFaceM4/idefics2-8b", "--use-flash-attn", "--device-map", "cuda:0"],
["HuggingFaceM4/idefics2-8b-AWQ", "--use-flash-attn", "--device-map", "cuda:0"],
["HuggingFaceM4/idefics2-8b-chatty", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["HuggingFaceM4/idefics2-8b-chatty", "--use-flash-attn", "--device-map", "cuda:0"],
["HuggingFaceM4/idefics2-8b-chatty-AWQ", "--use-flash-attn", "--device-map", "cuda:0"],
["HuggingFaceM4/idefics2-8b", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["HuggingFaceM4/idefics2-8b", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["HuggingFaceM4/idefics2-8b-AWQ", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["HuggingFaceM4/idefics2-8b-chatty", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["HuggingFaceM4/idefics2-8b-chatty", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["HuggingFaceM4/idefics2-8b-chatty-AWQ", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["OpenGVLab/InternVL-Chat-V1-5", "--device-map", "cuda:0", "--load-in-4bit"],
["OpenGVLab/InternVL-Chat-V1-5", "--device-map", "cuda:0", "--max-tiles", "40", "--load-in-4bit"],
["OpenGVLab/InternVL-Chat-V1-5", "--device-map", "cuda:0", "--max-tiles", "40"],
Expand All @@ -29,6 +29,8 @@
["OpenGVLab/InternVL2-1B", "--device-map", "cuda:0"],
["OpenGVLab/InternVL2-2B", "--device-map", "cuda:0", "--load-in-4bit"],
["OpenGVLab/InternVL2-2B", "--device-map", "cuda:0"],
["OpenGVLab/InternVL2-4B", "--device-map", "cuda:0", "--load-in-4bit"],
["OpenGVLab/InternVL2-4B", "--device-map", "cuda:0"],
["OpenGVLab/InternVL2-8B", "--device-map", "cuda:0", "--load-in-4bit"],
["OpenGVLab/InternVL2-8B", "--device-map", "cuda:0"],
["OpenGVLab/InternVL2-26B", "--device-map", "cuda:0", "--load-in-4bit"],
Expand All @@ -40,6 +42,9 @@
["OpenGVLab/Mini-InternVL-Chat-2B-V1-5", "--max-tiles", "40", "--load-in-4bit"],
["OpenGVLab/Mini-InternVL-Chat-2B-V1-5", "--max-tiles", "40"],
["OpenGVLab/Mini-InternVL-Chat-2B-V1-5"],
["OpenGVLab/Mini-InternVL-Chat-4B-V1-5", "--max-tiles", "40", "--load-in-4bit"],
["OpenGVLab/Mini-InternVL-Chat-4B-V1-5", "--load-in-4bit"],
["OpenGVLab/Mini-InternVL-Chat-4B-V1-5"],
["Qwen/Qwen-VL-Chat", "--load-in-4bit"],
["Qwen/Qwen-VL-Chat"],
["Salesforce/xgen-mm-phi3-mini-instruct-dpo-r-v1.5"],
Expand All @@ -50,68 +55,67 @@
["THUDM/glm-4v-9b", "--device-map", "cuda:0"],
["TIGER-Lab/Mantis-8B-Fuyu", "--device-map", "cuda:0", "--load-in-4bit"],
["TIGER-Lab/Mantis-8B-Fuyu", "--device-map", "cuda:0"],
["TIGER-Lab/Mantis-8B-clip-llama3", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["TIGER-Lab/Mantis-8B-clip-llama3", "--use-flash-attn", "--device-map", "cuda:0"],
["TIGER-Lab/Mantis-8B-siglip-llama3", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["TIGER-Lab/Mantis-8B-siglip-llama3", "--use-flash-attn", "--device-map", "cuda:0"],
["TIGER-Lab/Mantis-8B-clip-llama3", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["TIGER-Lab/Mantis-8B-clip-llama3", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["TIGER-Lab/Mantis-8B-siglip-llama3", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["TIGER-Lab/Mantis-8B-siglip-llama3", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["adept/fuyu-8b", "--device-map", "cuda:0", "--load-in-4bit"],
["adept/fuyu-8b", "--device-map", "cuda:0"],
["echo840/Monkey", "--load-in-4bit"],
["echo840/Monkey"],
["echo840/Monkey-Chat", "--load-in-4bit"],
["echo840/Monkey-Chat"],
["failspy/Phi-3-vision-128k-instruct-abliterated-alpha", "--use-flash-attn", "--load-in-4bit"],
["failspy/Phi-3-vision-128k-instruct-abliterated-alpha", "--use-flash-attn"],
["fancyfeast/joy-caption-pre-alpha", "--load-in-4bit", "--use-flash-attn"],
["fancyfeast/joy-caption-pre-alpha", "--use-flash-attn"],
["internlm/internlm-xcomposer2d5-7b", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["internlm/internlm-xcomposer2d5-7b", "--use-flash-attn", "--device-map", "cuda:0"],
["internlm/internlm-xcomposer2-4khd-7b", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["internlm/internlm-xcomposer2-4khd-7b", "--use-flash-attn", "--device-map", "cuda:0"],
["internlm/internlm-xcomposer2-7b", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["internlm/internlm-xcomposer2-7b", "--use-flash-attn", "--device-map", "cuda:0"],
["internlm/internlm-xcomposer2-7b-4bit", "--use-flash-attn"],
["internlm/internlm-xcomposer2-vl-1_8b", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["internlm/internlm-xcomposer2-vl-1_8b", "--use-flash-attn", "--device-map", "cuda:0"],
["internlm/internlm-xcomposer2-vl-7b", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["internlm/internlm-xcomposer2-vl-7b", "--use-flash-attn", "--device-map", "cuda:0"],
["internlm/internlm-xcomposer2-vl-7b-4bit", "--use-flash-attn"],
["llava-hf/llava-1.5-13b-hf", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["llava-hf/llava-1.5-13b-hf", "--use-flash-attn", "--device-map", "cuda:0"],
["llava-hf/llava-1.5-7b-hf", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["llava-hf/llava-1.5-7b-hf", "--use-flash-attn", "--device-map", "cuda:0"],
["llava-hf/llava-v1.6-34b-hf", "--use-flash-attn", "--load-in-4bit"],
["llava-hf/llava-v1.6-34b-hf", "--use-flash-attn"],
["llava-hf/llava-v1.6-vicuna-13b-hf", "--use-flash-attn", "--load-in-4bit"],
["llava-hf/llava-v1.6-vicuna-13b-hf", "--use-flash-attn"],
["llava-hf/llava-v1.6-vicuna-7b-hf", "--use-flash-attn", "--load-in-4bit"],
["llava-hf/llava-v1.6-vicuna-7b-hf", "--use-flash-attn"],
["lmms-lab/llava-onevision-qwen2-0.5b-ov", "--use-flash-attn"],
["lmms-lab/llava-onevision-qwen2-7b-ov", "--use-flash-attn"],
["microsoft/Florence-2-base-ft", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["microsoft/Florence-2-base-ft", "--use-flash-attn", "--device-map", "cuda:0"],
["microsoft/Florence-2-large-ft", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["microsoft/Florence-2-large-ft", "--use-flash-attn", "--device-map", "cuda:0"],
["microsoft/Phi-3-vision-128k-instruct", "--use-flash-attn", "--load-in-4bit"],
["microsoft/Phi-3-vision-128k-instruct", "--use-flash-attn"],
["microsoft/Phi-3.5-vision-instruct", "--use-flash-attn", "--load-in-4bit"],
["microsoft/Phi-3.5-vision-instruct", "--use-flash-attn"],
["openbmb/MiniCPM-V-2_6", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["openbmb/MiniCPM-V-2_6", "--use-flash-attn", "--device-map", "cuda:0"],
["openbmb/MiniCPM-Llama3-V-2_5", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["openbmb/MiniCPM-Llama3-V-2_5", "--use-flash-attn", "--device-map", "cuda:0"],
["qihoo360/360VL-8B", "--use-flash-attn", "--load-in-4bit"],
["qihoo360/360VL-8B", "--use-flash-attn"],
["qnguyen3/nanoLLaVA", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["qnguyen3/nanoLLaVA", "--use-flash-attn", "--device-map", "cuda:0"],
["qnguyen3/nanoLLaVA-1.5", "--use-flash-attn", "--device-map", "cuda:0", "--load-in-4bit"],
["qnguyen3/nanoLLaVA-1.5", "--use-flash-attn", "--device-map", "cuda:0"],
["failspy/Phi-3-vision-128k-instruct-abliterated-alpha", "-A", "flash_attention_2", "--load-in-4bit"],
["failspy/Phi-3-vision-128k-instruct-abliterated-alpha", "-A", "flash_attention_2"],
["fancyfeast/joy-caption-pre-alpha", "--load-in-4bit", "-A", "flash_attention_2"],
["fancyfeast/joy-caption-pre-alpha", "-A", "flash_attention_2"],
["internlm/internlm-xcomposer2d5-7b", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["internlm/internlm-xcomposer2d5-7b", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["internlm/internlm-xcomposer2-4khd-7b", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["internlm/internlm-xcomposer2-4khd-7b", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["internlm/internlm-xcomposer2-7b", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["internlm/internlm-xcomposer2-7b", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["internlm/internlm-xcomposer2-7b-4bit", "-A", "flash_attention_2"],
["internlm/internlm-xcomposer2-vl-1_8b", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["internlm/internlm-xcomposer2-vl-1_8b", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["internlm/internlm-xcomposer2-vl-7b", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["internlm/internlm-xcomposer2-vl-7b", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["internlm/internlm-xcomposer2-vl-7b-4bit", "-A", "flash_attention_2"],
["llava-hf/llava-1.5-13b-hf", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["llava-hf/llava-1.5-13b-hf", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["llava-hf/llava-1.5-7b-hf", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["llava-hf/llava-1.5-7b-hf", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["llava-hf/llava-v1.6-34b-hf", "-A", "flash_attention_2", "--load-in-4bit"],
["llava-hf/llava-v1.6-34b-hf", "-A", "flash_attention_2"],
["llava-hf/llava-v1.6-vicuna-13b-hf", "-A", "flash_attention_2", "--load-in-4bit"],
["llava-hf/llava-v1.6-vicuna-13b-hf", "-A", "flash_attention_2"],
["llava-hf/llava-v1.6-vicuna-7b-hf", "-A", "flash_attention_2", "--load-in-4bit"],
["llava-hf/llava-v1.6-vicuna-7b-hf", "-A", "flash_attention_2"],
["lmms-lab/llava-onevision-qwen2-0.5b-ov", "-A", "flash_attention_2"],
["lmms-lab/llava-onevision-qwen2-7b-ov", "-A", "flash_attention_2"],
["microsoft/Florence-2-base-ft", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["microsoft/Florence-2-base-ft", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["microsoft/Florence-2-large-ft", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["microsoft/Florence-2-large-ft", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["microsoft/Phi-3-vision-128k-instruct", "-A", "flash_attention_2", "--load-in-4bit"],
["microsoft/Phi-3-vision-128k-instruct", "-A", "flash_attention_2"],
["microsoft/Phi-3.5-vision-instruct", "-A", "flash_attention_2", "--load-in-4bit"],
["microsoft/Phi-3.5-vision-instruct", "-A", "flash_attention_2"],
["mistralai/Pixtral-12B-2409"],
["openbmb/MiniCPM-V-2_6", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["openbmb/MiniCPM-V-2_6", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["openbmb/MiniCPM-Llama3-V-2_5", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["openbmb/MiniCPM-Llama3-V-2_5", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["qihoo360/360VL-8B", "-A", "flash_attention_2", "--load-in-4bit"],
["qihoo360/360VL-8B", "-A", "flash_attention_2"],
["qnguyen3/nanoLLaVA", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["qnguyen3/nanoLLaVA", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["qnguyen3/nanoLLaVA-1.5", "-A", "flash_attention_2", "--device-map", "cuda:0", "--load-in-4bit"],
["qnguyen3/nanoLLaVA-1.5", "-A", "flash_attention_2", "--device-map", "cuda:0"],
["qresearch/llama-3-vision-alpha-hf", "--device", "cuda:0", "--load-in-4bit"],
["qresearch/llama-3-vision-alpha-hf", "--device", "cuda:0"],
["togethercomputer/Llama-3-8B-Dragonfly-Med-v1", "--load-in-4bit"],
["togethercomputer/Llama-3-8B-Dragonfly-Med-v1"],
["togethercomputer/Llama-3-8B-Dragonfly-v1", "--load-in-4bit"],
["togethercomputer/Llama-3-8B-Dragonfly-v1"],
["vikhyatk/moondream2", "--use-flash-attn", "--load-in-4bit"],
["vikhyatk/moondream2", "--use-flash-attn"]
["vikhyatk/moondream2", "-A", "flash_attention_2", "--load-in-4bit"],
["vikhyatk/moondream2", "-A", "flash_attention_2"]
]
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,7 @@ logger

# llava-onevision
git+https://github.com/LLaVA-VL/LLaVA-NeXT.git

# mistral
mistral_inference>=1.4.0
mistral_common>=1.4.0
Loading

0 comments on commit 14ef2fc

Please sign in to comment.