Templates: Switch to async jinja engine
This prevents any possible blocking of the event loop due to template
rendering.

Signed-off-by: kingbri <[email protected]>
bdashore3 committed Aug 17, 2024
1 parent b4752c1 commit a51acb9
Showing 3 changed files with 17 additions and 14 deletions.
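In Jinja2, switching to the async engine is a one-flag environment change, after which templates are rendered with the `*_async` coroutine variants. A minimal standalone sketch of the pattern this commit adopts (the environment flags match the diff below; the template text is invented for illustration):

```python
import asyncio

from jinja2.sandbox import ImmutableSandboxedEnvironment

# Same settings as common/templating.py after this commit;
# enable_async=True exposes render_async()/make_module_async()
environment = ImmutableSandboxedEnvironment(
    trim_blocks=True, lstrip_blocks=True, enable_async=True
)

async def main():
    template = environment.from_string("Hello, {{ name }}!")
    # In async mode, rendering is a coroutine and must be awaited
    print(await template.render_async(name="world"))

asyncio.run(main())
```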
11 changes: 5 additions & 6 deletions common/templating.py
@@ -1,6 +1,5 @@
 """Small replication of AutoTokenizer's chat template system for efficiency"""
 
-from functools import lru_cache
 import json
 import pathlib
 from importlib.metadata import version as package_version
@@ -33,11 +32,11 @@ class PromptTemplate:
     raw_template: str
     template: Template
     environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
-        trim_blocks=True, lstrip_blocks=True
+        trim_blocks=True, lstrip_blocks=True, enable_async=True
     )
     metadata: Optional[TemplateMetadata] = None
 
-    def extract_metadata(self, template_vars: dict):
+    async def extract_metadata(self, template_vars: dict):
         """
         Returns deserialized template metadata from a chat template.
@@ -52,7 +51,7 @@ def extract_metadata(self, template_vars: dict):
 
         template_metadata = TemplateMetadata()
 
-        template_module = self.template.make_module(template_vars)
+        template_module = await self.template.make_module_async(template_vars)
 
         if hasattr(template_module, "stop_strings"):
             if isinstance(template_module.stop_strings, list):
@@ -74,7 +73,7 @@ def extract_metadata(self, template_vars: dict):
         self.metadata = template_metadata
         return template_metadata
 
-    def render(self, template_vars: dict):
+    async def render(self, template_vars: dict):
         """Get a prompt from a template and a list of messages."""
         if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
             raise ImportError(
@@ -84,7 +83,7 @@ def render(self, template_vars: dict):
                 "pip install --upgrade jinja2"
             )
 
-        rendered_template = self.template.render(**template_vars)
+        rendered_template = await self.template.render_async(**template_vars)
 
         return rendered_template

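extract_metadata() builds on Jinja template modules: a chat template can export module-level variables such as stop_strings, which the compiled module exposes as attributes. With the async engine, make_module() gives way to its coroutine counterpart. A standalone sketch of that pattern (the template text and stop strings below are invented):

```python
import asyncio

from jinja2.sandbox import ImmutableSandboxedEnvironment

environment = ImmutableSandboxedEnvironment(enable_async=True)

async def main():
    # Top-level {% set %} variables become attributes of the template module
    template = environment.from_string(
        "{% set stop_strings = ['</s>', '<|im_end|>'] %}"
    )
    # The async engine replaces make_module() with make_module_async()
    module = await template.make_module_async({})
    if hasattr(module, "stop_strings"):
        print(module.stop_strings)  # ['</s>', '<|im_end|>']

asyncio.run(main())
```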
2 changes: 1 addition & 1 deletion endpoints/OAI/router.py
@@ -110,7 +110,7 @@ async def chat_completion_request(
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
-        prompt = format_prompt_with_template(data)
+        prompt = await format_prompt_with_template(data)
 
     # Set an empty JSON schema if the request wants a JSON response
     if data.response_format.type == "json":
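The call-site change above is load-bearing: once format_prompt_with_template becomes a coroutine, calling it without await yields a coroutine object rather than a prompt string. A small illustration (the function body is a simplified stand-in, not the real implementation):

```python
import asyncio

async def format_prompt_with_template(data):
    # Simplified stand-in for the real coroutine in chat_completion.py
    return f"prompt built from {data!r}"

async def main():
    wrong = format_prompt_with_template("messages")         # coroutine object
    right = await format_prompt_with_template("messages")   # the actual string
    print(type(wrong).__name__)  # 'coroutine'
    print(right)
    wrong.close()  # silence the 'coroutine was never awaited' warning

asyncio.run(main())
```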
18 changes: 11 additions & 7 deletions endpoints/OAI/utils/chat_completion.py
@@ -178,10 +178,10 @@ def _create_stream_chunk(
     return chunk
 
 
-def _append_template_metadata(data: ChatCompletionRequest):
+async def _append_template_metadata(data: ChatCompletionRequest):
     """Adding metadata is a one-time process."""
 
-    template_metadata = model.container.prompt_template.extract_metadata(
+    template_metadata = await model.container.prompt_template.extract_metadata(
         data.template_vars
     )

@@ -200,7 +200,7 @@ def _append_template_metadata(data: ChatCompletionRequest):
         data.stop.extend(template_metadata.tool_starts)
 
 
-def format_prompt_with_template(
+async def format_prompt_with_template(
     data: ChatCompletionRequest, tool_precursor: Optional[str] = None
 ):
     """
@@ -242,7 +242,7 @@ def format_prompt_with_template(
         }
     )
 
-    prompt = model.container.prompt_template.render(data.template_vars)
+    prompt = await model.container.prompt_template.render(data.template_vars)
 
     # Append response prefix if present
     if data.response_prefix:
@@ -261,7 +261,9 @@ def format_prompt_with_template(
         prompt = prompt.removeprefix(bos_token)
 
     # Add template metadata
-    _append_template_metadata(data)
+    await _append_template_metadata(data)
+    print(prompt)
+    print(model.container.prompt_template.metadata.tool_starts)
 
     return prompt

@@ -441,11 +443,13 @@ async def generate_tool_calls(
         if gen["stop_str"] in tool_data.tool_call_start:
             if "text" in gen:
                 # non streaming, all generations will have the text they generated
-                pre_tool_prompt = format_prompt_with_template(data, gen["text"])
+                pre_tool_prompt = await format_prompt_with_template(data, gen["text"])
             elif current_generations is not None:
                 # streaming, we wont have text in the generation,
                 # we'll have to use the current_generations
-                pre_tool_prompt = format_prompt_with_template(data, current_generations)
+                pre_tool_prompt = await format_prompt_with_template(
+                    data, current_generations
+                )
 
         gen_tasks.append(
             asyncio.create_task(
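The payoff of the switch: renders sharing one event loop can overlap any async work a template triggers, since Jinja's async engine awaits async function calls made from inside a template. An illustrative sketch of that behavior (all names here are invented; this is not code from this repo):

```python
import asyncio

from jinja2.sandbox import ImmutableSandboxedEnvironment

environment = ImmutableSandboxedEnvironment(enable_async=True)

async def fetch_system_prompt() -> str:
    await asyncio.sleep(0.1)  # stand-in for slow async work
    return "You are a helpful assistant."

async def main():
    # In async mode, Jinja automatically awaits async calls in templates
    template = environment.from_string(
        "{{ fetch_system_prompt() }}\nUser: {{ user_message }}"
    )
    # The two renders overlap on one event loop (~0.1s total, not ~0.2s)
    rendered = await asyncio.gather(
        template.render_async(
            fetch_system_prompt=fetch_system_prompt, user_message="first"
        ),
        template.render_async(
            fetch_system_prompt=fetch_system_prompt, user_message="second"
        ),
    )
    print(rendered)

asyncio.run(main())
```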
