OAI: Initial vision support in OAI chat completions
* Support image_url inputs containing URLs or base64 strings, following the OAI vision spec
* Use an async LRU cache for image embeddings
* Add a generic wrapper class for multimodal embeddings
DocShotgun committed Nov 18, 2024
1 parent 5fa298e commit dd41eec
Showing 7 changed files with 115 additions and 26 deletions.
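
For context, a client request that exercises the new support might look like the sketch below. The message content follows the OAI vision spec; the endpoint URL, port, model name, and image file are illustrative assumptions, not part of this commit.

import base64

import aiohttp


async def send_vision_request():
    # Encode a local image as a base64 data URI (the second accepted image_url form).
    with open("cat.png", "rb") as f:
        b64_image = base64.b64encode(f.read()).decode()

    payload = {
        "model": "my-vision-model",  # assumed model name
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe both images."},
                    {
                        "type": "image_url",
                        "image_url": {"url": "https://example.com/cat.png"},
                    },
                    {
                        "type": "image_url",
                        "image_url": {"url": f"data:image/png;base64,{b64_image}"},
                    },
                ],
            }
        ],
    }

    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://127.0.0.1:5000/v1/chat/completions", json=payload
        ) as resp:
            return await resp.json()


# Requires a running server with a vision-capable model loaded, e.g.:
# asyncio.run(send_vision_request())
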
backends/exllamav2/model.py (21 changes: 18 additions & 3 deletions)
@@ -6,6 +6,8 @@
import math
import pathlib
import traceback
from backends.exllamav2.vision import clear_image_embedding_cache
from common.multimodal import MultimodalEmbeddingWrapper
import torch
import uuid
from copy import deepcopy
@@ -816,6 +818,9 @@ async def unload(self, loras_only: bool = False, **kwargs):
# Delete references held in the grammar module
clear_grammar_func_cache()

# Clear the image embedding cache
clear_image_embedding_cache()

# Unload LoRAs
if self.generator and self.generator.generator.current_loras:
for lora in self.generator.generator.current_loras:
@@ -908,12 +913,17 @@ def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
return dict(zip_longest(top_tokens, cleaned_values))

async def generate(
self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
self,
prompt: str,
embeddings: MultimodalEmbeddingWrapper,
request_id: str,
abort_event: asyncio.Event = None,
**kwargs,
):
"""Generate a response to a prompt."""
generations = []
async for generation in self.generate_gen(
prompt, request_id, abort_event, **kwargs
prompt, embeddings, request_id, abort_event, **kwargs
):
generations.append(generation)

@@ -979,6 +989,7 @@ def check_unsupported_settings(self, **kwargs):
async def generate_gen(
self,
prompt: str,
embeddings: MultimodalEmbeddingWrapper,
request_id: str,
abort_event: Optional[asyncio.Event] = None,
**kwargs,
@@ -1246,7 +1257,10 @@ async def generate_gen(
# Encode both positive and negative prompts
input_ids = [
self.tokenizer.encode(
prompt, add_bos=add_bos_token, encode_special_tokens=True
prompt,
add_bos=add_bos_token,
encode_special_tokens=True,
embeddings=embeddings.content,
)
for prompt in prompts
]
@@ -1297,6 +1311,7 @@ async def generate_gen(
banned_strings=banned_strings,
token_healing=token_healing,
identifier=job_id,
embeddings=embeddings.content,
)

# Save generated tokens and full response
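
A caller-side sketch of the new generate() signature above; the empty wrapper and the sampling keyword are assumptions used only for illustration.

from common import model
from common.multimodal import MultimodalEmbeddingWrapper


async def generate_text_only(prompt: str, request_id: str):
    # The backend now always expects a wrapper; its .content list is forwarded
    # to tokenizer.encode() and to the generator job as embeddings.
    embeddings = MultimodalEmbeddingWrapper()
    return await model.container.generate(
        prompt,
        embeddings,
        request_id,
        max_tokens=200,  # assumed sampling parameter, passed through **kwargs
    )
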
backends/exllamav2/vision.py (30 changes: 14 additions & 16 deletions)
@@ -4,18 +4,14 @@
import base64
import re
from PIL import Image
from common import model
import aiohttp
from common.networking import (
handle_request_error,
)
from fastapi import HTTPException
from exllamav2 import (
ExLlamaV2,
ExLlamaV2Tokenizer,
ExLlamaV2VisionTower,
ExLlamaV2MMEmbedding,
)
from functools import lru_cache
from exllamav2.generator import ExLlamaV2MMEmbedding
from async_lru import alru_cache


async def get_image(url: str) -> Image:
@@ -50,14 +46,16 @@ async def get_image(url: str) -> Image:
return Image.open(io.BytesIO(bytes_image))


@lru_cache(20)
async def get_image_embedding(
model: ExLlamaV2,
tokenizer: ExLlamaV2Tokenizer,
vision_model: ExLlamaV2VisionTower,
url: str,
) -> ExLlamaV2MMEmbedding:
@alru_cache(20)
async def get_image_embedding(url: str) -> ExLlamaV2MMEmbedding:
image = await get_image(url)
return vision_model.get_image_embeddings(
model=model, tokenizer=tokenizer, image=image
return model.container.vision_model.get_image_embeddings(
model=model.container.model,
tokenizer=model.container.tokenizer,
image=image,
text_alias=None,
)


def clear_image_embedding_cache():
get_image_embedding.cache_clear()
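
functools.lru_cache cannot cache coroutine results, which is why the decorator changes to async_lru.alru_cache. A standalone sketch of that pattern, independent of the vision code and assuming only that the async_lru package is installed:

import asyncio

from async_lru import alru_cache


@alru_cache(maxsize=20)
async def fake_embedding(url: str) -> str:
    await asyncio.sleep(0)  # stand-in for the download and vision-tower work
    return f"embedding-for-{url}"


async def main():
    first = await fake_embedding("https://example.com/cat.png")
    second = await fake_embedding("https://example.com/cat.png")  # cache hit
    assert first == second
    fake_embedding.cache_clear()  # same mechanism as clear_image_embedding_cache()


asyncio.run(main())
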
common/multimodal.py (36 changes: 36 additions & 0 deletions)
@@ -0,0 +1,36 @@
from typing import List
from backends.exllamav2.vision import get_image_embedding
from common import model
from pydantic import BaseModel
from loguru import logger

from common.optional_dependencies import dependencies

if dependencies.exllamav2:
from exllamav2 import ExLlamaV2VisionTower


class MultimodalEmbeddingWrapper(BaseModel):
"""Common multimodal embedding wrapper"""

type: str = None
content: List = []
text_alias: List[str] = []


async def add_image_embedding(
embeddings: MultimodalEmbeddingWrapper, url: str
) -> MultimodalEmbeddingWrapper:
# Determine the type of vision embedding to use
if not embeddings.type:
if isinstance(model.container.vision_model, ExLlamaV2VisionTower):
embeddings.type = "ExLlamaV2MMEmbedding"

if embeddings.type == "ExLlamaV2MMEmbedding":
embedding = await get_image_embedding(url)
embeddings.content.append(embedding)
embeddings.text_alias.append(embedding.text_alias)
else:
logger.error("No valid vision model to create embedding")

return embeddings
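
A hypothetical usage sketch of the wrapper: each add_image_embedding call appends a backend-specific embedding object to content and the matching prompt placeholder to text_alias. The helper name and URLs below are illustrative.

from typing import List

from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding


async def collect_image_embeddings(urls: List[str]) -> MultimodalEmbeddingWrapper:
    # Assumes a vision-capable model is already loaded in model.container.
    embeddings = MultimodalEmbeddingWrapper()
    for url in urls:
        embeddings = await add_image_embedding(embeddings, url)
    # embeddings.content    -> embedding objects (ExLlamaV2MMEmbedding for this backend)
    # embeddings.text_alias -> placeholder strings to splice into the prompt text
    return embeddings
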
endpoints/OAI/router.py (9 changes: 7 additions & 2 deletions)
@@ -17,6 +17,7 @@
from endpoints.OAI.utils.chat_completion import (
format_prompt_with_template,
generate_chat_completion,
preprocess_vision_request,
stream_generate_chat_completion,
)
from endpoints.OAI.utils.completion import (
@@ -126,6 +127,8 @@ async def chat_completion_request(
if isinstance(data.messages, str):
prompt = data.messages
else:
if model.container.use_vision:
data.messages, embeddings = await preprocess_vision_request(data.messages)
prompt = await format_prompt_with_template(data)

# Set an empty JSON schema if the request wants a JSON response
@@ -136,12 +139,14 @@

if data.stream and not disable_request_streaming:
return EventSourceResponse(
stream_generate_chat_completion(prompt, data, request, model_path),
stream_generate_chat_completion(
prompt, embeddings, data, request, model_path
),
ping=maxsize,
)
else:
generate_task = asyncio.create_task(
generate_chat_completion(prompt, data, request, model_path)
generate_chat_completion(prompt, embeddings, data, request, model_path)
)

response = await run_with_request_disconnect(
endpoints/OAI/utils/chat_completion.py (40 changes: 36 additions & 4 deletions)
@@ -3,9 +3,10 @@
import asyncio
import pathlib
from asyncio import CancelledError
from typing import List, Optional
from typing import Dict, List, Optional
import json

from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
from fastapi import HTTPException, Request
from jinja2 import TemplateError
from loguru import logger
@@ -279,7 +280,11 @@ async def format_prompt_with_template(


async def stream_generate_chat_completion(
prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
prompt: str,
embeddings: MultimodalEmbeddingWrapper,
data: ChatCompletionRequest,
request: Request,
model_path: pathlib.Path,
):
"""Generator for the generation process."""
abort_event = asyncio.Event()
@@ -298,6 +303,7 @@ async def stream_generate_chat_completion(
n,
gen_queue,
prompt,
embeddings,
request.state.id,
abort_event,
**task_gen_params.model_dump(exclude={"prompt"}),
@@ -372,7 +378,11 @@ async def stream_generate_chat_completion(


async def generate_chat_completion(
prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
prompt: str,
embeddings: MultimodalEmbeddingWrapper,
data: ChatCompletionRequest,
request: Request,
model_path: pathlib.Path,
):
gen_tasks: List[asyncio.Task] = []

@@ -381,7 +391,10 @@ async def generate_chat_completion(
gen_tasks.append(
asyncio.create_task(
model.container.generate(
prompt, request.state.id, **data.model_dump(exclude={"prompt"})
prompt,
embeddings,
request.state.id,
**data.model_dump(exclude={"prompt"}),
)
)
)
@@ -459,3 +472,22 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]:
tool_call["function"]["arguments"]
)
return [ToolCall(**tool_call) for tool_call in tool_calls]


async def preprocess_vision_request(messages: List[Dict]):
embeddings = MultimodalEmbeddingWrapper()
for message in messages:
if isinstance(message["content"], list):
concatenated_content = ""
for content in message["content"]:
if content["type"] == "text":
concatenated_content += content["text"]
elif content["type"] == "image_url":
embeddings = await add_image_embedding(
embeddings, content["image_url"]["url"]
)
concatenated_content += embeddings.text_alias[-1]

message["content"] = concatenated_content

return messages, embeddings
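
An illustrative before/after for preprocess_vision_request. The flattened content string is the concatenation of the text parts with each image replaced by its embedding's text alias; the alias value is generated by the backend, so the one shown is only a placeholder.

# Illustration only; the call itself must run inside the server with a vision model loaded.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image? "},
            {
                "type": "image_url",
                "image_url": {"url": "https://example.com/cat.png"},
            },
        ],
    }
]

# messages, embeddings = await preprocess_vision_request(messages)
#
# Afterwards:
#   messages[0]["content"] == "What is in this image? <alias-from-backend>"
#   embeddings.content holds the matching image embedding
#   embeddings.text_alias[-1] is the placeholder spliced into the string
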
endpoints/OAI/utils/completion.py (4 changes: 3 additions & 1 deletion)
@@ -7,6 +7,7 @@
import asyncio
import pathlib
from asyncio import CancelledError
from common.multimodal import MultimodalEmbeddingWrapper
from fastapi import HTTPException, Request
from typing import List, Union

@@ -87,6 +88,7 @@ async def _stream_collector(
task_idx: int,
gen_queue: asyncio.Queue,
prompt: str,
embeddings: MultimodalEmbeddingWrapper,
request_id: str,
abort_event: asyncio.Event,
**kwargs,
@@ -95,7 +97,7 @@

try:
new_generation = model.container.generate_gen(
prompt, request_id, abort_event, **kwargs
prompt, embeddings, request_id, abort_event, **kwargs
)
async for generation in new_generation:
generation["index"] = task_idx
pyproject.toml (1 change: 1 addition & 0 deletions)
@@ -29,6 +29,7 @@ dependencies = [
"lm-format-enforcer >= 0.9.6",
"aiofiles",
"aiohttp",
"async_lru",
"huggingface_hub",
"psutil",
"httptools>=0.5.0",
