feat: few-shot selector (#42)
sgnatonski authored Jun 25, 2024
1 parent cd5bf7b commit d482638
Showing 13 changed files with 435 additions and 62 deletions.
1 change: 1 addition & 0 deletions .coveragerc
@@ -12,4 +12,5 @@ omit =
 exclude_lines =
     pragma: no cover
     if __name__ == .__main__.
+    \.\.\.
 show_missing = True
6 changes: 3 additions & 3 deletions benchmark/dbally_benchmark/iql_benchmark.py
@@ -23,6 +23,7 @@
 from dbally.iql_generator.iql_generator import IQLGenerator
 from dbally.iql_generator.iql_prompt_template import UnsupportedQueryError, default_iql_template
 from dbally.llms.litellm import LiteLLM
+from dbally.prompts.formatters import IQLInputFormatter
 from dbally.views.structured import BaseStructuredView

@@ -31,11 +32,10 @@ async def _run_iql_for_single_example(
 ) -> IQLResult:
     filter_list = view.list_filters()
     event_tracker = EventTracker()
+    input_formatter = IQLInputFormatter(question=example.question, filters=filter_list)

     try:
-        iql_filters, _ = await iql_generator.generate_iql(
-            question=example.question, filters=filter_list, event_tracker=event_tracker
-        )
+        iql_filters, _ = await iql_generator.generate_iql(input_formatter=input_formatter, event_tracker=event_tracker)
     except UnsupportedQueryError:
         return IQLResult(question=example.question, iql_filters="UNSUPPORTED_QUERY", exception_raised=True)
45 changes: 44 additions & 1 deletion examples/recruiting/views.py
@@ -1,10 +1,13 @@
-from typing import Literal
+from datetime import date
+from typing import List, Literal

 import awoc  # pip install a-world-of-countries
 import sqlalchemy
+from dateutil.relativedelta import relativedelta
 from sqlalchemy import and_, select

 from dbally import SqlAlchemyBaseView, decorators
+from dbally.prompts.elements import FewShotExample

 from .db import Candidate

@@ -57,3 +60,43 @@ def is_from_continent(  # pylint: disable=W0602, C0116, W9011
     @decorators.view_filter()
     def studied_at(self, university: str) -> sqlalchemy.ColumnElement:  # pylint: disable=W0602, C0116, W9011
         return Candidate.university == university
+
+
+class FewShotRecruitmentView(RecruitmentView):
+    """
+    A view for the recruitment database including examples of question:answer pairs (few-shot).
+    """
+
+    @decorators.view_filter()
+    def is_available_within_months(  # pylint: disable=W0602, C0116, W9011
+        self, months: int
+    ) -> sqlalchemy.ColumnElement:
+        start = date.today()
+        end = start + relativedelta(months=months)
+        return Candidate.available_from.between(start, end)
+
+    def list_few_shots(self) -> List[FewShotExample]:  # pylint: disable=W9011
+        return [
+            FewShotExample(
+                "Which candidates studied at University of Toronto?",
+                'studied_at("University of Toronto")',
+            ),
+            FewShotExample(
+                "Do we have any soon available candidate?",
+                lambda: self.is_available_within_months(1),
+            ),
+            FewShotExample(
+                "Do we have any soon available perfect fits for senior data scientist positions?",
+                lambda: (
+                    self.is_available_within_months(1)
+                    and self.data_scientist_position()
+                    and self.has_seniority("senior")
+                ),
+            ),
+            FewShotExample(
+                "List all junior or senior data scientist positions",
+                lambda: (
+                    self.data_scientist_position() and (self.has_seniority("junior") or self.has_seniority("senior"))
+                ),
+            ),
+        ]
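For context, a minimal sketch of wiring this view into a collection so that question answering can pick up its few-shot examples. This is hedged: it assumes db-ally's documented create_collection / add / ask API, the engine and module path are stand-ins (the recruiting example keeps its real database in examples/recruiting/db.py), and LiteLLM needs an API key in the environment:

import asyncio

import sqlalchemy

import dbally
from dbally.llms.litellm import LiteLLM

from examples.recruiting.views import FewShotRecruitmentView  # assumes running from the repo root

# Stand-in engine; the recruiting example defines its real database in its .db module
engine = sqlalchemy.create_engine("sqlite://")


async def main() -> None:
    collection = dbally.create_collection("recruitment", llm=LiteLLM())
    collection.add(FewShotRecruitmentView, lambda: FewShotRecruitmentView(engine))
    result = await collection.ask("Do we have any soon available candidate?")
    print(result.results)


asyncio.run(main())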
52 changes: 10 additions & 42 deletions src/dbally/iql_generator/iql_generator.py
@@ -1,11 +1,10 @@
-import copy
-from typing import Callable, List, Optional, Tuple, TypeVar
+from typing import List, Optional, Tuple, TypeVar

 from dbally.audit.event_tracker import EventTracker
-from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template
+from dbally.iql_generator.iql_prompt_template import IQLPromptTemplate, default_iql_template  # noqa
 from dbally.llms.base import LLM
 from dbally.llms.clients.base import LLMOptions
-from dbally.views.exposed_functions import ExposedFunction
+from dbally.prompts.formatters import IQLInputFormatter


 class IQLGenerator:
@@ -24,26 +23,16 @@ class IQLGenerator:

     TException = TypeVar("TException", bound=Exception)

-    def __init__(
-        self,
-        llm: LLM,
-        prompt_template: Optional[IQLPromptTemplate] = None,
-        promptify_view: Optional[Callable] = None,
-    ) -> None:
+    def __init__(self, llm: LLM) -> None:
         """
         Args:
             llm: LLM used to generate IQL
-            prompt_template: If not provided by the users is set to `default_iql_template`
-            promptify_view: Function formatting filters for prompt
         """
         self._llm = llm
-        self._prompt_template = prompt_template or copy.deepcopy(default_iql_template)
-        self._promptify_view = promptify_view or _promptify_filters

     async def generate_iql(
         self,
-        filters: List[ExposedFunction],
-        question: str,
+        input_formatter: IQLInputFormatter,
         event_tracker: EventTracker,
         conversation: Optional[IQLPromptTemplate] = None,
         llm_options: Optional[LLMOptions] = None,
@@ -52,30 +41,25 @@ async def generate_iql(
         Uses LLM to generate IQL in text form

         Args:
-            question: user question
-            filters: list of filters exposed by the view
+            input_formatter: formatter used to prepare prompt arguments dictionary
             event_tracker: event store used to audit the generation process
             conversation: conversation to be continued
             llm_options: options to use for the LLM client

         Returns:
             IQL - iql generated based on the user question
         """
-        filters_for_prompt = self._promptify_view(filters)
-
-        template = conversation or self._prompt_template
+        conversation, fmt = input_formatter(conversation or default_iql_template)

         llm_response = await self._llm.generate_text(
-            template=template,
-            fmt={"filters": filters_for_prompt, "question": question},
+            template=conversation,
+            fmt=fmt,
             event_tracker=event_tracker,
             options=llm_options,
         )

-        iql_filters = self._prompt_template.llm_response_parser(llm_response)
-
-        if conversation is None:
-            conversation = self._prompt_template
+        iql_filters = conversation.llm_response_parser(llm_response)

         conversation = conversation.add_assistant_message(content=llm_response)

@@ -98,19 +82,3 @@ def add_error_msg(self, conversation: IQLPromptTemplate, errors: List[TException]) -> IQLPromptTemplate:
             msg += str(error) + "\n"

         return conversation.add_user_message(content=msg)
-
-
-def _promptify_filters(
-    filters: List[ExposedFunction],
-) -> str:
-    """
-    Formats filters for prompt
-
-    Args:
-        filters: list of filters exposed by the view
-
-    Returns:
-        filters_for_prompt: filters formatted for prompt
-    """
-    filters_for_prompt = "\n".join([str(filter) for filter in filters])
-    return filters_for_prompt
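The new call shape, as a sketch grounded in the benchmark change above (the view argument is a placeholder for any structured view instance, and LiteLLM's default model is assumed):

import asyncio

from dbally.audit.event_tracker import EventTracker
from dbally.iql_generator.iql_generator import IQLGenerator
from dbally.llms.litellm import LiteLLM
from dbally.prompts.formatters import IQLInputFormatter


async def generate(view) -> str:  # view: placeholder for any structured view instance
    generator = IQLGenerator(llm=LiteLLM())
    input_formatter = IQLInputFormatter(question="Who studied at MIT?", filters=view.list_filters())
    # generate_iql now returns a (filters, conversation) tuple, as in the benchmark above
    iql_filters, _conversation = await generator.generate_iql(
        input_formatter=input_formatter,
        event_tracker=EventTracker(),
    )
    return iql_filters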
2 changes: 1 addition & 1 deletion src/dbally/llms/base.py
@@ -52,7 +52,7 @@ def format_prompt(self, template: PromptTemplate, fmt: Dict[str, str]) -> ChatFormat:
         Returns:
             Prompt in the format of the client.
         """
-        return [{**message, "content": message["content"].format(**fmt)} for message in template.chat]
+        return [{"role": message["role"], "content": message["content"].format(**fmt)} for message in template.chat]

     def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int:
         """
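A plausible reading of this one-line change, inferred from the few-shot formatter introduced below: injected example messages carry an internal is_example flag, and spreading the whole message dict would leak that flag into the payload sent to the LLM client. A self-contained sketch of the difference:

msg = {"role": "user", "content": "Hi {name}", "is_example": True}
fmt = {"name": "Alice"}

old = {**msg, "content": msg["content"].format(**fmt)}
# {'role': 'user', 'content': 'Hi Alice', 'is_example': True}  <- extra key reaches the client

new = {"role": msg["role"], "content": msg["content"].format(**fmt)}
# {'role': 'user', 'content': 'Hi Alice'}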
61 changes: 61 additions & 0 deletions src/dbally/prompts/elements.py
@@ -0,0 +1,61 @@
import inspect
import re
import textwrap
from typing import Callable, Union


class FewShotExample:
    """
    A question:answer representation for few-shot prompting
    """

    def __init__(self, question: str, answer_expr: Union[str, Callable]) -> None:
        """
        Args:
            question: sample question
            answer_expr: either a stringified expression, or a lambda for greater safety and code completion

        Raises:
            ValueError: if answer_expr is neither a string nor a callable
        """
        self.question = question
        self.answer_expr = answer_expr

        if isinstance(self.answer_expr, str):
            self.answer = self.answer_expr
        elif callable(answer_expr):
            self.answer = self._parse_lambda(answer_expr)
        else:
            raise ValueError("Answer expression should be either a string or a lambda")

    def _parse_lambda(self, expr: Callable) -> str:
        """
        Parses the provided callable in order to extract the lambda code.
        All comments and references to variables like `self` are removed
        to form a simple lambda representation.

        Args:
            expr: lambda expression to parse

        Returns:
            Parsed lambda in the form of a cleaned-up string
        """
        # extract lambda from code
        expr_source = textwrap.dedent(inspect.getsource(expr))
        expr_body = expr_source.replace("lambda:", "")

        # clean up by removing comments, new lines, free vars (self etc.)
        parsed_expr = re.sub("\\#.*\n", "\n", expr_body, flags=re.MULTILINE)

        for m_name in expr.__code__.co_names:
            parsed_expr = parsed_expr.replace(f"{expr.__code__.co_freevars[0]}.{m_name}", m_name)

        # clean up any dangling commas or leading and trailing brackets
        parsed_expr = " ".join(parsed_expr.split()).strip().rstrip(",").replace("( ", "(").replace(" )", ")")
        if parsed_expr.startswith("("):
            parsed_expr = parsed_expr[1:-1]

        return parsed_expr

    def __str__(self) -> str:
        return self.answer
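A quick illustration of both answer forms; the expected outputs are inferred from the parsing logic above rather than taken from the commit, and DemoView is a hypothetical stand-in for a real view:

from dbally.prompts.elements import FewShotExample

# Stringified answers are stored verbatim:
ex1 = FewShotExample(
    "Which candidates studied at University of Toronto?",
    'studied_at("University of Toronto")',
)
print(ex1)  # studied_at("University of Toronto")


class DemoView:
    """Hypothetical view; the lambda must close over `self` for the free-var cleanup to apply."""

    def is_available_within_months(self, months: int) -> bool:
        return True  # placeholder filter

    def example(self) -> FewShotExample:
        # inspect.getsource() recovers the lambda's code, then `self.` references are stripped
        return FewShotExample(
            "Do we have any soon available candidate?",
            lambda: self.is_available_within_months(1),
        )


print(DemoView().example())  # is_available_within_months(1)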
119 changes: 119 additions & 0 deletions src/dbally/prompts/formatters.py
@@ -0,0 +1,119 @@
import copy
from abc import ABCMeta, abstractmethod
from typing import Dict, List, Tuple

from dbally.prompts.elements import FewShotExample
from dbally.prompts.prompt_template import PromptTemplate
from dbally.views.exposed_functions import ExposedFunction


def _promptify_filters(
    filters: List[ExposedFunction],
) -> str:
    """
    Formats filters for the prompt.

    Args:
        filters: list of filters exposed by the view

    Returns:
        Filters formatted for the prompt
    """
    filters_for_prompt = "\n".join([str(filter) for filter in filters])
    return filters_for_prompt


class InputFormatter(metaclass=ABCMeta):
    """
    Formats provided parameters to a form acceptable by the IQL prompt
    """

    @abstractmethod
    def __call__(self, conversation_template: PromptTemplate) -> Tuple[PromptTemplate, Dict[str, str]]:
        """
        Runs the input formatting for the provided prompt template.

        Args:
            conversation_template: a prompt template to use

        Returns:
            A tuple with the template and a dictionary with formatted inputs
        """


class IQLInputFormatter(InputFormatter):
    """
    Formats provided parameters to a form acceptable by the default IQL prompt
    """

    def __init__(self, filters: List[ExposedFunction], question: str) -> None:
        self.filters = filters
        self.question = question

    def __call__(self, conversation_template: PromptTemplate) -> Tuple[PromptTemplate, Dict[str, str]]:
        """
        Runs the input formatting for the provided prompt template.

        Args:
            conversation_template: a prompt template to use

        Returns:
            A tuple with the template and a dictionary with the formatted filters and question
        """
        return conversation_template, {
            "filters": _promptify_filters(self.filters),
            "question": self.question,
        }


class IQLFewShotInputFormatter(InputFormatter):
    """
    Formats provided parameters to a form acceptable by the default IQL prompt.
    Calling it will inject `examples` before the last message in the conversation.
    """

    def __init__(
        self,
        filters: List[ExposedFunction],
        examples: List[FewShotExample],
        question: str,
    ) -> None:
        self.filters = filters
        self.question = question
        self.examples = examples

    def __call__(self, conversation_template: PromptTemplate) -> Tuple[PromptTemplate, Dict[str, str]]:
        """
        Performs a deep copy of the provided template and injects examples into the chat history.
        Also prepares filters and the question to be included within the prompt.

        Args:
            conversation_template: a prompt template to inject few-shot examples into

        Returns:
            A tuple with the deeply-copied, example-enriched template
            and a dictionary with the formatted filters and question
        """
        template_copy = copy.deepcopy(conversation_template)
        sys_msg = template_copy.chat[0]
        existing_msgs = [msg for msg in template_copy.chat[1:] if "is_example" not in msg]
        chat_examples = [
            msg
            for example in self.examples
            for msg in [
                {"role": "user", "content": example.question, "is_example": True},
                {"role": "assistant", "content": example.answer, "is_example": True},
            ]
        ]

        template_copy.chat = (
            sys_msg,
            *chat_examples,
            *existing_msgs,
        )

        return template_copy, {
            "filters": _promptify_filters(self.filters),
            "question": self.question,
        }
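To make the injection concrete, a small sketch; the chat layout in the comment is inferred from the code above, and the question and example contents are illustrative:

from dbally.iql_generator.iql_prompt_template import default_iql_template
from dbally.prompts.elements import FewShotExample
from dbally.prompts.formatters import IQLFewShotInputFormatter

formatter = IQLFewShotInputFormatter(
    filters=[],  # kept empty for brevity; normally view.list_filters()
    examples=[
        FewShotExample(
            "Which candidates studied at University of Toronto?",
            'studied_at("University of Toronto")',
        )
    ],
    question="Who studied at MIT?",
)

template, fmt = formatter(default_iql_template)
# template.chat is now:
#   (the original system message,
#    {"role": "user", "content": "Which candidates studied at University of Toronto?", "is_example": True},
#    {"role": "assistant", "content": 'studied_at("University of Toronto")', "is_example": True},
#    ...the remaining original messages, with "{question}" still unformatted)
# fmt == {"filters": "", "question": "Who studied at MIT?"}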
10 changes: 10 additions & 0 deletions src/dbally/views/base.py
@@ -5,6 +5,7 @@
 from dbally.collection.results import ViewExecutionResult
 from dbally.llms.base import LLM
 from dbally.llms.clients.base import LLMOptions
+from dbally.prompts.elements import FewShotExample
 from dbally.similarity import AbstractSimilarityIndex

 IndexLocation = Tuple[str, str, str]

@@ -49,3 +50,12 @@ def list_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[IndexLocation]]:
             Mapping of similarity indexes to their locations.
         """
         return {}
+
+    def list_few_shots(self) -> List[FewShotExample]:
+        """
+        List all examples to be injected into the few-shot prompt.
+
+        Returns:
+            List of few-shot examples
+        """
+        return []
(Diff truncated: 5 more changed files not shown.)