Skip to content

Commit

Permalink
Reuse jinja environment for a prompt
Browse files Browse the repository at this point in the history
  • Loading branch information
jantrienes committed Sep 19, 2024
1 parent 289ef5d commit 94a6d8c
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions outlines/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class Prompt:

def __post_init__(self):
self.parameters: List[str] = list(self.signature.parameters.keys())
self.jinja_environment = create_jinja_template(self.template)

def __call__(self, *args, **kwargs) -> str:
"""Render and return the template.
Expand All @@ -35,7 +36,7 @@ def __call__(self, *args, **kwargs) -> str:
"""
bound_arguments = self.signature.bind(*args, **kwargs)
bound_arguments.apply_defaults()
return render(self.template, **bound_arguments.arguments)
return self.jinja_environment.render(**bound_arguments.arguments)

def __str__(self):
return self.template
Expand Down Expand Up @@ -182,6 +183,11 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
A string that contains the rendered template.
"""
jinja_template = create_jinja_template(template)
return jinja_template.render(**values)


def create_jinja_template(template: str):
# Dedent, and remove extra linebreak
cleaned_template = inspect.cleandoc(template)

Expand Down Expand Up @@ -210,8 +216,7 @@ def render(template: str, **values: Optional[Dict[str, Any]]) -> str:
env.filters["args"] = get_fn_args

jinja_template = env.from_string(cleaned_template)

return jinja_template.render(**values)
return jinja_template


def get_fn_name(fn: Callable):
Expand Down

0 comments on commit 94a6d8c

Please sign in to comment.