From 2a0aaa2e8a231dd10f918324b8b7f7c1a52b616f Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 11 Apr 2024 00:55:32 -0400 Subject: [PATCH] OAI: Add ability to pass extra vars in jinja templates A chat completion can now declare extra template_vars to pass when a template is rendered, opening up the possibility of using state outside of huggingface's parameters. Signed-off-by: kingbri --- endpoints/OAI/types/chat_completion.py | 1 + endpoints/OAI/utils/chat_completion.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index 7d3138e5..5c1151f4 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -44,6 +44,7 @@ class ChatCompletionRequest(CommonCompletionRequest): messages: Union[str, List[Dict[str, str]]] prompt_template: Optional[str] = None add_generation_prompt: Optional[bool] = True + template_vars: Optional[dict] = {} class ChatCompletionResponse(BaseModel): diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 4b3d39cd..0ddaa942 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -141,14 +141,17 @@ def format_prompt_with_template(data: ChatCompletionRequest): unwrap(data.ban_eos_token, False), ) - template_vars = { - "messages": data.messages, - "add_generation_prompt": data.add_generation_prompt, - **special_tokens_dict, - } + # Overwrite any protected vars with their values + data.template_vars.update( + { + "messages": data.messages, + "add_generation_prompt": data.add_generation_prompt, + **special_tokens_dict, + } + ) prompt, template_stop_strings = get_prompt_from_template( - model.container.prompt_template, template_vars + model.container.prompt_template, data.template_vars ) # Append template stop strings