Skip to content

Commit

Permalink
Merge pull request #118 from vendi-ai/99-support-jinja-templates-in-docstrings
Browse files Browse the repository at this point in the history

Support jinja templates in docstrings
  • Loading branch information
matankley authored Sep 1, 2023
2 parents 0fbaafa + 3451b60 commit deb0c42
Show file tree
Hide file tree
Showing 11 changed files with 296 additions and 276 deletions.
24 changes: 23 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ Declarai aims to promote clean and readable code by enforcing the use of doc-str
The resulting code is readable and easily maintainable.


### Tasks with python native output parsing:
### Tasks with python native output parsing

Python primitives
```python
Expand Down Expand Up @@ -262,9 +262,31 @@ suggest_animals(location="jungle")
]
}
```
### Jinja templates 🥷
```python
from typing import List, Tuple

import declarai

gpt_35 = declarai.openai(model="gpt-3.5-turbo")

@gpt_35.task
def sentiment_classification(string: str, examples: List[Tuple[str, int]]) -> int:
    """
    Classify the sentiment of the provided string, based on the provided examples.
    The sentiment is ranked on a scale of 1-5, with 5 being the most positive.
    {% for example in examples %}
    {{ example[0] }} // {{ example[1] }}
    {% endfor %}
    {{ string }} //
    """

sentiment_classification(string="I love this product but there are some annoying bugs",
examples=[["I love this product", 5], ["I hate this product", 1]])

>>> 4
```

### Simple Chat interface

```python
import declarai

Expand Down
54 changes: 54 additions & 0 deletions docs/features/jinja_templating.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
## Jinja Templating

[Jinja](https://jinja.palletsprojects.com/) is a templating language for Python.

We can use Jinja to create templates for our tasks. This is useful when:
- Task has a lot of boilerplate code.
- Task has a lot of parameters.
- You want to control the task's prompt structure.

For example, let's say we want to create a task that takes in a string and ranks its sentiment. We
can use Jinja to create a template for this task:

```python
import declarai
from typing import List, Tuple

gpt_35 = declarai.openai(model="gpt-3.5-turbo")


@gpt_35.task
def sentiment_classification(string: str, examples: List[Tuple[str, int]]) -> int:
    """
    Classify the sentiment of the provided string, based on the provided examples.
    The sentiment is ranked on a scale of 1-5, with 5 being the most positive.
    {% for example in examples %}
    {{ example[0] }} // {{ example[1] }}
    {% endfor %}
    {{ string }} //
    """


sentiment_classification.compile(string="I love this product but there are some annoying bugs",
examples=[["I love this product", 5], ["I hate this product", 1]])

>>> {'messages': [
system: respond only with the value of type int:, # (1)!
user: Classify the sentiment of the provided string, based on the provided examples. The sentiment is ranked on a scale of 1-5, with 5 being the most positive. # (2)!
I love this product // 5
I hate this product // 1
I love this product //
]
}

sentiment_classification(string="I love this product but there are some annoying bugs",
examples=[["I love this product", 5], ["I hate this product", 1]])

>>> 4
```


1. The system message is generated based on the return type `int` of the function.
2. The user message is generated based on the docstring of the function. The Jinja template is rendered with the provided parameters.


24 changes: 18 additions & 6 deletions src/declarai/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
MessageRole,
resolve_operator,
)
from declarai.operators.utils import format_prompt_msg
from declarai.python_parser.parser import PythonParser
from declarai.task import Task

Expand Down Expand Up @@ -89,6 +90,8 @@ class Chat(BaseTask, metaclass=ChatMeta):
`DEFAULT_CHAT_HISTORY()`.
greeting (str, optional): Greeting message to use. Defaults to operator's greeting or None.
system (str, optional): System message to use. Defaults to operator's system message or None.
stream (bool, optional): Whether to stream the response from the LLM or not. Defaults to False.
**kwargs: Additional keyword arguments to pass to the formatting of the system message.
"""

is_declarai = True
Expand All @@ -103,13 +106,19 @@ def __init__(
chat_history: BaseChatMessageHistory = None,
greeting: str = None,
system: str = None,
**kwargs,
):
self.middlewares = middlewares
self.operator = operator
self._chat_history = chat_history or DEFAULT_CHAT_HISTORY()
self.greeting = greeting or self.operator.greeting
self.system = system or self.operator.system
self.__set_memory()
self.__set_system_prompt(**kwargs)

def __set_system_prompt(self, **kwargs):
if kwargs:
self.system = format_prompt_msg(self.system, **kwargs)

def __set_memory(self):
if self.greeting and len(self._chat_history.history) == 0:
Expand Down Expand Up @@ -186,7 +195,7 @@ def _exec_middlewares(self, kwargs) -> Any:
return self._exec(kwargs)

def __call__(
self, *, messages: List[Message], llm_params: LLMParamsType = None, **kwargs
self, *, messages: List[Message], llm_params: LLMParamsType = None
) -> Any:
"""
Executes the call to the LLM, based on the messages passed as argument, and the llm_params.
Expand All @@ -195,21 +204,20 @@ def __call__(
Args:
messages: The messages to pass to the LLM.
llm_params: The llm_params to use for the call to the LLM.
**kwargs: run time kwargs to use when formatting the system message prompt.
Returns:
The parsed response from the LLM.
"""
kwargs["messages"] = messages
runtime_kwargs = dict(messages=messages)
runtime_llm_params = (
llm_params or self.llm_params
) # order is important! We prioritize runtime params that
if runtime_llm_params:
kwargs["llm_params"] = runtime_llm_params
runtime_kwargs["llm_params"] = runtime_llm_params

self._call_kwargs = kwargs
return self._exec_middlewares(kwargs)
self._call_kwargs = runtime_kwargs
return self._exec_middlewares(runtime_kwargs)

def send(
self,
Expand Down Expand Up @@ -272,6 +280,9 @@ def chat(
middlewares: List[TaskMiddleware] = None,
llm_params: LLMParamsType = None,
chat_history: BaseChatMessageHistory = None,
greeting: str = None,
system: str = None,
streaming: bool = None,
**kwargs,
) -> Callable[..., Type[Chat]]:
"""
Expand Down Expand Up @@ -304,6 +315,7 @@ def chat(
chat_history (BaseChatMessageHistory, optional): Chat history mechanism to use. Defaults to None.
greeting (str, optional): Greeting message to use. Defaults to None.
system (str, optional): System message to use. Defaults to None.
streaming (bool, optional): Whether to use streaming or not. Defaults to None.
Returns:
(Type[Chat]): A new Chat class that inherits from the original class and has chat capabilities.
Expand Down
58 changes: 1 addition & 57 deletions src/declarai/operators/openai_operators/chat_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,9 @@
Chat implementation of OpenAI operator.
"""
import logging
from typing import List

from declarai.operators import Message, MessageRole
from declarai.operators.openai_operators.openai_llm import AzureOpenAILLM, OpenAILLM
from declarai.operators.operator import BaseChatOperator, CompiledTemplate
from declarai.operators.operator import BaseChatOperator
from declarai.operators.registry import register_operator
from declarai.operators.templates import (
StructuredOutputChatPrompt,
compile_output_prompt,
)

logger = logging.getLogger("OpenAIChatOperator")

Expand All @@ -27,55 +20,6 @@ class OpenAIChatOperator(BaseChatOperator):

llm: OpenAILLM

def _compile_output_prompt(self, template) -> str:
if not self.parsed_send_func.has_any_return_defs:
logger.warning(
"Couldn't create output schema for function %s."
"Falling back to unstructured output."
"Please add at least one of the following: return type, return doc, return name",
self.parsed_send_func.name,
)
return ""

signature_return = self.parsed_send_func.signature_return
return_name, return_doc = self.parsed_send_func.docstring_return
return compile_output_prompt(
return_type=signature_return.str_schema,
str_schema=return_name,
return_docstring=return_doc,
return_magic=self.parsed_send_func.magic.return_name,
structured=self.parsed_send_func.has_structured_return_type,
structured_template=template,
)

def compile(self, messages: List[Message], **kwargs) -> CompiledTemplate:
"""
Implementation of the compile method for the chat operator.
Compiles a system prompt based on the initialized system message
Compiles the message based on the user input and the StructuredOutputChatPrompt template
Args:
messages (List[Message]): A list of messages
**kwargs:
Returns:
"""
self.system = self.system.format(**kwargs)
structured_template = StructuredOutputChatPrompt
if self.parsed_send_func:
output_schema = self._compile_output_prompt(structured_template)
else:
output_schema = None

if output_schema:
compiled_system_prompt = f"{self.system}/n{output_schema}"
else:
compiled_system_prompt = self.system
messages = [
Message(message=compiled_system_prompt, role=MessageRole.system)
] + messages
return {"messages": messages}


@register_operator(provider="azure-openai", operator_type="chat")
class AzureOpenAIChatOperator(OpenAIChatOperator):
Expand Down
31 changes: 10 additions & 21 deletions src/declarai/operators/openai_operators/task_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from declarai.operators.registry import register_operator
from declarai.operators.message import Message, MessageRole
from declarai.operators.operator import BaseOperator, CompiledTemplate
from ..utils import can_be_jinja
from declarai.operators.templates import (
InstructFunctionTemplate,
StructuredOutputInstructionPrompt,
Expand Down Expand Up @@ -104,28 +105,16 @@ def compile_template(self) -> CompiledTemplate:
if output_schema:
messages.append(Message(message=output_schema, role=MessageRole.system))

populated_instruction = instruction_template.format(
input_instructions=self.parsed.docstring_freeform,
input_placeholder=self._compile_input_placeholder(),
)
messages.append(Message(message=populated_instruction, role=MessageRole.user))
return messages

def compile(self, **kwargs) -> CompiledTemplate:
"""
Implements the compile method of the BaseOperator class.
Args:
**kwargs:
Returns:
Dict[str, List[Message]]: A dictionary containing a list of messages.
if not can_be_jinja(self.parsed.docstring_freeform):
instruction_message = instruction_template.format(
input_instructions=self.parsed.docstring_freeform,
input_placeholder=self._compile_input_placeholder(),
)
else:
instruction_message = self.parsed.docstring_freeform

"""
template = self.compile_template()
if kwargs:
template[-1].message = template[-1].message.format(**kwargs)
return {"messages": template}
return {"messages": template}
messages.append(Message(message=instruction_message, role=MessageRole.user))
return messages


@register_operator(provider="azure-openai", operator_type="task")
Expand Down
Loading

0 comments on commit deb0c42

Please sign in to comment.