diff --git a/.coveragerc b/.coveragerc
index 2c70198f..5ba74cca 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -12,4 +12,5 @@ omit =
exclude_lines =
pragma: no cover
if __name__ == .__main__.
+ \.\.\.
show_missing = True
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/01_feature_request.yml b/.github/ISSUE_TEMPLATE/01_feature_request.yml
index c5d3a360..90087c8b 100644
--- a/.github/ISSUE_TEMPLATE/01_feature_request.yml
+++ b/.github/ISSUE_TEMPLATE/01_feature_request.yml
@@ -1,7 +1,7 @@
name: 🚀 Feature Request
description: Submit a proposal/request for a new db-ally feature.
title: "feat: "
-labels: ["enhancement"]
+labels: ["feature"]
body:
- type: markdown
attributes:
diff --git a/README.md b/README.md
index 9088bd15..0bbc2843 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-# db-ally
+# 🦮 db-ally
Efficient, consistent and secure library for querying structured data with natural language
diff --git a/benchmark/dbally_benchmark/e2e_benchmark.py b/benchmark/dbally_benchmark/e2e_benchmark.py
index 9ba0871c..aa686727 100644
--- a/benchmark/dbally_benchmark/e2e_benchmark.py
+++ b/benchmark/dbally_benchmark/e2e_benchmark.py
@@ -23,9 +23,9 @@
import dbally
from dbally.collection import Collection
from dbally.collection.exceptions import NoViewFoundError
-from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError, default_iql_template
+from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError
from dbally.llms.litellm import LiteLLM
-from dbally.view_selection.view_selector_prompt_template import default_view_selector_template
+from dbally.view_selection.prompt import VIEW_SELECTION_TEMPLATE
async def _run_dbally_for_single_example(example: BIRDExample, collection: Collection) -> Text2SQLResult:
@@ -126,9 +126,9 @@ async def evaluate(cfg: DictConfig) -> Any:
logger.info(f"db-ally predictions saved under directory: {output_dir}")
if run:
- run["config/iql_prompt_template"] = stringify_unsupported(default_iql_template.chat)
- run["config/view_selection_prompt_template"] = stringify_unsupported(default_view_selector_template.chat)
- run["config/iql_prompt_template"] = stringify_unsupported(default_iql_template)
+ run["config/iql_prompt_template"] = stringify_unsupported(IQL_GENERATION_TEMPLATE.chat)
+ run["config/view_selection_prompt_template"] = stringify_unsupported(VIEW_SELECTION_TEMPLATE.chat)
+ run["config/iql_prompt_template"] = stringify_unsupported(IQL_GENERATION_TEMPLATE)
run[f"evaluation/{metrics_file_name}"].upload((output_dir / metrics_file_name).as_posix())
run[f"evaluation/{results_file_name}"].upload((output_dir / results_file_name).as_posix())
run["evaluation/metrics"] = stringify_unsupported(metrics)
diff --git a/benchmark/dbally_benchmark/iql_benchmark.py b/benchmark/dbally_benchmark/iql_benchmark.py
index adf33710..2557b2c2 100644
--- a/benchmark/dbally_benchmark/iql_benchmark.py
+++ b/benchmark/dbally_benchmark/iql_benchmark.py
@@ -21,7 +21,7 @@
from dbally.audit.event_tracker import EventTracker
from dbally.iql_generator.iql_generator import IQLGenerator
-from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError, default_iql_template
+from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, UnsupportedQueryError
from dbally.llms.litellm import LiteLLM
from dbally.views.structured import BaseStructuredView
@@ -33,13 +33,15 @@ async def _run_iql_for_single_example(
event_tracker = EventTracker()
try:
- iql_filters, _ = await iql_generator.generate_iql(
- question=example.question, filters=filter_list, event_tracker=event_tracker
+ iql_filters = await iql_generator.generate_iql(
+ question=example.question,
+ filters=filter_list,
+ event_tracker=event_tracker,
)
except UnsupportedQueryError:
return IQLResult(question=example.question, iql_filters="UNSUPPORTED_QUERY", exception_raised=True)
- return IQLResult(question=example.question, iql_filters=iql_filters, exception_raised=False)
+ return IQLResult(question=example.question, iql_filters=str(iql_filters), exception_raised=False)
async def run_iql_for_dataset(
@@ -139,7 +141,7 @@ async def evaluate(cfg: DictConfig) -> Any:
logger.info(f"IQL predictions saved under directory: {output_dir}")
if run:
- run["config/iql_prompt_template"] = stringify_unsupported(default_iql_template.chat)
+ run["config/iql_prompt_template"] = stringify_unsupported(IQL_GENERATION_TEMPLATE.chat)
run[f"evaluation/{metrics_file_name}"].upload((output_dir / metrics_file_name).as_posix())
run[f"evaluation/{results_file_name}"].upload((output_dir / results_file_name).as_posix())
run["evaluation/metrics"] = stringify_unsupported(metrics)
diff --git a/benchmark/dbally_benchmark/text2sql/prompt_template.py b/benchmark/dbally_benchmark/text2sql/prompt_template.py
index abee9659..60349f38 100644
--- a/benchmark/dbally_benchmark/text2sql/prompt_template.py
+++ b/benchmark/dbally_benchmark/text2sql/prompt_template.py
@@ -1,4 +1,4 @@
-from dbally.prompts import PromptTemplate
+from dbally.prompt import PromptTemplate
TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate(
(
diff --git a/docs/about/roadmap.md b/docs/about/roadmap.md
index f6449c88..288aa359 100644
--- a/docs/about/roadmap.md
+++ b/docs/about/roadmap.md
@@ -10,14 +10,13 @@ Below you can find a list of planned features and integrations.
## Planned Features
- [ ] **Support analytical queries**: support for exposing operations beyond filtering.
-- [ ] **Few-shot prompting configuration**: allow users to configure the few-shot prompting in View definition to
+- [x] **Few-shot prompting configuration**: allow users to configure the few-shot prompting in View definition to
improve IQL generation accuracy.
- [ ] **Request contextualization**: allow to provide extra context for db-ally runs, such as user asking the question.
- [X] **OpenAI Assistants API adapter**: allow to embed db-ally into OpenAI's Assistants API to easily extend the
capabilities of the assistant.
- [ ] **Langchain adapter**: allow to embed db-ally into Langchain applications.
-
## Integrations
Being agnostic to the underlying technology is one of the main goals of db-ally.
diff --git a/docs/assets/guide_dog_lg.png b/docs/assets/guide_dog_lg.png
new file mode 100644
index 00000000..dee16c22
Binary files /dev/null and b/docs/assets/guide_dog_lg.png differ
diff --git a/docs/assets/guide_dog_sm.png b/docs/assets/guide_dog_sm.png
new file mode 100644
index 00000000..85f91ee5
Binary files /dev/null and b/docs/assets/guide_dog_sm.png differ
diff --git a/docs/how-to/create_custom_event_handler.md b/docs/how-to/create_custom_event_handler.md
index 410973c5..d4c26f74 100644
--- a/docs/how-to/create_custom_event_handler.md
+++ b/docs/how-to/create_custom_event_handler.md
@@ -10,8 +10,7 @@ In this guide we will implement a simple [Event Handler](../reference/event_hand
First, we need to create a new class that inherits from `EventHandler` and implements all of its abstract methods.
```python
-from dbally.audit import EventHandler
-from dbally.data_models.audit import RequestStart, RequestEnd
+from dbally.audit import EventHandler, RequestStart, RequestEnd
class FileEventHandler(EventHandler):
diff --git a/docs/how-to/llms/custom.md b/docs/how-to/llms/custom.md
index c262351d..7e249847 100644
--- a/docs/how-to/llms/custom.md
+++ b/docs/how-to/llms/custom.md
@@ -44,42 +44,29 @@ class MyLLMClient(LLMClient[LiteLLMOptions]):
async def call(
self,
- prompt: ChatFormat,
- response_format: Optional[Dict[str, str]],
+ conversation: ChatFormat,
options: LiteLLMOptions,
event: LLMEvent,
+ json_mode: bool = False,
) -> str:
# Your LLM API call
```
-The [`call`](../../reference/llms/index.md#dbally.llms.clients.base.LLMClient.call) method is an abstract method that must be implemented in your subclass. This method should call the LLM inference API and return the response.
+The [`call`](../../reference/llms/index.md#dbally.llms.clients.base.LLMClient.call) method is an abstract method that must be implemented in your subclass. This method should call the LLM inference API and return the response in string format.
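+
+For illustration, a minimal sketch of such an implementation is shown below. The HTTP endpoint, payload shape, and response field are hypothetical placeholders standing in for your vendor's API, not part of db-ally:
+
+```python
+import httpx
+
+from dbally.audit.events import LLMEvent
+from dbally.llms.clients.base import LLMClient
+from dbally.llms.clients.litellm import LiteLLMOptions
+from dbally.prompt.template import ChatFormat
+
+
+class MyLLMClient(LLMClient[LiteLLMOptions]):
+
+    async def call(
+        self,
+        conversation: ChatFormat,
+        options: LiteLLMOptions,
+        event: LLMEvent,
+        json_mode: bool = False,
+    ) -> str:
+        # Hypothetical REST call - swap in your vendor's SDK or endpoint.
+        # `options` can be mapped to vendor-specific decoding parameters,
+        # and `event` can be enriched with token usage if the API reports it.
+        async with httpx.AsyncClient() as client:
+            response = await client.post(
+                "https://api.example.com/v1/chat",
+                json={"messages": conversation, "json_mode": json_mode},
+            )
+        return response.json()["text"]
+```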
### Step 3: Use tokenizer to count tokens
-The [`count_tokens`](../../reference/llms/index.md#dbally.llms.base.LLM.count_tokens) method is used to count the number of tokens in the messages. You can override this method in your custom class to use the tokenizer and count tokens specifically for your model.
+The [`count_tokens`](../../reference/llms/index.md#dbally.llms.base.LLM.count_tokens) method is used to count the number of tokens in the prompt. You can override this method in your custom class to use the tokenizer and count tokens specifically for your model.
```python
class MyLLM(LLM[LiteLLMOptions]):
- def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int:
- # Count tokens in the messages in a custom way
+ def count_tokens(self, prompt: PromptTemplate) -> int:
+ # Count tokens in the prompt in a custom way
```
!!!warning
Incorrect token counting can cause problems in the [`NLResponder`](../../reference/nl_responder.md#dbally.nl_responder.nl_responder.NLResponder) and force the use of an explanation prompt template that is more generic and does not include specific rows from the IQL response.
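+
+For example, model-specific counting can rely on a tokenizer library. The sketch below uses `tiktoken` with the `cl100k_base` encoding as an assumption for OpenAI-style models; substitute the tokenizer that matches your model:
+
+```python
+import tiktoken
+
+from dbally.llms.base import LLM
+from dbally.llms.clients.litellm import LiteLLMOptions
+from dbally.prompt.template import PromptTemplate
+
+
+class MyLLM(LLM[LiteLLMOptions]):
+
+    def count_tokens(self, prompt: PromptTemplate) -> int:
+        # Tokenize the content of every message in the formatted conversation
+        encoding = tiktoken.get_encoding("cl100k_base")
+        return sum(len(encoding.encode(message["content"])) for message in prompt.chat)
+```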
-### Step 4: Define custom prompt formatting
-
-The [`format_prompt`](../../reference/llms/index.md#dbally.llms.base.LLM.format_prompt) method is used to apply formatting to the prompt template. You can override this method in your custom class to change how the formatting is performed.
-
-```python
-class MyLLM(LLM[LiteLLMOptions]):
-
- def format_prompt(self, template: PromptTemplate, fmt: Dict[str, str]) -> ChatFormat:
- # Apply custom formatting to the prompt template
-```
-!!!note
- In general, implementation of this method is not required unless the LLM API does not support [OpenAI conversation formatting](https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages){:target="_blank"}. If your model API expects a different format, override this method to avoid issues with inference call.
-
## Customising LLM Options
[`LLMOptions`](../../reference/llms/index.md#dbally.llms.clients.base.LLMOptions) is a class that defines the options your LLM will use. To create a custom options, you need to create a subclass of [`LLMOptions`](../../reference/llms/index.md#dbally.llms.clients.base.LLMOptions) and define the required properties that will be passed to the [`LLMClient`](../../reference/llms/index.md#dbally.llms.clients.base.LLMClient).
diff --git a/docs/how-to/llms/litellm.md b/docs/how-to/llms/litellm.md
index 03f41208..6d995af8 100644
--- a/docs/how-to/llms/litellm.md
+++ b/docs/how-to/llms/litellm.md
@@ -48,6 +48,24 @@ Integrate db-ally with your LLM vendor.
llm=LiteLLM(model_name="anyscale/meta-llama/Llama-2-70b-chat-hf")
```
+=== "Azure OpenAI"
+
+ ```python
+ import os
+ from dbally.llms.litellm import LiteLLM
+
+ ## set ENV variables
+ os.environ["AZURE_API_KEY"] = "your-api-key"
+ os.environ["AZURE_API_BASE"] = "your-api-base-url"
+ os.environ["AZURE_API_VERSION"] = "your-api-version"
+
+ # optional
+ os.environ["AZURE_AD_TOKEN"] = ""
+ os.environ["AZURE_API_TYPE"] = ""
+
+ llm = LiteLLM(model_name="azure/")
+ ```
+
Use LLM in your collection.
```python
diff --git a/docs/how-to/views/few-shots.md b/docs/how-to/views/few-shots.md
new file mode 100644
index 00000000..806ab171
--- /dev/null
+++ b/docs/how-to/views/few-shots.md
@@ -0,0 +1,97 @@
+# How-To: Define few shots
+
+There are many ways to improve the accuracy of IQL generation - one of them is to use few-shot prompting. db-ally allows you to inject few-shot examples for any type of defined view, both structured and freeform.
+
+Few shots are defined in the [`list_few_shots`](../../reference/views/index.md#dbally.views.base.BaseView.list_few_shots) method. Each few-shot example should be an instance of the [`FewShotExample`](../../reference/prompt.md#dbally.prompt.elements.FewShotExample) class, which defines an example question and the expected LLM answer.
+
+## Structured views
+
+For structured views, both the question and the answer of a [`FewShotExample`](../../reference/prompt.md#dbally.prompt.elements.FewShotExample) can be defined as strings; for answers, Python expressions are also allowed (see the lambda function in the example below).
+
+```python
+from dbally.prompt.elements import FewShotExample
+from dbally.views.sqlalchemy_base import SqlAlchemyBaseView
+
+class RecruitmentView(SqlAlchemyBaseView):
+ """
+ A view for retrieving candidates from the database.
+ """
+
+ def list_few_shots(self) -> List[FewShotExample]:
+ return [
+ FewShotExample(
+ "Which candidates studied at University of Toronto?",
+ 'studied_at("University of Toronto")',
+ ),
+ FewShotExample(
+ "Do we have any soon available perfect fits for senior data scientist positions?",
+ lambda: (
+ self.is_available_within_months(1)
+ and self.data_scientist_position()
+ and self.has_seniority("senior")
+ ),
+ ),
+ ...
+ ]
+```
+
+## Freeform views
+
+Currently, freeform views accept the answer as a raw SQL string. Support for a wider variety of parameter formats is planned for future db-ally releases.
+
+```python
+from dbally.prompt.elements import FewShotExample
+from dbally.views.freeform.text2sql import BaseText2SQLView
+
+class RecruitmentView(BaseText2SQLView):
+ """
+ A view for retrieving candidates from the database.
+ """
+
+ def list_few_shots(self) -> List[FewShotExample]:
+ return [
+ FewShotExample(
+ "Which candidates studied at University of Toronto?",
+ 'SELECT name FROM candidates WHERE university = "University of Toronto"',
+ ),
+ FewShotExample(
+ "Which clients are from NY?",
+ 'SELECT name FROM clients WHERE city = "NY"',
+ ),
+ ...
+ ]
+```
+
+## Prompt format
+
+By default, each few-shot example is injected right after the system prompt message. The format is as follows:
+
+```python
+[
+ {
+        "role": "user",
+ "content": "Question",
+ },
+ {
+ "role": "assistant",
+ "content": "Answer",
+ }
+]
+```
+
+If you use the `examples` formatting tag in the content field of the system or user message, all examples are injected inside that message instead of being added as separate conversation turns.
+
+An example of a prompt utilizing the `examples` tag:
+
+```python
+[
+ {
+        "role": "system",
+        "content": "Here are example responses:\n {examples}",
+ },
+]
+```
+
+!!!info
+    There is no single best way to inject few-shot examples. Different models can behave differently depending on the chosen formatting.
+    Generally, the first approach should yield the best results in most cases, so adding `examples` tags to your custom prompts is not recommended.
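+
+## Using a view with few shots
+
+Once defined, a view with few shots is registered and queried like any other view. Below is a minimal sketch, assuming a SQLAlchemy `engine`, the `RecruitmentView` from the structured example above, and LLM credentials already configured in the environment:
+
+```python
+import asyncio
+
+import dbally
+from dbally.llms.litellm import LiteLLM
+
+
+async def main() -> None:
+    llm = LiteLLM(model_name="gpt-3.5-turbo")
+    collection = dbally.create_collection("recruitment", llm=llm)
+    # Few shots returned by list_few_shots are injected automatically during IQL generation
+    collection.add(RecruitmentView, lambda: RecruitmentView(engine))
+
+    result = await collection.ask("Which candidates studied at University of Toronto?")
+    print(result.results)
+
+
+asyncio.run(main())
+```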
diff --git a/docs/reference/collection.md b/docs/reference/collection.md
index cb9b4b97..c7b7269a 100644
--- a/docs/reference/collection.md
+++ b/docs/reference/collection.md
@@ -3,8 +3,6 @@
!!! tip
To understand the general idea better, visit the [Collection concept page](../concepts/collections.md).
-::: dbally.create_collection
-
::: dbally.collection.Collection
::: dbally.collection.results.ExecutionResult
diff --git a/docs/reference/event_handlers/index.md b/docs/reference/event_handlers/index.md
index ae69bc0d..f95f5798 100644
--- a/docs/reference/event_handlers/index.md
+++ b/docs/reference/event_handlers/index.md
@@ -10,10 +10,10 @@ db-ally provides an `EventHandler` abstract class that can be used to log the ru
Each run of [dbally.Collection.ask][dbally.Collection.ask] will trigger all instances of EventHandler that were passed to the Collection's constructor (or the [dbally.create_collection][dbally.create_collection] function).
-1. `EventHandler.request_start` is called with [RequestStart][dbally.data_models.audit.RequestStart], it can return a context object that will be passed to next calls.
+1. `EventHandler.request_start` is called with [RequestStart][dbally.audit.events.RequestStart], it can return a context object that will be passed to next calls.
2. For each event that occurs during the run, `EventHandler.event_start` is called with the context object returned by `EventHandler.request_start` and an Event object. It can return context for the `EventHandler.event_end` method.
3. When the event ends `EventHandler.event_end` is called with the context object returned by `EventHandler.event_start` and an Event object.
-4. On the end of the run `EventHandler.request_end` is called with the context object returned by `EventHandler.request_start` and the [RequestEnd][dbally.data_models.audit.RequestEnd].
+4. On the end of the run `EventHandler.request_end` is called with the context object returned by `EventHandler.request_start` and the [RequestEnd][dbally.audit.events.RequestEnd].
``` mermaid
@@ -42,8 +42,14 @@ Currently handled events:
::: dbally.audit.EventHandler
-::: dbally.data_models.audit.RequestStart
+::: dbally.audit.events.RequestStart
-::: dbally.data_models.audit.RequestEnd
+::: dbally.audit.events.RequestEnd
-::: dbally.data_models.audit.LLMEvent
+::: dbally.audit.events.Event
+
+::: dbally.audit.events.LLMEvent
+
+::: dbally.audit.events.SimilarityEvent
+
+::: dbally.audit.spans.EventSpan
diff --git a/docs/reference/index.md b/docs/reference/index.md
index 0deb591a..fa1abc4f 100644
--- a/docs/reference/index.md
+++ b/docs/reference/index.md
@@ -1,4 +1,3 @@
# dbally
-
::: dbally.create_collection
diff --git a/docs/reference/iql/iql_generator.md b/docs/reference/iql/iql_generator.md
index 15edcb56..b91a0b0c 100644
--- a/docs/reference/iql/iql_generator.md
+++ b/docs/reference/iql/iql_generator.md
@@ -1,7 +1,3 @@
# IQLGenerator
::: dbally.iql_generator.iql_generator.IQLGenerator
-
-::: dbally.iql_generator.iql_prompt_template.IQLPromptTemplate
-
-::: dbally.iql_generator.iql_prompt_template.default_iql_template
diff --git a/docs/reference/nl_responder.md b/docs/reference/nl_responder.md
index fb80741c..531243de 100644
--- a/docs/reference/nl_responder.md
+++ b/docs/reference/nl_responder.md
@@ -26,7 +26,3 @@ Otherwise, a response is generated using a `nl_responder_prompt_template`.
To understand general idea better, visit the [NL Responder concept page](../concepts/nl_responder.md).
::: dbally.nl_responder.nl_responder.NLResponder
-
-::: dbally.nl_responder.query_explainer_prompt_template
-
-::: dbally.nl_responder.nl_responder_prompt_template.default_nl_responder_template
diff --git a/docs/reference/prompt.md b/docs/reference/prompt.md
new file mode 100644
index 00000000..42ab8901
--- /dev/null
+++ b/docs/reference/prompt.md
@@ -0,0 +1,7 @@
+# Prompt
+
+::: dbally.prompt.template.PromptTemplate
+
+::: dbally.prompt.template.PromptFormat
+
+::: dbally.prompt.elements.FewShotExample
diff --git a/docs/reference/view_selection/llm_view_selector.md b/docs/reference/view_selection/llm_view_selector.md
index 774aa4b9..a177a8bd 100644
--- a/docs/reference/view_selection/llm_view_selector.md
+++ b/docs/reference/view_selection/llm_view_selector.md
@@ -1,5 +1,3 @@
# LLMViewSelector
::: dbally.view_selection.LLMViewSelector
-
-::: dbally.view_selection.view_selector_prompt_template.default_view_selector_template
diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css
index 35115f64..698837d0 100644
--- a/docs/stylesheets/extra.css
+++ b/docs/stylesheets/extra.css
@@ -1,3 +1,13 @@
:root {
--md-primary-fg-color: #00b0e0;
+}
+
+.md-header__button.md-logo {
+ margin: 0;
+ padding: 0;
+}
+
+.md-header__button.md-logo img, .md-header__button.md-logo svg {
+ height: 1.8rem;
+ width: 1.8rem;
}
\ No newline at end of file
diff --git a/examples/recruiting.py b/examples/recruiting.py
index a4813b41..ea16a934 100644
--- a/examples/recruiting.py
+++ b/examples/recruiting.py
@@ -9,9 +9,37 @@
from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler
from dbally.audit.event_tracker import EventTracker
from dbally.llms.litellm import LiteLLM
-from dbally.prompts import PromptTemplate
+from dbally.prompt import PromptTemplate
+from dbally.prompt.elements import FewShotExample
+from dbally.prompt.template import PromptFormat
-TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate(
+
+class Text2SQLPromptFormat(PromptFormat):
+ """
+ Formats provided parameters to a form acceptable by SQL prompt.
+ """
+
+ def __init__(
+ self,
+ *,
+ question: str,
+ schema: str,
+ examples: List[FewShotExample] = None,
+ ) -> None:
+ """
+        Constructs a new Text2SQLPromptFormat instance.
+
+ Args:
+ question: Question to be asked.
+ schema: SQL schema description.
+ examples: List of examples to be injected into the conversation.
+ """
+ super().__init__(examples)
+ self.question = question
+ self.schema = schema
+
+
+TEXT2SQL_PROMPT_TEMPLATE = PromptTemplate[Text2SQLPromptFormat](
(
{
"role": "system",
@@ -112,9 +140,10 @@ async def recruiting_example(db_description: str, benchmark: Benchmark = example
for question in benchmark.questions:
await recruitment_db.ask(question.dbally_question, return_natural_response=True)
gpt_question = question.gpt_question if question.gpt_question else question.dbally_question
- gpt_response = await llm.generate_text(
- TEXT2SQL_PROMPT_TEMPLATE, {"schema": db_description, "question": gpt_question}, event_tracker=event_tracker
- )
+
+ prompt_format = Text2SQLPromptFormat(question=gpt_question, schema=db_description)
+ formatted_prompt = TEXT2SQL_PROMPT_TEMPLATE.format_prompt(prompt_format)
+ gpt_response = await llm.generate_text(formatted_prompt, event_tracker=event_tracker)
print(f"GPT response: {gpt_response}")
diff --git a/examples/recruiting/views.py b/examples/recruiting/views.py
index 22afb455..773d3f62 100644
--- a/examples/recruiting/views.py
+++ b/examples/recruiting/views.py
@@ -1,10 +1,13 @@
-from typing import Literal
+from datetime import date
+from typing import List, Literal
import awoc # pip install a-world-of-countries
import sqlalchemy
+from dateutil.relativedelta import relativedelta
from sqlalchemy import and_, select
from dbally import SqlAlchemyBaseView, decorators
+from dbally.prompt.elements import FewShotExample
from .db import Candidate
@@ -57,3 +60,43 @@ def is_from_continent( # pylint: disable=W0602, C0116, W9011
@decorators.view_filter()
def studied_at(self, university: str) -> sqlalchemy.ColumnElement: # pylint: disable=W0602, C0116, W9011
return Candidate.university == university
+
+
+class FewShotRecruitmentView(RecruitmentView):
+ """
+    A view for the recruitment database that includes example question-answer pairs (few shots).
+ """
+
+ @decorators.view_filter()
+ def is_available_within_months( # pylint: disable=W0602, C0116, W9011
+ self, months: int
+ ) -> sqlalchemy.ColumnElement:
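+        # Matches candidates whose availability date falls within the given number of months from today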
+ start = date.today()
+ end = start + relativedelta(months=months)
+ return Candidate.available_from.between(start, end)
+
+ def list_few_shots(self) -> List[FewShotExample]: # pylint: disable=W9011
+ return [
+ FewShotExample(
+ "Which candidates studied at University of Toronto?",
+ 'studied_at("University of Toronto")',
+ ),
+ FewShotExample(
+ "Do we have any soon available candidate?",
+ lambda: self.is_available_within_months(1),
+ ),
+ FewShotExample(
+ "Do we have any soon available perfect fits for senior data scientist positions?",
+ lambda: (
+ self.is_available_within_months(1)
+ and self.data_scientist_position()
+ and self.has_seniority("senior")
+ ),
+ ),
+ FewShotExample(
+ "List all junior or senior data scientist positions",
+ lambda: (
+ self.data_scientist_position() and (self.has_seniority("junior") or self.has_seniority("senior"))
+ ),
+ ),
+ ]
diff --git a/mkdocs.yml b/mkdocs.yml
index 35ad2f5c..826ffe15 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -21,6 +21,7 @@ nav:
- how-to/views/text-to-sql.md
- how-to/views/pandas.md
- how-to/views/custom.md
+ - how-to/views/few-shots.md
- Using LLMs:
- how-to/llms/litellm.md
- how-to/llms/custom.md
@@ -59,6 +60,7 @@ nav:
- LLMs:
- reference/llms/index.md
- reference/llms/litellm.md
+ - reference/prompt.md
- Similarity:
- reference/similarity/index.md
- Store:
@@ -80,6 +82,8 @@ nav:
- about/contact.md
theme:
name: material
+ logo: assets/guide_dog_lg.png
+ favicon: assets/guide_dog_sm.png
icon:
repo: fontawesome/brands/github
palette:
diff --git a/src/dbally/assistants/openai.py b/src/dbally/assistants/openai.py
index 8560cc95..4ec239df 100644
--- a/src/dbally/assistants/openai.py
+++ b/src/dbally/assistants/openai.py
@@ -6,7 +6,7 @@
from dbally.assistants.base import AssistantAdapter, FunctionCallingError, FunctionCallState
from dbally.collection import Collection
-from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError
+from dbally.iql_generator.prompt import UnsupportedQueryError
_DBALLY_INFO = "Dbally has access to the following database views: "
diff --git a/src/dbally/audit/__init__.py b/src/dbally/audit/__init__.py
index af9a5384..73253f71 100644
--- a/src/dbally/audit/__init__.py
+++ b/src/dbally/audit/__init__.py
@@ -7,8 +7,19 @@
except ImportError:
pass
+from .event_tracker import EventTracker
+from .events import Event, LLMEvent, RequestEnd, RequestStart, SimilarityEvent
+from .spans import EventSpan
+
__all__ = [
"CLIEventHandler",
"LangSmithEventHandler",
+ "Event",
"EventHandler",
+ "EventTracker",
+ "EventSpan",
+ "LLMEvent",
+ "RequestEnd",
+ "RequestStart",
+ "SimilarityEvent",
]
diff --git a/src/dbally/audit/event_handlers/base.py b/src/dbally/audit/event_handlers/base.py
index 10fce0cf..dc3ea7f8 100644
--- a/src/dbally/audit/event_handlers/base.py
+++ b/src/dbally/audit/event_handlers/base.py
@@ -1,8 +1,8 @@
import abc
from abc import ABC
-from typing import Generic, TypeVar, Union
+from typing import Generic, Optional, TypeVar
-from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent
+from dbally.audit.events import Event, RequestEnd, RequestStart
RequestCtx = TypeVar("RequestCtx")
EventCtx = TypeVar("EventCtx")
@@ -26,7 +26,7 @@ async def request_start(self, user_request: RequestStart) -> RequestCtx:
"""
@abc.abstractmethod
- async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_context: RequestCtx) -> EventCtx:
+ async def event_start(self, event: Event, request_context: RequestCtx) -> EventCtx:
"""
Function that is called during every event execution.
@@ -40,9 +40,7 @@ async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_con
"""
@abc.abstractmethod
- async def event_end(
- self, event: Union[None, LLMEvent, SimilarityEvent], request_context: RequestCtx, event_context: EventCtx
- ) -> None:
+ async def event_end(self, event: Optional[Event], request_context: RequestCtx, event_context: EventCtx) -> None:
"""
Function that is called during every event execution.
diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py
index f738f90b..aa48e049 100644
--- a/src/dbally/audit/event_handlers/cli_event_handler.py
+++ b/src/dbally/audit/event_handlers/cli_event_handler.py
@@ -1,7 +1,7 @@
import re
from io import StringIO
from sys import stdout
-from typing import Optional, Union
+from typing import Optional
try:
from rich import print as pprint
@@ -15,7 +15,7 @@
pprint = print # type: ignore
from dbally.audit.event_handlers.base import EventHandler
-from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent
+from dbally.audit.events import Event, LLMEvent, RequestEnd, RequestStart, SimilarityEvent
_RICH_FORMATING_KEYWORD_SET = {"green", "orange", "grey", "bold", "cyan"}
_RICH_FORMATING_PATTERN = rf"\[.*({'|'.join(_RICH_FORMATING_KEYWORD_SET)}).*\]"
@@ -40,14 +40,14 @@ class CLIEventHandler(EventHandler):
![Example output from CLIEventHandler](../../assets/event_handler_example.png)
"""
- def __init__(self, buffer: StringIO = None) -> None:
+ def __init__(self, buffer: Optional[StringIO] = None) -> None:
super().__init__()
self.buffer = buffer
out = self.buffer if buffer else stdout
self._console = Console(file=out, record=True) if RICH_OUTPUT else None
- def _print_syntax(self, content: str, lexer: str = None) -> None:
+ def _print_syntax(self, content: str, lexer: Optional[str] = None) -> None:
if self._console:
if lexer:
console_content = Syntax(content, lexer, word_wrap=True)
@@ -69,7 +69,7 @@ async def request_start(self, user_request: RequestStart) -> None:
self._print_syntax("[grey53]\n=======================================")
self._print_syntax("[grey53]=======================================\n")
- async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_context: None) -> None:
+ async def event_start(self, event: Event, request_context: None) -> None:
"""
Displays information that event has started, then all messages inside the prompt
@@ -98,9 +98,7 @@ async def event_start(self, event: Union[LLMEvent, SimilarityEvent], request_con
f"[cyan bold]FETCHER: [grey53]{event.fetcher}\n"
)
- async def event_end(
- self, event: Union[None, LLMEvent, SimilarityEvent], request_context: None, event_context: None
- ) -> None:
+ async def event_end(self, event: Optional[Event], request_context: None, event_context: None) -> None:
"""
Displays the response from the LLM.
diff --git a/src/dbally/audit/event_handlers/langsmith_event_handler.py b/src/dbally/audit/event_handlers/langsmith_event_handler.py
index 5974a068..c0b619c2 100644
--- a/src/dbally/audit/event_handlers/langsmith_event_handler.py
+++ b/src/dbally/audit/event_handlers/langsmith_event_handler.py
@@ -1,12 +1,12 @@
import socket
from getpass import getuser
-from typing import Optional, Union
+from typing import Optional
from langsmith.client import Client
from langsmith.run_trees import RunTree
from dbally.audit.event_handlers.base import EventHandler
-from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent
+from dbally.audit.events import Event, LLMEvent, RequestEnd, RequestStart, SimilarityEvent
class LangSmithEventHandler(EventHandler[RunTree, RunTree]):
@@ -47,7 +47,7 @@ async def request_start(self, user_request: RequestStart) -> RunTree:
return run_tree
- async def event_start(self, event: Union[None, LLMEvent, SimilarityEvent], request_context: RunTree) -> RunTree:
+ async def event_start(self, event: Event, request_context: RunTree) -> RunTree:
"""
Log the start of the event.
@@ -79,9 +79,7 @@ async def event_start(self, event: Union[None, LLMEvent, SimilarityEvent], reque
raise ValueError("Unsupported event")
- async def event_end(
- self, event: Union[None, LLMEvent, SimilarityEvent], request_context: RunTree, event_context: RunTree
- ) -> None:
+ async def event_end(self, event: Optional[Event], request_context: RunTree, event_context: RunTree) -> None:
"""
Log the end of the event.
diff --git a/src/dbally/audit/event_span.py b/src/dbally/audit/event_span.py
deleted file mode 100644
index c7cba584..00000000
--- a/src/dbally/audit/event_span.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from typing import Any, Optional, Union
-
-from dbally.data_models.audit import LLMEvent, SimilarityEvent
-
-
-class EventSpan:
- """Helper class for logging events."""
-
- data: Optional[Any]
-
- def __init__(self) -> None:
- self.data = None
-
- def __call__(self, data: Union[LLMEvent, SimilarityEvent]) -> None:
- """
- Call method for logging events.
-
- Args:
- data: Event data.
- """
-
- self.data = data
diff --git a/src/dbally/audit/event_tracker.py b/src/dbally/audit/event_tracker.py
index c483a65e..34faf803 100644
--- a/src/dbally/audit/event_tracker.py
+++ b/src/dbally/audit/event_tracker.py
@@ -1,9 +1,9 @@
from contextlib import asynccontextmanager
-from typing import AsyncIterator, Dict, List, Optional, Union
+from typing import AsyncIterator, Dict, List, Optional
from dbally.audit.event_handlers.base import EventHandler
-from dbally.audit.event_span import EventSpan
-from dbally.data_models.audit import LLMEvent, RequestEnd, RequestStart, SimilarityEvent
+from dbally.audit.events import Event, RequestEnd, RequestStart
+from dbally.audit.spans import EventSpan
class EventTracker:
@@ -69,7 +69,7 @@ def subscribe(self, event_handler: EventHandler) -> None:
self._handlers.append(event_handler)
@asynccontextmanager
- async def track_event(self, event: Union[LLMEvent, SimilarityEvent]) -> AsyncIterator[EventSpan]:
+ async def track_event(self, event: Event) -> AsyncIterator[EventSpan]:
"""
Context manager for processing an event.
diff --git a/src/dbally/data_models/audit.py b/src/dbally/audit/events.py
similarity index 82%
rename from src/dbally/data_models/audit.py
rename to src/dbally/audit/events.py
index c56e73be..de397a74 100644
--- a/src/dbally/data_models/audit.py
+++ b/src/dbally/audit/events.py
@@ -1,21 +1,20 @@
+from abc import ABC
from dataclasses import dataclass
-from enum import Enum
from typing import Optional, Union
from dbally.collection.results import ExecutionResult
-from dbally.prompts import ChatFormat
+from dbally.prompt.template import ChatFormat
-class EventType(Enum):
+@dataclass
+class Event(ABC):
"""
- Enum for event types.
+ Base class for all events.
"""
- LLM = "LLM"
-
@dataclass
-class LLMEvent:
+class LLMEvent(Event):
"""
Class for LLM event.
"""
@@ -30,7 +29,7 @@ class LLMEvent:
@dataclass
-class SimilarityEvent:
+class SimilarityEvent(Event):
"""
SimilarityEvent is fired when a SimilarityIndex lookup is performed.
"""
diff --git a/src/dbally/audit/spans.py b/src/dbally/audit/spans.py
new file mode 100644
index 00000000..0b9d273d
--- /dev/null
+++ b/src/dbally/audit/spans.py
@@ -0,0 +1,21 @@
+from typing import Optional
+
+from dbally.audit.events import Event
+
+
+class EventSpan:
+ """
+ Helper class for logging events.
+ """
+
+ def __init__(self) -> None:
+ self.data: Optional[Event] = None
+
+ def __call__(self, data: Event) -> None:
+ """
+ Call method for logging events.
+
+ Args:
+ data: Event data.
+ """
+ self.data = data
diff --git a/src/dbally/collection/collection.py b/src/dbally/collection/collection.py
index 2256606e..5e175cb5 100644
--- a/src/dbally/collection/collection.py
+++ b/src/dbally/collection/collection.py
@@ -7,9 +7,9 @@
from dbally.audit.event_handlers.base import EventHandler
from dbally.audit.event_tracker import EventTracker
+from dbally.audit.events import RequestEnd, RequestStart
from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError
from dbally.collection.results import ExecutionResult
-from dbally.data_models.audit import RequestEnd, RequestStart
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
from dbally.nl_responder.nl_responder import NLResponder
diff --git a/src/dbally/data_models/__init__.py b/src/dbally/data_models/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/src/dbally/gradio/gradio_interface.py b/src/dbally/gradio/gradio_interface.py
index 30182b37..761b0dd2 100644
--- a/src/dbally/gradio/gradio_interface.py
+++ b/src/dbally/gradio/gradio_interface.py
@@ -9,8 +9,8 @@
from dbally.audit import CLIEventHandler
from dbally.collection import Collection
from dbally.collection.exceptions import NoViewFoundError
-from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError
-from dbally.prompts import PromptTemplateError
+from dbally.iql_generator.prompt import UnsupportedQueryError
+from dbally.prompt.template import PromptTemplateError
async def create_gradio_interface(user_collection: Collection, preview_limit: int = 10) -> gradio.Interface:
diff --git a/src/dbally/iql/_query.py b/src/dbally/iql/_query.py
index 39274f54..2099d39c 100644
--- a/src/dbally/iql/_query.py
+++ b/src/dbally/iql/_query.py
@@ -1,4 +1,5 @@
from typing import TYPE_CHECKING, List, Optional, Type
+from typing_extensions import Self
from ..audit.event_tracker import EventTracker
from . import syntax
@@ -16,8 +17,12 @@ class IQLQuery:
root: syntax.Node
- def __init__(self, root: syntax.Node):
+ def __init__(self, root: syntax.Node, source: str) -> None:
self.root = root
+ self._source = source
+
+ def __str__(self) -> str:
+ return self._source
@classmethod
async def parse(
@@ -26,7 +31,7 @@ async def parse(
allowed_functions: List["ExposedFunction"],
event_tracker: Optional[EventTracker] = None,
context: Optional[CustomContextsList] = None
- ) -> "IQLQuery":
+ ) -> Self:
"""
Parse IQL string to IQLQuery object.
@@ -37,4 +42,6 @@ async def parse(
Returns:
IQLQuery object
"""
- return cls(await IQLProcessor(source, allowed_functions, context, event_tracker).process())
+
+ root = await IQLProcessor(source, allowed_functions, context, event_tracker).process()
+ return cls(root=root, source=source)
diff --git a/src/dbally/iql_generator/iql_generator.py b/src/dbally/iql_generator/iql_generator.py
index 8633afc0..7eeb9154 100644
--- a/src/dbally/iql_generator/iql_generator.py
+++ b/src/dbally/iql_generator/iql_generator.py
@@ -1,12 +1,17 @@
-import copy
-from typing import Callable, List, Optional, Tuple, TypeVar
+from typing import List, Optional
from dbally.audit.event_tracker import EventTracker
-from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template
+from dbally.iql import IQLError, IQLQuery
+from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
+from dbally.prompt.elements import FewShotExample
+from dbally.prompt.template import PromptTemplate
from dbally.views.exposed_functions import ExposedFunction
+ERROR_MESSAGE = "Unfortunately, generated IQL is not valid. Please try again, \
+ generation of correct IQL is very important. Below you have errors generated by the system:\n{error}"
+
class IQLGenerator:
"""
@@ -19,98 +24,61 @@ class IQLGenerator:
It uses LLM to generate text-based responses, passing in the prompt template, formatted filters, and user question.
"""
- _ERROR_MSG_PREFIX = "Unfortunately, generated IQL is not valid. Please try again, \
- generation of correct IQL is very important. Below you have errors generated by the system: \n"
-
- TException = TypeVar("TException", bound=Exception)
-
- def __init__(
- self,
- llm: LLM,
- prompt_template: Optional[IQLPromptTemplate] = None,
- promptify_view: Optional[Callable] = None,
- ) -> None:
+ def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[IQLGenerationPromptFormat]] = None) -> None:
"""
+ Constructs a new IQLGenerator instance.
+
Args:
llm: LLM used to generate IQL
- prompt_template: If not provided by the users is set to `default_iql_template`
- promptify_view: Function formatting filters for prompt
"""
self._llm = llm
- self._prompt_template = prompt_template or copy.deepcopy(default_iql_template)
- self._promptify_view = promptify_view or _promptify_filters
+ self._prompt_template = prompt_template or IQL_GENERATION_TEMPLATE
async def generate_iql(
self,
- filters: List[ExposedFunction],
question: str,
+ filters: List[ExposedFunction],
event_tracker: EventTracker,
- conversation: Optional[IQLPromptTemplate] = None,
+ examples: Optional[List[FewShotExample]] = None,
llm_options: Optional[LLMOptions] = None,
- ) -> Tuple[str, IQLPromptTemplate]:
+ n_retries: int = 3,
+ ) -> IQLQuery:
"""
- Uses LLM to generate IQL in text form
+ Generates IQL in text form using LLM.
Args:
- question: user question
- filters: list of filters exposed by the view
- event_tracker: event store used to audit the generation process
- conversation: conversation to be continued
- llm_options: options to use for the LLM client
+ question: User question.
+ filters: List of filters exposed by the view.
+ event_tracker: Event store used to audit the generation process.
+ examples: List of examples to be injected into the conversation.
+ llm_options: Options to use for the LLM client.
+ n_retries: Number of retries to regenerate IQL in case of errors.
Returns:
- IQL - iql generated based on the user question
+ Generated IQL query.
"""
- filters_for_prompt = self._promptify_view(filters)
-
- template = conversation or self._prompt_template
-
- llm_response = await self._llm.generate_text(
- template=template,
- fmt={"filters": filters_for_prompt, "question": question},
- event_tracker=event_tracker,
- options=llm_options,
+ prompt_format = IQLGenerationPromptFormat(
+ question=question,
+ filters=filters,
+ examples=examples,
)
-
- iql_filters = self._prompt_template.llm_response_parser(llm_response)
-
- if conversation is None:
- conversation = self._prompt_template
-
- conversation = conversation.add_assistant_message(content=llm_response)
-
- return iql_filters, conversation
-
- def add_error_msg(self, conversation: IQLPromptTemplate, errors: List[TException]) -> IQLPromptTemplate:
- """
- Appends to the conversation error messages returned due to the invalid IQL generated by the LLM.
-
- Args:
- conversation (IQLPromptTemplate): conversation containing current IQL generation trace
- errors (List[Exception]): errors to be appended
-
- Returns:
- IQLPromptTemplate: Conversation extended with errors
- """
-
- msg = self._ERROR_MSG_PREFIX
- for error in errors:
- msg += str(error) + "\n"
-
- return conversation.add_user_message(content=msg)
-
-
-def _promptify_filters(
- filters: List[ExposedFunction],
-) -> str:
- """
- Formats filters for prompt
-
- Args:
- filters: list of filters exposed by the view
-
- Returns:
- filters_for_prompt: filters formatted for prompt
- """
- filters_for_prompt = "\n".join([str(filter) for filter in filters])
- return filters_for_prompt
+ formatted_prompt = self._prompt_template.format_prompt(prompt_format)
+
+ for _ in range(n_retries + 1):
+ try:
+ response = await self._llm.generate_text(
+ prompt=formatted_prompt,
+ event_tracker=event_tracker,
+ options=llm_options,
+ )
+ # TODO: Move response parsing to llm generate_text method
+ iql = formatted_prompt.response_parser(response)
+ # TODO: Move IQL query parsing to prompt response parser
+ return await IQLQuery.parse(
+ source=iql,
+ allowed_functions=filters,
+ event_tracker=event_tracker,
+ )
+ except IQLError as exc:
+ formatted_prompt = formatted_prompt.add_assistant_message(response)
+ formatted_prompt = formatted_prompt.add_user_message(ERROR_MESSAGE.format(error=exc))
diff --git a/src/dbally/iql_generator/prompt.py b/src/dbally/iql_generator/prompt.py
new file mode 100644
index 00000000..44bb2cd4
--- /dev/null
+++ b/src/dbally/iql_generator/prompt.py
@@ -0,0 +1,87 @@
+# pylint: disable=C0301
+
+from typing import List
+
+from dbally.exceptions import DbAllyError
+from dbally.prompt.elements import FewShotExample
+from dbally.prompt.template import PromptFormat, PromptTemplate
+from dbally.views.exposed_functions import ExposedFunction
+
+
+class UnsupportedQueryError(DbAllyError):
+ """
+ Error raised when IQL generator is unable to construct a query
+ with given filters.
+ """
+
+
+def _validate_iql_response(llm_response: str) -> str:
+ """
+ Validates LLM response to IQL
+
+ Args:
+ llm_response: LLM response
+
+ Returns:
+ A string containing IQL for filters.
+
+ Raises:
+        UnsupportedQueryError: When IQL generator is unable to construct a query
+ with given filters.
+ """
+ if "unsupported query" in llm_response.lower():
+ raise UnsupportedQueryError
+ return llm_response
+
+
+class IQLGenerationPromptFormat(PromptFormat):
+ """
+ IQL prompt format, providing a question and filters to be used in the conversation.
+ """
+
+ def __init__(
+ self,
+ *,
+ question: str,
+ filters: List[ExposedFunction],
+ examples: List[FewShotExample] = None,
+ ) -> None:
+ """
+ Constructs a new IQLGenerationPromptFormat instance.
+
+ Args:
+ question: Question to be asked.
+ filters: List of filters exposed by the view.
+ examples: List of examples to be injected into the conversation.
+ """
+ super().__init__(examples)
+ self.question = question
+ self.filters = "\n".join([str(filter) for filter in filters])
+
+
+IQL_GENERATION_TEMPLATE = PromptTemplate[IQLGenerationPromptFormat](
+ [
+ {
+ "role": "system",
+ "content": (
+ "You have access to API that lets you query a database:\n"
+ "\n{filters}\n"
+ "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n"
+ "Remember! Don't give any comments, just the function calls.\n"
+ "The output will look like this:\n"
+ 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n'
+ "DO NOT INCLUDE arguments names in your response. Only the values.\n"
+ "You MUST use only these methods:\n"
+ "\n{filters}\n"
+ "It is VERY IMPORTANT not to use methods other than those listed above."
+ """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """
+ "This is CRUCIAL, otherwise the system will crash. "
+ ),
+ },
+ {
+ "role": "user",
+ "content": "{question}",
+ },
+ ],
+ response_parser=_validate_iql_response,
+)
diff --git a/src/dbally/llms/base.py b/src/dbally/llms/base.py
index e570547c..7e2381e1 100644
--- a/src/dbally/llms/base.py
+++ b/src/dbally/llms/base.py
@@ -1,12 +1,11 @@
from abc import ABC, abstractmethod
from functools import cached_property
-from typing import Dict, Generic, Optional, Type
+from typing import Generic, Optional, Type
from dbally.audit.event_tracker import EventTracker
-from dbally.data_models.audit import LLMEvent
+from dbally.audit.events import LLMEvent
from dbally.llms.clients.base import LLMClient, LLMClientOptions, LLMOptions
-from dbally.prompts.common_validation_utils import ChatFormat
-from dbally.prompts.prompt_template import PromptTemplate
+from dbally.prompt.template import PromptTemplate
class LLM(Generic[LLMClientOptions], ABC):
@@ -41,36 +40,21 @@ def client(self) -> LLMClient:
Client for the LLM.
"""
- def format_prompt(self, template: PromptTemplate, fmt: Dict[str, str]) -> ChatFormat:
+ def count_tokens(self, prompt: PromptTemplate) -> int:
"""
- Applies formatting to the prompt template.
+ Counts tokens in the prompt.
Args:
- template: Prompt template in system/user/assistant openAI format.
- fmt: Dictionary with formatting.
+ prompt: Formatted prompt template with conversation and response parsing configuration.
Returns:
- Prompt in the format of the client.
+ Number of tokens in the prompt.
"""
- return [{**message, "content": message["content"].format(**fmt)} for message in template.chat]
-
- def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int:
- """
- Counts tokens in the messages.
-
- Args:
- messages: Messages to count tokens for.
- fmt: Arguments to be used with prompt.
-
- Returns:
- Number of tokens in the messages.
- """
- return sum(len(message["content"].format(**fmt)) for message in messages)
+ return sum(len(message["content"]) for message in prompt.chat)
async def generate_text(
self,
- template: PromptTemplate,
- fmt: Dict[str, str],
+ prompt: PromptTemplate,
*,
event_tracker: Optional[EventTracker] = None,
options: Optional[LLMOptions] = None,
@@ -79,8 +63,7 @@ async def generate_text(
Prepares and sends a prompt to the LLM and returns the response.
Args:
- template: Prompt template in system/user/assistant openAI format.
- fmt: Dictionary with formatting.
+ prompt: Formatted prompt template with conversation and response parsing configuration.
event_tracker: Event store used to audit the generation process.
options: Options to use for the LLM client.
@@ -88,16 +71,15 @@ async def generate_text(
Text response from LLM.
"""
options = (self.default_options | options) if options else self.default_options
- prompt = self.format_prompt(template, fmt)
- event = LLMEvent(prompt=prompt, type=type(template).__name__)
+ event = LLMEvent(prompt=prompt.chat, type=type(prompt).__name__)
event_tracker = event_tracker or EventTracker()
async with event_tracker.track_event(event) as span:
event.response = await self.client.call(
- prompt=prompt,
- response_format=template.response_format,
+ conversation=prompt.chat,
options=options,
event=event,
+ json_mode=prompt.json_mode,
)
span(event)
diff --git a/src/dbally/llms/clients/base.py b/src/dbally/llms/clients/base.py
index bc55f6ea..0293390f 100644
--- a/src/dbally/llms/clients/base.py
+++ b/src/dbally/llms/clients/base.py
@@ -2,8 +2,8 @@
from dataclasses import asdict, dataclass
from typing import Any, ClassVar, Dict, Generic, Optional, TypeVar
-from dbally.data_models.audit import LLMEvent
-from dbally.prompts import ChatFormat
+from dbally.audit.events import LLMEvent
+from dbally.prompt.template import ChatFormat
from ..._types import NotGiven
@@ -67,19 +67,19 @@ def __init__(self, model_name: str) -> None:
@abstractmethod
async def call(
self,
- prompt: ChatFormat,
- response_format: Optional[Dict[str, str]],
+ conversation: ChatFormat,
options: LLMClientOptions,
event: LLMEvent,
+ json_mode: bool = False,
) -> str:
"""
Calls LLM inference API.
Args:
- prompt: Prompt passed to the LLM.
- response_format: Optional argument used in the OpenAI API - used to force a json output
+ conversation: List of dicts with "role" and "content" keys, representing the chat history so far.
options: Additional settings used by LLM.
event: LLMEvent instance which fields should be filled during the method execution.
+ json_mode: Force the response to be in JSON format.
Returns:
Response string from LLM.
diff --git a/src/dbally/llms/clients/litellm.py b/src/dbally/llms/clients/litellm.py
index 3ec4ccc9..1e23df91 100644
--- a/src/dbally/llms/clients/litellm.py
+++ b/src/dbally/llms/clients/litellm.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from typing import Dict, List, Optional, Union
+from typing import List, Optional, Union
try:
import litellm
@@ -9,10 +9,10 @@
HAVE_LITELLM = False
-from dbally.data_models.audit import LLMEvent
+from dbally.audit.events import LLMEvent
from dbally.llms.clients.base import LLMClient, LLMOptions
from dbally.llms.clients.exceptions import LLMConnectionError, LLMResponseError, LLMStatusError
-from dbally.prompts import ChatFormat
+from dbally.prompt.template import ChatFormat
from ..._types import NOT_GIVEN, NotGiven
@@ -72,19 +72,19 @@ def __init__(
async def call(
self,
- prompt: ChatFormat,
- response_format: Optional[Dict[str, str]],
+ conversation: ChatFormat,
options: LiteLLMOptions,
event: LLMEvent,
+ json_mode: bool = False,
) -> str:
"""
Calls the appropriate LLM endpoint with the given prompt and options.
Args:
- prompt: Prompt as an OpenAI client style list.
- response_format: Optional argument used in the OpenAI API - used to force the json output
+ conversation: List of dicts with "role" and "content" keys, representing the chat history so far.
options: Additional settings used by the LLM.
event: Container with the prompt, LLM response and call metrics.
+ json_mode: Force the response to be in JSON format.
Returns:
Response string from LLM.
@@ -94,9 +94,11 @@ async def call(
LLMStatusError: If the LLM API returns an error status code.
LLMResponseError: If the LLM API response is invalid.
"""
+ response_format = {"type": "json_object"} if json_mode else None
+
try:
response = await litellm.acompletion(
- messages=prompt,
+ messages=conversation,
model=self.model_name,
base_url=self.base_url,
api_key=self.api_key,
diff --git a/src/dbally/llms/litellm.py b/src/dbally/llms/litellm.py
index c5699a1e..077474e9 100644
--- a/src/dbally/llms/litellm.py
+++ b/src/dbally/llms/litellm.py
@@ -1,5 +1,5 @@
from functools import cached_property
-from typing import Dict, Optional
+from typing import Optional
try:
import litellm
@@ -10,7 +10,7 @@
from dbally.llms.base import LLM
from dbally.llms.clients.litellm import LiteLLMClient, LiteLLMOptions
-from dbally.prompts import ChatFormat
+from dbally.prompt.template import PromptTemplate
class LiteLLM(LLM[LiteLLMOptions]):
@@ -65,17 +65,14 @@ def client(self) -> LiteLLMClient:
api_version=self.api_version,
)
- def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int:
+ def count_tokens(self, prompt: PromptTemplate) -> int:
"""
- Counts tokens in the messages using a specified model.
+ Counts tokens in the prompt.
Args:
- messages: Messages to count tokens for.
- fmt: Arguments to be used with prompt.
+ prompt: Formatted prompt template with conversation and response parsing configuration.
Returns:
- Number of tokens in the messages.
+ Number of tokens in the prompt.
"""
- return sum(
- litellm.token_counter(model=self.model_name, text=message["content"].format(**fmt)) for message in messages
- )
+ return sum(litellm.token_counter(model=self.model_name, text=message["content"]) for message in prompt.chat)
diff --git a/src/dbally/nl_responder/nl_responder.py b/src/dbally/nl_responder/nl_responder.py
index 8bcafb11..7a8f98e4 100644
--- a/src/dbally/nl_responder/nl_responder.py
+++ b/src/dbally/nl_responder/nl_responder.py
@@ -1,48 +1,44 @@
-import copy
-from typing import Dict, List, Optional
-
-import pandas as pd
+from typing import Optional
from dbally.audit.event_tracker import EventTracker
from dbally.collection.results import ViewExecutionResult
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
-from dbally.nl_responder.nl_responder_prompt_template import NLResponderPromptTemplate, default_nl_responder_template
-from dbally.nl_responder.query_explainer_prompt_template import (
- QueryExplainerPromptTemplate,
- default_query_explainer_template,
+from dbally.nl_responder.prompts import (
+ NL_RESPONSE_TEMPLATE,
+ QUERY_EXPLANATION_TEMPLATE,
+ NLResponsePromptFormat,
+ QueryExplanationPromptFormat,
)
+from dbally.prompt.template import PromptTemplate
class NLResponder:
- """Class used to generate natural language response from the database output."""
-
- # Keys used to extract the query from the context (ordered by priority)
- QUERY_KEYS = ["iql", "sql", "query"]
+ """
+ Class used to generate natural language response from the database output.
+ """
def __init__(
self,
llm: LLM,
- query_explainer_prompt_template: Optional[QueryExplainerPromptTemplate] = None,
- nl_responder_prompt_template: Optional[NLResponderPromptTemplate] = None,
+ prompt_template: Optional[PromptTemplate[NLResponsePromptFormat]] = None,
+ explainer_prompt_template: Optional[PromptTemplate[QueryExplanationPromptFormat]] = None,
max_tokens_count: int = 4096,
) -> None:
"""
+ Constructs a new NLResponder instance.
+
Args:
- llm: LLM used to generate natural language response
- query_explainer_prompt_template: template for the prompt used to generate the iql explanation
- if not set defaults to `default_query_explainer_template`
- nl_responder_prompt_template: template for the prompt used to generate the NL response
- if not set defaults to `nl_responder_prompt_template`
- max_tokens_count: maximum number of tokens that can be used in the prompt
+ llm: LLM used to generate natural language response.
+ prompt_template: Template for the prompt used to generate the NL response
+ if not set defaults to `NL_RESPONSE_TEMPLATE`.
+ explainer_prompt_template: Template for the prompt used to generate the iql explanation
+ if not set defaults to `QUERY_EXPLANATION_TEMPLATE`.
+ max_tokens_count: Maximum number of tokens that can be used in the prompt.
"""
self._llm = llm
- self._nl_responder_prompt_template = nl_responder_prompt_template or copy.deepcopy(
- default_nl_responder_template
- )
- self._query_explainer_prompt_template = query_explainer_prompt_template or copy.deepcopy(
- default_query_explainer_template
- )
+ self._prompt_template = prompt_template or NL_RESPONSE_TEMPLATE
+ self._explainer_prompt_template = explainer_prompt_template or QUERY_EXPLANATION_TEMPLATE
self._max_tokens_count = max_tokens_count
async def generate_response(
@@ -56,53 +52,38 @@ async def generate_response(
Uses LLM to generate a response in natural language form.
Args:
- result: object representing the result of the query execution
- question: user question
- event_tracker: event store used to audit the generation process
- llm_options: options to use for the LLM client.
+ result: Object representing the result of the query execution.
+ question: User question.
+ event_tracker: Event store used to audit the generation process.
+ llm_options: Options to use for the LLM client.
Returns:
Natural language response to the user question.
"""
- rows = _promptify_rows(result.results)
-
- tokens_count = self._llm.count_tokens(
- messages=self._nl_responder_prompt_template.chat,
- fmt={"rows": rows, "question": question},
+ prompt_format = NLResponsePromptFormat(
+ question=question,
+ results=result.results,
)
+ formatted_prompt = self._prompt_template.format_prompt(prompt_format)
+ tokens_count = self._llm.count_tokens(formatted_prompt)
if tokens_count > self._max_tokens_count:
- context = result.context
- query = next((context.get(key) for key in self.QUERY_KEYS if context.get(key)), question)
+ prompt_format = QueryExplanationPromptFormat(
+ question=question,
+ context=result.context,
+ results=result.results,
+ )
+ formatted_prompt = self._explainer_prompt_template.format_prompt(prompt_format)
llm_response = await self._llm.generate_text(
- template=self._query_explainer_prompt_template,
- fmt={"question": question, "query": query, "number_of_results": len(result.results)},
+ prompt=formatted_prompt,
event_tracker=event_tracker,
options=llm_options,
)
-
return llm_response
llm_response = await self._llm.generate_text(
- template=self._nl_responder_prompt_template,
- fmt={"rows": _promptify_rows(result.results), "question": question},
+ prompt=formatted_prompt,
event_tracker=event_tracker,
options=llm_options,
)
return llm_response
-
-
-def _promptify_rows(rows: List[Dict]) -> str:
- """
- Formats rows into a markdown table.
-
- Args:
- rows: list of rows to be formatted
-
- Returns:
- str: formatted rows
- """
-
- df = pd.DataFrame.from_records(rows)
-
- return df.to_markdown(index=False, headers="keys", tablefmt="psql")
diff --git a/src/dbally/nl_responder/nl_responder_prompt_template.py b/src/dbally/nl_responder/nl_responder_prompt_template.py
deleted file mode 100644
index 9e6e687e..00000000
--- a/src/dbally/nl_responder/nl_responder_prompt_template.py
+++ /dev/null
@@ -1,47 +0,0 @@
-from typing import Callable, Dict, Optional
-
-from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables
-
-
-class NLResponderPromptTemplate(PromptTemplate):
- """
- Class for prompt templates meant for the natural response.
- """
-
- def __init__(
- self,
- chat: ChatFormat,
- response_format: Optional[Dict[str, str]] = None,
- llm_response_parser: Callable = lambda x: x,
- ) -> None:
- """
- Initializes NLResponderPromptTemplate class.
-
- Args:
- chat: chat format
- response_format: response format
- llm_response_parser: function to parse llm response
- """
-
- super().__init__(chat, response_format, llm_response_parser)
- self.chat = check_prompt_variables(chat, {"rows", "question"})
-
-
-default_nl_responder_template = NLResponderPromptTemplate(
- chat=(
- {
- "role": "system",
- "content": "You are a helpful assistant that helps answer the user's questions "
- "based on the table provided. You MUST use the table to answer the question. "
- "You are very intelligent and obedient.\n"
- "The table ALWAYS contains full answer to a question.\n"
- "Answer the question in a way that is easy to understand and informative.\n"
- "DON'T MENTION using a table in your answer.",
- },
- {
- "role": "user",
- "content": "The table below represents the answer to a question: {question}.\n"
- "{rows}\nAnswer the question: {question}.",
- },
- )
-)
diff --git a/src/dbally/nl_responder/prompts.py b/src/dbally/nl_responder/prompts.py
new file mode 100644
index 00000000..f99a8a6c
--- /dev/null
+++ b/src/dbally/nl_responder/prompts.py
@@ -0,0 +1,111 @@
+from typing import Any, Dict, List
+
+import pandas as pd
+
+from dbally.prompt.elements import FewShotExample
+from dbally.prompt.template import PromptFormat, PromptTemplate
+
+
+class NLResponsePromptFormat(PromptFormat):
+ """
+ Formats provided parameters to a form acceptable by default NL response prompt.
+ """
+
+ def __init__(
+ self,
+ *,
+ question: str,
+ results: List[Dict[str, Any]],
+ examples: List[FewShotExample] = None,
+ ) -> None:
+ """
+ Constructs a new NLResponsePromptFormat instance.
+
+ Args:
+ question: Question to be asked.
+ results: List of results returned by the query.
+ examples: List of examples to be injected into the conversation.
+ """
+ super().__init__(examples)
+ self.question = question
+ self.results = pd.DataFrame.from_records(results).to_markdown(index=False, headers="keys", tablefmt="psql")
+
+
+class QueryExplanationPromptFormat(PromptFormat):
+ """
+ Formats provided parameters to a form acceptable by default query explanation prompt.
+ """
+
+ def __init__(
+ self,
+ *,
+ question: str,
+ context: Dict[str, Any],
+ results: List[Dict[str, Any]],
+ examples: List[FewShotExample] = None,
+ ) -> None:
+ """
+ Constructs a new QueryExplanationPromptFormat instance.
+
+ Args:
+ question: Question to be asked.
+ context: Context of the query.
+ results: List of results returned by the query.
+ examples: List of examples to be injected into the conversation.
+ """
+ super().__init__(examples)
+ self.question = question
+ self.query = next((context.get(key) for key in ("iql", "sql", "query") if context.get(key)), question)
+ self.number_of_results = len(results)
+
+
+NL_RESPONSE_TEMPLATE = PromptTemplate[NLResponsePromptFormat](
+ [
+ {
+ "role": "system",
+ "content": (
+ "You are a helpful assistant that helps answer the user's questions "
+ "based on the table provided. You MUST use the table to answer the question. "
+ "You are very intelligent and obedient.\n"
+ "The table ALWAYS contains full answer to a question.\n"
+ "Answer the question in a way that is easy to understand and informative.\n"
+ "DON'T MENTION using a table in your answer."
+ ),
+ },
+ {
+ "role": "user",
+ "content": (
+ "The table below represents the answer to a question: {question}.\n"
+ "{results}\n"
+ "Answer the question: {question}."
+ ),
+ },
+ ],
+)
+
+QUERY_EXPLANATION_TEMPLATE = PromptTemplate[QueryExplanationPromptFormat](
+ [
+ {
+ "role": "system",
+ "content": (
+ "You are a helpful assistant that helps describe a table generated by a query "
+ "that answers users' question. "
+ "You are very intelligent and obedient.\n"
+ "Your task is to provide natural language description of the table used by the logical query "
+ "to the database.\n"
+ "Describe the table in a way that is short and informative.\n"
+ "Make your answer as short as possible, start it by infroming the user that the underlying "
+ "data is too long to print and then describe the table based on the question and the query.\n"
+ "DON'T MENTION using a query in your answer."
+ ),
+ },
+ {
+ "role": "user",
+ "content": (
+ "The query below represents the answer to a question: {question}.\n"
+ "Describe the table generated using this query: {query}.\n"
+ "Number of results to this query: {number_of_results}."
+ ),
+ },
+ ],
+)
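
For reference, a minimal sketch of how these two formats pair with their templates; the question, rows and context below are invented, and no LLM call is made:

from dbally.nl_responder.prompts import (
    NL_RESPONSE_TEMPLATE,
    NLResponsePromptFormat,
    QUERY_EXPLANATION_TEMPLATE,
    QueryExplanationPromptFormat,
)

results = [{"city": "Paris", "population": 2161000}]

# Regular path: the rows are rendered as a markdown table and injected into {results}.
nl_prompt = NL_RESPONSE_TEMPLATE.format_prompt(
    NLResponsePromptFormat(question="Which city is the largest?", results=results)
)

# Fallback path taken by NLResponder when the rendered rows exceed max_tokens_count:
# only the query (read from the execution context) and the result count reach the LLM.
explanation_prompt = QUERY_EXPLANATION_TEMPLATE.format_prompt(
    QueryExplanationPromptFormat(
        question="Which city is the largest?",
        context={"sql": "SELECT city FROM cities ORDER BY population DESC"},
        results=results,
    )
)
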
diff --git a/src/dbally/nl_responder/query_explainer_prompt_template.py b/src/dbally/nl_responder/query_explainer_prompt_template.py
deleted file mode 100644
index 00a3e6a6..00000000
--- a/src/dbally/nl_responder/query_explainer_prompt_template.py
+++ /dev/null
@@ -1,48 +0,0 @@
-from typing import Callable, Dict, Optional
-
-from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables
-
-
-class QueryExplainerPromptTemplate(PromptTemplate):
- """
- Class for prompt templates meant to generate explanations for queries
- (when the data cannot be shown due to token limit).
-
- Args:
- chat: chat format
- response_format: response format
- llm_response_parser: function to parse llm response
- """
-
- def __init__(
- self,
- chat: ChatFormat,
- response_format: Optional[Dict[str, str]] = None,
- llm_response_parser: Callable = lambda x: x,
- ) -> None:
- super().__init__(chat, response_format, llm_response_parser)
- self.chat = check_prompt_variables(chat, {"question", "query", "number_of_results"})
-
-
-default_query_explainer_template = QueryExplainerPromptTemplate(
- chat=(
- {
- "role": "system",
- "content": "You are a helpful assistant that helps describe a table generated by a query "
- "that answers users' question. "
- "You are very intelligent and obedient.\n"
- "Your task is to provide natural language description of the table used by the logical query "
- "to the database.\n"
- "Describe the table in a way that is short and informative.\n"
- "Make your answer as short as possible, start it by infroming the user that the underlying "
- "data is too long to print and then describe the table based on the question and the query.\n"
- "DON'T MENTION using a query in your answer.\n",
- },
- {
- "role": "user",
- "content": "The query below represents the answer to a question: {question}.\n"
- "Describe the table generated using this query: {query}.\n"
- "Number of results to this query: {number_of_results}.\n",
- },
- )
-)
diff --git a/src/dbally/prompt/__init__.py b/src/dbally/prompt/__init__.py
new file mode 100644
index 00000000..61495d33
--- /dev/null
+++ b/src/dbally/prompt/__init__.py
@@ -0,0 +1,3 @@
+from .template import ChatFormat, PromptTemplate, PromptTemplateError
+
+__all__ = ["PromptTemplate", "PromptTemplateError", "ChatFormat"]
diff --git a/src/dbally/prompt/elements.py b/src/dbally/prompt/elements.py
new file mode 100644
index 00000000..37375508
--- /dev/null
+++ b/src/dbally/prompt/elements.py
@@ -0,0 +1,61 @@
+import inspect
+import re
+import textwrap
+from typing import Callable, Union
+
+
+class FewShotExample:
+ """
+ A question:answer representation for few-shot prompting
+ """
+
+ def __init__(self, question: str, answer_expr: Union[str, Callable]) -> None:
+ """
+ Args:
+ question: sample question
+ answer_expr: either a stringified expression or a lambda (preferred for type safety and code completion).
+
+ Raises:
+ ValueError: If answer_expr is not a correct type.
+ """
+ self.question = question
+ self.answer_expr = answer_expr
+
+ if isinstance(self.answer_expr, str):
+ self.answer = self.answer_expr
+ elif callable(answer_expr):
+ self.answer = self._parse_lambda(answer_expr)
+ else:
+ raise ValueError("Answer expression should be either a string or a lambda")
+
+ def _parse_lambda(self, expr: Callable) -> str:
+ """
+ Parses provided callable in order to extract the lambda code.
+ All comments and references to variables such as `self` are removed
+ to form a plain lambda representation.
+
+ Args:
+ expr: lambda expression to parse
+
+ Returns:
+ Parsed lambda in a form of cleaned up string
+ """
+ # extract lambda from code
+ expr_source = textwrap.dedent(inspect.getsource(expr))
+ expr_body = expr_source.replace("lambda:", "")
+
+ # clean up by removing comments, new lines, free vars (self etc)
+ parsed_expr = re.sub("\\#.*\n", "\n", expr_body, flags=re.MULTILINE)
+
+ for m_name in expr.__code__.co_names:
+ parsed_expr = parsed_expr.replace(f"{expr.__code__.co_freevars[0]}.{m_name}", m_name)
+
+ # clean up any dangling commas or leading and trailing brackets
+ parsed_expr = " ".join(parsed_expr.split()).strip().rstrip(",").replace("( ", "(").replace(" )", ")")
+ if parsed_expr.startswith("("):
+ parsed_expr = parsed_expr[1:-1]
+
+ return parsed_expr
+
+ def __str__(self) -> str:
+ return f"{self.question} -> {self.answer}"
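
A rough sketch of the two accepted answer forms; the filter names are illustrative, and the lambda variant is written inside a class so that the `self.` prefix can be stripped during parsing:

from dbally.prompt.elements import FewShotExample

# String form: the answer is taken verbatim.
example = FewShotExample(
    question="Find people who studied at University of Toronto",
    answer_expr='studied_at("University of Toronto")',
)
# str(example) == 'Find people who studied at University of Toronto -> studied_at("University of Toronto")'


class ExampleView:
    def studied_at(self, university: str) -> bool:
        ...

    def list_few_shots(self):
        # Lambda form: the lambda source is parsed, `self.` is stripped, and the answer
        # becomes the same 'studied_at("University of Toronto")' string.
        return [
            FewShotExample(
                "Find people who studied at University of Toronto",
                lambda: self.studied_at("University of Toronto"),
            )
        ]
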
diff --git a/src/dbally/prompt/template.py b/src/dbally/prompt/template.py
new file mode 100644
index 00000000..124a3e1c
--- /dev/null
+++ b/src/dbally/prompt/template.py
@@ -0,0 +1,234 @@
+import copy
+import re
+from typing import Callable, Dict, Generic, List, TypeVar
+
+from typing_extensions import Self
+
+from dbally.exceptions import DbAllyError
+from dbally.prompt.elements import FewShotExample
+
+ChatFormat = List[Dict[str, str]]
+
+
+class PromptTemplateError(DbAllyError):
+ """
+ Error raised on incorrect PromptTemplate construction.
+ """
+
+
+def _check_chat_order(chat: ChatFormat) -> ChatFormat:
+ """
+ Checks if the chat template is constructed correctly (system, user, assistant alternating).
+
+ Args:
+ chat: Chat template
+
+ Raises:
+ PromptTemplateError: if chat template is not constructed correctly.
+
+ Returns:
+ Chat template
+ """
+ if len(chat) == 0:
+ raise PromptTemplateError("Template should not be empty")
+
+ expected_order = ["user", "assistant"]
+ for i, message in enumerate(chat):
+ role = message["role"]
+ if role == "system":
+ if i != 0:
+ raise PromptTemplateError("Only first message should come from system")
+ continue
+ index = i % len(expected_order)
+ if role != expected_order[index - 1]:
+ raise PromptTemplateError(
+ "Template format is not correct. It should be system, and then user/assistant alternating."
+ )
+
+ if expected_order[index] not in ["user", "assistant"]:
+ raise PromptTemplateError("Template needs to end on either user or assistant turn")
+ return chat
+
+
+class PromptFormat:
+ """
+ Generic prompt format that allows injecting few-shot examples into the conversation.
+ """
+
+ def __init__(self, examples: List[FewShotExample] = None) -> None:
+ """
+ Constructs a new PromptFormat instance.
+
+ Args:
+ examples: List of examples to be injected into the conversation.
+ """
+ self.examples = examples or []
+
+
+PromptFormatT = TypeVar("PromptFormatT", bound=PromptFormat)
+
+
+class PromptTemplate(Generic[PromptFormatT]):
+ """
+ Class for prompt templates.
+ """
+
+ def __init__(
+ self,
+ chat: ChatFormat,
+ *,
+ json_mode: bool = False,
+ response_parser: Callable = lambda x: x,
+ ) -> None:
+ """
+ Constructs a new PromptTemplate instance.
+
+ Args:
+ chat: Chat-formatted conversation template.
+ json_mode: Whether to enforce JSON response from LLM.
+ response_parser: Function parsing the LLM response into the desired format.
+ """
+ self.chat: ChatFormat = _check_chat_order(chat)
+ self.json_mode = json_mode
+ self.response_parser = response_parser
+
+ def __eq__(self, other: "PromptTemplate") -> bool:
+ return isinstance(other, PromptTemplate) and self.chat == other.chat
+
+ def _has_variable(self, variable: str) -> bool:
+ """
+ Checks whether the chat template uses the given formatting variable.
+
+ Args:
+ variable: Variable to check.
+
+ Returns:
+ True if the variable is present in the chat.
+ """
+ for message in self.chat:
+ if re.match(rf"{{{variable}}}", message["content"]):
+ return True
+ return False
+
+ def format_prompt(self, prompt_format: PromptFormatT) -> Self:
+ """
+ Applies formatting to the prompt template chat contents.
+
+ Args:
+ prompt_format: Format to be applied to the prompt.
+
+ Returns:
+ PromptTemplate with formatted chat contents.
+ """
+ formatted_prompt = copy.deepcopy(self)
+ formatting = dict(prompt_format.__dict__)
+
+ if self._has_variable("examples"):
+ formatting["examples"] = "\n".join(prompt_format.examples)
+ else:
+ formatted_prompt = formatted_prompt.clear_few_shot_messages()
+ for example in prompt_format.examples:
+ formatted_prompt = formatted_prompt.add_few_shot_message(example)
+
+ formatted_prompt.chat = [
+ {
+ "role": message.get("role"),
+ "content": message.get("content").format(**formatting),
+ "is_example": message.get("is_example", False),
+ }
+ for message in formatted_prompt.chat
+ ]
+ return formatted_prompt
+
+ def set_system_message(self, content: str) -> Self:
+ """
+ Adds a system message at the beginning of the template prompt.
+
+ Args:
+ content: Message to be added.
+
+ Returns:
+ PromptTemplate with the system message prepended.
+ """
+ return self.__class__(
+ chat=[{"role": "system", "content": content}, *self.chat],
+ json_mode=self.json_mode,
+ response_parser=self.response_parser,
+ )
+
+ def add_user_message(self, content: str) -> Self:
+ """
+ Add a user message to the template prompt.
+
+ Args:
+ content: Message to be added.
+
+ Returns:
+ PromptTemplate with appended user message.
+ """
+ return self.__class__(
+ chat=[*self.chat, {"role": "user", "content": content}],
+ json_mode=self.json_mode,
+ response_parser=self.response_parser,
+ )
+
+ def add_assistant_message(self, content: str) -> Self:
+ """
+ Add an assistant message to the template prompt.
+
+ Args:
+ content: Message to be added.
+
+ Returns:
+ PromptTemplate with appended assistant message.
+ """
+ return self.__class__(
+ chat=[*self.chat, {"role": "assistant", "content": content}],
+ json_mode=self.json_mode,
+ response_parser=self.response_parser,
+ )
+
+ def add_few_shot_message(self, example: FewShotExample) -> Self:
+ """
+ Add a few-shot message to the template prompt.
+
+ Args:
+ example: Few-shot example to be added.
+
+ Returns:
+ PromptTemplate with appended few-shot message.
+
+ Raises:
+ PromptTemplateError: if the template is empty.
+ """
+ if len(self.chat) == 0:
+ raise PromptTemplateError("Cannot add few-shot messages to an empty template.")
+
+ few_shot = [
+ {"role": "user", "content": example.question, "is_example": True},
+ {"role": "assistant", "content": example.answer, "is_example": True},
+ ]
+ few_shot_index = max(
+ (i for i, entry in enumerate(self.chat) if entry.get("is_example") or entry.get("role") == "system"),
+ default=0,
+ )
+ chat = self.chat[: few_shot_index + 1] + few_shot + self.chat[few_shot_index + 1 :]
+
+ return self.__class__(
+ chat=chat,
+ json_mode=self.json_mode,
+ response_parser=self.response_parser,
+ )
+
+ def clear_few_shot_messages(self) -> Self:
+ """
+ Removes all few-shot messages from the template prompt.
+
+ Returns:
+ PromptTemplate with few-shot messages removed.
+ """
+ return self.__class__(
+ chat=[message for message in self.chat if not message.get("is_example")],
+ json_mode=self.json_mode,
+ response_parser=self.response_parser,
+ )
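
As a short usage sketch (mirroring the unit tests at the bottom of this diff): templates are treated as immutable, every mutator returns a new instance, and format_prompt injects few-shot turns right after the system message when the template has no {examples} placeholder:

from typing import List, Optional

from dbally.prompt.elements import FewShotExample
from dbally.prompt.template import PromptFormat, PromptTemplate


class QuestionPromptFormat(PromptFormat):
    """Carries a single {question} variable plus optional few-shot examples."""

    def __init__(self, question: str, examples: Optional[List[FewShotExample]] = None) -> None:
        super().__init__(examples)
        self.question = question


TEMPLATE = PromptTemplate[QuestionPromptFormat](
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "{question}"},
    ]
)

prompt = TEMPLATE.format_prompt(
    QuestionPromptFormat(
        question="What is the capital of France?",
        examples=[FewShotExample("What is the capital of Germany?", "Berlin")],
    )
)
# prompt.chat: system message, the example user/assistant pair, then the formatted question.
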
diff --git a/src/dbally/prompts/__init__.py b/src/dbally/prompts/__init__.py
deleted file mode 100644
index 38e20cc7..00000000
--- a/src/dbally/prompts/__init__.py
+++ /dev/null
@@ -1,4 +0,0 @@
-from .common_validation_utils import ChatFormat, PromptTemplateError, check_prompt_variables
-from .prompt_template import PromptTemplate
-
-__all__ = ["PromptTemplate", "PromptTemplateError", "check_prompt_variables", "ChatFormat"]
diff --git a/src/dbally/prompts/common_validation_utils.py b/src/dbally/prompts/common_validation_utils.py
deleted file mode 100644
index f4660810..00000000
--- a/src/dbally/prompts/common_validation_utils.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import re
-from typing import Dict, List, Set
-
-from dbally.exceptions import DbAllyError
-
-ChatFormat = List[Dict[str, str]]
-
-
-class PromptTemplateError(DbAllyError):
- """Error raised on incorrect PromptTemplate construction"""
-
-
-def _extract_variables(text: str) -> List[str]:
- """
- Given a text string, extract all variables that can be filled using .format
-
- Args:
- text: string to process
-
- Returns:
- list of variables extracted from text
- """
- pattern = r"\{([^}]+)\}"
- return re.findall(pattern, text)
-
-
-def check_prompt_variables(chat: ChatFormat, variables_to_check: Set[str]) -> ChatFormat:
- """
- Function validates a given chat to make sure it contains variables required.
-
- Args:
- chat: chat to validate
- variables_to_check: set of variables to assert
-
- Raises:
- PromptTemplateError: If required variables are missing
-
- Returns:
- Chat, if it's valid.
- """
- variables = []
- for message in chat:
- content = message["content"]
- variables.extend(_extract_variables(content))
- if not set(variables_to_check).issubset(variables):
- raise PromptTemplateError(
- "Cannot build a prompt template from the provided chat, "
- "because it lacks necessary string variables. "
- "You need to format the following variables: {variables_to_check}"
- )
- return chat
diff --git a/src/dbally/prompts/prompt_template.py b/src/dbally/prompts/prompt_template.py
deleted file mode 100644
index 8e2746fe..00000000
--- a/src/dbally/prompts/prompt_template.py
+++ /dev/null
@@ -1,83 +0,0 @@
-from typing import Callable, Dict, Optional
-
-from typing_extensions import Self
-
-from .common_validation_utils import ChatFormat, PromptTemplateError
-
-
-def _check_chat_order(chat: ChatFormat) -> ChatFormat:
- """
- Pydantic validator. Checks if the chat template is constructed correctly (system, user, assistant alternating).
-
- Args:
- chat: Chat template
-
- Raises:
- PromptTemplateError: if chat template is not constructed correctly.
-
- Returns:
- Chat template
- """
- expected_order = ["user", "assistant"]
- for i, message in enumerate(chat):
- role = message["role"]
- if role == "system":
- if i != 0:
- raise PromptTemplateError("Only first message should come from system")
- continue
- index = i % len(expected_order)
- if role != expected_order[index - 1]:
- raise PromptTemplateError(
- "Template format is not correct. It should be system, and then user/assistant alternating."
- )
-
- if expected_order[index] not in ["user", "assistant"]:
- raise PromptTemplateError("Template needs to end on either user or assistant turn")
- return chat
-
-
-class PromptTemplate:
- """
- Class for prompt templates
-
- Attributes:
- response_format: Optional argument for OpenAI Turbo models - may be used to force json output
- llm_response_parser: Function parsing the LLM response into IQL
- """
-
- def __init__(
- self,
- chat: ChatFormat,
- response_format: Optional[Dict[str, str]] = None,
- llm_response_parser: Callable = lambda x: x,
- ):
- self.chat: ChatFormat = _check_chat_order(chat)
- self.response_format = response_format
- self.llm_response_parser = llm_response_parser
-
- def __eq__(self, __value: object) -> bool:
- return isinstance(__value, PromptTemplate) and self.chat == __value.chat
-
- def add_user_message(self, content: str) -> Self:
- """
- Add a user message to the template prompt.
-
- Args:
- content: Message to be added
-
- Returns:
- PromptTemplate with appended user message
- """
- return self.__class__((*self.chat, {"role": "user", "content": content}))
-
- def add_assistant_message(self, content: str) -> Self:
- """
- Add an assistant message to the template prompt.
-
- Args:
- content: Message to be added
-
- Returns:
- PromptTemplate with appended assistant message
- """
- return self.__class__((*self.chat, {"role": "assistant", "content": content}))
diff --git a/src/dbally/similarity/index.py b/src/dbally/similarity/index.py
index 6895c566..31cfe8cf 100644
--- a/src/dbally/similarity/index.py
+++ b/src/dbally/similarity/index.py
@@ -2,7 +2,7 @@
from typing import Optional
from dbally.audit.event_tracker import EventTracker
-from dbally.data_models.audit import SimilarityEvent
+from dbally.audit.events import SimilarityEvent
from dbally.similarity.fetcher import SimilarityFetcher
from dbally.similarity.store import SimilarityStore
diff --git a/src/dbally/view_selection/llm_view_selector.py b/src/dbally/view_selection/llm_view_selector.py
index 2d501922..b4069bb1 100644
--- a/src/dbally/view_selection/llm_view_selector.py
+++ b/src/dbally/view_selection/llm_view_selector.py
@@ -1,12 +1,11 @@
-import copy
-from typing import Callable, Dict, Optional
+from typing import Dict, Optional
from dbally.audit.event_tracker import EventTracker
-from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
+from dbally.prompt.template import PromptTemplate
from dbally.view_selection.base import ViewSelector
-from dbally.view_selection.view_selector_prompt_template import default_view_selector_template
+from dbally.view_selection.prompt import VIEW_SELECTION_TEMPLATE, ViewSelectionPromptFormat
class LLMViewSelector(ViewSelector):
@@ -20,22 +19,16 @@ class LLMViewSelector(ViewSelector):
ultimately returning the name of the most suitable view.
"""
- def __init__(
- self,
- llm: LLM,
- prompt_template: Optional[IQLPromptTemplate] = None,
- promptify_views: Optional[Callable[[Dict[str, str]], str]] = None,
- ) -> None:
+ def __init__(self, llm: LLM, prompt_template: Optional[PromptTemplate[ViewSelectionPromptFormat]] = None) -> None:
"""
+ Constructs a new LLMViewSelector instance.
+
Args:
llm: LLM used to generate IQL
prompt_template: template for the prompt used for the view selection
- promptify_views: Function formatting filters for prompt. By default names and descriptions of\
- all views are concatenated
"""
self._llm = llm
- self._prompt_template = prompt_template or copy.deepcopy(default_view_selector_template)
- self._promptify_views = promptify_views or _promptify_views
+ self._prompt_template = prompt_template or VIEW_SELECTION_TEMPLATE
async def select_view(
self,
@@ -56,28 +49,13 @@ async def select_view(
Returns:
The most relevant view name.
"""
-
- views_for_prompt = self._promptify_views(views)
+ prompt_format = ViewSelectionPromptFormat(question=question, views=views)
+ formatted_prompt = self._prompt_template.format_prompt(prompt_format)
llm_response = await self._llm.generate_text(
- template=self._prompt_template,
- fmt={"views": views_for_prompt, "question": question},
+ prompt=formatted_prompt,
event_tracker=event_tracker,
options=llm_options,
)
- selected_view = self._prompt_template.llm_response_parser(llm_response)
+ selected_view = self._prompt_template.response_parser(llm_response)
return selected_view
-
-
-def _promptify_views(views: Dict[str, str]) -> str:
- """
- Formats views for prompt
-
- Args:
- views: dictionary of available view names with corresponding descriptions.
-
- Returns:
- views_for_prompt: views formatted for prompt
- """
-
- return "\n".join([f"{name}: {description}" for name, description in views.items()])
diff --git a/src/dbally/view_selection/prompt.py b/src/dbally/view_selection/prompt.py
new file mode 100644
index 00000000..cdbedf5a
--- /dev/null
+++ b/src/dbally/view_selection/prompt.py
@@ -0,0 +1,52 @@
+from typing import Dict, List
+
+from dbally.prompt.elements import FewShotExample
+from dbally.prompt.template import PromptFormat, PromptTemplate
+
+
+class ViewSelectionPromptFormat(PromptFormat):
+ """
+ Formats provided parameters to a form acceptable by default view selection prompt.
+ """
+
+ def __init__(
+ self,
+ *,
+ question: str,
+ views: Dict[str, str],
+ examples: List[FewShotExample] = None,
+ ) -> None:
+ """
+ Constructs a new ViewSelectionPromptFormat instance.
+
+ Args:
+ question: Question to be asked.
+ views: Dictionary of available view names with corresponding descriptions.
+ examples: List of examples to be injected into the conversation.
+ """
+ super().__init__(examples)
+ self.question = question
+ self.views = "\n".join([f"{name}: {description}" for name, description in views.items()])
+
+
+VIEW_SELECTION_TEMPLATE = PromptTemplate[ViewSelectionPromptFormat](
+ [
+ {
+ "role": "system",
+ "content": (
+ "You are a very smart database programmer. "
+ "You have access to API that lets you query a database:\n"
+ "First you need to select a class to query, based on its description and the user question. "
+ "You have the following classes to choose from:\n"
+ "{views}\n"
+ "Return only the selected view name. Don't give any comments.\n"
+ "You can only use the classes that were listed. "
+ "If none of the classes listed can be used to answer the user question, say `NoViewFoundError`"
+ ),
+ },
+ {
+ "role": "user",
+ "content": "{question}",
+ },
+ ],
+)
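
A quick sketch of the formatting step, mirroring what LLMViewSelector.select_view now does before calling the LLM (the view names and descriptions are made up):

from dbally.view_selection.prompt import VIEW_SELECTION_TEMPLATE, ViewSelectionPromptFormat

views = {
    "CandidateView": "Information about job candidates.",
    "OfferView": "Information about open job offers.",
}

prompt_format = ViewSelectionPromptFormat(question="Who applied for the senior role?", views=views)
formatted_prompt = VIEW_SELECTION_TEMPLATE.format_prompt(prompt_format)
# The system turn now lists "CandidateView: ..." and "OfferView: ...";
# the user turn carries the question verbatim.
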
diff --git a/src/dbally/view_selection/view_selector_prompt_template.py b/src/dbally/view_selection/view_selector_prompt_template.py
deleted file mode 100644
index 60440c84..00000000
--- a/src/dbally/view_selection/view_selector_prompt_template.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import json
-from typing import Callable, Dict, Optional
-
-from dbally.prompts import ChatFormat, PromptTemplate, check_prompt_variables
-
-
-class ViewSelectorPromptTemplate(PromptTemplate):
- """
- Class for prompt templates meant for the ViewSelector
- """
-
- def __init__(
- self,
- chat: ChatFormat,
- response_format: Optional[Dict[str, str]] = None,
- llm_response_parser: Callable = lambda x: x,
- ):
- super().__init__(chat, response_format, llm_response_parser)
- self.chat = check_prompt_variables(chat, {"views"})
-
-
-def _convert_llm_json_response_to_selected_view(llm_response_json: str) -> str:
- """
- Converts LLM json response to IQL
-
- Args:
- llm_response_json: LLM response in JSON format
-
- Returns:
- A string containing selected view
- """
- llm_response_dict = json.loads(llm_response_json)
- return llm_response_dict.get("view")
-
-
-default_view_selector_template = ViewSelectorPromptTemplate(
- chat=(
- {
- "role": "system",
- "content": "You are a very smart database programmer. "
- "You have access to API that lets you query a database:\n"
- "First you need to select a class to query, based on its description and the user question. "
- "You have the following classes to choose from:\n"
- "{views}\n"
- "Return only the selected view name. Don't give any comments.\n"
- "You can only use the classes that were listed. "
- "If none of the classes listed can be used to answer the user question, say `NoViewFoundError`",
- },
- {"role": "user", "content": "{question}"},
- ),
-)
diff --git a/src/dbally/views/base.py b/src/dbally/views/base.py
index 104b8daf..365e83d6 100644
--- a/src/dbally/views/base.py
+++ b/src/dbally/views/base.py
@@ -5,6 +5,7 @@
from dbally.collection.results import ViewExecutionResult
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
+from dbally.prompt.elements import FewShotExample
from dbally.similarity import AbstractSimilarityIndex
from dbally.context.context import BaseCallerContext, CustomContextsList
@@ -51,3 +52,12 @@ def list_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLoc
Mapping of similarity indexes to their locations.
"""
return {}
+
+ def list_few_shots(self) -> List[FewShotExample]:
+ """
+ List all examples to be injected into few-shot prompt.
+
+ Returns:
+ List of few-shot examples
+ """
+ return []
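
By default no examples are returned; a concrete view can override the hook roughly like this (CandidateView, the base class choice and the filters referenced in the answer are illustrative only):

from typing import List

from dbally.prompt.elements import FewShotExample
from dbally.views.methods_base import MethodsBaseView


class CandidateView(MethodsBaseView):
    """Fragment: filters and execution methods are omitted for brevity."""

    def list_few_shots(self) -> List[FewShotExample]:
        return [
            FewShotExample(
                "Find senior data scientists",
                'data_scientist_position() AND has_seniority("senior")',
            ),
        ]
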
diff --git a/src/dbally/views/freeform/text2sql/prompt.py b/src/dbally/views/freeform/text2sql/prompt.py
new file mode 100644
index 00000000..5f9a547d
--- /dev/null
+++ b/src/dbally/views/freeform/text2sql/prompt.py
@@ -0,0 +1,61 @@
+# pylint: disable=C0301
+
+from typing import List
+
+from dbally.prompt.elements import FewShotExample
+from dbally.prompt.template import PromptFormat, PromptTemplate
+from dbally.views.freeform.text2sql.config import TableConfig
+
+
+class SQLGenerationPromptFormat(PromptFormat):
+ """
+ Formats provided parameters to a form acceptable by default SQL prompt.
+ """
+
+ def __init__(
+ self,
+ *,
+ question: str,
+ dialect: str,
+ tables: List[TableConfig],
+ examples: List[FewShotExample] = None,
+ ) -> None:
+ """
+ Constructs a new SQLGenerationPromptFormat instance.
+
+ Args:
+ question: Question to be asked.
+ dialect: SQL dialect of the target database.
+ tables: List of table configurations whose DDL is included in the prompt.
+ examples: List of examples to be injected into the conversation.
+ """
+ super().__init__(examples)
+ self.question = question
+ self.dialect = dialect
+ self.tables = "\n".join(table.ddl for table in tables)
+
+
+SQL_GENERATION_TEMPLATE = PromptTemplate[SQLGenerationPromptFormat](
+ [
+ {
+ "role": "system",
+ "content": (
+ "You are a very smart database programmer. "
+ "You have access to the following {dialect} tables:\n"
+ "{tables}\n"
+ "Create SQL query to answer user question. Response with JSON containing following keys:\n\n"
+ "- sql: SQL query to answer the question, with parameter :placeholders for user input.\n"
+ "- parameters: a list of parameters to be used in the query, represented by maps with the following keys:\n"
+ " - name: the name of the parameter\n"
+ " - value: the value of the parameter\n"
+ " - table: the table the parameter is used with (if any)\n"
+ " - column: the column the parameter is compared to (if any)\n\n"
+ "Respond ONLY with the raw JSON response. Don't include any additional text or characters."
+ ),
+ },
+ {
+ "role": "user",
+ "content": "{question}",
+ },
+ ],
+ json_mode=True,
+)
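
For orientation, a fragment mirroring how ask() in view.py (further down in this diff) builds the prompt; `view` stands for an already-configured Text2SQL view and is not constructed here:

from dbally.views.freeform.text2sql.prompt import SQL_GENERATION_TEMPLATE, SQLGenerationPromptFormat

prompt_format = SQLGenerationPromptFormat(
    question="How many candidates applied this month?",
    dialect="sqlite",                  # ask() takes this from the SQLAlchemy engine dialect
    tables=view.get_tables(),          # List[TableConfig]; each table contributes its DDL to {tables}
    examples=view.list_few_shots(),
)
formatted_prompt = SQL_GENERATION_TEMPLATE.format_prompt(prompt_format)
# json_mode=True on the template signals the LLM client to request a raw JSON response.
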
diff --git a/src/dbally/views/freeform/text2sql/view.py b/src/dbally/views/freeform/text2sql/view.py
index af948b3b..7f24f00e 100644
--- a/src/dbally/views/freeform/text2sql/view.py
+++ b/src/dbally/views/freeform/text2sql/view.py
@@ -10,32 +10,12 @@
from dbally.collection.results import ViewExecutionResult
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
-from dbally.prompts import PromptTemplate
+from dbally.prompt.template import PromptTemplate
from dbally.similarity import AbstractSimilarityIndex, SimpleSqlAlchemyFetcher
from dbally.views.base import BaseView, IndexLocation
from dbally.views.freeform.text2sql.config import TableConfig
from dbally.views.freeform.text2sql.exceptions import Text2SQLError
-
-text2sql_prompt = PromptTemplate(
- chat=(
- {
- "role": "system",
- "content": "You are a very smart database programmer. "
- "You have access to the following {dialect} tables:\n"
- "{tables}\n"
- "Create SQL query to answer user question. Response with JSON containing following keys:\n\n"
- "- sql: SQL query to answer the question, with parameter :placeholders for user input.\n"
- "- parameters: a list of parameters to be used in the query, represented by maps with the following keys:\n"
- " - name: the name of the parameter\n"
- " - value: the value of the parameter\n"
- " - table: the table the parameter is used with (if any)\n"
- " - column: the column the parameter is compared to (if any)\n\n"
- "Respond ONLY with the raw JSON response. Don't include any additional text or characters.",
- },
- {"role": "user", "content": "{question}"},
- ),
- response_format={"type": "json_object"},
-)
+from dbally.views.freeform.text2sql.prompt import SQL_GENERATION_TEMPLATE, SQLGenerationPromptFormat
@dataclass
@@ -142,18 +122,26 @@ async def ask(
Raises:
Text2SQLError: If the text2sql query generation fails after n_retries.
"""
-
- conversation = text2sql_prompt
sql, rows = None, None
exceptions = []
- for _ in range(n_retries):
+ tables = self.get_tables()
+ examples = self.list_few_shots()
+
+ prompt_format = SQLGenerationPromptFormat(
+ question=query,
+ dialect=self._engine.dialect.name,
+ tables=tables,
+ examples=examples,
+ )
+ formatted_prompt = SQL_GENERATION_TEMPLATE.format_prompt(prompt_format)
+
+ for _ in range(n_retries + 1):
# We want to catch all exceptions to retry the process.
# pylint: disable=broad-except
try:
- sql, parameters, conversation = await self._generate_sql(
- query=query,
- conversation=conversation,
+ sql, parameters, formatted_prompt = await self._generate_sql(
+ conversation=formatted_prompt,
llm=llm,
event_tracker=event_tracker,
llm_options=llm_options,
@@ -165,7 +153,7 @@ async def ask(
rows = await self._execute_sql(sql, parameters, event_tracker=event_tracker)
break
except Exception as e:
- conversation = conversation.add_user_message(f"Response is invalid! Error: {e}")
+ formatted_prompt = formatted_prompt.add_user_message(f"Response is invalid! Error: {e}")
exceptions.append(e)
continue
@@ -183,15 +171,13 @@ async def ask(
async def _generate_sql(
self,
- query: str,
conversation: PromptTemplate,
llm: LLM,
event_tracker: EventTracker,
llm_options: Optional[LLMOptions] = None,
) -> Tuple[str, List[SQLParameterOption], PromptTemplate]:
response = await llm.generate_text(
- template=conversation,
- fmt={"tables": self._get_tables_context(), "dialect": self._engine.dialect.name, "question": query},
+ prompt=conversation,
event_tracker=event_tracker,
options=llm_options,
)
@@ -222,12 +208,6 @@ async def _execute_sql(
with self._engine.connect() as conn:
return conn.execute(text(sql), param_values).fetchall()
- def _get_tables_context(self) -> str:
- context = ""
- for table in self._table_index.values():
- context += f"{table.ddl}\n"
- return context
-
def _create_default_fetcher(self, table: str, column: str) -> SimpleSqlAlchemyFetcher:
return SimpleSqlAlchemyFetcher(
sqlalchemy_engine=self._engine,
diff --git a/src/dbally/views/structured.py b/src/dbally/views/structured.py
index 8a2cd3c1..5e9c4e1a 100644
--- a/src/dbally/views/structured.py
+++ b/src/dbally/views/structured.py
@@ -5,7 +5,7 @@
from dbally.audit.event_tracker import EventTracker
from dbally.collection.results import ViewExecutionResult
from dbally.context.context import BaseCallerContext
-from dbally.iql import IQLError, IQLQuery
+from dbally.iql import IQLQuery
from dbally.iql_generator.iql_generator import IQLGenerator
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
@@ -27,10 +27,10 @@ def get_iql_generator(self, llm: LLM) -> IQLGenerator:
Returns the IQL generator for the view.
Args:
- llm: LLM used to generate the IQL queries
+ llm: LLM used to generate the IQL queries.
Returns:
- IQLGenerator: IQL generator for the view
+ IQL generator for the view.
"""
return IQLGenerator(llm=llm)
@@ -60,39 +60,30 @@ async def ask(
The result of the query.
"""
iql_generator = self.get_iql_generator(llm)
- filter_list = self.list_filters()
- iql_filters, conversation = await iql_generator.generate_iql(
+ filters = self.list_filters()
+ examples = self.list_few_shots()
+
+ iql = await iql_generator.generate_iql(
question=query,
- filters=filter_list,
+ filters=filters,
+ examples=examples,
event_tracker=event_tracker,
llm_options=llm_options,
+ n_retries=n_retries,
)
- for _ in range(n_retries):
- try:
- filters = await IQLQuery.parse(iql_filters, filter_list, event_tracker=event_tracker, context=context)
- await self.apply_filters(filters)
- break
- except (IQLError, ValueError) as e:
- conversation = iql_generator.add_error_msg(conversation, [e])
- iql_filters, conversation = await iql_generator.generate_iql(
- question=query,
- filters=filter_list,
- event_tracker=event_tracker,
- conversation=conversation,
- llm_options=llm_options,
- )
- continue
+ await self.apply_filters(iql)
result = self.execute(dry_run=dry_run)
- result.context["iql"] = iql_filters
+ result.context["iql"] = f"{iql}"
return result
@abc.abstractmethod
def list_filters(self) -> List[ExposedFunction]:
"""
+ Lists all available filters for the View.
Returns:
Filters defined inside the View.
@@ -129,4 +120,5 @@ def list_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLoc
for param in filter_.parameters:
if param.similarity_index:
indexes[param.similarity_index].append((self.__class__.__name__, filter_.name, param.name))
+
return indexes
diff --git a/src/dbally_codegen/autodiscovery.py b/src/dbally_codegen/autodiscovery.py
index 1e20c542..c842a07f 100644
--- a/src/dbally_codegen/autodiscovery.py
+++ b/src/dbally_codegen/autodiscovery.py
@@ -6,12 +6,59 @@
from typing_extensions import Self
from dbally.llms.base import LLM
-from dbally.prompts import PromptTemplate
+from dbally.prompt.template import PromptFormat, PromptTemplate
from dbally.similarity.index import SimilarityIndex
-from dbally.views.freeform.text2sql import ColumnConfig, TableConfig
+from dbally.views.freeform.text2sql.config import ColumnConfig, TableConfig
-DISCOVERY_TEMPLATE = PromptTemplate(
- chat=(
+
+class DiscoveryPromptFormat(PromptFormat):
+ """
+ Formats provided parameters to a form acceptable by default discovery prompt.
+ """
+
+ def __init__(
+ self,
+ *,
+ dialect: str,
+ table_ddl: str,
+ samples: List[Dict[str, Any]],
+ ) -> None:
+ """
+ Constructs a new DiscoveryPromptFormat instance.
+
+ Args:
+ dialect: The SQL dialect of the database.
+ table_ddl: The DDL of the table.
+ samples: The example rows from the table.
+ """
+ super().__init__()
+ self.dialect = dialect
+ self.table_ddl = table_ddl
+ self.samples = samples
+
+
+class SimilarityPromptFormat(PromptFormat):
+ """
+ Formats provided parameters to a form acceptable by default similarity prompt.
+ """
+
+ def __init__(self, *, table_summary: str, column_name: str, samples: List[Any]) -> None:
+ """
+ Constructs a new SimilarityPromptFormat instance.
+
+ Args:
+ table_summary: The summary of the table.
+ column_name: The name of the column.
+ samples: The example values from the column.
+ """
+ super().__init__()
+ self.table_summary = table_summary
+ self.column_name = column_name
+ self.samples = samples
+
+
+DISCOVERY_TEMPLATE = PromptTemplate[DiscoveryPromptFormat](
+ [
{
"role": "system",
"content": (
@@ -24,11 +71,11 @@
"role": "user",
"content": "DDL:\n {table_ddl}\n" "EXAMPLE ROWS:\n {samples}",
},
- ),
+ ],
)
-SIMILARITY_TEMPLATE = PromptTemplate(
- chat=(
+SIMILARITY_TEMPLATE = PromptTemplate[SimilarityPromptFormat](
+ [
{
"role": "system",
"content": (
@@ -43,7 +90,7 @@
"role": "user",
"content": "TABLE SUMMARY: {table_summary}\n" "COLUMN NAME: {column_name}\n" "EXAMPLE VALUES: {samples}",
},
- )
+ ],
)
@@ -108,14 +155,15 @@ async def extract_description(self, table: Table, connection: Connection) -> str
"""
ddl = self._generate_ddl(table)
samples = self._fetch_samples(connection, table)
- return await self.llm.generate_text(
- template=DISCOVERY_TEMPLATE,
- fmt={
- "dialect": self.engine.dialect.name,
- "table_ddl": ddl,
- "samples": samples,
- },
+
+ prompt_format = DiscoveryPromptFormat(
+ dialect=self.engine.dialect.name,
+ table_ddl=ddl,
+ samples=samples,
)
+ formatted_prompt = DISCOVERY_TEMPLATE.format_prompt(prompt_format)
+
+ return await self.llm.generate_text(formatted_prompt)
def _fetch_samples(self, connection: Connection, table: Table) -> List[Dict[str, Any]]:
rows = connection.execute(table.select().limit(self.samples_count)).fetchall()
@@ -218,14 +266,15 @@ async def select_index(
table=table,
column=column,
)
- use_index = await self.llm.generate_text(
- template=SIMILARITY_TEMPLATE,
- fmt={
- "table_summary": description,
- "column_name": column.name,
- "samples": samples,
- },
+
+ prompt_format = SimilarityPromptFormat(
+ table_summary=description,
+ column_name=column.name,
+ samples=samples,
)
+ formatted_prompt = SIMILARITY_TEMPLATE.format_prompt(prompt_format)
+
+ use_index = await self.llm.generate_text(formatted_prompt)
return self.index_builder(connection.engine, table, column) if use_index.upper() == "TRUE" else None
def _fetch_samples(self, connection: Connection, table: Table, column: Column) -> List[Any]:
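
As a rough illustration of the two new formats (the DDL, sample rows and column values are invented; in the extractors above the formatted prompts are passed to llm.generate_text inside async methods):

from dbally_codegen.autodiscovery import (
    DISCOVERY_TEMPLATE,
    SIMILARITY_TEMPLATE,
    DiscoveryPromptFormat,
    SimilarityPromptFormat,
)

discovery_prompt = DISCOVERY_TEMPLATE.format_prompt(
    DiscoveryPromptFormat(
        dialect="sqlite",
        table_ddl="CREATE TABLE cities (name TEXT, population INTEGER)",
        samples=[{"name": "Paris", "population": 2161000}],
    )
)

similarity_prompt = SIMILARITY_TEMPLATE.format_prompt(
    SimilarityPromptFormat(
        table_summary="Cities with their population.",
        column_name="name",
        samples=["Paris", "Berlin"],
    )
)
# select_index() enables a similarity index only when the LLM answers "TRUE".
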
diff --git a/tests/integration/test_llm_options.py b/tests/integration/test_llm_options.py
index e8c53435..fb8cfba4 100644
--- a/tests/integration/test_llm_options.py
+++ b/tests/integration/test_llm_options.py
@@ -35,20 +35,20 @@ async def test_llm_options_propagation():
llm.client.call.assert_has_calls(
[
call(
- prompt=ANY,
- response_format=ANY,
+ conversation=ANY,
+ json_mode=ANY,
event=ANY,
options=expected_options,
),
call(
- prompt=ANY,
- response_format=ANY,
+ conversation=ANY,
+ json_mode=ANY,
event=ANY,
options=expected_options,
),
call(
- prompt=ANY,
- response_format=ANY,
+ conversation=ANY,
+ json_mode=ANY,
event=ANY,
options=expected_options,
),
diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py
index 9858e45f..75cc914b 100644
--- a/tests/unit/mocks.py
+++ b/tests/unit/mocks.py
@@ -6,12 +6,11 @@
from dataclasses import dataclass
from functools import cached_property
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
from dbally import NOT_GIVEN, NotGiven
from dbally.iql import IQLQuery
from dbally.iql_generator.iql_generator import IQLGenerator
-from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMClient, LLMOptions
from dbally.similarity.index import AbstractSimilarityIndex
@@ -35,12 +34,12 @@ def execute(self, dry_run=False) -> ViewExecutionResult:
class MockIQLGenerator(IQLGenerator):
- def __init__(self, iql: str) -> None:
+ def __init__(self, iql: IQLQuery) -> None:
self.iql = iql
super().__init__(llm=MockLLM())
- async def generate_iql(self, *_, **__) -> Tuple[str, IQLPromptTemplate]:
- return self.iql, default_iql_template
+ async def generate_iql(self, *_, **__) -> IQLQuery:
+ return self.iql
class MockViewSelector(ViewSelector):
diff --git a/tests/unit/test_assistants_adapters.py b/tests/unit/test_assistants_adapters.py
index 72a55e06..9c203bd6 100644
--- a/tests/unit/test_assistants_adapters.py
+++ b/tests/unit/test_assistants_adapters.py
@@ -8,7 +8,7 @@
from dbally.assistants.base import FunctionCallingError, FunctionCallState
from dbally.assistants.openai import _DBALLY_INFO, _DBALLY_INSTRUCTION, OpenAIAdapter, OpenAIDballyResponse
-from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError
+from dbally.iql_generator.prompt import UnsupportedQueryError
MOCK_VIEWS = {"view1": "description1", "view2": "description2"}
F_ID = "f_id"
diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py
index 3e5bddb5..38ec3e99 100644
--- a/tests/unit/test_collection.py
+++ b/tests/unit/test_collection.py
@@ -1,7 +1,7 @@
# pylint: disable=missing-docstring, missing-return-doc, missing-param-doc, disallowed-name, missing-return-type-doc
from typing import List, Tuple, Type
-from unittest.mock import AsyncMock, Mock, call, patch
+from unittest.mock import AsyncMock, Mock
import pytest
from typing_extensions import Annotated
@@ -10,9 +10,9 @@
from dbally.collection import Collection
from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError
from dbally.collection.results import ViewExecutionResult
-from dbally.iql._exceptions import IQLError
+from dbally.iql import IQLQuery
+from dbally.iql.syntax import FunctionCall
from dbally.views.exposed_functions import ExposedFunction, MethodParamWithTyping
-from dbally.views.structured import BaseStructuredView
from tests.unit.mocks import MockIQLGenerator, MockLLM, MockSimilarityIndex, MockViewBase, MockViewSelector
@@ -59,8 +59,8 @@ def execute(self, dry_run=False) -> ViewExecutionResult:
def list_filters(self) -> List[ExposedFunction]:
return [ExposedFunction("test_filter", "", [])]
- def get_iql_generator(self, *_, **__):
- return MockIQLGenerator("test_filter()")
+ def get_iql_generator(self, *_, **__) -> MockIQLGenerator:
+ return MockIQLGenerator(IQLQuery(FunctionCall("test_filter", []), "test_filter()"))
@pytest.fixture(name="similarity_classes")
@@ -275,42 +275,6 @@ def get_iql_generator(self, *_, **__):
return collection
-async def test_ask_feedback_loop(collection_feedback: Collection) -> None:
- """
- Tests that the ask_feedback_loop method works correctly
- """
-
- mock_node = Mock(col_offset=0, end_col_offset=-1)
- errors = [
- IQLError("err1", mock_node, "src1"),
- IQLError("err2", mock_node, "src2"),
- ValueError("err3"),
- ValueError("err4"),
- ]
- with patch("dbally.iql._query.IQLQuery.parse") as mock_iql_query:
- mock_iql_query.side_effect = errors
- view = collection_feedback.get("ViewWithMockGenerator")
- assert isinstance(view, BaseStructuredView)
- iql_generator = view.get_iql_generator(llm=MockLLM())
-
- await collection_feedback.ask("Mock question")
-
- iql_gen_error: Mock = iql_generator.add_error_msg # type: ignore
-
- iql_gen_error.assert_has_calls(
- [call("iql1_c", [errors[0]]), call("iql2_c", [errors[1]]), call("iql3_c", [errors[2]])]
- )
- assert iql_gen_error.call_count == 3
-
- iql_gen_gen_iql: Mock = iql_generator.generate_iql # type: ignore
-
- for i, c in enumerate(iql_gen_gen_iql.call_args_list):
- if i > 0:
- assert c[1]["conversation"] == f"err{i}"
-
- assert iql_gen_gen_iql.call_count == 4
-
-
async def test_ask_view_selection_single_view() -> None:
"""
Tests that the ask method select view correctly when there is only one view
diff --git a/tests/unit/test_fewshot.py b/tests/unit/test_fewshot.py
new file mode 100644
index 00000000..2b8ba8b3
--- /dev/null
+++ b/tests/unit/test_fewshot.py
@@ -0,0 +1,73 @@
+from typing import Callable, List, Tuple
+
+import pytest
+
+from dbally.prompt.elements import FewShotExample
+
+
+class TestExamples:
+ def studied_at(self, _: str) -> bool:
+ return False
+
+ def is_available_within_months(self, _: int) -> bool:
+ return False
+
+ def data_scientist_position(self) -> bool:
+ return False
+
+ def has_seniority(self, _: str) -> bool:
+ return False
+
+ def __call__(self) -> List[Tuple[str, Callable]]: # pylint: disable=W0602, C0116, W9011
+ return [
+ (
+ # dummy test
+ "None",
+ lambda: None,
+ ),
+ (
+ # test lambda
+ "True and False or data_scientist_position() or (True or True)",
+ lambda: (True and False or self.data_scientist_position() or (True or True)),
+ ),
+ (
+ # test string
+ 'studied_at("University of Toronto")',
+ lambda: self.studied_at("University of Toronto"),
+ ),
+ (
+ # test complex conditions with comments
+ 'is_available_within_months(1) and data_scientist_position() and has_seniority("senior")',
+ lambda: (
+ self.is_available_within_months(1)
+ and self.data_scientist_position()
+ and self.has_seniority("senior")
+ ), # pylint: disable=line-too-long
+ ),
+ (
+ # test nested conditions with comments
+ 'data_scientist_position(1) and (has_seniority("junior") or has_seniority("senior"))',
+ lambda: (
+ self.data_scientist_position(1)
+ and (
+ self.has_seniority("junior") or self.has_seniority("senior")
+ ) # pylint: disable=too-many-function-args
+ ),
+ ),
+ ]
+
+
+@pytest.mark.parametrize(
+ "repr_lambda",
+ TestExamples()(),
+)
+def test_fewshot_lambda(repr_lambda: Tuple[str, Callable]) -> None:
+ result = FewShotExample("question", repr_lambda[1])
+ assert result.answer == repr_lambda[0]
+ assert str(result) == f"question -> {repr_lambda[0]}"
+
+
+def test_fewshot_string() -> None:
+ result = FewShotExample("question", "answer")
+ assert result.answer == "answer"
+ assert str(result) == "question -> answer"
diff --git a/tests/unit/test_iql_format.py b/tests/unit/test_iql_format.py
new file mode 100644
index 00000000..8f583c4c
--- /dev/null
+++ b/tests/unit/test_iql_format.py
@@ -0,0 +1,89 @@
+from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat
+from dbally.prompt.elements import FewShotExample
+
+
+async def test_iql_prompt_format_default() -> None:
+ prompt_format = IQLGenerationPromptFormat(
+ question="",
+ filters=[],
+ examples=[],
+ )
+ formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format)
+
+ assert formatted_prompt.chat == [
+ {
+ "role": "system",
+ "content": "You have access to API that lets you query a database:\n"
+ "\n\n"
+ "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n"
+ "Remember! Don't give any comments, just the function calls.\n"
+ "The output will look like this:\n"
+ 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n'
+ "DO NOT INCLUDE arguments names in your response. Only the values.\n"
+ "You MUST use only these methods:\n"
+ "\n\n"
+ "It is VERY IMPORTANT not to use methods other than those listed above."
+ """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """
+ "This is CRUCIAL, otherwise the system will crash. ",
+ "is_example": False,
+ },
+ {"role": "user", "content": "", "is_example": False},
+ ]
+
+
+async def test_iql_prompt_format_few_shots_injected() -> None:
+ examples = [FewShotExample("q1", "a1")]
+ prompt_format = IQLGenerationPromptFormat(
+ question="",
+ filters=[],
+ examples=examples,
+ )
+ formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format)
+
+ assert formatted_prompt.chat == [
+ {
+ "role": "system",
+ "content": "You have access to API that lets you query a database:\n"
+ "\n\n"
+ "Please suggest which one(s) to call and how they should be joined with logic operators (AND, OR, NOT).\n"
+ "Remember! Don't give any comments, just the function calls.\n"
+ "The output will look like this:\n"
+ 'filter1("arg1") AND (NOT filter2(120) OR filter3(True))\n'
+ "DO NOT INCLUDE arguments names in your response. Only the values.\n"
+ "You MUST use only these methods:\n"
+ "\n\n"
+ "It is VERY IMPORTANT not to use methods other than those listed above."
+ """If you DON'T KNOW HOW TO ANSWER DON'T SAY \"\", SAY: `UNSUPPORTED QUERY` INSTEAD! """
+ "This is CRUCIAL, otherwise the system will crash. ",
+ "is_example": False,
+ },
+ {"role": "user", "content": examples[0].question, "is_example": True},
+ {"role": "assistant", "content": examples[0].answer, "is_example": True},
+ {"role": "user", "content": "", "is_example": False},
+ ]
+
+
+async def test_iql_input_format_few_shot_examples_repeat_no_example_duplicates() -> None:
+ examples = [FewShotExample("q1", "a1")]
+ prompt_format = IQLGenerationPromptFormat(
+ question="",
+ filters=[],
+ examples=examples,
+ )
+ formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format)
+
+ assert len(formatted_prompt.chat) == len(IQL_GENERATION_TEMPLATE.chat) + (len(examples) * 2)
+ assert formatted_prompt.chat[1]["role"] == "user"
+ assert formatted_prompt.chat[1]["content"] == examples[0].question
+ assert formatted_prompt.chat[2]["role"] == "assistant"
+ assert formatted_prompt.chat[2]["content"] == examples[0].answer
+
+ formatted_prompt = formatted_prompt.add_assistant_message("response")
+
+ formatted_prompt2 = formatted_prompt.format_prompt(prompt_format)
+
+ assert len(formatted_prompt2.chat) == len(formatted_prompt.chat)
+ assert formatted_prompt2.chat[1]["role"] == "user"
+ assert formatted_prompt2.chat[1]["content"] == examples[0].question
+ assert formatted_prompt2.chat[2]["role"] == "assistant"
+ assert formatted_prompt2.chat[2]["content"] == examples[0].answer
diff --git a/tests/unit/test_iql_generator.py b/tests/unit/test_iql_generator.py
index c330f747..ce3f593d 100644
--- a/tests/unit/test_iql_generator.py
+++ b/tests/unit/test_iql_generator.py
@@ -1,15 +1,15 @@
# mypy: disable-error-code="empty-body"
-from unittest.mock import AsyncMock
+from unittest.mock import AsyncMock, Mock, patch
import pytest
import sqlalchemy
from dbally import decorators
from dbally.audit.event_tracker import EventTracker
-from dbally.iql import IQLQuery
+from dbally.iql import IQLError, IQLQuery
from dbally.iql_generator.iql_generator import IQLGenerator
-from dbally.iql_generator.iql_prompt_template import default_iql_template
+from dbally.iql_generator.prompt import IQL_GENERATION_TEMPLATE, IQLGenerationPromptFormat
from dbally.views.methods_base import MethodsBaseView
from tests.unit.mocks import MockLLM
@@ -41,7 +41,7 @@ def view() -> MockView:
@pytest.fixture
def llm() -> MockLLM:
llm = MockLLM()
- llm.client.call = AsyncMock(return_value="LLM IQL mock answer")
+ llm.generate_text = AsyncMock(return_value="filter_by_id(1)")
return llm
@@ -50,35 +50,64 @@ def event_tracker() -> EventTracker:
return EventTracker()
-@pytest.mark.asyncio
-async def test_iql_generation(llm: MockLLM, event_tracker: EventTracker, view: MockView) -> None:
- iql_generator = IQLGenerator(llm, default_iql_template)
-
- filters_for_prompt = iql_generator._promptify_view(view.list_filters())
- filters_in_prompt = set(filters_for_prompt.split("\n"))
-
- assert filters_in_prompt == {"filter_by_id(idx: int)", "filter_by_name(city: str)"}
-
- response = await iql_generator.generate_iql(view.list_filters(), "Mock_question", event_tracker)
+@pytest.fixture
+def iql_generator(llm: MockLLM) -> IQLGenerator:
+ return IQLGenerator(llm)
- template_after_response = default_iql_template.add_assistant_message(content="LLM IQL mock answer")
- assert response == ("LLM IQL mock answer", template_after_response)
- template_after_response = template_after_response.add_user_message(content="Mock_error")
- response2 = await iql_generator.generate_iql(
- view.list_filters(), "Mock_question", event_tracker, template_after_response
+@pytest.mark.asyncio
+async def test_iql_generation(iql_generator: IQLGenerator, event_tracker: EventTracker, view: MockView) -> None:
+ filters = view.list_filters()
+ prompt_format = IQLGenerationPromptFormat(
+ question="Mock_question",
+ filters=filters,
)
- template_after_2nd_response = template_after_response.add_assistant_message(content="LLM IQL mock answer")
- assert response2 == ("LLM IQL mock answer", template_after_2nd_response)
-
-
-def test_add_error_msg(llm: MockLLM) -> None:
- iql_generator = IQLGenerator(llm, default_iql_template)
- errors = [ValueError("Mock_error")]
+ formatted_prompt = IQL_GENERATION_TEMPLATE.format_prompt(prompt_format)
+
+ with patch("dbally.iql.IQLQuery.parse", AsyncMock(return_value="filter_by_id(1)")) as mock_parse:
+ iql = await iql_generator.generate_iql(
+ question="Mock_question",
+ filters=filters,
+ event_tracker=event_tracker,
+ )
+ assert iql == "filter_by_id(1)"
+ iql_generator._llm.generate_text.assert_called_once_with(
+ prompt=formatted_prompt,
+ event_tracker=event_tracker,
+ options=None,
+ )
+ mock_parse.assert_called_once_with(
+ source="filter_by_id(1)",
+ allowed_functions=filters,
+ event_tracker=event_tracker,
+ )
- conversation = default_iql_template.add_assistant_message(content="Assistant")
- conversation_with_error = iql_generator.add_error_msg(conversation, errors)
-
- error_msg = iql_generator._ERROR_MSG_PREFIX + "Mock_error\n"
- assert conversation_with_error == conversation.add_user_message(content=error_msg)
+@pytest.mark.asyncio
+async def test_iql_generation_error_handling(
+ iql_generator: IQLGenerator,
+ event_tracker: EventTracker,
+ view: MockView,
+) -> None:
+ filters = view.list_filters()
+
+ mock_node = Mock(col_offset=0, end_col_offset=-1)
+ errors = [
+ IQLError("err1", mock_node, "src1"),
+ IQLError("err2", mock_node, "src2"),
+ IQLError("err3", mock_node, "src3"),
+ IQLError("err4", mock_node, "src4"),
+ ]
+
+ with patch("dbally.iql.IQLQuery.parse", AsyncMock(return_value="filter_by_id(1)")) as mock_parse:
+ mock_parse.side_effect = errors
+ iql = await iql_generator.generate_iql(
+ question="Mock_question",
+ filters=filters,
+ event_tracker=event_tracker,
+ )
+
+ assert iql is None
+ assert iql_generator._llm.generate_text.call_count == 4
+ for i, arg in enumerate(iql_generator._llm.generate_text.call_args_list[1:], start=1):
+ assert f"err{i}" in arg[1]["prompt"].chat[-1]["content"]
diff --git a/tests/unit/test_prompt_builder.py b/tests/unit/test_prompt_builder.py
index f8a886fe..00fa7fd5 100644
--- a/tests/unit/test_prompt_builder.py
+++ b/tests/unit/test_prompt_builder.py
@@ -1,116 +1,99 @@
+from typing import List
+
import pytest
-from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate
-from dbally.prompts import ChatFormat, PromptTemplate, PromptTemplateError
-from tests.unit.mocks import MockLLM
+from dbally.prompt.elements import FewShotExample
+from dbally.prompt.template import ChatFormat, PromptFormat, PromptTemplate, PromptTemplateError
+
+
+class QuestionPromptFormat(PromptFormat):
+ """
+ Generic prompt format that allows injecting few-shot examples into the conversation.
+ """
+
+ def __init__(self, question: str, examples: Optional[List[FewShotExample]] = None) -> None:
+ """
+ Constructs a new QuestionPromptFormat instance.
+
+ Args:
+ question: Question to be asked.
+ examples: List of examples to be injected into the conversation.
+ """
+ super().__init__(examples)
+ self.question = question
@pytest.fixture()
-def simple_template():
- simple_template = PromptTemplate(
- chat=(
+def template() -> PromptTemplate[QuestionPromptFormat]:
+ return PromptTemplate[QuestionPromptFormat](
+ [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "{question}"},
- )
+ ]
)
- return simple_template
-@pytest.fixture()
-def llm():
- return MockLLM()
+def test_prompt_template_formatting(template: PromptTemplate[QuestionPromptFormat]) -> None:
+ prompt_format = QuestionPromptFormat(question="Example user question?")
+ formatted_prompt = template.format_prompt(prompt_format)
+ assert formatted_prompt.chat == [
+ {"content": "You are a helpful assistant.", "role": "system", "is_example": False},
+ {"content": "Example user question?", "role": "user", "is_example": False},
+ ]
-def test_default_llm_format_prompt(llm, simple_template):
- prompt = llm.format_prompt(
- template=simple_template,
- fmt={"question": "Example user question?"},
- )
- assert prompt == [
- {"content": "You are a helpful assistant.", "role": "system"},
- {"content": "Example user question?", "role": "user"},
+def test_missing_prompt_template_formatting(template: PromptTemplate[QuestionPromptFormat]) -> None:
+ prompt_format = PromptFormat()
+ with pytest.raises(KeyError):
+ template.format_prompt(prompt_format)
+
+
+def test_add_few_shots(template: PromptTemplate[QuestionPromptFormat]) -> None:
+ examples = [
+ FewShotExample(
+ question="What is the capital of France?",
+ answer_expr="Paris",
+ ),
+ FewShotExample(
+ question="What is the capital of Germany?",
+ answer_expr="Berlin",
+ ),
]
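+ # Few-shot examples should be inserted as flagged user/assistant turns before the final user message.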
+ for example in examples:
+ template = template.add_few_shot_message(example)
-def test_missing_format_dict(llm, simple_template):
- with pytest.raises(KeyError):
- _ = llm.format_prompt(simple_template, fmt={})
+ assert template.chat == [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "What is the capital of France?", "is_example": True},
+ {"role": "assistant", "content": "Paris", "is_example": True},
+ {"role": "user", "content": "What is the capital of Germany?", "is_example": True},
+ {"role": "assistant", "content": "Berlin", "is_example": True},
+ {"role": "user", "content": "{question}"},
+ ]
@pytest.mark.parametrize(
"invalid_chat",
[
- (
+ [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "{question}"},
{"role": "user", "content": "{question}"},
- ),
- (
+ ],
+ [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "assistant", "content": "{question}"},
{"role": "assistant", "content": "{question}"},
- ),
- (
+ ],
+ [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "{question}"},
{"role": "assistant", "content": "{question}"},
{"role": "system", "content": "{question}"},
- ),
+ ],
],
)
-def test_chat_order_validation(invalid_chat):
+def test_chat_order_validation(invalid_chat: ChatFormat) -> None:
with pytest.raises(PromptTemplateError):
- _ = PromptTemplate(chat=invalid_chat)
-
-
-def test_dynamic_few_shot(llm, simple_template):
- assert (
- len(
- llm.format_prompt(
- simple_template.add_assistant_message("assistant message").add_user_message("user message"),
- fmt={"question": "user question"},
- )
- )
- == 4
- )
-
-
-@pytest.mark.parametrize(
- "invalid_chat",
- [
- (
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "{question}"},
- ),
- (
- {"role": "system", "content": "You are a helpful assistant."},
- {"role": "user", "content": "Hello"},
- ),
- (
- {"role": "system", "content": "You are a helpful assistant. {filters}}"},
- {"role": "user", "content": "Hello"},
- ),
- ],
- ids=["Missing filters", "Missing filters, question", "Missing question"],
-)
-def test_bad_iql_prompt_template(invalid_chat: ChatFormat):
- with pytest.raises(PromptTemplateError):
- _ = IQLPromptTemplate(invalid_chat)
-
-
-@pytest.mark.parametrize(
- "chat",
- [
- (
- {"role": "system", "content": "You are a helpful assistant.{filters}"},
- {"role": "user", "content": "{question}"},
- ),
- (
- {"role": "system", "content": "{filters}{filters}{filters}}}"},
- {"role": "user", "content": "{question}"},
- ),
- ],
- ids=["Good template", "Good template with repeating variables"],
-)
-def test_good_iql_prompt_template(chat: ChatFormat):
- _ = IQLPromptTemplate(chat)
+ PromptTemplate[QuestionPromptFormat](invalid_chat)
diff --git a/tests/unit/test_view_selector.py b/tests/unit/test_view_selector.py
index 2d3b1d9c..8de038e2 100644
--- a/tests/unit/test_view_selector.py
+++ b/tests/unit/test_view_selector.py
@@ -31,7 +31,7 @@ def views() -> Dict[str, str]:
@pytest.mark.asyncio
-async def test_view_selection(llm: LLM, views: Dict[str, str]):
+async def test_view_selection(llm: LLM, views: Dict[str, str]) -> None:
view_selector = LLMViewSelector(llm)
view = await view_selector.select_view("Mock question?", views, event_tracker=EventTracker())
assert view == "MockView1"