Skip to content

Commit

Permalink
Merge pull request #28 from parkervg/feature/lazy-evaluation
Browse files Browse the repository at this point in the history
`unpack_options()`, adding options argument to MapIngredient
  • Loading branch information
parkervg authored Jul 1, 2024
2 parents 24c0008 + c5b7145 commit e3f2668
Show file tree
Hide file tree
Showing 14 changed files with 217 additions and 66 deletions.
14 changes: 14 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,20 @@ SELECT COUNT(*) FROM parks
| 1 |
<hr>

_What's the difference in visitors for those parks with a superlative in their description vs. those without?_
```sql
SELECT SUM(CAST(REPLACE("Recreation Visitors (2022)", ',', '') AS integer)) AS "Total Visitors",
{{LLMMap('Contains a superlative?', 'parks::Description', options='t;f')}} AS "Description Contains Superlative",
GROUP_CONCAT(Name, ', ') AS "Park Names"
FROM parks
GROUP BY "Description Contains Superlative"
```
| Total Visitors | Description Contains Superlative | Park Names |
|---------------:|-----------------------------------:|:------------------------------|
| 43365 | 0 | Everglades, Katmai |
| 2722385 | 1 | Death Valley, New River Gorge |
<hr>

Now, we have an intermediate representation for our LLM to use that is explainable, debuggable, and [very effective at hybrid question-answering tasks](https://arxiv.org/abs/2402.17882).

For in-depth descriptions of the above queries, check out our [documentation](https://parkervg.github.io/blendsql/).
Expand Down
1 change: 1 addition & 0 deletions blendsql/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ class IngredientKwarg:
CONTEXT: str = "context"
VALUES: str = "values"
OPTIONS: str = "options"
REGEX: str = "regex"
MODEL: str = "model"
16 changes: 16 additions & 0 deletions blendsql/_smoothie.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,19 @@ class Smoothie:

def __post_init__(self):
self.df = PrettyDataFrame(self.df)

def summary(self):
    """Return a human-readable report of this execution.

    The report consists of a banner line, the query string stored on
    ``self.meta.query``, and a one-row table holding the timing,
    value-count and token-count fields read from ``self.meta``.

    Returns:
        str: The formatted summary text.
    """
    stats = pd.DataFrame(
        {
            "Time (s)": self.meta.process_time_seconds,
            "Values Passed to Ingredients": self.meta.num_values_passed,
            "Prompt Tokens": self.meta.prompt_tokens,
            "Completion Tokens": self.meta.completion_tokens,
        },
        index=[0],
    )
    s = "---------------- SUMMARY ----------------\n"
    s += self.meta.query + "\n"
    # tabulate() omits DataFrame column names unless headers="keys" is given,
    # which would leave four unlabeled numbers. The single row's index (0)
    # carries no information, so it is hidden.
    s += tabulate(stats, headers="keys", showindex=False)
    return s
33 changes: 24 additions & 9 deletions blendsql/blend.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from colorama import Fore
import string

from ._logger import logger, msg_box
from ._logger import logger
from .utils import (
sub_tablename,
get_temp_session_table,
Expand Down Expand Up @@ -374,7 +374,7 @@ def disambiguate_and_submit_blend(
for alias, d in ingredient_alias_to_parsed_dict.items():
query = re.sub(re.escape(alias), d["raw"], query)
logger.debug(
Fore.CYAN + f"Executing `{query}` and setting to `{aliasname}`..." + Fore.RESET
Fore.CYAN + f"Executing `{query}` and setting to `{aliasname}`" + Fore.RESET
)
return _blend(query=query, **kwargs)

Expand Down Expand Up @@ -547,7 +547,11 @@ def _blend(
db.lazy_tables.pop(tablename).collect()
logger.debug(
Fore.CYAN
+ f"Executing `{abstracted_query_str}` and setting to `{_get_temp_subquery_table(tablename)}`..."
+ "Executing "
+ Fore.LIGHTCYAN_EX
+ f"`{abstracted_query_str}` "
+ Fore.CYAN
+ f"and setting to `{_get_temp_subquery_table(tablename)}`..."
+ Fore.RESET
)
try:
Expand Down Expand Up @@ -629,7 +633,13 @@ def _blend(
continue
executed_subquery_ingredients.add(alias_function_str)
kwargs_dict = parsed_results_dict["kwargs_dict"]

logger.debug(
Fore.CYAN
+ "Executing "
+ Fore.LIGHTCYAN_EX
+ f" `{parsed_results_dict['raw']}`..."
+ Fore.RESET
)
if infer_gen_constraints:
# Latter is the winner.
# So if we already define something in kwargs_dict,
Expand Down Expand Up @@ -774,8 +784,13 @@ def _blend(
if naive_execution:
break
# Combine all the retrieved ingredient outputs
for tablename, llm_outs in tablename_to_map_out.items():
if len(llm_outs) > 0:
for tablename, ingredient_outputs in tablename_to_map_out.items():
if len(ingredient_outputs) > 0:
logger.debug(
Fore.CYAN
+ f"Combining {len(ingredient_outputs)} outputs for table `{tablename}`"
+ Fore.RESET
)
# Once we finish parsing this subquery, write to our session_uuid table
# Below, we differ from Binder, which seems to replace the old table
# On their left join merge command: https://github.com/HKUNLP/Binder/blob/9eede69186ef3f621d2a50572e1696bc418c0e77/nsql/database.py#L196
Expand All @@ -793,8 +808,8 @@ def _blend(
previously_added_columns = base_table.columns.difference(
_base_table.columns
)
assert len(set([len(x) for x in llm_outs])) == 1
llm_out_df = pd.concat(llm_outs, axis=1)
assert len(set([len(x) for x in ingredient_outputs])) == 1
llm_out_df = pd.concat(ingredient_outputs, axis=1)
llm_out_df = llm_out_df.loc[:, ~llm_out_df.columns.duplicated()]
# Handle duplicate columns, e.g. in test_nested_duplicate_ingredient_calls()
for column in previously_added_columns:
Expand Down Expand Up @@ -842,7 +857,7 @@ def _blend(
if table.name in db.lazy_tables:
db.lazy_tables.pop(table.name).collect()

logger.debug(Fore.LIGHTGREEN_EX + msg_box(f"Final Query:\n{query}") + Fore.RESET)
logger.debug(Fore.LIGHTGREEN_EX + f"Final Query:\n{query}" + Fore.RESET)

df = db.execute_to_df(query)

Expand Down
2 changes: 1 addition & 1 deletion blendsql/db/_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def to_temp_table(self, df: pd.DataFrame, tablename: str):
create_table_stmt = re.sub(
r"^CREATE TABLE", "CREATE TEMP TABLE", create_table_stmt
)
logger.debug(Fore.CYAN + create_table_stmt + Fore.RESET)
logger.debug(Fore.LIGHTBLACK_EX + create_table_stmt + Fore.RESET)
self.con.execute(text(create_table_stmt))
df.to_sql(name=tablename, con=self.con, if_exists="append", index=False)

Expand Down
4 changes: 2 additions & 2 deletions blendsql/generate/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
def regex(
model: Model,
prompt: str,
pattern: str,
regex: str,
max_tokens: Optional[int] = None,
stop_at: Optional[Union[List[str], str]] = None,
) -> str:
generator = outlines.generate.regex(model.model_obj, regex_str=pattern)
generator = outlines.generate.regex(model.model_obj, regex_str=regex)
return generator(prompt, max_tokens=max_tokens, stop_at=stop_at)


Expand Down
2 changes: 1 addition & 1 deletion blendsql/ingredients/builtin/join/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def __call__(
response = generate.regex(
model,
prompt=prompt,
pattern=regex(len(left_values)),
regex=regex(len(left_values)),
max_tokens=max_tokens,
stop_at=["---"],
)
Expand Down
12 changes: 6 additions & 6 deletions blendsql/ingredients/builtin/map/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def __call__(
prompt += f"\nHere are some example outputs: {example_outputs}\n"
prompt += "\nA:"
if isinstance(model, LocalModel) and regex is not None:
response = generate.regex(model, prompt=prompt, pattern=regex(len(values)))
response = generate.regex(model, prompt=prompt, regex=regex(len(values)))
else:
response = generate.text(
model, prompt=prompt, max_tokens=max_tokens, stop_at="\n"
Expand All @@ -139,7 +139,7 @@ def run(
value_limit: Union[int, None] = None,
example_outputs: Optional[str] = None,
output_type: Optional[str] = None,
pattern: Optional[str] = None,
regex: Optional[str] = None,
table_to_title: Optional[Dict[str, str]] = None,
**kwargs,
) -> Iterable[Any]:
Expand All @@ -152,7 +152,7 @@ def run(
value_limit: Optional limit on the number of values to pass to the Model
example_outputs: If binary == False, this gives the Model an example of the output we expect.
output_type: One of 'numeric', 'string', 'bool'
pattern: Optional regex to constrain answer generation.
regex: Optional regex to constrain answer generation.
table_to_title: Mapping from tablename to a title providing some more context.
Returns:
Expand All @@ -165,7 +165,7 @@ def run(
# Unpack default kwargs
tablename, colname = self.unpack_default_kwargs(**kwargs)
# Remote endpoints can't use regex-constrained generation
pattern = None if isinstance(model, RemoteModel) else pattern
regex = None if isinstance(model, RemoteModel) else regex
if value_limit is not None:
values = values[:value_limit]
values = [value if not pd.isna(value) else "-" for value in values]
Expand Down Expand Up @@ -207,7 +207,7 @@ def run(
output_type=output_type,
include_tf_disclaimer=include_tf_disclaimer,
table_title=table_title,
regex=pattern,
regex=regex,
max_tokens=max_tokens,
**kwargs,
)
Expand Down Expand Up @@ -259,7 +259,7 @@ def run(
continue
logger.debug(
Fore.YELLOW
+ f"Finished with values:\n{json.dumps(dict(zip(values[:10], split_results[:10])), indent=4)}"
+ f"Finished LLMMap with values:\n{json.dumps(dict(zip(values[:10], split_results[:10])), indent=4)}"
+ Fore.RESET
)
return split_results
58 changes: 33 additions & 25 deletions blendsql/ingredients/ingredient.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import re
from attr import attrs, attrib
from abc import abstractmethod
import pandas as pd
from sqlglot import exp
import json
from skrub import Joiner
from typing import Any, Union, Dict, Tuple, Callable, Set, Optional, List, Type
from typing import Any, Union, Dict, Tuple, Callable, Set, Optional, Type
from collections.abc import Collection, Iterable
import uuid
from colorama import Fore
Expand All @@ -14,10 +15,16 @@
from .._exceptions import IngredientException
from .._logger import logger
from .. import utils
from .._constants import IngredientKwarg, IngredientType
from .._constants import (
IngredientKwarg,
IngredientType,
DEFAULT_ANS_SEP,
DEFAULT_NAN_ANS,
)
from ..db import Database
from ..db.utils import select_all_from_table_query
from ..models import Model
from .utils import unpack_options


def unpack_default_kwargs(**kwargs):
Expand Down Expand Up @@ -94,7 +101,13 @@ def unpack_default_kwargs(self, **kwargs):
return unpack_default_kwargs(**kwargs)

def __call__(
self, question: str = None, context: str = None, *args, **kwargs
self,
question: str = None,
context: str = None,
regex: Optional[Callable] = None,
options: Optional[Union[list, str]] = None,
*args,
**kwargs,
) -> tuple:
"""Returns tuple with format (arg, tablename, colname, new_table)"""
# Unpack kwargs
Expand Down Expand Up @@ -159,8 +172,18 @@ def __call__(
original_table[new_arg_column] = None
return (new_arg_column, tablename, colname, original_table)

if options is not None:
# Override any user-supplied regex with one built from the unpacked options
unpacked_options = unpack_options(
options=options, aliases_to_tablenames=aliases_to_tablenames, db=self.db
)
base_regex = f"(({'|'.join([re.escape(option) for option in unpacked_options])}|{DEFAULT_NAN_ANS}){DEFAULT_ANS_SEP})"
kwargs[IngredientKwarg.REGEX] = (
lambda num_repeats: base_regex + "{" + str(num_repeats) + "}"
)
else:
kwargs[IngredientKwarg.REGEX] = regex
kwargs[IngredientKwarg.VALUES] = values
kwargs["original_table"] = original_table
kwargs[IngredientKwarg.QUESTION] = question
mapped_values: Collection[Any] = self._run(*args, **kwargs)
self.num_values_passed += len(mapped_values)
Expand Down Expand Up @@ -385,29 +408,14 @@ def __call__(
if subtable.empty:
raise IngredientException("Empty subtable passed to QAIngredient!")

unpacked_options: Union[List[str], None] = options
if options is not None:
if not isinstance(options, list):
try:
tablename, colname = utils.get_tablename_colname(options)
tablename = aliases_to_tablenames.get(tablename, tablename)
# Optionally materialize a CTE
if tablename in self.db.lazy_tables:
unpacked_options = (
self.db.lazy_tables.pop(tablename)
.collect()[colname]
.unique()
.tolist()
)
else:
unpacked_options = self.db.execute_to_list(
f'SELECT DISTINCT "{colname}" FROM "{tablename}"'
)
except ValueError:
unpacked_options = options.split(";")
unpacked_options: Set[str] = set(unpacked_options)
kwargs[IngredientKwarg.OPTIONS] = unpack_options(
options=options, aliases_to_tablenames=aliases_to_tablenames, db=self.db
)
else:
kwargs[IngredientKwarg.OPTIONS] = None

self.num_values_passed += len(subtable) if subtable is not None else 0
kwargs[IngredientKwarg.OPTIONS] = unpacked_options
kwargs[IngredientKwarg.CONTEXT] = subtable
kwargs[IngredientKwarg.QUESTION] = question
response: Union[str, int, float] = self._run(*args, **kwargs)
Expand Down
26 changes: 26 additions & 0 deletions blendsql/ingredients/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Union, List, Set, Dict

from ..utils import get_tablename_colname
from ..db import Database


def unpack_options(
    options: Union[List[str], str], aliases_to_tablenames: Dict[str, str], db: Database
) -> Set[str]:
    """Resolve an ``options`` argument into a set of candidate values.

    Args:
        options: Either a list of options, which is used as-is, or a string.
            A string is first passed to ``get_tablename_colname``; if that
            (or the subsequent lookup) raises ``ValueError``, the string is
            instead split on ``";"``.
        aliases_to_tablenames: Mapping used to translate the parsed table
            name; names without an entry are left unchanged.
        db: Database whose ``lazy_tables`` / ``execute_to_list`` are used to
            fetch the distinct values of the referenced column.

    Returns:
        The de-duplicated options as a set.
    """
    # A list needs no resolution at all.
    if isinstance(options, list):
        return set(options)

    try:
        table, column = get_tablename_colname(options)
        table = aliases_to_tablenames.get(table, table)
        if table not in db.lazy_tables:
            resolved = db.execute_to_list(
                f'SELECT DISTINCT "{column}" FROM "{table}"'
            )
        else:
            # Optionally materialize a CTE; popping removes it from the
            # pending lazy tables once it has been collected.
            lazy_table = db.lazy_tables.pop(table)
            resolved = lazy_table.collect()[column].unique().tolist()
    except ValueError:
        # Fall back to treating the string as a ';'-separated literal list.
        resolved = options.split(";")
    return set(resolved)
2 changes: 1 addition & 1 deletion blendsql/models/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def predict(self, program: Type[Program], **kwargs) -> str:
# First, check our cache
key: str = self._create_key(program, **kwargs)
if key in self.cache:
logger.debug(Fore.MAGENTA + "Using cache..." + Fore.RESET)
logger.debug(Fore.MAGENTA + "Using model cache..." + Fore.RESET)
response: str = self.cache.get(key) # type: ignore
self.prompts.insert(-1, self.format_prompt(response, **kwargs))
return response
Expand Down
Loading

0 comments on commit e3f2668

Please sign in to comment.