Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/lazy evaluation #27

Merged
merged 10 commits into from
Jun 23, 2024
86 changes: 42 additions & 44 deletions blendsql/_sqlglot.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import copy

import sqlglot
from sqlglot import exp, Schema
from sqlglot.optimizer.scope import build_scope
Expand Down Expand Up @@ -31,14 +33,6 @@
"""

SUBQUERY_EXP = (exp.Select,)
CONDITIONS = (
exp.Where,
exp.Group,
# IMPORTANT: If we uncomment limit, then `test_limit` in `test_single_table_blendsql.py` will not pass
# exp.Limit,
exp.Except,
exp.Order,
)
MODIFIERS = (
exp.Delete,
exp.AlterColumn,
Expand Down Expand Up @@ -385,7 +379,7 @@ def maybe_set_subqueries_to_true(node):
return node.transform(set_subqueries_to_true).transform(prune_empty_where)


def all_terminals_are_true(node) -> bool:
def check_all_terminals_are_true(node) -> bool:
"""Check to see if all terminal nodes of a given node are TRUE booleans."""
for n, _, _ in node.walk():
try:
Expand Down Expand Up @@ -423,6 +417,32 @@ def get_scope_nodes(
yield tablenode


def check_ingredients_only_in_top_select(node) -> bool:
select_exps = list(node.find_all(exp.Select))
if len(select_exps) == 1:
# Check if the only `STRUCT` nodes are found in select
all_struct_exps = list(node.find_all(exp.Struct))
if len(all_struct_exps) > 0:
num_select_struct_exps = sum(
[
len(list(n.find_all(exp.Struct)))
for n in select_exps[0].find_all(exp.Alias)
]
)
if num_select_struct_exps == len(all_struct_exps):
return True
return False


def to_select_star(node) -> exp.Expression:
""" """
select_star_node = copy.deepcopy(node)
select_star_node.find(exp.Select).set(
"expressions", exp.select("*").args["expressions"]
)
return select_star_node


@attrs
class QueryContextManager:
"""Handles manipulation of underlying SQL query.
Expand Down Expand Up @@ -514,40 +534,18 @@ def abstracted_table_selects(self) -> Generator[Tuple[str, str], None, None]:
# Example: """SELECT w.title, w."designer ( s )", {{LLMMap('How many animals are in this image?', 'images::title')}}
# FROM images JOIN w ON w.title = images.title
# WHERE "designer ( s )" = 'georgia gerber'"""
join_exp = self.node.find(exp.Join)
if join_exp is not None:
select_exps = list(self.node.find_all(exp.Select))
if len(select_exps) == 1:
# Check if the only `STRUCT` nodes are found in select
all_struct_exps = list(self.node.find_all(exp.Struct))
if len(all_struct_exps) > 0:
num_select_struct_exps = sum(
[
len(list(n.find_all(exp.Struct)))
for n in select_exps[0].find_all(exp.Alias)
]
)
if num_select_struct_exps == len(all_struct_exps):
if len(self.tables_in_ingredients) == 1:
tablename = next(iter(self.tables_in_ingredients))
join_tablename = set(
[i.name for i in self.node.find_all(exp.Table)]
).difference({tablename})
if len(join_tablename) == 1:
join_tablename = next(iter(join_tablename))
base_select_str = f'SELECT "{tablename}".* FROM "{tablename}", {join_tablename} WHERE '
table_conditions_str = self.get_table_predicates_str(
tablename=tablename,
disambiguate_multi_tables=False,
)
abstracted_query = _parse_one(
base_select_str + table_conditions_str
)
abstracted_query_str = recover_blendsql(
abstracted_query.sql(dialect=FTS5SQLite)
)
yield (tablename, abstracted_query_str)
return
# Below, we need `self.node.find(exp.Table)` in case we get a QAIngredient on its own
# E.g. `SELECT A() AS _col_0` should be ignored
if check_ingredients_only_in_top_select(self.node) and self.node.find(
exp.Table
):
abstracted_query = to_select_star(self.node).transform(set_structs_to_true)
abstracted_query_str = recover_blendsql(
abstracted_query.sql(dialect=FTS5SQLite)
)
for tablename in self.tables_in_ingredients:
yield (tablename, abstracted_query_str)
return
for tablename, table_star_query in self._table_star_queries():
# If this table_star_query doesn't have an ingredient at the top-level, we can safely ignore
if (
Expand Down Expand Up @@ -577,7 +575,7 @@ def abstracted_table_selects(self) -> Generator[Tuple[str, str], None, None]:
continue
elif isinstance(where_node.args["this"], exp.Column):
continue
elif all_terminals_are_true(where_node):
elif check_all_terminals_are_true(where_node):
continue
elif not where_node:
continue
Expand Down
34 changes: 26 additions & 8 deletions blendsql/blend.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
get_tablename_colname,
)
from ._exceptions import InvalidBlendSQL
from .db import Database
from .db import Database, DuckDB
from .db.utils import double_quote_escape, select_all_from_table_query, LazyTable
from ._sqlglot import (
MODIFIERS,
Expand Down Expand Up @@ -237,6 +237,12 @@ def preprocess_blendsql(query: str, default_model: Model) -> Tuple[str, dict, se
# maybe if I was better at pp.Suppress we wouldn't need this
kwargs_dict = {x[0]: x[-1] for x in parsed_results_dict["kwargs"]}
kwargs_dict[IngredientKwarg.MODEL] = default_model
# Heuristic check to see if we should snag the singleton arg as context
if (
len(parsed_results_dict["args"]) == 1
and "::" in parsed_results_dict["args"][0]
):
kwargs_dict[IngredientKwarg.CONTEXT] = parsed_results_dict["args"].pop()
context_arg = kwargs_dict.get(
IngredientKwarg.CONTEXT,
(
Expand Down Expand Up @@ -547,8 +553,25 @@ def _blend(
+ Fore.RESET
)
try:
abstracted_df = db.execute_to_df(abstracted_query)
if isinstance(db, DuckDB):
set_of_column_names = set(
i.strip('"') for i in schema[f'"{tablename}"']
)
# In case of a join, duckdb formats columns with 'column_1'
# But some columns (e.g. 'parent_category') just have underscores in them already
abstracted_df = abstracted_df.rename(
columns=lambda x: re.sub(r"_\d$", "", x)
if x not in set_of_column_names # noqa: B023
else x
)
# In case of a join, we could have duplicate column names in our pandas dataframe
# This will throw an error when we try to write to the database
abstracted_df = abstracted_df.loc[
:, ~abstracted_df.columns.duplicated()
]
db.to_temp_table(
df=db.execute_to_df(abstracted_query),
df=abstracted_df,
tablename=_get_temp_subquery_table(tablename),
)
except OperationalError as e:
Expand Down Expand Up @@ -622,12 +645,7 @@ def _blend(

if table_to_title is not None:
kwargs_dict["table_to_title"] = table_to_title
# Heuristic check to see if we should snag the singleton arg as context
if (
len(parsed_results_dict["args"]) == 1
and "::" in parsed_results_dict["args"][0]
):
kwargs_dict[IngredientKwarg.CONTEXT] = parsed_results_dict["args"].pop()

# Optionally, recursively call blend() again to get subtable from args
# This applies to `context` and `options`
for i, unpack_kwarg in enumerate(
Expand Down
20 changes: 10 additions & 10 deletions tests/test_multi_table_blendsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_join_multi_exec(db, ingredients):
LEFT JOIN constituents ON account_history.Symbol = constituents.Symbol
WHERE constituents.Sector = 'Information Technology'
AND {{starts_with('A', 'constituents::Name')}} = 1
AND lower(account_history.Action) like '%dividend%'
AND lower(LOWER(account_history.Action)) like '%dividend%'
ORDER BY "Total Dividend Payout ($$)"
"""
sql = """
Expand All @@ -99,7 +99,7 @@ def test_join_multi_exec(db, ingredients):
LEFT JOIN constituents ON account_history.Symbol = constituents.Symbol
WHERE constituents.Sector = 'Information Technology'
AND constituents.Name LIKE 'A%'
AND lower(account_history.Action) like '%dividend%'
AND lower(LOWER(account_history.Action)) like '%dividend%'
ORDER BY "Total Dividend Payout ($$)"
"""
smoothie = blend(
Expand Down Expand Up @@ -165,7 +165,7 @@ def test_select_multi_exec(db, ingredients):
FROM account_history
LEFT JOIN constituents ON account_history.Symbol = constituents.Symbol
WHERE constituents.Sector = {{select_first_sorted(options='constituents::Sector')}}
AND lower(account_history.Action) like '%dividend%'
AND lower(LOWER(account_history.Action)) like '%dividend%'
"""
sql = """
SELECT "Run Date", Account, Action, ROUND("Amount ($)", 2) AS 'Total Dividend Payout ($$)', Name
Expand All @@ -175,7 +175,7 @@ def test_select_multi_exec(db, ingredients):
SELECT Sector FROM constituents
ORDER BY Sector LIMIT 1
)
AND lower(account_history.Action) like '%dividend%'
AND lower(LOWER(account_history.Action)) like '%dividend%'
"""
smoothie = blend(
query=blendsql,
Expand Down Expand Up @@ -457,13 +457,13 @@ def test_ingredient_in_select_with_join_multi_exec(db, ingredients):
blendsql = """
SELECT {{get_length('n_length', 'constituents::Name')}}
FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol
WHERE account_history.Action like '%dividend%'
WHERE LOWER(account_history.Action) like '%dividend%'
ORDER BY constituents.Name
"""
sql = """
SELECT LENGTH(constituents.Name)
FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol
WHERE account_history.Action like '%dividend%'
WHERE LOWER(account_history.Action) like '%dividend%'
ORDER BY constituents.Name
"""
smoothie = blend(
Expand All @@ -478,7 +478,7 @@ def test_ingredient_in_select_with_join_multi_exec(db, ingredients):
"""
SELECT COUNT(DISTINCT constituents.Name)
FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol
WHERE account_history.Action like '%dividend%'
WHERE LOWER(account_history.Action) like '%dividend%'
"""
)[0]
assert smoothie.meta.num_values_passed == passed_to_ingredient
Expand All @@ -493,12 +493,12 @@ def test_ingredient_in_select_with_join_multi_select_multi_exec(db, ingredients)
blendsql = """
SELECT {{get_length('n_length', 'constituents::Name')}}, Action
FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol
WHERE Action like '%dividend%'
WHERE LOWER(Action) like '%dividend%'
"""
sql = """
SELECT LENGTH(constituents.Name), Action
FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol
WHERE Action like '%dividend%'
WHERE LOWER(Action) like '%dividend%'
"""
smoothie = blend(
query=blendsql,
Expand All @@ -512,7 +512,7 @@ def test_ingredient_in_select_with_join_multi_select_multi_exec(db, ingredients)
"""
SELECT COUNT(DISTINCT constituents.Name)
FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol
WHERE account_history.Action like '%dividend%'
WHERE LOWER(account_history.Action) like '%dividend%'
"""
)[0]
assert smoothie.meta.num_values_passed == passed_to_ingredient
Expand Down
47 changes: 47 additions & 0 deletions tests/test_single_table_blendsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,5 +551,52 @@ def test_query_options_arg(db, ingredients):
assert smoothie.df.values.flat[0] == "Paypal"


@pytest.mark.parametrize("db", databases)
def test_apply_limit(db, ingredients):
# commit 335c67a
blendsql = """
SELECT {{get_length('length', 'transactions::merchant')}} FROM transactions ORDER BY merchant LIMIT 1
"""
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
)
assert smoothie.meta.num_values_passed == 1


@pytest.mark.parametrize("db", databases)
def test_apply_limit_with_predicate(db, ingredients):
# commit 335c67a
blendsql = """
SELECT {{get_length('length', 'transactions::merchant')}}
FROM transactions
WHERE amount > 1300
ORDER BY merchant LIMIT 3
"""
sql = """
SELECT LENGTH(merchant)
FROM transactions
WHERE amount > 1300
ORDER BY merchant LIMIT 3
"""
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df)
# Make sure we only pass what's necessary to our ingredient
passed_to_ingredient = db.execute_to_list(
"""
SELECT COUNT(DISTINCT merchant) FROM transactions WHERE amount > 1300 LIMIT 3 """
)[0]
# We say `<=` here because the ingredient operates over sets, rather than lists
# So this kind of screws up the `LIMIT` calculation
# But execution outputs should be the same (tested above)
assert smoothie.meta.num_values_passed <= passed_to_ingredient


if __name__ == "__main__":
pytest.main()
Loading