diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index eaa431c..e82c94e 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -31,7 +31,6 @@
 )
 from itertools import zip_longest
 from loguru import logger
-from PIL import Image
 from typing import List, Optional, Union
 from ruamel.yaml import YAML

@@ -374,6 +373,8 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
             self.draft_config.max_input_len = chunk_size
             self.draft_config.max_attention_size = chunk_size**2

+        self.prompt_template = None
+
         # Return the created instance
         return self

@@ -875,17 +876,18 @@ async def unload(self, loras_only: bool = False, **kwargs):
         async with self.load_condition:
             self.load_condition.notify_all()

-    def encode_tokens(
-        self, text: str, embeddings: MultimodalEmbeddingWrapper, **kwargs
-    ):
+    def encode_tokens(self, text: str, **kwargs):
         """Wrapper to encode tokens from a text string."""

+        mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
+        mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
+
         return (
             self.tokenizer.encode(
                 text,
                 add_bos=unwrap(kwargs.get("add_bos_token"), True),
                 encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
-                embeddings=embeddings.content,
+                embeddings=mm_embeddings_content,
             )
             .flatten()
             .tolist()
@@ -931,7 +933,6 @@ def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
     async def generate(
         self,
         prompt: str,
-        embeddings: MultimodalEmbeddingWrapper,
         request_id: str,
         abort_event: asyncio.Event = None,
         **kwargs,
@@ -939,7 +940,7 @@
         """Generate a response to a prompt."""
         generations = []
         async for generation in self.generate_gen(
-            prompt, embeddings, request_id, abort_event, **kwargs
+            prompt, request_id, abort_event, **kwargs
         ):
             generations.append(generation)

@@ -1005,7 +1006,6 @@ def check_unsupported_settings(self, **kwargs):
     async def generate_gen(
         self,
         prompt: str,
-        embeddings: MultimodalEmbeddingWrapper,
         request_id: str,
         abort_event: Optional[asyncio.Event] = None,
         **kwargs,
@@ -1270,13 +1270,17 @@
         else:
             stop_conditions += eos_tokens

+        # Get multimodal embeddings if present
+        mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
+        mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
+
         # Encode both positive and negative prompts
         input_ids = [
             self.tokenizer.encode(
                 prompt,
                 add_bos=add_bos_token,
                 encode_special_tokens=True,
-                embeddings=embeddings.content,
+                embeddings=mm_embeddings_content,
             )
             for prompt in prompts
         ]
@@ -1327,7 +1331,7 @@
             banned_strings=banned_strings,
             token_healing=token_healing,
             identifier=job_id,
-            embeddings=embeddings.content,
+            embeddings=mm_embeddings_content,
         )

         # Save generated tokens and full response
diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index 8403a87..8f4e7a4 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -1,5 +1,4 @@
 import asyncio
-from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import APIRouter, Depends, HTTPException, Request
 from sse_starlette import EventSourceResponse
 from sys import maxsize
@@ -16,9 +15,8 @@
 )
 from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
 from endpoints.OAI.utils.chat_completion import (
-    format_prompt_with_template,
+    apply_chat_template,
     generate_chat_completion,
-    preprocess_vision_request,
     stream_generate_chat_completion,
 )
 from endpoints.OAI.utils.completion import (
@@ -125,15 +123,7 @@ async def chat_completion_request(

     model_path = model.container.model_dir

-    embeddings = MultimodalEmbeddingWrapper()
-
-    if isinstance(data.messages, str):
-        prompt = data.messages
-    else:
-        if model.container.use_vision:
-            data.messages, embeddings = await preprocess_vision_request(data.messages)
-
-        prompt = await format_prompt_with_template(data)
+    prompt, embeddings = await apply_chat_template(data)

     # Set an empty JSON schema if the request wants a JSON response
     if data.response_format.type == "json":
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 7a31f39..84905db 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -177,11 +177,11 @@ def _create_stream_chunk(
     return chunk


-async def _append_template_metadata(data: ChatCompletionRequest):
+async def _append_template_metadata(data: ChatCompletionRequest, template_vars: dict):
     """Adding metadata is a one-time process."""

     template_metadata = await model.container.prompt_template.extract_metadata(
-        data.template_vars
+        template_vars
     )

     # Stop strings
@@ -199,7 +199,43 @@
     data.stop.extend(template_metadata.tool_starts)


-async def format_prompt_with_template(
+async def format_messages_with_template(
+    messages: List[ChatCompletionMessage],
+    existing_template_vars: Optional[dict] = None,
+    add_bos_token: bool = True,
+    ban_eos_token: bool = False,
+):
+    """Barebones function to format chat completion messages into a prompt."""
+
+    template_vars = unwrap(existing_template_vars, {})
+    mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None
+
+    for message in messages:
+        if isinstance(message.content, list):
+            concatenated_content = ""
+            for content in message.content:
+                if content.type == "text":
+                    concatenated_content += content.text
+                elif content.type == "image_url" and mm_embeddings:
+                    await mm_embeddings.add(content.image_url.url)
+                    concatenated_content += mm_embeddings.text_alias[-1]
+
+            message.content = concatenated_content
+
+        if message.tool_calls:
+            message.tool_calls_json = json.dumps(message.tool_calls, indent=2)
+
+    special_tokens_dict = model.container.get_special_tokens(
+        add_bos_token, ban_eos_token
+    )
+
+    template_vars.update({"messages": messages, **special_tokens_dict})
+
+    prompt = await model.container.prompt_template.render(template_vars)
+    return prompt, mm_embeddings, template_vars
+
+
+async def apply_chat_template(
     data: ChatCompletionRequest, tool_precursor: Optional[str] = None
 ):
     """
@@ -208,40 +244,18 @@
     """

     try:
-        special_tokens_dict = model.container.get_special_tokens(
-            unwrap(data.add_bos_token, True),
-            unwrap(data.ban_eos_token, False),
-        )
-
-        # Convert list to text-based content
-        # Use the first instance of text inside the part list
-        for message in data.messages:
-            if isinstance(message.content, list):
-                message.content = next(
-                    (
-                        content.text
-                        for content in message.content
-                        if content.type == "text"
-                    ),
-                    "",
-                )
-
-            if message.tool_calls:
-                message.tool_calls_json = json.dumps(message.tool_calls, indent=2)
-
-        # Overwrite any protected vars with their values
         data.template_vars.update(
             {
-                "messages": data.messages,
                 "add_generation_prompt": data.add_generation_prompt,
                 "tools_json": json.dumps(data.model_dump()["tools"], indent=2),
                 "functions_json": json.dumps(data.functions, indent=2),
                 "tool_precursor": tool_precursor,
-                **special_tokens_dict,
             }
         )

-        prompt = await model.container.prompt_template.render(data.template_vars)
+        prompt, mm_embeddings, template_vars = await format_messages_with_template(
+            data.messages, data.template_vars, data.add_bos_token, data.ban_eos_token
+        )

         # Append response prefix if present
         if data.response_prefix:
@@ -255,14 +269,14 @@

         # Removes the starting BOS token if present
         # This is to prevent add_bos_token from adding multiple bos tokens
-        bos_token = special_tokens_dict.get("bos_token")
+        bos_token = template_vars.get("bos_token")
         if bos_token and prompt.startswith(bos_token):
             prompt = prompt.removeprefix(bos_token)

         # Add template metadata
-        await _append_template_metadata(data)
+        await _append_template_metadata(data, template_vars)

-        return prompt
+        return prompt, mm_embeddings

     except KeyError as exc:
         error_message = handle_request_error(
@@ -302,9 +316,9 @@
                         n,
                         gen_queue,
                         prompt,
-                        embeddings,
                         request.state.id,
                         abort_event,
+                        embeddings=embeddings,
                         **task_gen_params.model_dump(exclude={"prompt"}),
                     )
                 )
@@ -391,8 +405,8 @@
             asyncio.create_task(
                 model.container.generate(
                     prompt,
-                    embeddings,
                     request.state.id,
+                    embeddings=embeddings,
                     **data.model_dump(exclude={"prompt"}),
                 )
             )
@@ -439,13 +453,11 @@
         if gen["stop_str"] in tool_data.tool_call_start:
             if "text" in gen:
                 # non streaming, all generations will have the text they generated
-                pre_tool_prompt = await format_prompt_with_template(data, gen["text"])
+                pre_tool_prompt = await apply_chat_template(data, gen["text"])
             elif current_generations is not None:
                 # streaming, we wont have text in the generation,
                 # we'll have to use the current_generations
-                pre_tool_prompt = await format_prompt_with_template(
-                    data, current_generations
-                )
+                pre_tool_prompt = await apply_chat_template(data, current_generations)

             gen_tasks.append(
                 asyncio.create_task(
@@ -471,21 +483,3 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]:
             tool_call["function"]["arguments"]
         )
     return [ToolCall(**tool_call) for tool_call in tool_calls]
-
-
-# TODO: Combine this with the existing preprocessor in format_prompt_with_template
-async def preprocess_vision_request(messages: List[ChatCompletionMessage]):
-    embeddings = MultimodalEmbeddingWrapper()
-    for message in messages:
-        if isinstance(message.content, list):
-            concatenated_content = ""
-            for content in message.content:
-                if content.type == "text":
-                    concatenated_content += content.text
-                elif content.type == "image_url":
-                    await embeddings.add(content.image_url.url)
-                    concatenated_content += embeddings.text_alias[-1]
-
-            message.content = concatenated_content
-
-    return messages, embeddings
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index e798176..9fd8b90 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -7,7 +7,6 @@
 import asyncio
 import pathlib
 from asyncio import CancelledError
-from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import HTTPException, Request
 from typing import List, Union

@@ -88,7 +87,6 @@ async def _stream_collector(
     task_idx: int,
     gen_queue: asyncio.Queue,
     prompt: str,
-    embeddings: MultimodalEmbeddingWrapper,
     request_id: str,
     abort_event: asyncio.Event,
     **kwargs,
@@ -97,7 +95,7 @@
     """Collects a stream and places results in a common queue"""
     try:
         new_generation = model.container.generate_gen(
-            prompt, embeddings, request_id, abort_event, **kwargs
+            prompt, request_id, abort_event, **kwargs
         )
         async for generation in new_generation:
             generation["index"] = task_idx
diff --git a/endpoints/core/router.py b/endpoints/core/router.py
index 64450f4..ccb26d9 100644
--- a/endpoints/core/router.py
+++ b/endpoints/core/router.py
@@ -1,6 +1,7 @@
 import asyncio
 import pathlib
 from sys import maxsize
+from typing import Optional
 from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import APIRouter, Depends, HTTPException, Request, Response
 from sse_starlette import EventSourceResponse
@@ -14,6 +15,7 @@
 from common.templating import PromptTemplate, get_all_templates
 from common.utils import unwrap
 from common.health import HealthManager
+from endpoints.OAI.utils.chat_completion import format_messages_with_template
 from endpoints.core.types.auth import AuthPermissionResponse
 from endpoints.core.types.download import DownloadRequest, DownloadResponse
 from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse
@@ -359,61 +361,48 @@ async def unload_embedding_model():
 )
 async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
     """Encodes a string or chat completion messages into tokens."""
-    embeddings = MultimodalEmbeddingWrapper()
+
+    mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None

     if isinstance(data.text, str):
         text = data.text
-    elif isinstance(data.text, list) and "oai" in config.network.api_servers:
-        # TODO: Support additional chat completion args for encode
-        # i.e. add_generation_prompt, template selection, tool args, template kwargs
-        if model.container.prompt_template is None:
+    elif isinstance(data.text, list):
+        if "oai" not in config.network.api_servers:
             error_message = handle_request_error(
-                "Tokenization of chat completion requests is disabled "
-                "because a prompt template is not set.",
+                "Enable the OAI server to handle chat completion messages.",
                 exc_info=False,
             ).error.message

             raise HTTPException(422, error_message)

-        from endpoints.OAI.utils.chat_completion import preprocess_vision_request
-
-        if model.container.use_vision:
-            data.text, embeddings = await preprocess_vision_request(data.text)
-
-        # Keeping behavior consistent with format_prompt_with_template
-        # Deal with list in messages.content
-        # Just replace the content list with the very first text message
-        for message in data.text:
-            if isinstance(message["content"], list):
-                message["content"] = next(
-                    (
-                        content["text"]
-                        for content in message["content"]
-                        if content["type"] == "text"
-                    ),
-                    "",
-                )
-
-        special_tokens_dict = model.container.get_special_tokens(
-            unwrap(data.add_bos_token, True)
-        )
+        if not model.container.prompt_template:
+            error_message = handle_request_error(
+                "Cannot tokenize chat completion message because "
+                + "a prompt template is not set.",
+                exc_info=False,
+            ).error.message
+
+            raise HTTPException(422, error_message)

         template_vars = {
-            "messages": data.text,
             "add_generation_prompt": False,
-            **special_tokens_dict,
         }

-        text = await model.container.prompt_template.render(template_vars)
+        # Don't need template vars again
+        text, mm_embeddings, _ = await format_messages_with_template(
+            data.text, template_vars, data.add_bos_token
+        )
     else:
         error_message = handle_request_error(
-            "OAI API server must be enabled to handle chat completion message inputs.",
Check your formatting?", exc_info=False, ).error.message raise HTTPException(422, error_message) - raw_tokens = model.container.encode_tokens(text, embeddings, **data.get_params()) + raw_tokens = model.container.encode_tokens( + text, embeddings=mm_embeddings, **data.get_params() + ) tokens = unwrap(raw_tokens, []) response = TokenEncodeResponse(tokens=tokens, length=len(tokens)) diff --git a/endpoints/core/types/token.py b/endpoints/core/types/token.py index 2c205ab..d43e65e 100644 --- a/endpoints/core/types/token.py +++ b/endpoints/core/types/token.py @@ -1,7 +1,9 @@ """Tokenization types""" from pydantic import BaseModel -from typing import Dict, List, Union +from typing import List, Union + +from endpoints.OAI.types.chat_completion import ChatCompletionMessage class CommonTokenRequest(BaseModel): @@ -23,7 +25,7 @@ def get_params(self): class TokenEncodeRequest(CommonTokenRequest): """Represents a tokenization request.""" - text: Union[str, List[Dict]] + text: Union[str, List[ChatCompletionMessage]] class TokenEncodeResponse(BaseModel):