Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/infer options arg #25

Merged
merged 10 commits into from
Jun 19, 2024
9 changes: 9 additions & 0 deletions blendsql/_sqlglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,9 @@ def infer_gen_constraints(self, start: int, end: int) -> dict:

- pattern: regular expression pattern lambda to use in constrained decoding with Model
- See `create_pattern` for more info on these pattern lambdas

- options: Optional str default to pass to `options` argument in a QAIngredient
- Will have the form '{table}::{column}'
"""

def create_pattern(
Expand Down Expand Up @@ -763,6 +766,12 @@ def create_pattern(
predicate_literals: List[str] = []
if start_node is not None:
predicate_literals = get_predicate_literals(start_node)
if isinstance(start_node, exp.EQ):
if isinstance(start_node.args["this"], exp.Column):
# This is valid for a default `options` set
added_kwargs[
"options"
] = f"{start_node.args['this'].args['table'].name}::{start_node.args['this'].args['this'].name}"
if len(predicate_literals) > 0:
if all(isinstance(x, bool) for x in predicate_literals):
output_type = "boolean"
Expand Down
29 changes: 16 additions & 13 deletions blendsql/blend.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,23 +606,22 @@ def _blend(
executed_subquery_ingredients.add(alias_function_str)
kwargs_dict = parsed_results_dict["kwargs_dict"]

if ingredient.ingredient_type == IngredientType.MAP:
if infer_gen_constraints:
# Latter is the winner.
# So if we already define something in kwargs_dict,
# It's not overridden here
kwargs_dict = (
scm.infer_gen_constraints(
start=start,
end=end,
)
| kwargs_dict
if infer_gen_constraints:
# Latter is the winner.
# So if we already define something in kwargs_dict,
# It's not overridden here
kwargs_dict = (
scm.infer_gen_constraints(
start=start,
end=end,
)
| kwargs_dict
)

if table_to_title is not None:
kwargs_dict["table_to_title"] = table_to_title

# Optionally, recursively call blend() again to get subtable
# Optionally, recursively call blend() again to get subtable from args
# This applies to `context` and `options`
for i, unpack_kwarg in enumerate(
[IngredientKwarg.CONTEXT, IngredientKwarg.OPTIONS]
Expand Down Expand Up @@ -655,7 +654,11 @@ def _blend(
_prev_passed_values = _smoothie.meta.num_values_passed
subtable = _smoothie.df
if unpack_kwarg == IngredientKwarg.OPTIONS:
# Here, we need to format as a flat list
if len(subtable.columns) != 1:
raise InvalidBlendSQL(
f"Invalid subquery passed to `options`!\nNeeds to return exactly one column, got {len(subtable.columns)} instead"
)
# Here, we need to format as a flat list
kwargs_dict[unpack_kwarg] = list(subtable.values.flat)
else:
kwargs_dict[unpack_kwarg] = subtable
Expand Down
4 changes: 1 addition & 3 deletions blendsql/db/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ def execute_to_df(self, query: str, params: Optional[dict] = None) -> pd.DataFra
...

@abstractmethod
def execute_to_list(
self, query: str, to_type: Optional[Callable] = lambda x: x
) -> list:
def execute_to_list(self, query: str, to_type: Callable = lambda x: x) -> list:
"""A lower-level execute method that doesn't use the pandas processing logic.
Returns results as a list.
"""
Expand Down
2 changes: 1 addition & 1 deletion blendsql/ingredients/builtin/qa/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,4 @@ def run(
**kwargs,
)
# Post-process language model response
return "'{}'".format(single_quote_escape(result.strip().lower()))
return "'{}'".format(single_quote_escape(result.strip()))
4 changes: 2 additions & 2 deletions blendsql/models/remote/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(
**kwargs
)

def _load_model(self, config: OpenAIConfig) -> ModelObj:
def _load_model(self, config: Optional[OpenAIConfig] = None) -> ModelObj:
return azure_openai(
self.model_name_or_path,
config=config,
Expand Down Expand Up @@ -170,7 +170,7 @@ def __init__(
**kwargs
)

def _load_model(self, config: OpenAIConfig) -> ModelObj:
def _load_model(self, config: Optional[OpenAIConfig] = None) -> ModelObj:
return openai(
self.model_name_or_path, config=config, api_key=os.getenv("OPENAI_API_KEY")
) # type: ignore
Expand Down
10 changes: 0 additions & 10 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,22 +79,12 @@ def find_version(*file_paths):
"recognizers-text-suite",
"emoji==1.7.0",
],
"test": [
"pytest",
"pre-commit",
"llama-cpp-python",
"transformers",
"torch",
"coverage",
"tox",
],
"docs": [
"mkdocs-material",
"mkdocstrings",
"mkdocs-section-index",
"mkdocstrings-python",
"mkdocs-jupyter",
],
"demo": ["chainlit"],
},
)
41 changes: 28 additions & 13 deletions tests/test_generic_blendsql.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
import pytest
import pandas as pd
import sqlite3
from pathlib import Path
from blendsql import blend
from blendsql.db import SQLite
from blendsql.db import Pandas
from blendsql._exceptions import IngredientException, InvalidBlendSQL
from tests.utils import select_first_option


@pytest.fixture(scope="session")
def db() -> SQLite:
"""Create a dummy sqlite db to use in tests."""
dbpath = "./test_generic.db"
df = pd.DataFrame({"Name": ["Danny", "Emma", "Tony"], "Age": [23, 26, 19]})
con = sqlite3.connect(dbpath)
df.to_sql("w", con=con)
con.close()
yield SQLite(dbpath)
Path(dbpath).unlink()
def db() -> Pandas:
"""Create a dummy db to use in tests."""
return Pandas(
pd.DataFrame({"Name": ["Danny", "Emma", "Tony"], "Age": [23, 26, 19]}),
tablename="w",
)


def test_error_on_delete1(db):
Expand Down Expand Up @@ -45,11 +41,30 @@ def test_error_on_delete2(db):

def test_error_on_invalid_ingredient(db):
blendsql = """
SELECT * transactions WHERE {{ingredient()}} = TRUE
SELECT * w WHERE {{ingredient()}} = TRUE
"""
with pytest.raises(IngredientException):
_ = blend(
query=blendsql,
db=db,
ingredients={"This is not an ingredient type"},
)


def test_error_on_bad_options_subquery(db):
    """An `options` subquery that is not single-column must be rejected.

    The query below passes `SELECT * FROM w` as the `options` subquery of
    `select_first_option`; running it through `blend` is expected to raise
    `InvalidBlendSQL`.
    """
    bad_options_query = """
SELECT * FROM w
WHERE {{
select_first_option(
'I am at a nice cafe right now',
'w::Name',
options=(SELECT * FROM w)
)
}}
"""
    with pytest.raises(InvalidBlendSQL):
        # The return value is irrelevant; only the raised exception matters.
        blend(
            query=bad_options_query,
            db=db,
            ingredients={select_first_option},
        )
26 changes: 26 additions & 0 deletions tests/test_multi_table_blendsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
do_join,
return_aapl,
get_table_size,
select_first_option,
)

databases = [
Expand All @@ -29,6 +30,7 @@ def ingredients() -> set:
do_join.from_args(use_skrub_joiner=False),
return_aapl,
get_table_size,
select_first_option,
}


Expand Down Expand Up @@ -605,3 +607,27 @@ def test_options_referencing_cte_multi_exec(db, ingredients):
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df)


@pytest.mark.parametrize("db", databases)
def test_infer_options_arg(db, ingredients):
    """Check `options` inference for a `column = {{QAIngredient()}}` predicate.

    The BlendSQL query calls `select_first_option` with no arguments, and its
    result is compared against a plain SQL query that selects the first
    non-null `Symbol` in sorted order.

    Reference carried over from the original note: 1a98559
    (presumably a commit hash).
    """
    blendsql_query = """
SELECT * FROM account_history
WHERE Symbol = {{select_first_option()}}
"""
    expected_sql = """
SELECT * FROM account_history
WHERE Symbol = (SELECT Symbol FROM account_history WHERE Symbol NOT NULL ORDER BY Symbol LIMIT 1)
"""
    # Run the BlendSQL query first, then the reference SQL, as before.
    smoothie = blend(query=blendsql_query, db=db, ingredients=ingredients)
    reference_df = db.execute_to_df(expected_sql)
    assert_equality(smoothie=smoothie, sql_df=reference_df)
2 changes: 1 addition & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def run(
self, question: str, context: pd.DataFrame, options: set, **kwargs
) -> Union[str, int, float]:
"""Returns the first item in the (ordered) options set"""
return f"'{single_quote_escape(sorted(list(options))[0])}'"
return f"'{single_quote_escape(sorted(list(filter(lambda x: x, options)))[0])}'"


class do_join(JoinIngredient):
Expand Down
Loading