Skip to content

Commit

Permalink
Merge pull request #25 from parkervg/feature/infer-options-arg
Browse files Browse the repository at this point in the history
Feature/infer options arg
  • Loading branch information
parkervg authored Jun 19, 2024
2 parents 1446145 + 45e7c86 commit 7fefb40
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 43 deletions.
9 changes: 9 additions & 0 deletions blendsql/_sqlglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,9 @@ def infer_gen_constraints(self, start: int, end: int) -> dict:
- pattern: regular expression pattern lambda to use in constrained decoding with Model
- See `create_pattern` for more info on these pattern lambdas
- options: Optional str default to pass to `options` argument in a QAIngredient
- Will have the form '{table}::{column}'
"""

def create_pattern(
Expand Down Expand Up @@ -763,6 +766,12 @@ def create_pattern(
predicate_literals: List[str] = []
if start_node is not None:
predicate_literals = get_predicate_literals(start_node)
if isinstance(start_node, exp.EQ):
if isinstance(start_node.args["this"], exp.Column):
# This is valid for a default `options` set
added_kwargs[
"options"
] = f"{start_node.args['this'].args['table'].name}::{start_node.args['this'].args['this'].name}"
if len(predicate_literals) > 0:
if all(isinstance(x, bool) for x in predicate_literals):
output_type = "boolean"
Expand Down
29 changes: 16 additions & 13 deletions blendsql/blend.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,23 +606,22 @@ def _blend(
executed_subquery_ingredients.add(alias_function_str)
kwargs_dict = parsed_results_dict["kwargs_dict"]

if ingredient.ingredient_type == IngredientType.MAP:
if infer_gen_constraints:
# Latter is the winner.
# So if we already define something in kwargs_dict,
# It's not overridden here
kwargs_dict = (
scm.infer_gen_constraints(
start=start,
end=end,
)
| kwargs_dict
if infer_gen_constraints:
# In a dict union (`a | b`), the right-hand operand wins on duplicate keys.
# So if we already define something in kwargs_dict,
# it's not overridden by the inferred constraints here
kwargs_dict = (
scm.infer_gen_constraints(
start=start,
end=end,
)
| kwargs_dict
)

if table_to_title is not None:
kwargs_dict["table_to_title"] = table_to_title

# Optionally, recursively call blend() again to get subtable
# Optionally, recursively call blend() again to get subtable from args
# This applies to `context` and `options`
for i, unpack_kwarg in enumerate(
[IngredientKwarg.CONTEXT, IngredientKwarg.OPTIONS]
Expand Down Expand Up @@ -655,7 +654,11 @@ def _blend(
_prev_passed_values = _smoothie.meta.num_values_passed
subtable = _smoothie.df
if unpack_kwarg == IngredientKwarg.OPTIONS:
# Here, we need to format as a flat list
if len(subtable.columns) != 1:
raise InvalidBlendSQL(
f"Invalid subquery passed to `options`!\nNeeds to return exactly one column, got {len(subtable.columns)} instead"
)
# Here, we need to flatten the single-column result into a flat list of values
kwargs_dict[unpack_kwarg] = list(subtable.values.flat)
else:
kwargs_dict[unpack_kwarg] = subtable
Expand Down
4 changes: 1 addition & 3 deletions blendsql/db/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ def execute_to_df(self, query: str, params: Optional[dict] = None) -> pd.DataFra
...

@abstractmethod
def execute_to_list(
self, query: str, to_type: Optional[Callable] = lambda x: x
) -> list:
def execute_to_list(self, query: str, to_type: Callable = lambda x: x) -> list:
"""A lower-level execute method that doesn't use the pandas processing logic.
Returns results as a list.
"""
Expand Down
2 changes: 1 addition & 1 deletion blendsql/ingredients/builtin/qa/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,4 @@ def run(
**kwargs,
)
# Post-process language model response
return "'{}'".format(single_quote_escape(result.strip().lower()))
return "'{}'".format(single_quote_escape(result.strip()))
4 changes: 2 additions & 2 deletions blendsql/models/remote/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(
**kwargs
)

def _load_model(self, config: OpenAIConfig) -> ModelObj:
def _load_model(self, config: Optional[OpenAIConfig] = None) -> ModelObj:
return azure_openai(
self.model_name_or_path,
config=config,
Expand Down Expand Up @@ -170,7 +170,7 @@ def __init__(
**kwargs
)

def _load_model(self, config: OpenAIConfig) -> ModelObj:
def _load_model(self, config: Optional[OpenAIConfig] = None) -> ModelObj:
return openai(
self.model_name_or_path, config=config, api_key=os.getenv("OPENAI_API_KEY")
) # type: ignore
Expand Down
10 changes: 0 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,22 +79,12 @@ def find_version(*file_paths):
"recognizers-text-suite",
"emoji==1.7.0",
],
"test": [
"pytest",
"pre-commit",
"llama-cpp-python",
"transformers",
"torch",
"coverage",
"tox",
],
"docs": [
"mkdocs-material",
"mkdocstrings",
"mkdocs-section-index",
"mkdocstrings-python",
"mkdocs-jupyter",
],
"demo": ["chainlit"],
},
)
41 changes: 28 additions & 13 deletions tests/test_generic_blendsql.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
import pytest
import pandas as pd
import sqlite3
from pathlib import Path
from blendsql import blend
from blendsql.db import SQLite
from blendsql.db import Pandas
from blendsql._exceptions import IngredientException, InvalidBlendSQL
from tests.utils import select_first_option


@pytest.fixture(scope="session")
def db() -> SQLite:
"""Create a dummy sqlite db to use in tests."""
dbpath = "./test_generic.db"
df = pd.DataFrame({"Name": ["Danny", "Emma", "Tony"], "Age": [23, 26, 19]})
con = sqlite3.connect(dbpath)
df.to_sql("w", con=con)
con.close()
yield SQLite(dbpath)
Path(dbpath).unlink()
def db() -> Pandas:
"""Create a dummy db to use in tests."""
return Pandas(
pd.DataFrame({"Name": ["Danny", "Emma", "Tony"], "Age": [23, 26, 19]}),
tablename="w",
)


def test_error_on_delete1(db):
Expand Down Expand Up @@ -45,11 +41,30 @@ def test_error_on_delete2(db):

def test_error_on_invalid_ingredient(db):
blendsql = """
SELECT * transactions WHERE {{ingredient()}} = TRUE
SELECT * w WHERE {{ingredient()}} = TRUE
"""
with pytest.raises(IngredientException):
_ = blend(
query=blendsql,
db=db,
ingredients={"This is not an ingredient type"},
)


def test_error_on_bad_options_subquery(db):
    """A subquery passed to `options` must return exactly one column.

    The `options` subquery below is `SELECT * FROM w`, which selects every
    column of `w` rather than a single one, so `blend` is expected to
    reject the query with `InvalidBlendSQL`.
    """
    multi_column_options_query = """
SELECT * FROM w
WHERE {{
select_first_option(
'I am at a nice cafe right now',
'w::Name',
options=(SELECT * FROM w)
)
}}
"""
    # The return value is irrelevant: the call itself must raise.
    with pytest.raises(InvalidBlendSQL):
        blend(
            query=multi_column_options_query,
            db=db,
            ingredients={select_first_option},
        )
26 changes: 26 additions & 0 deletions tests/test_multi_table_blendsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
do_join,
return_aapl,
get_table_size,
select_first_option,
)

databases = [
Expand All @@ -29,6 +30,7 @@ def ingredients() -> set:
do_join.from_args(use_skrub_joiner=False),
return_aapl,
get_table_size,
select_first_option,
}


Expand Down Expand Up @@ -605,3 +607,27 @@ def test_options_referencing_cte_multi_exec(db, ingredients):
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df)


@pytest.mark.parametrize("db", databases)
def test_infer_options_arg(db, ingredients):
    """Check that `options` is inferred for a `column = {{QAIngredient()}}` predicate.

    `select_first_option` is called without any arguments, so it only has
    values to choose from if `infer_gen_constraints` supplies a default
    `options` for the column on the left-hand side of the `=` (`Symbol`).
    The expected result is the same filter written in plain SQL.

    Related commit: 1a98559
    """
    blendsql = """
SELECT * FROM account_history
WHERE Symbol = {{select_first_option()}}
"""
    # `IS NOT NULL` is the standard SQL spelling. The bare postfix
    # `Symbol NOT NULL` form is an SQLite-specific operator, and this test is
    # parametrized over every backend listed in `databases`.
    sql = """
SELECT * FROM account_history
WHERE Symbol = (SELECT Symbol FROM account_history WHERE Symbol IS NOT NULL ORDER BY Symbol LIMIT 1)
"""
    smoothie = blend(
        query=blendsql,
        db=db,
        ingredients=ingredients,
    )
    sql_df = db.execute_to_df(sql)
    assert_equality(smoothie=smoothie, sql_df=sql_df)
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def run(
self, question: str, context: pd.DataFrame, options: set, **kwargs
) -> Union[str, int, float]:
"""Returns the first item in the (ordered) options set"""
return f"'{single_quote_escape(sorted(list(options))[0])}'"
return f"'{single_quote_escape(sorted(list(filter(lambda x: x, options)))[0])}'"


class do_join(JoinIngredient):
Expand Down

0 comments on commit 7fefb40

Please sign in to comment.