Skip to content

Commit

Permalink
API: Transform multimodal into an actual class
Browse files Browse the repository at this point in the history
Migrate the add method into the class itself. Also, a BaseModel isn't
needed here since this isn't a serialized class.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Nov 20, 2024
1 parent 8ffc636 commit c652a6e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 23 deletions.
32 changes: 13 additions & 19 deletions common/multimodal.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing import List
from backends.exllamav2.vision import get_image_embedding
from common import model
from pydantic import BaseModel
from loguru import logger

from common.optional_dependencies import dependencies
Expand All @@ -10,27 +9,22 @@
from exllamav2 import ExLlamaV2VisionTower


class MultimodalEmbeddingWrapper(BaseModel):
class MultimodalEmbeddingWrapper:
"""Common multimodal embedding wrapper"""

type: str = None
content: List = []
text_alias: List[str] = []


async def add_image_embedding(
embeddings: MultimodalEmbeddingWrapper, url: str
) -> MultimodalEmbeddingWrapper:
# Determine the type of vision embedding to use
if not embeddings.type:
if isinstance(model.container.vision_model, ExLlamaV2VisionTower):
embeddings.type = "ExLlamaV2MMEmbedding"

if embeddings.type == "ExLlamaV2MMEmbedding":
embedding = await get_image_embedding(url)
embeddings.content.append(embedding)
embeddings.text_alias.append(embedding.text_alias)
else:
logger.error("No valid vision model to create embedding")

return embeddings
async def add(self, url: str):
# Determine the type of vision embedding to use
if not self.type:
if isinstance(model.container.vision_model, ExLlamaV2VisionTower):
self.type = "ExLlamaV2MMEmbedding"

if self.type == "ExLlamaV2MMEmbedding":
embedding = await get_image_embedding(url)
self.content.append(embedding)
self.text_alias.append(embedding.text_alias)
else:
logger.error("No valid vision model to create embedding")
6 changes: 2 additions & 4 deletions endpoints/OAI/utils/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from loguru import logger

from common import model
from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
from common.multimodal import MultimodalEmbeddingWrapper
from common.networking import (
get_generator_error,
handle_request_disconnect,
Expand Down Expand Up @@ -483,9 +483,7 @@ async def preprocess_vision_request(messages: List[ChatCompletionMessage]):
if content.type == "text":
concatenated_content += content.text
elif content.type == "image_url":
embeddings = await add_image_embedding(
embeddings, content.image_url.url
)
await embeddings.add(content.image_url.url)
concatenated_content += embeddings.text_alias[-1]

message.content = concatenated_content
Expand Down

0 comments on commit c652a6e

Please sign in to comment.