Merge pull request #249 from theroyallab/vision
Vision
bdashore3 authored Nov 22, 2024
2 parents a69f860 + 388d36e commit 9c8186c
Showing 13 changed files with 321 additions and 69 deletions.
58 changes: 56 additions & 2 deletions backends/exllamav2/model.py
@@ -6,6 +6,8 @@
import math
import pathlib
import traceback
from backends.exllamav2.vision import clear_image_embedding_cache
from common.multimodal import MultimodalEmbeddingWrapper
import torch
import uuid
from copy import deepcopy
@@ -20,6 +22,7 @@
ExLlamaV2Cache_TP,
ExLlamaV2Tokenizer,
ExLlamaV2Lora,
ExLlamaV2VisionTower,
)
from exllamav2.generator import (
ExLlamaV2Sampler,
@@ -91,6 +94,10 @@ class ExllamaV2Container:
autosplit_reserve: List[float] = [96 * 1024**2]
use_tp: bool = False

# Vision vars
use_vision: bool = False
vision_model: Optional[ExLlamaV2VisionTower] = None

# Load state
model_is_loading: bool = False
model_loaded: bool = False
@@ -144,6 +151,15 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
# Apply a model's config overrides while respecting user settings
kwargs = await self.set_model_overrides(**kwargs)

# Set vision state and error if vision isn't supported on the current model
self.use_vision = unwrap(kwargs.get("vision"), False)
if self.use_vision and not self.config.vision_model_type:
raise ValueError(
"The provided model does not have vision capabilities that are "
"supported by ExllamaV2. "
"Please reload with vision disabled."
)

# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft_model"), {})
draft_model_name = draft_args.get("draft_model_name")
@@ -477,6 +493,7 @@ def get_model_parameters(self):
"prompt_template": self.prompt_template.name
if self.prompt_template
else None,
"use_vision": self.use_vision,
}

if self.draft_config:
@@ -620,6 +637,14 @@ def progress(loaded_modules: int, total_modules: int)
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)

# Load vision tower if it exists
if self.use_vision:
self.vision_model = ExLlamaV2VisionTower(self.config)

for value in self.vision_model.load_gen(callback_gen=progress_callback):
if value:
yield value

self.model = ExLlamaV2(self.config)
if not self.quiet:
logger.info("Loading model: " + self.config.model_dir)
@@ -811,6 +836,9 @@ async def unload(self, loras_only: bool = False, **kwargs):
# Delete references held in the grammar module
clear_grammar_func_cache()

# Clear the image embedding cache
clear_image_embedding_cache()

# Unload LoRAs
if self.generator and self.generator.generator.current_loras:
for lora in self.generator.generator.current_loras:
@@ -824,6 +852,16 @@ async def unload(self, loras_only: bool = False, **kwargs):
self.model.unload()
self.model = None

if self.vision_model:
# TODO: Remove this with newer exl2 versions
# Required otherwise unload function won't finish
try:
self.vision_model.unload()
except AttributeError:
pass

self.vision_model = None

if self.draft_model:
self.draft_model.unload()
self.draft_model = None
@@ -855,11 +893,15 @@ async def unload(self, loras_only: bool = False, **kwargs):
def encode_tokens(self, text: str, **kwargs):
"""Wrapper to encode tokens from a text string."""

mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
mm_embeddings_content = mm_embeddings.content if mm_embeddings else []

return (
self.tokenizer.encode(
text,
add_bos=unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
embeddings=mm_embeddings_content,
)
.flatten()
.tolist()
@@ -903,7 +945,11 @@ def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
return dict(zip_longest(top_tokens, cleaned_values))

async def generate(
self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
self,
prompt: str,
request_id: str,
abort_event: asyncio.Event = None,
**kwargs,
):
"""Generate a response to a prompt."""
generations = []
@@ -1238,10 +1284,17 @@ async def generate_gen(
else:
stop_conditions += eos_tokens

# Get multimodal embeddings if present
mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
mm_embeddings_content = mm_embeddings.content if mm_embeddings else []

# Encode both positive and negative prompts
input_ids = [
self.tokenizer.encode(
prompt, add_bos=add_bos_token, encode_special_tokens=True
prompt,
add_bos=add_bos_token,
encode_special_tokens=True,
embeddings=mm_embeddings_content,
)
for prompt in prompts
]
@@ -1292,6 +1345,7 @@ async def generate_gen(
banned_strings=banned_strings,
token_healing=token_healing,
identifier=job_id,
embeddings=mm_embeddings_content,
)

# Save generated tokens and full response
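For orientation, here is a minimal usage sketch of the new `embeddings` kwarg wired through `encode_tokens` and `generate_gen` above. This is not part of the commit; the helper name and URL handling are illustrative, and it assumes a vision-capable model is loaded with vision enabled.

```python
# Illustrative sketch only, not part of this diff.
from common import model
from common.multimodal import MultimodalEmbeddingWrapper


async def encode_prompt_with_image(prompt: str, image_url: str) -> list[int]:
    mm_embeddings = MultimodalEmbeddingWrapper()
    await mm_embeddings.add(image_url)  # builds an ExLlamaV2MMEmbedding via the vision tower

    # encode_tokens() unwraps .content and hands it to the ExLlamaV2 tokenizer;
    # generate()/generate_gen() accept the same kwarg and also forward the
    # embeddings to the generator job (see the hunks above).
    return model.container.encode_tokens(prompt, embeddings=mm_embeddings)
```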
70 changes: 70 additions & 0 deletions backends/exllamav2/vision.py
@@ -0,0 +1,70 @@
"""Vision utilities for ExLlamaV2."""

import io
import base64
import re
from PIL import Image
from common import model
import aiohttp
from common.networking import (
handle_request_error,
)
from common.tabby_config import config
from fastapi import HTTPException
from exllamav2.generator import ExLlamaV2MMEmbedding
from async_lru import alru_cache


async def get_image(url: str) -> Image:
if url.startswith("data:image"):
# Handle base64 image
match = re.match(r"^data:image\/[a-zA-Z0-9]+;base64,(.*)$", url)
if match:
base64_image = match.group(1)
bytes_image = base64.b64decode(base64_image)
else:
error_message = handle_request_error(
"Failed to read base64 image input.",
exc_info=False,
).error.message

raise HTTPException(400, error_message)

else:
# Handle image URL
if config.network.disable_fetch_requests:
error_message = handle_request_error(
f"Failed to fetch image from {url} as fetch requests are disabled.",
exc_info=False,
).error.message

raise HTTPException(400, error_message)

async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
if response.status == 200:
bytes_image = await response.read()
else:
error_message = handle_request_error(
f"Failed to fetch image from {url}.",
exc_info=False,
).error.message

raise HTTPException(400, error_message)

return Image.open(io.BytesIO(bytes_image))


@alru_cache(20)
async def get_image_embedding(url: str) -> ExLlamaV2MMEmbedding:
image = await get_image(url)
return model.container.vision_model.get_image_embeddings(
model=model.container.model,
tokenizer=model.container.tokenizer,
image=image,
text_alias=None,
)


def clear_image_embedding_cache():
get_image_embedding.cache_clear()
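A rough sketch of how the cached lookup above behaves (illustrative only; the URL is a placeholder and a vision-capable model must already be loaded in `model.container`):

```python
# get_image_embedding() sits behind a 20-entry async LRU cache, so repeated
# references to the same image URL reuse the same ExLlamaV2MMEmbedding
# instead of refetching and re-encoding the image.
import asyncio

from backends.exllamav2.vision import clear_image_embedding_cache, get_image_embedding


async def demo_cache(url: str):
    first = await get_image_embedding(url)   # downloads/decodes, runs the vision tower
    second = await get_image_embedding(url)  # cache hit: same object, no refetch
    assert first is second

    # Called from ExllamaV2Container.unload() so cached embeddings
    # don't keep references to an unloaded model.
    clear_image_embedding_cache()


asyncio.run(demo_cache("https://example.com/cat.png"))
```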
13 changes: 13 additions & 0 deletions common/config_models.py
@@ -78,6 +78,13 @@ class NetworkConfig(BaseConfigModel):
"Turn on this option if you are ONLY connecting from localhost."
),
)
disable_fetch_requests: Optional[bool] = Field(
False,
description=(
"Disable fetching external content in response to requests,"
"such as images from URLs."
),
)
send_tracebacks: Optional[bool] = Field(
False,
description=(
@@ -281,6 +288,12 @@ class ModelConfig(BaseConfigModel):
"NOTE: Only works with chat completion message lists!"
),
)
vision: Optional[bool] = Field(
False,
description=(
"Enables vision support if the model supports it. (default: False)"
),
)
num_experts_per_token: Optional[int] = Field(
None,
description=(
23 changes: 17 additions & 6 deletions common/model.py
@@ -33,6 +33,7 @@ class ModelType(Enum):
MODEL = "model"
DRAFT = "draft"
EMBEDDING = "embedding"
VISION = "vision"


def load_progress(module, modules):
@@ -70,29 +71,39 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
# Create a new container
container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)

model_type = "draft" if container.draft_config else "model"
# Add possible types of models that can be loaded
model_type = [ModelType.MODEL]

if container.use_vision:
model_type.insert(0, ModelType.VISION)

if container.draft_config:
model_type.insert(0, ModelType.DRAFT)

load_status = container.load_gen(load_progress, **kwargs)

progress = get_loading_progress_bar()
progress.start()

try:
index = 0
async for module, modules in load_status:
current_model_type = model_type[index].value
if module == 0:
loading_task = progress.add_task(
f"[cyan]Loading {model_type} modules", total=modules
f"[cyan]Loading {current_model_type} modules", total=modules
)
else:
progress.advance(loading_task)

yield module, modules, model_type
yield module, modules, current_model_type

if module == modules:
# Switch to model progress if the draft model is loaded
if model_type == "draft":
model_type = "model"
else:
if index == len(model_type):
progress.stop()
else:
index += 1
finally:
progress.stop()

Expand Down
30 changes: 30 additions & 0 deletions common/multimodal.py
@@ -0,0 +1,30 @@
from typing import List
from backends.exllamav2.vision import get_image_embedding
from common import model
from loguru import logger

from common.optional_dependencies import dependencies

if dependencies.exllamav2:
from exllamav2 import ExLlamaV2VisionTower


class MultimodalEmbeddingWrapper:
"""Common multimodal embedding wrapper"""

type: str = None
content: List = []
text_alias: List[str] = []

async def add(self, url: str):
# Determine the type of vision embedding to use
if not self.type:
if isinstance(model.container.vision_model, ExLlamaV2VisionTower):
self.type = "ExLlamaV2MMEmbedding"

if self.type == "ExLlamaV2MMEmbedding":
embedding = await get_image_embedding(url)
self.content.append(embedding)
self.text_alias.append(embedding.text_alias)
else:
logger.error("No valid vision model to create embedding")
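A hedged sketch of what the wrapper holds after `add()`. Not part of this file; the alias placement is an assumption, since the prompt-formatting side lives in changed files not shown on this page.

```python
# Sketch only: the URL is a placeholder and a vision-capable model must
# already be loaded in model.container. Intended to be awaited from within
# the server's running event loop.
from common.multimodal import MultimodalEmbeddingWrapper


async def inspect_wrapper() -> MultimodalEmbeddingWrapper:
    wrapper = MultimodalEmbeddingWrapper()
    await wrapper.add("https://example.com/cat.png")

    # wrapper.type       -> "ExLlamaV2MMEmbedding"
    # wrapper.content    -> [ExLlamaV2MMEmbedding]  (fed to the tokenizer/generator)
    # wrapper.text_alias -> [alias string]          (presumably spliced into the prompt
    #                        text so ExLlamaV2 knows where the image belongs)
    return wrapper
```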
6 changes: 6 additions & 0 deletions config_sample.yml
@@ -20,6 +20,9 @@ network:
# Turn on this option if you are ONLY connecting from localhost.
disable_auth: false

# Disable fetching external content in response to requests, such as images from URLs.
disable_fetch_requests: false

# Send tracebacks over the API (default: False).
# NOTE: Only enable this for debug purposes.
send_tracebacks: false
@@ -130,6 +133,9 @@ model:
# NOTE: Only works with chat completion message lists!
prompt_template:

# Enables vision support if the model supports it. (default: False)
vision: false

# Number of experts to use per token.
# Fetched from the model's config.json if empty.
# NOTE: For MoE models only.
13 changes: 6 additions & 7 deletions endpoints/OAI/router.py
@@ -15,7 +15,7 @@
)
from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
from endpoints.OAI.utils.chat_completion import (
format_prompt_with_template,
apply_chat_template,
generate_chat_completion,
stream_generate_chat_completion,
)
@@ -123,10 +123,7 @@ async def chat_completion_request(

model_path = model.container.model_dir

if isinstance(data.messages, str):
prompt = data.messages
else:
prompt = await format_prompt_with_template(data)
prompt, embeddings = await apply_chat_template(data)

# Set an empty JSON schema if the request wants a JSON response
if data.response_format.type == "json":
@@ -136,12 +133,14 @@

if data.stream and not disable_request_streaming:
return EventSourceResponse(
stream_generate_chat_completion(prompt, data, request, model_path),
stream_generate_chat_completion(
prompt, embeddings, data, request, model_path
),
ping=maxsize,
)
else:
generate_task = asyncio.create_task(
generate_chat_completion(prompt, data, request, model_path)
generate_chat_completion(prompt, embeddings, data, request, model_path)
)

response = await run_with_request_disconnect(
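To close the loop, a hypothetical client-side request: the router now builds the prompt and multimodal embeddings together via `apply_chat_template`, so an OpenAI-style `image_url` content part should exercise the pieces above. The base URL, API key, model name, and exact message schema here are assumptions, not confirmed by this diff.

```python
# Hypothetical request against a TabbyAPI instance with a vision model loaded
# and vision enabled; endpoint address and message schema are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:5000/v1", api_key="YOUR_TABBY_API_KEY")

response = client.chat.completions.create(
    model="my-vision-model",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://example.com/cat.png"},
                },
            ],
        }
    ],
)
print(response.choices[0].message.content)
```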
(The remaining 6 changed files are not shown on this page.)
