diff --git a/blendsql/_sqlglot.py b/blendsql/_sqlglot.py index d0dedec5..8bb28d37 100644 --- a/blendsql/_sqlglot.py +++ b/blendsql/_sqlglot.py @@ -1,3 +1,5 @@ +import copy + import sqlglot from sqlglot import exp, Schema from sqlglot.optimizer.scope import build_scope @@ -31,14 +33,6 @@ """ SUBQUERY_EXP = (exp.Select,) -CONDITIONS = ( - exp.Where, - exp.Group, - # IMPORTANT: If we uncomment limit, then `test_limit` in `test_single_table_blendsql.py` will not pass - # exp.Limit, - exp.Except, - exp.Order, -) MODIFIERS = ( exp.Delete, exp.AlterColumn, @@ -385,7 +379,7 @@ def maybe_set_subqueries_to_true(node): return node.transform(set_subqueries_to_true).transform(prune_empty_where) -def all_terminals_are_true(node) -> bool: +def check_all_terminals_are_true(node) -> bool: """Check to see if all terminal nodes of a given node are TRUE booleans.""" for n, _, _ in node.walk(): try: @@ -423,6 +417,32 @@ def get_scope_nodes( yield tablenode +def check_ingredients_only_in_top_select(node) -> bool: + select_exps = list(node.find_all(exp.Select)) + if len(select_exps) == 1: + # Check if the only `STRUCT` nodes are found in select + all_struct_exps = list(node.find_all(exp.Struct)) + if len(all_struct_exps) > 0: + num_select_struct_exps = sum( + [ + len(list(n.find_all(exp.Struct))) + for n in select_exps[0].find_all(exp.Alias) + ] + ) + if num_select_struct_exps == len(all_struct_exps): + return True + return False + + +def to_select_star(node) -> exp.Expression: + """ """ + select_star_node = copy.deepcopy(node) + select_star_node.find(exp.Select).set( + "expressions", exp.select("*").args["expressions"] + ) + return select_star_node + + @attrs class QueryContextManager: """Handles manipulation of underlying SQL query. @@ -514,40 +534,18 @@ def abstracted_table_selects(self) -> Generator[Tuple[str, str], None, None]: # Example: """SELECT w.title, w."designer ( s )", {{LLMMap('How many animals are in this image?', 'images::title')}} # FROM images JOIN w ON w.title = images.title # WHERE "designer ( s )" = 'georgia gerber'""" - join_exp = self.node.find(exp.Join) - if join_exp is not None: - select_exps = list(self.node.find_all(exp.Select)) - if len(select_exps) == 1: - # Check if the only `STRUCT` nodes are found in select - all_struct_exps = list(self.node.find_all(exp.Struct)) - if len(all_struct_exps) > 0: - num_select_struct_exps = sum( - [ - len(list(n.find_all(exp.Struct))) - for n in select_exps[0].find_all(exp.Alias) - ] - ) - if num_select_struct_exps == len(all_struct_exps): - if len(self.tables_in_ingredients) == 1: - tablename = next(iter(self.tables_in_ingredients)) - join_tablename = set( - [i.name for i in self.node.find_all(exp.Table)] - ).difference({tablename}) - if len(join_tablename) == 1: - join_tablename = next(iter(join_tablename)) - base_select_str = f'SELECT "{tablename}".* FROM "{tablename}", {join_tablename} WHERE ' - table_conditions_str = self.get_table_predicates_str( - tablename=tablename, - disambiguate_multi_tables=False, - ) - abstracted_query = _parse_one( - base_select_str + table_conditions_str - ) - abstracted_query_str = recover_blendsql( - abstracted_query.sql(dialect=FTS5SQLite) - ) - yield (tablename, abstracted_query_str) - return + # Below, we need `self.node.find(exp.Table)` in case we get a QAIngredient on its own + # E.g. `SELECT A() AS _col_0` should be ignored + if check_ingredients_only_in_top_select(self.node) and self.node.find( + exp.Table + ): + abstracted_query = to_select_star(self.node).transform(set_structs_to_true) + abstracted_query_str = recover_blendsql( + abstracted_query.sql(dialect=FTS5SQLite) + ) + for tablename in self.tables_in_ingredients: + yield (tablename, abstracted_query_str) + return for tablename, table_star_query in self._table_star_queries(): # If this table_star_query doesn't have an ingredient at the top-level, we can safely ignore if ( @@ -577,7 +575,7 @@ def abstracted_table_selects(self) -> Generator[Tuple[str, str], None, None]: continue elif isinstance(where_node.args["this"], exp.Column): continue - elif all_terminals_are_true(where_node): + elif check_all_terminals_are_true(where_node): continue elif not where_node: continue diff --git a/blendsql/blend.py b/blendsql/blend.py index 2af22d86..536b8358 100644 --- a/blendsql/blend.py +++ b/blendsql/blend.py @@ -31,7 +31,7 @@ get_tablename_colname, ) from ._exceptions import InvalidBlendSQL -from .db import Database +from .db import Database, DuckDB from .db.utils import double_quote_escape, select_all_from_table_query, LazyTable from ._sqlglot import ( MODIFIERS, @@ -237,6 +237,12 @@ def preprocess_blendsql(query: str, default_model: Model) -> Tuple[str, dict, se # maybe if I was better at pp.Suppress we wouldn't need this kwargs_dict = {x[0]: x[-1] for x in parsed_results_dict["kwargs"]} kwargs_dict[IngredientKwarg.MODEL] = default_model + # Heuristic check to see if we should snag the singleton arg as context + if ( + len(parsed_results_dict["args"]) == 1 + and "::" in parsed_results_dict["args"][0] + ): + kwargs_dict[IngredientKwarg.CONTEXT] = parsed_results_dict["args"].pop() context_arg = kwargs_dict.get( IngredientKwarg.CONTEXT, ( @@ -547,8 +553,25 @@ def _blend( + Fore.RESET ) try: + abstracted_df = db.execute_to_df(abstracted_query) + if isinstance(db, DuckDB): + set_of_column_names = set( + i.strip('"') for i in schema[f'"{tablename}"'] + ) + # In case of a join, duckdb formats columns with 'column_1' + # But some columns (e.g. 'parent_category') just have underscores in them already + abstracted_df = abstracted_df.rename( + columns=lambda x: re.sub(r"_\d$", "", x) + if x not in set_of_column_names # noqa: B023 + else x + ) + # In case of a join, we could have duplicate column names in our pandas dataframe + # This will throw an error when we try to write to the database + abstracted_df = abstracted_df.loc[ + :, ~abstracted_df.columns.duplicated() + ] db.to_temp_table( - df=db.execute_to_df(abstracted_query), + df=abstracted_df, tablename=_get_temp_subquery_table(tablename), ) except OperationalError as e: @@ -622,12 +645,7 @@ def _blend( if table_to_title is not None: kwargs_dict["table_to_title"] = table_to_title - # Heuristic check to see if we should snag the singleton arg as context - if ( - len(parsed_results_dict["args"]) == 1 - and "::" in parsed_results_dict["args"][0] - ): - kwargs_dict[IngredientKwarg.CONTEXT] = parsed_results_dict["args"].pop() + # Optionally, recursively call blend() again to get subtable from args # This applies to `context` and `options` for i, unpack_kwarg in enumerate( diff --git a/tests/test_multi_table_blendsql.py b/tests/test_multi_table_blendsql.py index bd29bd56..c14e5bac 100644 --- a/tests/test_multi_table_blendsql.py +++ b/tests/test_multi_table_blendsql.py @@ -90,7 +90,7 @@ def test_join_multi_exec(db, ingredients): LEFT JOIN constituents ON account_history.Symbol = constituents.Symbol WHERE constituents.Sector = 'Information Technology' AND {{starts_with('A', 'constituents::Name')}} = 1 - AND lower(account_history.Action) like '%dividend%' + AND lower(LOWER(account_history.Action)) like '%dividend%' ORDER BY "Total Dividend Payout ($$)" """ sql = """ @@ -99,7 +99,7 @@ def test_join_multi_exec(db, ingredients): LEFT JOIN constituents ON account_history.Symbol = constituents.Symbol WHERE constituents.Sector = 'Information Technology' AND constituents.Name LIKE 'A%' - AND lower(account_history.Action) like '%dividend%' + AND lower(LOWER(account_history.Action)) like '%dividend%' ORDER BY "Total Dividend Payout ($$)" """ smoothie = blend( @@ -165,7 +165,7 @@ def test_select_multi_exec(db, ingredients): FROM account_history LEFT JOIN constituents ON account_history.Symbol = constituents.Symbol WHERE constituents.Sector = {{select_first_sorted(options='constituents::Sector')}} - AND lower(account_history.Action) like '%dividend%' + AND lower(LOWER(account_history.Action)) like '%dividend%' """ sql = """ SELECT "Run Date", Account, Action, ROUND("Amount ($)", 2) AS 'Total Dividend Payout ($$)', Name @@ -175,7 +175,7 @@ def test_select_multi_exec(db, ingredients): SELECT Sector FROM constituents ORDER BY Sector LIMIT 1 ) - AND lower(account_history.Action) like '%dividend%' + AND lower(LOWER(account_history.Action)) like '%dividend%' """ smoothie = blend( query=blendsql, @@ -457,13 +457,13 @@ def test_ingredient_in_select_with_join_multi_exec(db, ingredients): blendsql = """ SELECT {{get_length('n_length', 'constituents::Name')}} FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol - WHERE account_history.Action like '%dividend%' + WHERE LOWER(account_history.Action) like '%dividend%' ORDER BY constituents.Name """ sql = """ SELECT LENGTH(constituents.Name) FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol - WHERE account_history.Action like '%dividend%' + WHERE LOWER(account_history.Action) like '%dividend%' ORDER BY constituents.Name """ smoothie = blend( @@ -478,7 +478,7 @@ def test_ingredient_in_select_with_join_multi_exec(db, ingredients): """ SELECT COUNT(DISTINCT constituents.Name) FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol - WHERE account_history.Action like '%dividend%' + WHERE LOWER(account_history.Action) like '%dividend%' """ )[0] assert smoothie.meta.num_values_passed == passed_to_ingredient @@ -493,12 +493,12 @@ def test_ingredient_in_select_with_join_multi_select_multi_exec(db, ingredients) blendsql = """ SELECT {{get_length('n_length', 'constituents::Name')}}, Action FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol - WHERE Action like '%dividend%' + WHERE LOWER(Action) like '%dividend%' """ sql = """ SELECT LENGTH(constituents.Name), Action FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol - WHERE Action like '%dividend%' + WHERE LOWER(Action) like '%dividend%' """ smoothie = blend( query=blendsql, @@ -512,7 +512,7 @@ def test_ingredient_in_select_with_join_multi_select_multi_exec(db, ingredients) """ SELECT COUNT(DISTINCT constituents.Name) FROM constituents JOIN account_history ON account_history.Symbol = constituents.Symbol - WHERE account_history.Action like '%dividend%' + WHERE LOWER(account_history.Action) like '%dividend%' """ )[0] assert smoothie.meta.num_values_passed == passed_to_ingredient diff --git a/tests/test_single_table_blendsql.py b/tests/test_single_table_blendsql.py index 040148ef..e53806f7 100644 --- a/tests/test_single_table_blendsql.py +++ b/tests/test_single_table_blendsql.py @@ -551,5 +551,52 @@ def test_query_options_arg(db, ingredients): assert smoothie.df.values.flat[0] == "Paypal" +@pytest.mark.parametrize("db", databases) +def test_apply_limit(db, ingredients): + # commit 335c67a + blendsql = """ + SELECT {{get_length('length', 'transactions::merchant')}} FROM transactions ORDER BY merchant LIMIT 1 + """ + smoothie = blend( + query=blendsql, + db=db, + ingredients=ingredients, + ) + assert smoothie.meta.num_values_passed == 1 + + +@pytest.mark.parametrize("db", databases) +def test_apply_limit_with_predicate(db, ingredients): + # commit 335c67a + blendsql = """ + SELECT {{get_length('length', 'transactions::merchant')}} + FROM transactions + WHERE amount > 1300 + ORDER BY merchant LIMIT 3 + """ + sql = """ + SELECT LENGTH(merchant) + FROM transactions + WHERE amount > 1300 + ORDER BY merchant LIMIT 3 + """ + smoothie = blend( + query=blendsql, + db=db, + ingredients=ingredients, + ) + sql_df = db.execute_to_df(sql) + assert_equality(smoothie=smoothie, sql_df=sql_df) + # Make sure we only pass what's necessary to our ingredient + passed_to_ingredient = db.execute_to_list( + """ + SELECT COUNT(DISTINCT merchant) FROM transactions WHERE amount > 1300 LIMIT 3 """ + )[0] + # We say `<=` here because the ingredient operates over sets, rather than lists + # So this kind of screws up the `LIMIT` calculation + # But execution outputs should be the same (tested above) + assert smoothie.meta.num_values_passed <= passed_to_ingredient + + if __name__ == "__main__": pytest.main()