OAI: Tokenize chat completion messages
Since chat completion messages are structured data rather than a plain string, format them into a prompt before passing the result to the tokenizer.

Signed-off-by: kingbri <[email protected]>
bdashore3 committed Apr 15, 2024
1 parent ed05f37 commit 515b3c2
Showing 3 changed files with 26 additions and 5 deletions.
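
For context on the commit message: "format the prompt" means rendering the structured message list through a chat template into a flat string before counting tokens. Below is a minimal sketch with jinja2; the template string is a made-up stand-in, not one of the project's actual prompt templates (which common.templating loads from files):

from jinja2 import Template

# Hypothetical chat template for illustration only; real templates are
# loaded from template files rather than hardcoded like this.
chat_template = Template(
    "{% for m in messages %}{{ bos_token }}{{ m.role }}: {{ m.content }}\n"
    "{% endfor %}"
)

messages = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there."},
]
prompt = chat_template.render(messages=messages, bos_token="<s>")
# prompt is now a flat string that a tokenizer can encode directly.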
backends/exllamav2/model.py (3 additions, 1 deletion)

@@ -566,7 +566,9 @@ def decode_tokens(self, ids: List[int], **kwargs):
             decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
         )[0]
 
-    def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool):
+    def get_special_tokens(
+        self, add_bos_token: bool = True, ban_eos_token: bool = False
+    ):
         return {
             "bos_token": self.tokenizer.bos_token if add_bos_token else "",
             "eos_token": self.tokenizer.eos_token if not ban_eos_token else "",
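
The new defaults are what allow the tokenize path in router.py to call this method with add_bos_token alone. A runnable sketch of the behavior, using hypothetical stand-in classes rather than the real ExllamaV2 container:

class FakeTokenizer:
    # Stand-in values; real tokens come from the loaded model's tokenizer.
    bos_token = "<s>"
    eos_token = "</s>"

class FakeContainer:
    tokenizer = FakeTokenizer()

    def get_special_tokens(
        self, add_bos_token: bool = True, ban_eos_token: bool = False
    ):
        return {
            "bos_token": self.tokenizer.bos_token if add_bos_token else "",
            "eos_token": self.tokenizer.eos_token if not ban_eos_token else "",
        }

# ban_eos_token can now be omitted, which the new endpoint code relies on:
print(FakeContainer().get_special_tokens(add_bos_token=True))
# -> {'bos_token': '<s>', 'eos_token': '</s>'}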
endpoints/OAI/router.py (21 additions, 2 deletions)

@@ -16,6 +16,7 @@
 from common.networking import handle_request_error, run_with_request_disconnect
 from common.templating import (
     get_all_templates,
+    get_prompt_from_template,
     get_template_from_file,
 )
 from common.utils import coalesce, unwrap
@@ -386,8 +387,26 @@ async def unload_loras():
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
 async def encode_tokens(data: TokenEncodeRequest):
-    """Encodes a string into tokens."""
-    raw_tokens = model.container.encode_tokens(data.text, **data.get_params())
+    """Encodes a string or chat completion messages into tokens."""
+
+    if isinstance(data.text, str):
+        text = data.text
+    else:
+        special_tokens_dict = model.container.get_special_tokens(
+            unwrap(data.add_bos_token, True)
+        )
+
+        template_vars = {
+            "messages": data.text,
+            "add_generation_prompt": False,
+            **special_tokens_dict,
+        }
+
+        text, _ = get_prompt_from_template(
+            model.container.prompt_template, template_vars
+        )
+
+    raw_tokens = model.container.encode_tokens(text, **data.get_params())
     tokens = unwrap(raw_tokens, [])
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
 
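
With this branch in place, the endpoint accepts either payload shape. A hedged client-side sketch follows; the base URL, the /v1/token/encode route, and the x-api-key header are assumptions about the deployment, so adjust them for your server:

import requests

# The "text" field may now be a chat-completion message list instead of a
# plain string; add_bos_token is read by the new template branch above.
payload = {
    "text": [
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi, how can I help?"},
    ],
    "add_bos_token": True,
}
resp = requests.post(
    "http://localhost:5000/v1/token/encode",  # assumed route and port
    json=payload,
    headers={"x-api-key": "your-api-key"},  # assumed auth header
)
print(resp.json())  # {"tokens": [...], "length": <token count>}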
endpoints/OAI/types/token.py (2 additions, 2 deletions)

@@ -1,7 +1,7 @@
 """Tokenization types"""
 
 from pydantic import BaseModel
-from typing import List
+from typing import Dict, List, Union
 
 
 class CommonTokenRequest(BaseModel):

@@ -23,7 +23,7 @@ def get_params(self):
 class TokenEncodeRequest(CommonTokenRequest):
     """Represents a tokenization request."""
 
-    text: str
+    text: Union[str, List[Dict[str, str]]]
 
 
 class TokenEncodeResponse(BaseModel):
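
A self-contained check that the widened field validates both shapes; EncodeProbe is a simplified stand-in for TokenEncodeRequest, which additionally inherits CommonTokenRequest's parameters:

from typing import Dict, List, Union

from pydantic import BaseModel

class EncodeProbe(BaseModel):
    """Simplified stand-in for TokenEncodeRequest."""

    text: Union[str, List[Dict[str, str]]]

EncodeProbe(text="plain string")                       # accepted before and after
EncodeProbe(text=[{"role": "user", "content": "hi"}])  # accepted after this change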
