Skip to content

Commit

Permalink
fix set system message to go down to operator.system
Browse files Browse the repository at this point in the history
  • Loading branch information
matankley committed Sep 1, 2023
1 parent deb0c42 commit eb586af
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 5 deletions.
27 changes: 27 additions & 0 deletions docs/features/jinja_templating.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,30 @@ sentiment_classification(string="I love this product but there are some annoying
2. The user message is generated based on the docstring of the function. The Jinja template is rendered with the provided parameters.


Same thing can be done with the `chat` decorator:

```python
import declarai

gpt_35 = declarai.openai(model="gpt-3.5-turbo")


@gpt_35.experimental.chat
class TranslatorBot:
"""
You are a translator bot,
You will translate the provided text from English to {{ language }}.
Do not translate the following categories of words: {{ exclude_words }}
"""


bot = TranslatorBot(language="French", exclude_words=["bad words"])

bot.compile()

>>> {'messages': [
system: You are a translator bot, You will translate the provided text from English to French.
Do not translate the following categories of words: ['bad words']
]
}
```
12 changes: 8 additions & 4 deletions src/declarai/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,17 @@ def __init__(
self.operator = operator
self._chat_history = chat_history or DEFAULT_CHAT_HISTORY()
self.greeting = greeting or self.operator.greeting
self.system = system or self.operator.system
self.system = self.__set_system_prompt(system=system, **kwargs)
self.__set_memory()
self.__set_system_prompt(**kwargs)

def __set_system_prompt(self, **kwargs):
def __set_system_prompt(self, system: str, **kwargs) -> str:
if system:
self.operator.system = system
if kwargs:
self.system = format_prompt_msg(self.system, **kwargs)
formatted_system = format_prompt_msg(self.operator.system, **kwargs)
self.operator.system = formatted_system

return self.operator.system

def __set_memory(self):
if self.greeting and len(self._chat_history.history) == 0:
Expand Down
11 changes: 10 additions & 1 deletion tests/api/test_chat_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,13 @@ class OverrideChatParams:
chat2 = OverrideChatParams(greeting="New Message")

assert chat2.__name__ == "OverrideChatParams"
assert chat2.greeting == "New Message"
assert chat2.greeting == "New Message"

@declarai.experimental.chat(system="This is a decorated chat.\n", greeting="This is a greeting message")
class ChatWithParamsDecorated:
...

chat3 = ChatWithParamsDecorated()

assert chat3.system == "This is a decorated chat.\n"
assert chat3.greeting == "This is a greeting message"

0 comments on commit eb586af

Please sign in to comment.