Skip to content

Commit

Permalink
OAI: Add response_prefix and fix BOS token issues in chat completions
Browse files Browse the repository at this point in the history
response_prefix is used to add a prefix before generating the next
message. This is used in many cases such as continuing a prompt
(see #96).

Also, if a template has a BOS token specified, add_bos_token will
append two BOS tokens. Add a check which strips a starting BOS token
from the prompt if it exists.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Apr 25, 2024
1 parent ed7cd3c commit fb1d2f3
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 1 deletion.
1 change: 1 addition & 0 deletions backends/exllamav2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,6 +878,7 @@ def generate_gen_sync(
encode_special_tokens=True,
return_offsets=True,
)
print(ids)
mask = (
self.tokenizer.padding_mask(ids)
if self.use_cfg and gen_settings.cfg_scale not in [None, 1.0]
Expand Down
1 change: 1 addition & 0 deletions endpoints/OAI/types/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
prompt_template: Optional[str] = None
add_generation_prompt: Optional[bool] = True
template_vars: Optional[dict] = {}
response_prefix: Optional[str] = None


class ChatCompletionResponse(BaseModel):
Expand Down
17 changes: 17 additions & 0 deletions endpoints/OAI/utils/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from fastapi import HTTPException
from jinja2 import TemplateError
from loguru import logger

from common import model
from common.networking import (
Expand Down Expand Up @@ -153,6 +154,22 @@ def format_prompt_with_template(data: ChatCompletionRequest):
data.template_vars
)

# Append response prefix if present
if data.response_prefix:
if data.add_generation_prompt:
prompt += data.response_prefix
else:
logger.warning(
"Could not add response prefix because "
"add_generation_prompt is False"
)

# Removes the starting BOS token if present
# This is to prevent add_bos_token from adding multiple bos tokens
bos_token = special_tokens_dict.get("bos_token")
if bos_token and prompt.startswith(bos_token):
prompt = prompt.removeprefix(bos_token)

# Append template stop strings
if isinstance(data.stop, str):
data.stop = [data.stop] + template_stop_strings
Expand Down
2 changes: 1 addition & 1 deletion endpoints/OAI/utils/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)

try:
generation = await model.container.generate(data.prompt, **data.to_gen_params())

response = _create_response(generation, model_path.name)

return response
except Exception as exc:
error_message = handle_request_error(
Expand Down

0 comments on commit fb1d2f3

Please sign in to comment.