API: Fix chat completion formatting flow
Previously, the flow for parsing chat completion messages and rendering
them through the prompt template was disconnected between endpoints. A
common function now renders the prompt and handles everything
appropriately afterwards.

Signed-off-by: kingbri <[email protected]>
bdashore3 committed Nov 21, 2024
1 parent c652a6e commit 902045e
Showing 6 changed files with 93 additions and 116 deletions.
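
To illustrate the unified flow the commit message describes, here is a minimal, self-contained Python sketch of a combined render step. The names (ToyEmbeddingWrapper, dict-based messages, the role-prefixed join) are simplified stand-ins, not the project's actual types; the real apply_chat_template renders through the model's prompt template instead.

from dataclasses import dataclass, field
from typing import List, Tuple, Union

@dataclass
class ToyEmbeddingWrapper:
    """Stand-in for MultimodalEmbeddingWrapper (assumed shape: .content, .text_alias, .add())."""
    content: List[str] = field(default_factory=list)
    text_alias: List[str] = field(default_factory=list)

    def add(self, url: str) -> None:
        # The real wrapper fetches and embeds the image; here we only record an alias.
        alias = f"<image_{len(self.content)}>"
        self.content.append(url)
        self.text_alias.append(alias)

def apply_chat_template(messages: Union[str, List[dict]]) -> Tuple[str, ToyEmbeddingWrapper]:
    """One entry point: flatten message parts and collect vision embeddings together."""
    embeddings = ToyEmbeddingWrapper()
    if isinstance(messages, str):
        return messages, embeddings

    lines = []
    for message in messages:
        content = message["content"]
        if isinstance(content, list):
            flattened = ""
            for part in content:
                if part["type"] == "text":
                    flattened += part["text"]
                elif part["type"] == "image_url":
                    embeddings.add(part["image_url"]["url"])
                    flattened += embeddings.text_alias[-1]
            content = flattened
        lines.append(f"{message['role']}: {content}")
    return "\n".join(lines), embeddings

prompt, embeddings = apply_chat_template([
    {"role": "user", "content": [
        {"type": "text", "text": "Describe this image: "},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ]},
])
print(prompt)              # user: Describe this image: <image_0>
print(embeddings.content)  # ['https://example.com/cat.png']
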
24 changes: 14 additions & 10 deletions backends/exllamav2/model.py
@@ -31,7 +31,6 @@
)
from itertools import zip_longest
from loguru import logger
from PIL import Image
from typing import List, Optional, Union

from ruamel.yaml import YAML
@@ -374,6 +373,8 @@ async def create(cls, model_directory: pathlib.Path, quiet=False, **kwargs):
self.draft_config.max_input_len = chunk_size
self.draft_config.max_attention_size = chunk_size**2

self.prompt_template = None

# Return the created instance
return self

@@ -875,17 +876,18 @@ async def unload(self, loras_only: bool = False, **kwargs):
async with self.load_condition:
self.load_condition.notify_all()

def encode_tokens(
self, text: str, embeddings: MultimodalEmbeddingWrapper, **kwargs
):
def encode_tokens(self, text: str, **kwargs):
"""Wrapper to encode tokens from a text string."""

mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
mm_embeddings_content = mm_embeddings.content if mm_embeddings else []

return (
self.tokenizer.encode(
text,
add_bos=unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
embeddings=embeddings.content,
embeddings=mm_embeddings_content,
)
.flatten()
.tolist()
@@ -931,15 +933,14 @@ def get_logprobs(self, token_ids: torch.Tensor, token_probs: torch.Tensor):
async def generate(
self,
prompt: str,
embeddings: MultimodalEmbeddingWrapper,
request_id: str,
abort_event: asyncio.Event = None,
**kwargs,
):
"""Generate a response to a prompt."""
generations = []
async for generation in self.generate_gen(
prompt, embeddings, request_id, abort_event, **kwargs
prompt, request_id, abort_event, **kwargs
):
generations.append(generation)

@@ -1005,7 +1006,6 @@ def check_unsupported_settings(self, **kwargs):
async def generate_gen(
self,
prompt: str,
embeddings: MultimodalEmbeddingWrapper,
request_id: str,
abort_event: Optional[asyncio.Event] = None,
**kwargs,
@@ -1270,13 +1270,17 @@ async def generate_gen(
else:
stop_conditions += eos_tokens

# Get multimodal embeddings if present
mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
mm_embeddings_content = mm_embeddings.content if mm_embeddings else []

# Encode both positive and negative prompts
input_ids = [
self.tokenizer.encode(
prompt,
add_bos=add_bos_token,
encode_special_tokens=True,
embeddings=embeddings.content,
embeddings=mm_embeddings_content,
)
for prompt in prompts
]
@@ -1327,7 +1331,7 @@ async def generate_gen(
banned_strings=banned_strings,
token_healing=token_healing,
identifier=job_id,
embeddings=embeddings.content,
embeddings=mm_embeddings_content,
)

# Save generated tokens and full response
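
The model.py hunks above all follow one pattern: embeddings are no longer a required positional argument of encode_tokens / generate / generate_gen, but an optional keyword that defaults to an empty content list. Below is a small standalone sketch of that retrieval pattern, using an assumed stand-in wrapper class rather than the real backend objects.

from typing import Any, List, Optional

class ToyEmbeddings:
    """Stand-in carrying only the .content attribute the backend reads."""
    def __init__(self, content: List[Any]):
        self.content = content

def encode_tokens(text: str, **kwargs) -> List[Any]:
    # Same retrieval pattern as the diff: embeddings ride in via kwargs and
    # default to an empty content list when the caller is text-only.
    mm_embeddings: Optional[ToyEmbeddings] = kwargs.get("embeddings")
    mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
    # A real backend would pass embeddings=mm_embeddings_content to the tokenizer;
    # here we just return the pieces to show what gets forwarded.
    return [text, mm_embeddings_content]

print(encode_tokens("hello"))                                      # ['hello', []]
print(encode_tokens("hello", embeddings=ToyEmbeddings(["img0"])))  # ['hello', ['img0']]
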
14 changes: 2 additions & 12 deletions endpoints/OAI/router.py
@@ -1,5 +1,4 @@
import asyncio
from common.multimodal import MultimodalEmbeddingWrapper
from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from sys import maxsize
@@ -16,9 +15,8 @@
)
from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
from endpoints.OAI.utils.chat_completion import (
format_prompt_with_template,
apply_chat_template,
generate_chat_completion,
preprocess_vision_request,
stream_generate_chat_completion,
)
from endpoints.OAI.utils.completion import (
@@ -125,15 +123,7 @@ async def chat_completion_request(

model_path = model.container.model_dir

embeddings = MultimodalEmbeddingWrapper()

if isinstance(data.messages, str):
prompt = data.messages
else:
if model.container.use_vision:
data.messages, embeddings = await preprocess_vision_request(data.messages)

prompt = await format_prompt_with_template(data)
prompt, embeddings = await apply_chat_template(data)

# Set an empty JSON schema if the request wants a JSON response
if data.response_format.type == "json":
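
After the router.py change, the endpoint no longer branches on string versus list messages or pre-processes vision requests itself; one awaited call returns both the prompt and the embeddings, which are then forwarded as a keyword argument. A rough async sketch with stub functions follows; the names and signatures are simplified assumptions, not the real endpoint code.

import asyncio
from typing import Any, Dict, Optional, Tuple

async def apply_chat_template(data: Dict[str, Any]) -> Tuple[str, Optional[Any]]:
    # Stand-in for the real template renderer: returns (prompt, embeddings or None).
    return f"rendered: {data['messages']}", None

async def generate(prompt: str, request_id: str, **kwargs) -> Dict[str, Any]:
    # Stand-in for model.container.generate; embeddings now arrive via kwargs.
    return {"prompt": prompt, "request_id": request_id, "embeddings": kwargs.get("embeddings")}

async def chat_completion_request(data: Dict[str, Any], request_id: str = "req-1") -> Dict[str, Any]:
    prompt, embeddings = await apply_chat_template(data)
    return await generate(prompt, request_id, embeddings=embeddings)

# {'prompt': 'rendered: hello', 'request_id': 'req-1', 'embeddings': None}
print(asyncio.run(chat_completion_request({"messages": "hello"})))
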
104 changes: 49 additions & 55 deletions endpoints/OAI/utils/chat_completion.py
@@ -177,11 +177,11 @@ def _create_stream_chunk(
return chunk


async def _append_template_metadata(data: ChatCompletionRequest):
async def _append_template_metadata(data: ChatCompletionRequest, template_vars: dict):
"""Adding metadata is a one-time process."""

template_metadata = await model.container.prompt_template.extract_metadata(
data.template_vars
template_vars
)

# Stop strings
@@ -199,7 +199,43 @@ async def _append_template_metadata(data: ChatCompletionRequest):
data.stop.extend(template_metadata.tool_starts)


async def format_prompt_with_template(
async def format_messages_with_template(
messages: List[ChatCompletionMessage],
existing_template_vars: Optional[dict] = None,
add_bos_token: bool = True,
ban_eos_token: bool = False,
):
"""Barebones function to format chat completion messages into a prompt."""

template_vars = unwrap(existing_template_vars, {})
mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None

for message in messages:
if isinstance(message.content, list):
concatenated_content = ""
for content in message.content:
if content.type == "text":
concatenated_content += content.text
elif content.type == "image_url" and mm_embeddings:
await mm_embeddings.add(content.image_url.url)
concatenated_content += mm_embeddings.text_alias[-1]

if message.tool_calls:
message.tool_calls_json = json.dumps(message.tool_calls, indent=2)

message.content = concatenated_content

special_tokens_dict = model.container.get_special_tokens(
add_bos_token, ban_eos_token
)

template_vars.update({"messages": messages, **special_tokens_dict})

prompt = await model.container.prompt_template.render(template_vars)
return prompt, mm_embeddings, template_vars


async def apply_chat_template(
data: ChatCompletionRequest, tool_precursor: Optional[str] = None
):
"""
@@ -208,40 +244,18 @@ async def format_prompt_with_template(
"""

try:
special_tokens_dict = model.container.get_special_tokens(
unwrap(data.add_bos_token, True),
unwrap(data.ban_eos_token, False),
)

# Convert list to text-based content
# Use the first instance of text inside the part list
for message in data.messages:
if isinstance(message.content, list):
message.content = next(
(
content.text
for content in message.content
if content.type == "text"
),
"",
)

if message.tool_calls:
message.tool_calls_json = json.dumps(message.tool_calls, indent=2)

# Overwrite any protected vars with their values
data.template_vars.update(
{
"messages": data.messages,
"add_generation_prompt": data.add_generation_prompt,
"tools_json": json.dumps(data.model_dump()["tools"], indent=2),
"functions_json": json.dumps(data.functions, indent=2),
"tool_precursor": tool_precursor,
**special_tokens_dict,
}
)

prompt = await model.container.prompt_template.render(data.template_vars)
prompt, mm_embeddings, template_vars = await format_messages_with_template(
data.messages, data.template_vars, data.add_bos_token, data.ban_eos_token
)

# Append response prefix if present
if data.response_prefix:
@@ -255,14 +269,14 @@

# Removes the starting BOS token if present
# This is to prevent add_bos_token from adding multiple bos tokens
bos_token = special_tokens_dict.get("bos_token")
bos_token = template_vars.get("bos_token")
if bos_token and prompt.startswith(bos_token):
prompt = prompt.removeprefix(bos_token)

# Add template metadata
await _append_template_metadata(data)
await _append_template_metadata(data, template_vars)

return prompt
return prompt, mm_embeddings

except KeyError as exc:
error_message = handle_request_error(
@@ -302,9 +316,9 @@ async def stream_generate_chat_completion(
n,
gen_queue,
prompt,
embeddings,
request.state.id,
abort_event,
embeddings=embeddings,
**task_gen_params.model_dump(exclude={"prompt"}),
)
)
@@ -391,8 +405,8 @@ async def generate_chat_completion(
asyncio.create_task(
model.container.generate(
prompt,
embeddings,
request.state.id,
embeddings=embeddings,
**data.model_dump(exclude={"prompt"}),
)
)
@@ -439,13 +453,11 @@ async def generate_tool_calls(
if gen["stop_str"] in tool_data.tool_call_start:
if "text" in gen:
# non streaming, all generations will have the text they generated
pre_tool_prompt = await format_prompt_with_template(data, gen["text"])
pre_tool_prompt = await apply_chat_template(data, gen["text"])
elif current_generations is not None:
# streaming, we wont have text in the generation,
# we'll have to use the current_generations
pre_tool_prompt = await format_prompt_with_template(
data, current_generations
)
pre_tool_prompt = await apply_chat_template(data, current_generations)

gen_tasks.append(
asyncio.create_task(
@@ -471,21 +483,3 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]:
tool_call["function"]["arguments"]
)
return [ToolCall(**tool_call) for tool_call in tool_calls]


# TODO: Combine this with the existing preprocessor in format_prompt_with_template
async def preprocess_vision_request(messages: List[ChatCompletionMessage]):
embeddings = MultimodalEmbeddingWrapper()
for message in messages:
if isinstance(message.content, list):
concatenated_content = ""
for content in message.content:
if content.type == "text":
concatenated_content += content.text
elif content.type == "image_url":
await embeddings.add(content.image_url.url)
concatenated_content += embeddings.text_alias[-1]

message.content = concatenated_content

return messages, embeddings
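
One behavior the chat_completion.py refactor keeps is stripping a template-emitted BOS token from the rendered prompt so that the tokenizer's add_bos_token cannot produce a duplicate, with bos_token now read from the returned template_vars. Here is a standalone sketch of just that check (an isolated helper written for illustration, not a function from the codebase).

def strip_leading_bos(prompt: str, template_vars: dict) -> str:
    """Drop a template-emitted BOS so add_bos_token can't produce two of them."""
    bos_token = template_vars.get("bos_token")
    if bos_token and prompt.startswith(bos_token):
        prompt = prompt.removeprefix(bos_token)
    return prompt

print(strip_leading_bos("<s>Hello", {"bos_token": "<s>"}))  # Hello
print(strip_leading_bos("Hello", {"bos_token": "<s>"}))     # Hello
print(strip_leading_bos("Hello", {}))                       # Hello
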
4 changes: 1 addition & 3 deletions endpoints/OAI/utils/completion.py
@@ -7,7 +7,6 @@
import asyncio
import pathlib
from asyncio import CancelledError
from common.multimodal import MultimodalEmbeddingWrapper
from fastapi import HTTPException, Request
from typing import List, Union

@@ -88,7 +87,6 @@ async def _stream_collector(
task_idx: int,
gen_queue: asyncio.Queue,
prompt: str,
embeddings: MultimodalEmbeddingWrapper,
request_id: str,
abort_event: asyncio.Event,
**kwargs,
@@ -97,7 +95,7 @@

try:
new_generation = model.container.generate_gen(
prompt, embeddings, request_id, abort_event, **kwargs
prompt, request_id, abort_event, **kwargs
)
async for generation in new_generation:
generation["index"] = task_idx