Skip to content

Commit

Permalink
OAI: Add ability to pass extra vars in jinja templates
Browse files Browse the repository at this point in the history
A chat completion can now declare extra template_vars to pass when
a template is rendered, opening up the possibility of using state
outside of huggingface's parameters.

Signed-off-by: kingbri <[email protected]>
  • Loading branch information
bdashore3 committed Apr 11, 2024
1 parent b1f3baa commit 2a0aaa2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
1 change: 1 addition & 0 deletions endpoints/OAI/types/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
messages: Union[str, List[Dict[str, str]]]
prompt_template: Optional[str] = None
add_generation_prompt: Optional[bool] = True
template_vars: Optional[dict] = {}


class ChatCompletionResponse(BaseModel):
Expand Down
15 changes: 9 additions & 6 deletions endpoints/OAI/utils/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,17 @@ def format_prompt_with_template(data: ChatCompletionRequest):
unwrap(data.ban_eos_token, False),
)

template_vars = {
"messages": data.messages,
"add_generation_prompt": data.add_generation_prompt,
**special_tokens_dict,
}
# Overwrite any protected vars with their values
data.template_vars.update(
{
"messages": data.messages,
"add_generation_prompt": data.add_generation_prompt,
**special_tokens_dict,
}
)

prompt, template_stop_strings = get_prompt_from_template(
model.container.prompt_template, template_vars
model.container.prompt_template, data.template_vars
)

# Append template stop strings
Expand Down

0 comments on commit 2a0aaa2

Please sign in to comment.