Skip to content

Commit

Permalink
Functioning DPR retriever for few-shot examples
Browse files Browse the repository at this point in the history
  • Loading branch information
parkervg committed Oct 15, 2024
1 parent dc81ff8 commit b1f84a1
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 92 deletions.
82 changes: 41 additions & 41 deletions tests/test_multi_table_blendsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


@pytest.fixture
def ingredients() -> set:
def dummy_ingredients() -> set:
return {
starts_with,
get_length,
Expand All @@ -35,7 +35,7 @@ def ingredients() -> set:


@pytest.mark.parametrize("db", databases)
def test_simple_multi_exec(db, ingredients):
def test_simple_multi_exec(db, dummy_ingredients):
"""Test with multiple tables.
Also ensures we only pass what is necessary to the external ingredient F().
"Show me the price of tech stocks in my portfolio that start with 'A'"
Expand Down Expand Up @@ -65,7 +65,7 @@ def test_simple_multi_exec(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A"])
Expand All @@ -83,7 +83,7 @@ def test_simple_multi_exec(db, ingredients):


@pytest.mark.parametrize("db", databases)
def test_join_multi_exec(db, ingredients):
def test_join_multi_exec(db, dummy_ingredients):
blendsql = """
SELECT "Run Date", Account, Action, ROUND("Amount ($)", 2) AS 'Total Dividend Payout ($$)', Name
FROM account_history
Expand All @@ -105,7 +105,7 @@ def test_join_multi_exec(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A"])
Expand All @@ -119,7 +119,7 @@ def test_join_multi_exec(db, ingredients):


@pytest.mark.parametrize("db", databases)
def test_join_not_qualified_multi_exec(db, ingredients):
def test_join_not_qualified_multi_exec(db, dummy_ingredients):
"""Same test as test_join_multi_exec(), but without qualifying columns if we don't need to.
i.e. 'Action' and 'Sector' don't have tablename preceding them.
commit fefbc0a
Expand All @@ -145,7 +145,7 @@ def test_join_not_qualified_multi_exec(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A"])
Expand All @@ -159,7 +159,7 @@ def test_join_not_qualified_multi_exec(db, ingredients):


@pytest.mark.parametrize("db", databases)
def test_select_multi_exec(db, ingredients):
def test_select_multi_exec(db, dummy_ingredients):
blendsql = """
SELECT "Run Date", Account, Action, ROUND("Amount ($)", 2) AS 'Total Dividend Payout ($$)', Name
FROM account_history
Expand All @@ -180,14 +180,14 @@ def test_select_multi_exec(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df)


@pytest.mark.parametrize("db", databases)
def test_complex_multi_exec(db, ingredients):
def test_complex_multi_exec(db, dummy_ingredients):
"""
Below yields a tie in constituents.Name lengths, with 'Amgen' and 'Cisco'.
DuckDB has different sorting behavior depending on the subset that's passed?
Expand All @@ -213,14 +213,14 @@ def test_complex_multi_exec(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df)


@pytest.mark.parametrize("db", databases)
def test_complex_not_qualified_multi_exec(db, ingredients):
def test_complex_not_qualified_multi_exec(db, dummy_ingredients):
"""Same test as test_complex_multi_exec(), but without qualifying columns if we don't need to.
commit fefbc0a
"""
Expand All @@ -245,14 +245,14 @@ def test_complex_not_qualified_multi_exec(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df)


@pytest.mark.parametrize("db", databases)
def test_join_ingredient_multi_exec(db, ingredients):
def test_join_ingredient_multi_exec(db, dummy_ingredients):
blendsql = """
SELECT Account, Quantity FROM returns
JOIN {{
Expand All @@ -268,14 +268,14 @@ def test_join_ingredient_multi_exec(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df)


@pytest.mark.parametrize("db", databases)
def test_qa_equals_multi_exec(db, ingredients):
def test_qa_equals_multi_exec(db, dummy_ingredients):
blendsql = """
SELECT Action FROM account_history
WHERE Symbol = {{return_aapl()}}
Expand All @@ -287,14 +287,14 @@ def test_qa_equals_multi_exec(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df)


@pytest.mark.parametrize("db", databases)
def test_table_alias_multi_exec(db, ingredients):
def test_table_alias_multi_exec(db, dummy_ingredients):
blendsql = """
SELECT Symbol FROM portfolio AS w
WHERE {{starts_with('A', 'w::Symbol')}} = TRUE
Expand All @@ -308,7 +308,7 @@ def test_table_alias_multi_exec(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df, args=["A"])
Expand All @@ -322,7 +322,7 @@ def test_table_alias_multi_exec(db, ingredients):


@pytest.mark.parametrize("db", databases)
def test_subquery_alias_multi_exec(db, ingredients):
def test_subquery_alias_multi_exec(db, dummy_ingredients):
blendsql = """
SELECT Symbol FROM (
SELECT DISTINCT Symbol FROM portfolio WHERE Symbol IN (
Expand All @@ -340,7 +340,7 @@ def test_subquery_alias_multi_exec(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"])
Expand All @@ -354,7 +354,7 @@ def test_subquery_alias_multi_exec(db, ingredients):


@pytest.mark.parametrize("db", databases)
def test_cte_qa_multi_exec(db, ingredients):
def test_cte_qa_multi_exec(db, dummy_ingredients):
blendsql = """
{{
get_table_size(
Expand All @@ -376,7 +376,7 @@ def test_cte_qa_multi_exec(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"])
Expand All @@ -402,7 +402,7 @@ def test_cte_qa_multi_exec(db, ingredients):


@pytest.mark.parametrize("db", databases)
def test_cte_qa_named_multi_exec(db, ingredients):
def test_cte_qa_named_multi_exec(db, dummy_ingredients):
blendsql = """
{{
get_table_size(
Expand All @@ -424,7 +424,7 @@ def test_cte_qa_named_multi_exec(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"])
Expand All @@ -448,7 +448,7 @@ def test_cte_qa_named_multi_exec(db, ingredients):


@pytest.mark.parametrize("db", databases)
def test_ingredient_in_select_with_join_multi_exec(db, ingredients):
def test_ingredient_in_select_with_join_multi_exec(db, dummy_ingredients):
"""If the query only has an ingredient in the `SELECT` statement, and `JOIN` clause,
we should run the `JOIN` statement first, and then call the ingredient.
Expand All @@ -469,7 +469,7 @@ def test_ingredient_in_select_with_join_multi_exec(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df)
Expand All @@ -485,7 +485,7 @@ def test_ingredient_in_select_with_join_multi_exec(db, ingredients):


@pytest.mark.parametrize("db", databases)
def test_ingredient_in_select_with_join_multi_select_multi_exec(db, ingredients):
def test_ingredient_in_select_with_join_multi_select_multi_exec(db, dummy_ingredients):
"""A modified version of the above
commit de4a7bc
Expand All @@ -503,7 +503,7 @@ def test_ingredient_in_select_with_join_multi_select_multi_exec(db, ingredients)
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df)
Expand All @@ -519,7 +519,7 @@ def test_ingredient_in_select_with_join_multi_select_multi_exec(db, ingredients)


@pytest.mark.parametrize("db", databases)
def test_subquery_alias_with_join_multi_exec(db, ingredients):
def test_subquery_alias_with_join_multi_exec(db, dummy_ingredients):
blendsql = """
SELECT w."Percent of Account" FROM (SELECT * FROM "portfolio" WHERE Quantity > 200 OR "Today's Gain/Loss Percent" > 0.05) as w
JOIN {{
Expand All @@ -540,7 +540,7 @@ def test_subquery_alias_with_join_multi_exec(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df, args=["F"])
Expand All @@ -554,7 +554,7 @@ def test_subquery_alias_with_join_multi_exec(db, ingredients):


@pytest.mark.parametrize("db", databases)
def test_materialize_ctes_multi_exec(db, ingredients):
def test_materialize_ctes_multi_exec(db, dummy_ingredients):
"""We shouldn't create materialized CTE tables if they aren't used in an ingredient.
commit dba7540
Expand All @@ -572,7 +572,7 @@ def test_materialize_ctes_multi_exec(db, ingredients):
_ = _blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
assert db.has_temp_table("a")
assert not db.has_temp_table("b")
Expand All @@ -581,7 +581,7 @@ def test_materialize_ctes_multi_exec(db, ingredients):


@pytest.mark.parametrize("db", databases)
def test_options_referencing_cte_multi_exec(db, ingredients):
def test_options_referencing_cte_multi_exec(db, dummy_ingredients):
"""You should be able to reference a CTE in a QAIngredient `options` argument.
commit f849ed3
Expand All @@ -603,14 +603,14 @@ def test_options_referencing_cte_multi_exec(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df)


@pytest.mark.parametrize("db", databases)
def test_infer_options_arg(db, ingredients):
def test_infer_options_arg(db, dummy_ingredients):
"""The infer_gen_constraints function should extend to cases when we do a
`column = {{QAIngredient()}}` predicate.
Expand All @@ -627,14 +627,14 @@ def test_infer_options_arg(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df)


@pytest.mark.parametrize("db", databases)
def test_join_with_multiple_ingredients(db, ingredients):
def test_join_with_multiple_ingredients(db, dummy_ingredients):
"""
commit af86714
"""
Expand All @@ -660,14 +660,14 @@ def test_join_with_multiple_ingredients(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df)


@pytest.mark.parametrize("db", databases)
def test_null_negation(db, ingredients):
def test_null_negation(db, dummy_ingredients):
blendsql = """
SELECT DISTINCT Name FROM constituents
LEFT JOIN account_history ON constituents.Symbol = account_history.Symbol
Expand All @@ -685,7 +685,7 @@ def test_null_negation(db, ingredients):
smoothie = blend(
query=blendsql,
db=db,
ingredients=ingredients,
ingredients=dummy_ingredients,
)
sql_df = db.execute_to_df(sql)
assert_equality(smoothie=smoothie, sql_df=sql_df)
Loading

0 comments on commit b1f84a1

Please sign in to comment.