refactor(prompts): prompt templates (#66)
micpst authored Jul 2, 2024
1 parent a4fd411 commit 6510bd8
Showing 54 changed files with 1,200 additions and 1,081 deletions.
10 changes: 5 additions & 5 deletions benchmark/dbally_benchmark/e2e_benchmark.py
@@ -23,9 +23,9 @@
import dbally
from dbally.collection import Collection
from dbally.collection.exceptions import NoViewFoundError
-from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError, default_iql_template
+from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError
from dbally.llms.litellm import LiteLLM
-from dbally.view_selection.view_selector_prompt_template import default_view_selector_template
+from dbally.view_selection.prompt import VIEW_SELECTION_TEMPLATE


async def _run_dbally_for_single_example(example: BIRDExample, collection: Collection) -> Text2SQLResult:
@@ -126,9 +126,9 @@ async def evaluate(cfg: DictConfig) -> Any:
logger.info(f"db-ally predictions saved under directory: {output_dir}")

if run:
run["config/iql_prompt_template"] = stringify_unsupported(default_iql_template.chat)
run["config/view_selection_prompt_template"] = stringify_unsupported(default_view_selector_template.chat)
run["config/iql_prompt_template"] = stringify_unsupported(default_iql_template)
run["config/iql_prompt_template"] = stringify_unsupported(IQL_GENERATION_TEMPLATE.chat)
run["config/view_selection_prompt_template"] = stringify_unsupported(VIEW_SELECTION_TEMPLATE.chat)
run["config/iql_prompt_template"] = stringify_unsupported(IQL_GENERATION_TEMPLATE)
run[f"evaluation/{metrics_file_name}"].upload((output_dir / metrics_file_name).as_posix())
run[f"evaluation/{results_file_name}"].upload((output_dir / results_file_name).as_posix())
run["evaluation/metrics"] = stringify_unsupported(metrics)
14 changes: 8 additions & 6 deletions benchmark/dbally_benchmark/iql_benchmark.py
@@ -21,9 +21,8 @@

from dbally.audit.event_tracker import EventTracker
from dbally.iql_generator.iql_generator import IQLGenerator
-from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError, default_iql_template
+from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError
from dbally.llms.litellm import LiteLLM
-from dbally.prompts.formatters import IQLInputFormatter
from dbally.views.structured import BaseStructuredView


@@ -32,14 +31,17 @@ async def _run_iql_for_single_example(
) -> IQLResult:
    filter_list = view.list_filters()
    event_tracker = EventTracker()
-    input_formatter = IQLInputFormatter(question=example.question, filters=filter_list)

    try:
-        iql_filters, _ = await iql_generator.generate_iql(input_formatter=input_formatter, event_tracker=event_tracker)
+        iql_filters = await iql_generator.generate_iql(
+            question=example.question,
+            filters=filter_list,
+            event_tracker=event_tracker,
+        )
    except UnsupportedQueryError:
        return IQLResult(question=example.question, iql_filters="UNSUPPORTED_QUERY", exception_raised=True)

-    return IQLResult(question=example.question, iql_filters=iql_filters, exception_raised=False)
+    return IQLResult(question=example.question, iql_filters=str(iql_filters), exception_raised=False)


async def run_iql_for_dataset(
@@ -139,7 +141,7 @@ async def evaluate(cfg: DictConfig) -> Any:
logger.info(f"IQL predictions saved under directory: {output_dir}")

if run:
run["config/iql_prompt_template"] = stringify_unsupported(default_iql_template.chat)
run["config/iql_prompt_template"] = stringify_unsupported(IQL_GENERATION_TEMPLATE.chat)
run[f"evaluation/{metrics_file_name}"].upload((output_dir / metrics_file_name).as_posix())
run[f"evaluation/{results_file_name}"].upload((output_dir / results_file_name).as_posix())
run["evaluation/metrics"] = stringify_unsupported(metrics)
2 changes: 1 addition & 1 deletion benchmark/dbally_benchmark/text2sql/prompt_template.py
@@ -1,4 +1,4 @@
-from dbally.prompts import PromptTemplate
+from dbally.prompt import PromptTemplate

TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate(
    (
3 changes: 1 addition & 2 deletions docs/about/roadmap.md
@@ -10,14 +10,13 @@ Below you can find a list of planned features and integrations.
## Planned Features

- [ ] **Support analytical queries**: support for exposing operations beyond filtering.
-- [ ] **Few-shot prompting configuration**: allow users to configure the few-shot prompting in View definition to
+- [x] **Few-shot prompting configuration**: allow users to configure the few-shot prompting in View definition to
  improve IQL generation accuracy.
- [ ] **Request contextualization**: allow to provide extra context for db-ally runs, such as user asking the question.
- [X] **OpenAI Assistants API adapter**: allow to embed db-ally into OpenAI's Assistants API to easily extend the
capabilities of the assistant.
- [ ] **Langchain adapter**: allow to embed db-ally into Langchain applications.


## Integrations

Being agnostic to the underlying technology is one of the main goals of db-ally.
25 changes: 6 additions & 19 deletions docs/how-to/llms/custom.md
@@ -44,42 +44,29 @@ class MyLLMClient(LLMClient[LiteLLMOptions]):

    async def call(
        self,
-        prompt: ChatFormat,
-        response_format: Optional[Dict[str, str]],
+        conversation: ChatFormat,
        options: LiteLLMOptions,
        event: LLMEvent,
+        json_mode: bool = False,
    ) -> str:
        # Your LLM API call
```

-The [`call`](../../reference/llms/index.md#dbally.llms.clients.base.LLMClient.call) method is an abstract method that must be implemented in your subclass. This method should call the LLM inference API and return the response.
+The [`call`](../../reference/llms/index.md#dbally.llms.clients.base.LLMClient.call) method is an abstract method that must be implemented in your subclass. This method should call the LLM inference API and return the response in string format.
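
For illustration, a minimal subclass might look like the sketch below. This is not the library's reference implementation: the endpoint URL, request body, and `payload["text"]` response field are invented for the example, the import paths for `LiteLLMOptions` and `LLMEvent` are assumptions, and mapping `options` onto your API's parameters is left out.

```python
import aiohttp

from dbally.audit.events import LLMEvent  # import path assumed
from dbally.llms.clients.base import LLMClient
from dbally.llms.clients.litellm import LiteLLMOptions  # import path assumed
from dbally.prompt.template import ChatFormat

MY_API_URL = "https://llm.example.com/v1/chat"  # hypothetical inference endpoint


class MyLLMClient(LLMClient[LiteLLMOptions]):

    async def call(
        self,
        conversation: ChatFormat,
        options: LiteLLMOptions,
        event: LLMEvent,
        json_mode: bool = False,
    ) -> str:
        # Forward the OpenAI-style message list to the inference API;
        # translating `options` into API parameters is omitted here.
        async with aiohttp.ClientSession() as session:
            async with session.post(MY_API_URL, json={"messages": conversation}) as response:
                payload = await response.json()
        return payload["text"]  # assumed response shape
```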

### Step 3: Use tokenizer to count tokens

-The [`count_tokens`](../../reference/llms/index.md#dbally.llms.base.LLM.count_tokens) method is used to count the number of tokens in the messages. You can override this method in your custom class to use the tokenizer and count tokens specifically for your model.
+The [`count_tokens`](../../reference/llms/index.md#dbally.llms.base.LLM.count_tokens) method is used to count the number of tokens in the prompt. You can override this method in your custom class to use the tokenizer and count tokens specifically for your model.

```python
class MyLLM(LLM[LiteLLMOptions]):

-    def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int:
-        # Count tokens in the messages in a custom way
+    def count_tokens(self, prompt: PromptTemplate) -> int:
+        # Count tokens in the prompt in a custom way
```
!!!warning
    Incorrect token counting can cause problems in the [`NLResponder`](../../reference/nl_responder.md#dbally.nl_responder.nl_responder.NLResponder) and force the use of an explanation prompt template that is more generic and does not include specific rows from the IQL response.
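
As an illustration, a `tiktoken`-based override could look roughly like this. It is a sketch, not the library's implementation: it assumes `prompt.chat` holds OpenAI-style message dicts (consistent with the `.chat` attribute used elsewhere in this commit), that the `cl100k_base` encoding approximates your model's tokenizer, and the import paths flagged in the comments.

```python
import tiktoken

from dbally.llms.base import LLM  # import path assumed
from dbally.llms.clients.litellm import LiteLLMOptions  # import path assumed
from dbally.prompt.template import PromptTemplate


class MyLLM(LLM[LiteLLMOptions]):

    def count_tokens(self, prompt: PromptTemplate) -> int:
        # Sum the token counts of every message in the formatted prompt.
        encoding = tiktoken.get_encoding("cl100k_base")
        return sum(len(encoding.encode(message["content"])) for message in prompt.chat)
```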

-### Step 4: Define custom prompt formatting
-
-The [`format_prompt`](../../reference/llms/index.md#dbally.llms.base.LLM.format_prompt) method is used to apply formatting to the prompt template. You can override this method in your custom class to change how the formatting is performed.
-
-```python
-class MyLLM(LLM[LiteLLMOptions]):
-
-    def format_prompt(self, template: PromptTemplate, fmt: Dict[str, str]) -> ChatFormat:
-        # Apply custom formatting to the prompt template
-```
-!!!note
-    In general, implementation of this method is not required unless the LLM API does not support [OpenAI conversation formatting](https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages){:target="_blank"}. If your model API expects a different format, override this method to avoid issues with inference call.

## Customising LLM Options

[`LLMOptions`](../../reference/llms/index.md#dbally.llms.clients.base.LLMOptions) is a class that defines the options your LLM will use. To create a custom options, you need to create a subclass of [`LLMOptions`](../../reference/llms/index.md#dbally.llms.clients.base.LLMOptions) and define the required properties that will be passed to the [`LLMClient`](../../reference/llms/index.md#dbally.llms.clients.base.LLMClient).
97 changes: 97 additions & 0 deletions docs/how-to/views/few-shots.md
@@ -0,0 +1,97 @@
# How-To: Define few shots

There are many ways to improve the accuracy of IQL generation; one of them is few-shot prompting. db-ally allows you to inject few-shot examples for any type of defined view, both structured and freeform.

Few-shot examples are defined in the [`list_few_shots`](../../reference/views/index.md#dbally.views.base.BaseView.list_few_shots) method. Each example should be an instance of the [`FewShotExample`](../../reference/prompt.md#dbally.prompt.elements.FewShotExample) class, which defines an example question and the expected LLM answer.

## Structured views

For structured views, both questions and answers for [`FewShotExample`](../../reference/prompt.md#dbally.prompt.elements.FewShotExample) can be defined as strings; for answers, Python expressions are also allowed (see the lambda function in the example below).

```python
from typing import List

from dbally.prompt.elements import FewShotExample
from dbally.views.sqlalchemy_base import SqlAlchemyBaseView


class RecruitmentView(SqlAlchemyBaseView):
    """
    A view for retrieving candidates from the database.
    """

    def list_few_shots(self) -> List[FewShotExample]:
        return [
            FewShotExample(
                "Which candidates studied at University of Toronto?",
                'studied_at("University of Toronto")',
            ),
            FewShotExample(
                "Do we have any soon available perfect fits for senior data scientist positions?",
                lambda: (
                    self.is_available_within_months(1)
                    and self.data_scientist_position()
                    and self.has_seniority("senior")
                ),
            ),
            ...
        ]
```

## Freeform views

Currently, freeform views accept SQL queries as raw strings. A wider variety of ways to pass parameters is planned for future db-ally releases.

```python
from typing import List

from dbally.prompt.elements import FewShotExample
from dbally.views.freeform.text2sql import BaseText2SQLView


class RecruitmentView(BaseText2SQLView):
    """
    A view for retrieving candidates from the database.
    """

    def list_few_shots(self) -> List[FewShotExample]:
        return [
            FewShotExample(
                "Which candidates studied at University of Toronto?",
                'SELECT name FROM candidates WHERE university = "University of Toronto"',
            ),
            FewShotExample(
                "Which clients are from NY?",
                'SELECT name FROM clients WHERE city = "NY"',
            ),
            ...
        ]
```

## Prompt format

By default, each few-shot example is injected directly after the system prompt message, in the following format:

```python
[
    {
        "role": "user",
        "content": "Question",
    },
    {
        "role": "assistant",
        "content": "Answer",
    }
]
```

If you use the `examples` formatting tag in the content field of the system or user message, all examples are injected inside that message, without additional conversation turns.

An example of a prompt utilizing the `examples` tag:

```python
[
    {
        "role": "system",
        "content": "Here are example responses:\n {examples}",
    },
]
```

!!!info
    There is no single best way to inject few-shot examples. Different models can behave differently depending on the chosen formatting.
    Generally, the first approach should yield the best results in most cases. Therefore, adding `examples` tags in your custom prompts is not recommended.
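
To see the examples at work end to end, the structured view above can be wired into a collection. The following is a sketch under assumptions: the SQLite URL and collection name are placeholders, and the `create_collection`/`add`/`ask` calls follow the usual db-ally collection API rather than anything shown in this commit.

```python
import asyncio

from sqlalchemy import create_engine

import dbally
from dbally.llms.litellm import LiteLLM


async def main() -> None:
    engine = create_engine("sqlite:///recruitment.db")  # placeholder database
    collection = dbally.create_collection("recruitment", llm=LiteLLM())
    collection.add(RecruitmentView, lambda: RecruitmentView(engine))

    # The examples returned by list_few_shots() are injected into the
    # IQL generation prompt automatically when a question is asked.
    result = await collection.ask("Which candidates studied at University of Toronto?")
    print(result.results)


asyncio.run(main())
```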
2 changes: 0 additions & 2 deletions docs/reference/collection.md
@@ -3,8 +3,6 @@
!!! tip
    To understand the general idea better, visit the [Collection concept page](../concepts/collections.md).

-::: dbally.create_collection

::: dbally.collection.Collection

::: dbally.collection.results.ExecutionResult
1 change: 0 additions & 1 deletion docs/reference/index.md
@@ -1,4 +1,3 @@
# dbally


::: dbally.create_collection
4 changes: 0 additions & 4 deletions docs/reference/iql/iql_generator.md
@@ -1,7 +1,3 @@
# IQLGenerator

::: dbally.iql_generator.iql_generator.IQLGenerator

-::: dbally.iql_generator.iql_prompt_template.IQLPromptTemplate

-::: dbally.iql_generator.iql_prompt_template.default_iql_template
4 changes: 0 additions & 4 deletions docs/reference/nl_responder.md
@@ -26,7 +26,3 @@ Otherwise, a response is generated using a `nl_responder_prompt_template`.
    To understand the general idea better, visit the [NL Responder concept page](../concepts/nl_responder.md).

::: dbally.nl_responder.nl_responder.NLResponder

-::: dbally.nl_responder.query_explainer_prompt_template

-::: dbally.nl_responder.nl_responder_prompt_template.default_nl_responder_template
7 changes: 7 additions & 0 deletions docs/reference/prompt.md
@@ -0,0 +1,7 @@
# Prompt

::: dbally.prompt.template.PromptTemplate

::: dbally.prompt.template.PromptFormat

::: dbally.prompt.elements.FewShotExample
2 changes: 0 additions & 2 deletions docs/reference/view_selection/llm_view_selector.md
@@ -1,5 +1,3 @@
# LLMViewSelector

::: dbally.view_selection.LLMViewSelector

-::: dbally.view_selection.view_selector_prompt_template.default_view_selector_template
39 changes: 34 additions & 5 deletions examples/recruiting.py
@@ -9,9 +9,37 @@
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
from dbally.audit.event_tracker import EventTracker
from dbally.llms.litellm import LiteLLM
-from dbally.prompts import PromptTemplate
+from dbally.prompt import PromptTemplate
+from dbally.prompt.elements import FewShotExample
+from dbally.prompt.template import PromptFormat

-TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate(
+
+class Text2SQLPromptFormat(PromptFormat):
+    """
+    Formats provided parameters to a form acceptable by SQL prompt.
+    """
+
+    def __init__(
+        self,
+        *,
+        question: str,
+        schema: str,
+        examples: List[FewShotExample] = None,
+    ) -> None:
+        """
+        Constructs a new Text2SQLPromptFormat instance.
+
+        Args:
+            question: Question to be asked.
+            schema: SQL schema description.
+            examples: List of examples to be injected into the conversation.
+        """
+        super().__init__(examples)
+        self.question = question
+        self.schema = schema
+
+
+TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate[Text2SQLPromptFormat](
    (
        {
            "role": "system",
@@ -112,9 +140,10 @@ async def recruiting_example(db_description: str, benchmark: Benchmark = example
    for question in benchmark.questions:
        await recruitment_db.ask(question.dbally_question, return_natural_response=True)
        gpt_question = question.gpt_question if question.gpt_question else question.dbally_question
-        gpt_response = await llm.generate_text(
-            TEXT2SQL_PROMPT_TEMPLATE, {"schema": db_description, "question": gpt_question}, event_tracker=event_tracker
-        )
+
+        prompt_format = Text2SQLPromptFormat(question=gpt_question, schema=db_description)
+        formatted_prompt = TEXT2SQL_PROMPT_TEMPLATE.format_prompt(prompt_format)
+        gpt_response = await llm.generate_text(formatted_prompt, event_tracker=event_tracker)

        print(f"GPT response: {gpt_response}")

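Condensed, the refactored example separates filling prompt variables (`Text2SQLPromptFormat`) from rendering (`format_prompt`) and generation (`generate_text`). A minimal sketch of the pattern in isolation, assuming an async context and placeholder question and schema strings:

```python
llm = LiteLLM()
event_tracker = EventTracker()

prompt_format = Text2SQLPromptFormat(
    question="How many candidates have over five years of experience?",  # placeholder
    schema="candidates(id, name, country, years_of_experience)",  # placeholder
)
formatted_prompt = TEXT2SQL_PROMPT_TEMPLATE.format_prompt(prompt_format)
gpt_response = await llm.generate_text(formatted_prompt, event_tracker=event_tracker)
```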
2 changes: 1 addition & 1 deletion examples/recruiting/views.py
@@ -7,7 +7,7 @@
from sqlalchemy import and_, select

from dbally import SqlAlchemyBaseView, decorators
-from dbally.prompts.elements import FewShotExample
+from dbally.prompt.elements import FewShotExample

from .db import Candidate

2 changes: 2 additions & 0 deletions mkdocs.yml
@@ -21,6 +21,7 @@ nav:
- how-to/views/text-to-sql.md
- how-to/views/pandas.md
- how-to/views/custom.md
+- how-to/views/few-shots.md
- Using LLMs:
- how-to/llms/litellm.md
- how-to/llms/custom.md
@@ -59,6 +60,7 @@ nav:
- LLMs:
- reference/llms/index.md
- reference/llms/litellm.md
+- reference/prompt.md
- Similarity:
- reference/similarity/index.md
- Store:
2 changes: 1 addition & 1 deletion src/dbally/assistants/openai.py
@@ -6,7 +6,7 @@

from dbally.assistants.base import AssistantAdapter, FunctionCallingError, FunctionCallState
from dbally.collection import Collection
-from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError
+from dbally.iql_generator.prompt import UnsupportedQueryError

_DBALLY_INFO = "Dbally has access to the following database views: "

2 changes: 1 addition & 1 deletion src/dbally/audit/events.py
@@ -3,7 +3,7 @@
from typing import Optional, Union

from dbally.collection.results import ExecutionResult
-from dbally.prompts import ChatFormat
+from dbally.prompt.template import ChatFormat


@dataclass
4 changes: 2 additions & 2 deletions src/dbally/gradio/gradio_interface.py
@@ -9,8 +9,8 @@
from dbally.audit import CLIEventHandler
from dbally.collection import Collection
from dbally.collection.exceptions import NoViewFoundError
-from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError
-from dbally.prompts import PromptTemplateError
+from dbally.iql_generator.prompt import UnsupportedQueryError
+from dbally.prompt.template import PromptTemplateError


async def create_gradio_interface(user_collection: Collection, preview_limit: int = 10) -> gradio.Interface: