diff --git a/.codespellrc b/.codespellrc index 1f3cfdd7f4aa5..79c092b8a5cb4 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,6 +1,6 @@ [codespell] # local codespell matches `./docs`, pre-commit codespell matches `docs` skip = *.lock,.direnv,.git,./docs/_freeze,./docs/_output/**,./docs/_inv/**,docs/_freeze/**,*.svg,*.css,*.html,*.js,ibis/backends/tests/tpc/queries/duckdb/ds/*.sql -ignore-regex = \b(i[if]f|I[IF]F|AFE|alls)\b +ignore-regex = \b(i[if]f|I[IF]F|AFE|alls|ND)\b builtin = clear,rare,names ignore-words-list = tim,notin,ang diff --git a/ibis/backends/sql/compilers/base.py b/ibis/backends/sql/compilers/base.py index 02b198ab930a3..495de8c665c64 100644 --- a/ibis/backends/sql/compilers/base.py +++ b/ibis/backends/sql/compilers/base.py @@ -1400,13 +1400,42 @@ def visit_JoinLink(self, op, *, how, table, predicates): def _generate_groups(groups): return map(sge.convert, range(1, len(groups) + 1)) - def visit_Aggregate(self, op, *, parent, groups, metrics): - sel = sg.select( - *self._cleanup_names(groups), *self._cleanup_names(metrics), copy=False + def _compile_agg_select(self, op, *, parent, keys, metrics): + return sg.select( + *self._cleanup_names(keys), *self._cleanup_names(metrics), copy=False ).from_(parent, copy=False) - if groups: - sel = sel.group_by(*self._generate_groups(groups.values()), copy=False) + def _compile_group_by(self, sel, *, groups, grouping_sets, rollups, cubes): + expressions = list(self._generate_groups(groups.values())) + group = sge.Group( + expressions=expressions, + grouping_sets=[ + sge.GroupingSets( + expressions=[ + sge.Tuple(expressions=expressions) + for expressions in grouping_set + ] + ) + for grouping_set in grouping_sets + ], + rollup=[sge.Rollup(expressions=rollup) for rollup in rollups], + cube=[sge.Cube(expressions=cube) for cube in cubes], + ) + return sel.group_by(group, copy=False) + + def visit_Aggregate( + self, op, *, parent, keys, groups, metrics, grouping_sets, rollups, cubes + ): + sel = self._compile_agg_select(op, parent=parent, keys=keys, metrics=metrics) + + if groups or grouping_sets or rollups or cubes: + sel = self._compile_group_by( + sel, + groups=groups, + grouping_sets=grouping_sets, + rollups=rollups, + cubes=cubes, + ) return sel @@ -1609,6 +1638,9 @@ def _make_sample_backwards_compatible(self, *, sample, parent): parent.args["sample"] = sample return sg.select(STAR).from_(parent) + def visit_GroupID(self, op, *, arg): + return self.f.grouping(*arg) + # `__init_subclass__` is uncalled for subclasses - we manually call it here to # autogenerate the base class implementations as well. 
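The `_compile_group_by` hook above is the heart of the compiler change: instead of passing bare key expressions to `group_by`, it packs plain keys, grouping sets, rollups, and cubes into a single `sge.Group` node. A minimal sketch (not part of the patch) of how those sqlglot expressions render; the table and column names here are hypothetical:

```python
# Sketch mirroring the sge.Group construction in _compile_group_by;
# "t", "channel", and "id" are made-up names for illustration.
import sqlglot as sg
import sqlglot.expressions as sge

cols = [sg.column("channel"), sg.column("id")]
group = sge.Group(
    expressions=[],  # plain GROUP BY keys; empty when only a rollup is used
    grouping_sets=[],
    rollup=[sge.Rollup(expressions=cols)],
    cube=[],
)
# group_by accepts a prebuilt Group node, as the patch itself does
sel = sg.select(*cols, "SUM(sales) AS sales").from_("t").group_by(group, copy=False)
print(sel.sql(dialect="duckdb"))
# roughly: SELECT channel, id, SUM(sales) AS sales FROM t GROUP BY ROLLUP (channel, id)
```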
diff --git a/ibis/backends/sql/compilers/datafusion.py b/ibis/backends/sql/compilers/datafusion.py index cfa174d7ba188..3be7b04821a52 100644 --- a/ibis/backends/sql/compilers/datafusion.py +++ b/ibis/backends/sql/compilers/datafusion.py @@ -464,12 +464,12 @@ def visit_ArgMax(self, op, *, arg, key, where): arg, where=where, order_by=[sge.Ordered(this=key, desc=True)] ) - def visit_Aggregate(self, op, *, parent, groups, metrics): + def _compile_agg_select(self, op, *, parent, keys, metrics): """Support `GROUP BY` expressions in `SELECT` since DataFusion does not.""" quoted = self.quoted metrics = tuple(self._cleanup_names(metrics)) - if groups: + if keys: # datafusion doesn't support count distinct aggregations alongside # computed grouping keys so create a projection of the key and all # existing columns first, followed by the usual group by @@ -484,11 +484,11 @@ def visit_Aggregate(self, op, *, parent, groups, metrics): ), # can't use set subtraction here since the schema keys' # order matters and set subtraction doesn't preserve order - (k for k in op.parent.schema.keys() if k not in groups), + (k for k in op.parent.schema.keys() if k not in keys), ) ) table = ( - sg.select(*cols, *self._cleanup_names(groups)) + sg.select(*cols, *self._cleanup_names(keys)) .from_(parent) .subquery(parent.alias) ) @@ -497,19 +497,14 @@ def visit_Aggregate(self, op, *, parent, groups, metrics): # quoted=True is required here for correctness by_names_quoted = tuple( sg.column(key, table=getattr(value, "table", None), quoted=quoted) - for key, value in groups.items() + for key, value in keys.items() ) selections = by_names_quoted + metrics else: selections = metrics or (STAR,) table = parent - sel = sg.select(*selections).from_(table) - - if groups: - sel = sel.group_by(*by_names_quoted) - - return sel + return sg.select(*selections).from_(table) def visit_StructColumn(self, op, *, names, values): args = [] diff --git a/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/datafusion/out.sql b/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/datafusion/out.sql index 703ef7e85d34d..3585d09c1d000 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/datafusion/out.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_group_by_has_index/datafusion/out.sql @@ -25,4 +25,4 @@ FROM ( FROM "countries" AS "t0" ) AS t0 GROUP BY - "cont" \ No newline at end of file + 1 \ No newline at end of file diff --git a/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/datafusion/out.sql b/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/datafusion/out.sql index ddd0af6c0c025..f5381dbc67da1 100644 --- a/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/datafusion/out.sql +++ b/ibis/backends/tests/snapshots/test_sql/test_union_aliasing/datafusion/out.sql @@ -63,7 +63,7 @@ WITH "t5" AS ( ) AS "t4" ) AS t4 GROUP BY - "t4"."field_of_study" + 1 ) SELECT * diff --git a/ibis/backends/tests/tpc/ds/test_queries.py b/ibis/backends/tests/tpc/ds/test_queries.py index 40b23e4ade0b6..a23563657b992 100644 --- a/ibis/backends/tests/tpc/ds/test_queries.py +++ b/ibis/backends/tests/tpc/ds/test_queries.py @@ -6,7 +6,18 @@ import pytest import ibis -from ibis import _, coalesce, cumulative_window, date, ifelse, null, rank, union +from ibis import ( + _, + coalesce, + cumulative_window, + date, + group_id, + ifelse, + null, + rank, + rollup, + union, +) from ibis import literal as lit from ibis import selectors as s from ibis.backends.tests.errors import ClickHouseDatabaseError, 
TrinoUserError @@ -212,7 +223,7 @@ def profile(sales, *, name): @tpc_test("ds") -@pytest.mark.xfail(raises=NotImplementedError, reason="requires rollup") +@pytest.mark.notyet(["datafusion"], raises=Exception, reason="Ambiguous reference") def test_05( store_sales, store_returns, @@ -225,7 +236,130 @@ def test_05( web_site, date_dim, ): - raise NotImplementedError() + ssr = ( + store_sales.select( + store_sk=_.ss_store_sk, + date_sk=_.ss_sold_date_sk, + sales_price=_.ss_ext_sales_price, + profit=_.ss_net_profit, + return_amt=lit(0, type="decimal(7, 2)"), + net_loss=lit(0, type="decimal(7, 2)"), + ) + .union( + store_returns.select( + store_sk=_.sr_store_sk, + date_sk=_.sr_returned_date_sk, + sales_price=lit(0, type="decimal(7, 2)"), + profit=lit(0, type="decimal(7, 2)"), + return_amt=_.sr_return_amt, + net_loss=_.sr_net_loss, + ) + ) + .join(date_dim, [("date_sk", "d_date_sk")]) + .join(store, [("store_sk", "s_store_sk")]) + .filter(_.d_date.between(date("2000-08-23"), date("2000-09-06"))) + .group_by(_.s_store_id) + .agg( + sales=_.sales_price.sum(), + profit=_.profit.sum(), + returns_=_.return_amt.sum(), + profit_loss=_.net_loss.sum(), + ) + ) + csr = ( + catalog_sales.select( + page_sk=_.cs_catalog_page_sk, + date_sk=_.cs_sold_date_sk, + sales_price=_.cs_ext_sales_price, + profit=_.cs_net_profit, + return_amt=lit(0, type="decimal(7, 2)"), + net_loss=lit(0, type="decimal(7, 2)"), + ) + .union( + catalog_returns.select( + page_sk=_.cr_catalog_page_sk, + date_sk=_.cr_returned_date_sk, + sales_price=lit(0, type="decimal(7, 2)"), + profit=lit(0, type="decimal(7, 2)"), + return_amt=_.cr_return_amount, + net_loss=_.cr_net_loss, + ) + ) + .join(date_dim, [("date_sk", "d_date_sk")]) + .join(catalog_page, [("page_sk", "cp_catalog_page_sk")]) + .filter(_.d_date.between(date("2000-08-23"), date("2000-09-06"))) + .group_by(_.cp_catalog_page_id) + .agg( + sales=_.sales_price.sum(), + profit=_.profit.sum(), + returns_=_.return_amt.sum(), + profit_loss=_.net_loss.sum(), + ) + ) + wsr = ( + web_sales.select( + wsr_web_site_sk=_.ws_web_site_sk, + date_sk=_.ws_sold_date_sk, + sales_price=_.ws_ext_sales_price, + profit=_.ws_net_profit, + return_amt=lit(0, type="decimal(7, 2)"), + net_loss=lit(0, type="decimal(7, 2)"), + ) + .union( + web_returns.left_join( + web_sales, + [("wr_item_sk", "ws_item_sk"), ("wr_order_number", "ws_order_number")], + ).select( + wsr_web_site_sk=_.ws_web_site_sk, + date_sk=_.wr_returned_date_sk, + sales_price=lit(0, type="decimal(7, 2)"), + profit=lit(0, type="decimal(7, 2)"), + return_amt=_.wr_return_amt, + net_loss=_.wr_net_loss, + ) + ) + .join(date_dim, [("date_sk", "d_date_sk")]) + .join(web_site, [("wsr_web_site_sk", "web_site_sk")]) + .filter(_.d_date.between(date("2000-08-23"), date("2000-09-06"))) + .group_by(_.web_site_id) + .agg( + sales=_.sales_price.sum(), + profit=_.profit.sum(), + returns_=_.return_amt.sum(), + profit_loss=_.net_loss.sum(), + ) + ) + + return ( + ssr.select( + _.sales, + _.returns_, + channel=lit("store channel"), + id="store" + _.s_store_id, + profit=_.profit - _.profit_loss, + ) + .union( + csr.select( + _.sales, + _.returns_, + channel=lit("catalog channel"), + id="catalog_page" + _.cp_catalog_page_id, + profit=_.profit - _.profit_loss, + ), + wsr.select( + _.sales, + _.returns_, + channel=lit("web channel"), + id="web_site" + _.web_site_id, + profit=_.profit - _.profit_loss, + ), + ) + .group_by(rollup(_.channel, _.id)) + .agg(sales=_.sales.sum(), returns_=_.returns_.sum(), profit=_.profit.sum()) + .mutate(s.across(s.of_type("string"), 
_.nullif(""))) + .order_by(_.channel.asc(nulls_first=True), _.id.asc(nulls_first=True)) + .limit(100) + ) @pytest.mark.notyet( @@ -997,9 +1131,119 @@ def test_13( @tpc_test("ds") -@pytest.mark.xfail(raises=NotImplementedError, reason="requires rollup") def test_14(item, store_sales, date_dim, catalog_sales, web_sales): - raise NotImplementedError() + def make_cross_items(sales, *, sold_date_sk, item_sk): + return ( + sales.join(item, [(item_sk, "i_item_sk")]) + .join(date_dim, [(sold_date_sk, "d_date_sk")]) + .filter(_.d_year == 1999 + 2) + .select( + brand_id=_.i_brand_id, + class_id=_.i_class_id, + category_id=_.i_category_id, + ) + ) + + def make_avg_sales(sales, *, sold_date_sk, quantity, list_price): + return ( + sales.join(date_dim, [(sold_date_sk, "d_date_sk")]) + .filter(_.d_year.between(1999, 1999 + 2)) + .select(quantity=quantity, list_price=list_price) + ) + + def item_sales(sales, *, item_sk, sold_date_sk, channel, quantity, list_price): + return ( + sales.join(item, [(item_sk, "i_item_sk")]) + .join(date_dim, [(sold_date_sk, "d_date_sk")]) + .filter( + item_sk.isin(cross_items.ss_item_sk), + _.d_year == 1999 + 2, + _.d_moy == 11, + ) + .group_by( + channel=channel, + i_brand_id=_.i_brand_id, + i_class_id=_.i_class_id, + i_category_id=_.i_category_id, + ) + .having((quantity * list_price).sum() > avg_sales.average_sales.as_scalar()) + .agg(sales=(quantity * list_price).sum(), number_sales=_.count()) + ) + + cross_items = item.join( + make_cross_items( + store_sales, sold_date_sk=_.ss_sold_date_sk, item_sk=_.ss_item_sk + ).intersect( + make_cross_items( + catalog_sales, sold_date_sk=_.cs_sold_date_sk, item_sk=_.cs_item_sk + ), + make_cross_items( + web_sales, sold_date_sk=_.ws_sold_date_sk, item_sk=_.ws_item_sk + ), + ), + [ + ("i_brand_id", "brand_id"), + ("i_class_id", "class_id"), + ("i_category_id", "category_id"), + ], + ).select(ss_item_sk=_.i_item_sk) + + avg_sales = ( + make_avg_sales( + store_sales, + sold_date_sk=_.ss_sold_date_sk, + quantity=_.ss_quantity, + list_price=_.ss_list_price, + ) + .union( + make_avg_sales( + catalog_sales, + sold_date_sk=_.cs_sold_date_sk, + quantity=_.cs_quantity, + list_price=_.cs_list_price, + ), + make_avg_sales( + web_sales, + sold_date_sk=_.ws_sold_date_sk, + quantity=_.ws_quantity, + list_price=_.ws_list_price, + ), + ) + .agg(average_sales=(_.quantity * _.list_price).mean()) + ) + + return ( + item_sales( + store_sales, + item_sk=_.ss_item_sk, + sold_date_sk=_.ss_sold_date_sk, + channel=lit("store"), + quantity=_.ss_quantity, + list_price=_.ss_list_price, + ) + .union( + item_sales( + catalog_sales, + item_sk=_.cs_item_sk, + sold_date_sk=_.cs_sold_date_sk, + channel=lit("catalog"), + quantity=_.cs_quantity, + list_price=_.cs_list_price, + ), + item_sales( + web_sales, + item_sk=_.ws_item_sk, + sold_date_sk=_.ws_sold_date_sk, + channel=lit("web"), + quantity=_.ws_quantity, + list_price=_.ws_list_price, + ), + ) + .group_by(rollup(_.channel, _.i_brand_id, _.i_class_id, _.i_category_id)) + .agg(sum_sales=_.sales.sum(), sum_number_sales=_.number_sales.sum()) + .order_by(s.across(~s.endswith("_sales"), _.asc(nulls_first=True))) + .limit(100) + ) @tpc_test("ds") @@ -1136,11 +1380,45 @@ def test_17(store_sales, store_returns, catalog_sales, date_dim, store, item): @tpc_test("ds") -@pytest.mark.xfail(raises=NotImplementedError, reason="requires rollup") def test_18( catalog_sales, customer_demographics, customer, customer_address, date_dim, item ): - raise NotImplementedError() + cd1 = customer_demographics + return ( + 
catalog_sales.join(date_dim, [("cs_sold_date_sk", "d_date_sk")]) + .join(item, [("cs_item_sk", "i_item_sk")]) + .join(cd1, [("cs_bill_cdemo_sk", "cd_demo_sk")]) + .join(customer, [("cs_bill_customer_sk", "c_customer_sk")]) + .join( + customer_demographics[["cd_demo_sk"]], + [("c_current_cdemo_sk", "cd_demo_sk")], + ) + .join(customer_address, [("c_current_addr_sk", "ca_address_sk")]) + .filter( + cd1.cd_gender == "F", + cd1.cd_education_status == "Unknown", + _.c_birth_month.isin((1, 6, 8, 9, 12, 2)), + _.d_year == 1998, + _.ca_state.isin(("MS", "IN", "ND", "OK", "NM", "VA", "MS")), + ) + .group_by(rollup(_.i_item_id, _.ca_country, _.ca_state, _.ca_county)) + .agg( + agg1=_.cs_quantity.cast("decimal(12, 2)").mean(), + agg2=_.cs_list_price.cast("decimal(12, 2)").mean(), + agg3=_.cs_coupon_amt.cast("decimal(12, 2)").mean(), + agg4=_.cs_sales_price.cast("decimal(12, 2)").mean(), + agg5=_.cs_net_profit.cast("decimal(12, 2)").mean(), + agg6=_.c_birth_year.cast("decimal(12, 2)").mean(), + agg7=_.cd_dep_count.cast("decimal(12, 2)").mean(), + ) + .order_by( + _.ca_country.asc(nulls_first=True), + _.ca_state.asc(nulls_first=True), + _.ca_county.asc(nulls_first=True), + _.i_item_id.asc(nulls_first=True), + ) + .limit(100) + ) @tpc_test("ds") @@ -1227,9 +1505,19 @@ def test_21(inventory, warehouse, item, date_dim): @tpc_test("ds") -@pytest.mark.xfail(raises=NotImplementedError, reason="requires rollup") def test_22(inventory, date_dim, item): - raise NotImplementedError() + return ( + inventory.join(date_dim, [("inv_date_sk", "d_date_sk")]) + .join(item, [("inv_item_sk", "i_item_sk")]) + .filter(_.d_month_seq.between(1200, 1200 + 11)) + .group_by(rollup(_.i_product_name, _.i_brand, _.i_class, _.i_category)) + .agg(qoh=_.inv_quantity_on_hand.mean()) + .order_by( + _.qoh.asc(nulls_first=True), + s.across(~s.cols("qoh"), _.asc(nulls_first=True)), + ) + .limit(100) + ) @tpc_test("ds") @@ -3566,10 +3854,34 @@ def agg_sales_net_by_month(sales, ns, sales_expr, net_expr): ) -@pytest.mark.xfail(raises=NotImplementedError, reason="requires rollup") @tpc_test("ds") +@pytest.mark.notyet( + ["clickhouse"], reason="clickhouse returns the wrong result", raises=AssertionError +) def test_67(store_sales, date_dim, store, item): - raise NotImplementedError() + return ( + store_sales.join(date_dim, [("ss_sold_date_sk", "d_date_sk")]) + .join(store, [("ss_store_sk", "s_store_sk")]) + .join(item, [("ss_item_sk", "i_item_sk")]) + .filter(_.d_month_seq.between(1200, 1200 + 11)) + .group_by( + rollup( + _.i_category, + _.i_class, + _.i_brand, + _.i_product_name, + _.d_year, + _.d_qoy, + _.d_moy, + _.s_store_id, + ) + ) + .agg(sumsales=(_.ss_sales_price * _.ss_quantity).coalesce(0).sum()) + .mutate(rk=1 + rank().over(group_by=_.i_category, order_by=_.sumsales.desc())) + .filter(_.rk <= 100) + .order_by(s.across(s.all(), _.asc(nulls_first=True))) + .limit(100) + ) @tpc_test("ds") @@ -3698,14 +4010,53 @@ def test_69( @pytest.mark.notyet( ["trino"], raises=TrinoUserError, reason="grouping() is not allowed in order by" ) -@pytest.mark.notimpl( - ["duckdb", "clickhouse", "snowflake", "datafusion"], - raises=NotImplementedError, - reason="requires rollup", -) +@pytest.mark.notyet(["datafusion"], raises=Exception, reason="grouping not implemented") @tpc_test("ds") def test_70(store_sales, date_dim, store): - raise NotImplementedError() + return ( + store_sales.join(date_dim, [("ss_sold_date_sk", "d_date_sk")]) + .join(store, [("ss_store_sk", "s_store_sk")]) + .filter( + _.d_month_seq.between(1200, 1200 + 11), + _.s_state.isin( + 
store_sales.join(store, [("ss_store_sk", "s_store_sk")]) + .join(date_dim, [("ss_sold_date_sk", "d_date_sk")]) + .filter(_.d_month_seq.between(1200, 1200 + 11)) + .group_by(_.s_state) + .agg(net_profit=_.ss_net_profit.sum()) + .mutate( + ranking=rank().over( + group_by=_.s_state, order_by=_.net_profit.desc() + ) + + 1 + ) + .filter(_.ranking <= 5) + .s_state + ), + ) + .group_by(rollup(_.s_state, _.s_county)) + .agg( + total_sum=_.ss_net_profit.sum(), + lochierarchy=group_id(_.s_state) + group_id(_.s_county), + county_group_id=group_id(_.s_county), + ) + .mutate( + rank_within_parent=rank().over( + group_by=(_.lochierarchy, ifelse(_.county_group_id, _.s_state, null())), + order_by=_.total_sum.desc(), + ) + + 1 + ) + .select( + _.total_sum, _.s_state, _.s_county, _.lochierarchy, _.rank_within_parent + ) + .order_by( + _.lochierarchy.desc(), + ifelse(_.lochierarchy == 0, _.s_state, null()), + _.rank_within_parent, + ) + .limit(100) + ) @tpc_test("ds") @@ -4193,8 +4544,7 @@ def _sales( return expr -@pytest.mark.xfail(raises=AttributeError, reason="requires rollup") -@tpc_test("ds", result_is_empty=True) +@tpc_test("ds") def test_77( store_sales, date_dim, @@ -4287,7 +4637,7 @@ def test_77( channel=ibis.literal("store channel"), id=_.s_store_sk, sales=_.sales, - returns=_.returns_.coalesce(0), + returns_=_.returns_.coalesce(0), profit=(_.profit - _.profit_loss.coalesce(0)), ) @@ -4295,7 +4645,7 @@ def test_77( channel=ibis.literal("catalog channel"), id=_.cs_call_center_sk, sales=_.sales, - returns=_.returns_, + returns_=_.returns_, profit=(_.profit - _.profit_loss), ) @@ -4303,13 +4653,13 @@ def test_77( channel=ibis.literal("web channel"), id=_.wp_web_page_sk, sales=_.sales, - returns=_.returns_.coalesce(0), + returns_=_.returns_.coalesce(0), profit=(_.profit - _.profit_loss.coalesce(0)), ) expr = ( x1.union(x2, x3) - .group_by(ibis.rollup("channel", "id")) + .group_by(rollup(_.channel, _.id)) .agg( sales=_.sales.sum(), returns_=_.returns_.sum(), @@ -4509,7 +4859,7 @@ def test_79(store_sales, date_dim, store, household_demographics, customer): @tpc_test("ds") -@pytest.mark.xfail(raises=NotImplementedError, reason="requires rollup") +@pytest.mark.notyet(["datafusion"], reason="broken referencing inside datafusion") def test_80( store_sales, store_returns, @@ -4524,7 +4874,122 @@ def test_80( item, promotion, ): - raise NotImplementedError() + def sr( + *, + sales, + returns, + dimension, + sales_item_sk, + returns_item_sk, + sales_order_number, + returns_order_number, + sales_sold_date_sk, + dimension_sk, + sales_dimension_sk, + sales_promo_sk, + grouping_key, + ext_sales_price, + return_amt, + net_profit, + net_loss, + ): + return ( + sales.left_join( + returns, + [ + (sales_item_sk, returns_item_sk), + (sales_order_number, returns_order_number), + ], + ) + .join(date_dim, [(sales_sold_date_sk, "d_date_sk")]) + .join(dimension, [(sales_dimension_sk, dimension_sk)]) + .join(item, [(sales_item_sk, "i_item_sk")]) + .join(promotion, [(sales_promo_sk, "p_promo_sk")]) + .filter( + _.d_date.between(date("2000-08-23"), date("2000-09-22")), + _.i_current_price > 50, + _.p_channel_tv == "N", + ) + .group_by(grouping_key) + .agg( + sales=ext_sales_price.sum(), + returns_=return_amt.coalesce(0).sum(), + profit=(net_profit - net_loss.coalesce(0)).sum(), + ) + ) + + ssr = sr( + sales=store_sales, + returns=store_returns, + dimension=store, + sales_item_sk=_.ss_item_sk, + returns_item_sk=_.sr_item_sk, + sales_order_number=_.ss_ticket_number, + returns_order_number=_.sr_ticket_number, + 
sales_sold_date_sk=_.ss_sold_date_sk, + dimension_sk=_.s_store_sk, + sales_dimension_sk=_.ss_store_sk, + sales_promo_sk=_.ss_promo_sk, + grouping_key=_.s_store_id.name("store_id"), + ext_sales_price=_.ss_ext_sales_price, + return_amt=_.sr_return_amt, + net_profit=_.ss_net_profit, + net_loss=_.sr_net_loss, + ) + csr = sr( + sales=catalog_sales, + returns=catalog_returns, + dimension=catalog_page, + sales_item_sk=_.cs_item_sk, + returns_item_sk=_.cr_item_sk, + sales_order_number=_.cs_order_number, + returns_order_number=_.cr_order_number, + sales_sold_date_sk=_.cs_sold_date_sk, + dimension_sk=_.cp_catalog_page_sk, + sales_dimension_sk=_.cs_catalog_page_sk, + sales_promo_sk=_.cs_promo_sk, + grouping_key=_.cp_catalog_page_id.name("catalog_page_id"), + ext_sales_price=_.cs_ext_sales_price, + return_amt=_.cr_return_amount, + net_profit=_.cs_net_profit, + net_loss=_.cr_net_loss, + ) + wsr = sr( + sales=web_sales, + returns=web_returns, + dimension=web_site, + sales_item_sk=_.ws_item_sk, + returns_item_sk=_.wr_item_sk, + sales_order_number=_.ws_order_number, + returns_order_number=_.wr_order_number, + sales_sold_date_sk=_.ws_sold_date_sk, + dimension_sk=_.web_site_sk, + sales_dimension_sk=_.ws_web_site_sk, + sales_promo_sk=_.ws_promo_sk, + grouping_key=_.web_site_id, + ext_sales_price=_.ws_ext_sales_price, + return_amt=_.wr_return_amt, + net_profit=_.ws_net_profit, + net_loss=_.wr_net_loss, + ) + + return ( + ssr.mutate(channel=lit("store channel"), id="store" + _.store_id) + .drop("store_id") + .union( + csr.mutate( + channel=lit("catalog channel"), id="catalog_page" + _.catalog_page_id + ).drop("catalog_page_id"), + wsr.mutate(channel=lit("web channel"), id="web_site" + _.web_site_id).drop( + "web_site_id" + ), + ) + .group_by(rollup(_.channel, _.id)) + .agg(sales=_.sales.sum(), returns_=_.returns_.sum(), profit=_.profit.sum()) + .mutate(s.across(s.cols("channel", "id"), _.nullif(""))) + .order_by(_.channel.asc(nulls_first=True), _.id.asc(nulls_first=True)) + .limit(100) + ) @pytest.mark.notyet( @@ -4764,12 +5229,41 @@ def test_85( reason="doesn't support grouping function in order_by", ) @pytest.mark.notimpl( - ["snowflake", "duckdb", "datafusion", "clickhouse"], - raises=NotImplementedError, - reason="requires rollup", + ["datafusion"], raises=Exception, reason="grouping function not implemented" ) def test_86(web_sales, date_dim, item): - raise NotImplementedError() + return ( + web_sales.join( + date_dim.filter(_.d_month_seq.between(1200, 1200 + 11)), + [("ws_sold_date_sk", "d_date_sk")], + ) + .join(item, [("ws_item_sk", "i_item_sk")]) + .group_by(rollup(_.i_category, _.i_class)) + .agg( + total_sum=_.ws_net_paid.sum(), + lochierarchy=group_id(_.i_category) + group_id(_.i_class), + class_grouping=group_id(_.i_class), + ) + .mutate( + rank_within_parent=rank().over( + group_by=( + _.lochierarchy, + ifelse(_.class_grouping == 0, _.i_category, null()), + ), + order_by=_.total_sum.desc(), + ) + + 1 + ) + .select( + _.total_sum, _.i_category, _.i_class, _.lochierarchy, _.rank_within_parent + ) + .order_by( + _.lochierarchy.desc(nulls_first=True), + ifelse(_.lochierarchy == 0, _.i_category, null()).asc(nulls_first=True), + _.rank_within_parent.asc(nulls_first=True), + ) + .limit(100) + ) @tpc_test("ds") diff --git a/ibis/backends/tests/tpc/queries/clickhouse/ds/18.sql b/ibis/backends/tests/tpc/queries/clickhouse/ds/18.sql new file mode 100644 index 0000000000000..a7b523bb94867 --- /dev/null +++ b/ibis/backends/tests/tpc/queries/clickhouse/ds/18.sql @@ -0,0 +1,49 @@ +SELECT i_item_id, + 
ca_country, + ca_state, + ca_county, + avg(cast(cs_quantity AS decimal(12, 2))) agg1, + avg(cast(cs_list_price AS decimal(12, 2))) agg2, + avg(cast(cs_coupon_amt AS decimal(12, 2))) agg3, + avg(cast(cs_sales_price AS decimal(12, 2))) agg4, + avg(cast(cs_net_profit AS decimal(12, 2))) agg5, + avg(cast(c_birth_year AS decimal(12, 2))) agg6, + avg(cast(cd1.cd_dep_count AS decimal(12, 2))) agg7 +FROM catalog_sales +JOIN customer_demographics cd1 + ON cs_bill_cdemo_sk = cd1.cd_demo_sk +JOIN customer + ON cs_bill_customer_sk = c_customer_sk +JOIN customer_demographics cd2 + ON c_current_cdemo_sk = cd2.cd_demo_sk +JOIN customer_address + ON c_current_addr_sk = ca_address_sk +JOIN date_dim + ON cs_sold_date_sk = d_date_sk +JOIN item + ON cs_item_sk = i_item_sk +WHERE cd1.cd_gender = 'F' + AND cd1.cd_education_status = 'Unknown' + AND c_birth_month IN (1, + 6, + 8, + 9, + 12, + 2) + AND d_year = 1998 + AND ca_state IN ('MS', + 'IN', + 'ND', + 'OK', + 'NM', + 'VA', + 'MS') +GROUP BY ROLLUP (i_item_id, + ca_country, + ca_state, + ca_county) +ORDER BY ca_country NULLS FIRST, + ca_state NULLS FIRST, + ca_county NULLS FIRST, + i_item_id NULLS FIRST +LIMIT 100; diff --git a/ibis/backends/tests/tpc/queries/duckdb/ds/05.sql b/ibis/backends/tests/tpc/queries/duckdb/ds/05.sql index 7c9342dce4c5d..5598c23f43421 100644 --- a/ibis/backends/tests/tpc/queries/duckdb/ds/05.sql +++ b/ibis/backends/tests/tpc/queries/duckdb/ds/05.sql @@ -81,8 +81,8 @@ WITH ssr AS AND d_date BETWEEN cast('2000-08-23' AS date) AND cast('2000-09-06' AS date) AND wsr_web_site_sk = web_site_sk GROUP BY web_site_id) -SELECT channel , - id , +SELECT nullif(channel, '') AS channel , + nullif(id, '') AS id , sum(sales) AS sales , sum(returns_) AS returns_ , sum(profit) AS profit diff --git a/ibis/backends/tests/tpc/queries/duckdb/ds/80.sql b/ibis/backends/tests/tpc/queries/duckdb/ds/80.sql index f06c35b1250a0..6a1f96f47903a 100644 --- a/ibis/backends/tests/tpc/queries/duckdb/ds/80.sql +++ b/ibis/backends/tests/tpc/queries/duckdb/ds/80.sql @@ -55,8 +55,8 @@ WITH ssr AS AND ws_promo_sk = p_promo_sk AND p_channel_tv = 'N' GROUP BY web_site_id) -SELECT channel , - id , +SELECT nullif(channel, '') AS channel , + nullif(id, '') AS id , sum(sales) AS sales , sum(returns_) AS returns_ , sum(profit) AS profit diff --git a/ibis/expr/api.py b/ibis/expr/api.py index d03c48d4ccb02..6d3aad6d645b6 100644 --- a/ibis/expr/api.py +++ b/ibis/expr/api.py @@ -41,6 +41,7 @@ null, struct, ) +from ibis.expr.types.groupby import cube, group_id, grouping_sets, rollup from ibis.util import deprecated, experimental if TYPE_CHECKING: @@ -72,6 +73,7 @@ "coalesce", "connect", "cross_join", + "cube", "cume_dist", "cumulative_window", "date", @@ -85,6 +87,8 @@ "following", "get_backend", "greatest", + "group_id", + "grouping_sets", "ifelse", "infer_dtype", "infer_schema", @@ -112,6 +116,7 @@ "read_delta", "read_json", "read_parquet", + "rollup", "row_number", "rows_window", "schema", diff --git a/ibis/expr/decompile.py b/ibis/expr/decompile.py index 3f4eb578e90dd..06259aa47f95a 100644 --- a/ibis/expr/decompile.py +++ b/ibis/expr/decompile.py @@ -171,16 +171,25 @@ def sort(op, parent, keys): @translate.register(ops.Aggregate) -def aggregation(op, parent, groups, metrics): +def aggregation(op, parent, metrics, keys, groups, grouping_sets, rollups, cubes): groups = _wrap_alias(op.groups, groups) metrics = _wrap_alias(op.metrics, metrics) - if groups and metrics: - return f"{parent}.aggregate([{_inline(metrics)}], by=[{_inline(groups)}])" - elif metrics: - return 
f"{parent}.aggregate([{_inline(metrics)}])" - else: + + if not metrics: raise ValueError("No metrics to aggregate") + args = [f"[{_inline(metrics)}]"] + + if groups: + args.append(f"by=[{_inline(groups)}]") + + if grouping_sets or rollups or cubes: + raise NotImplementedError( + "grouping_sets, rollups, and cubes not yet implemented in the decompiler" + ) + + return f"{parent}.aggregate({', '.join(args)})" + @translate.register(ops.Distinct) def distinct(op, parent): diff --git a/ibis/expr/operations/generic.py b/ibis/expr/operations/generic.py index 0efabf0473079..5042cae96b724 100644 --- a/ibis/expr/operations/generic.py +++ b/ibis/expr/operations/generic.py @@ -336,4 +336,14 @@ def dtype(self): return rlz.highest_precedence_dtype(exprs) +@public +class GroupID(Value): + arg: Annotated[VarTuple[Value], Length(at_least=1)] + + dtype = dt.int64 + + # scalar because it's always used in grouping context + shape = ds.scalar + + public(NULL=NULL) diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py index 8bd06eac215cc..c0dd4bac4a2cb 100644 --- a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -307,21 +307,40 @@ class Aggregate(Relation): """Aggregate a table by a set of group by columns and metrics.""" parent: Relation - groups: FrozenOrderedDict[str, Unaliased[Value]] + metrics: FrozenOrderedDict[str, Unaliased[Scalar]] - def __init__(self, parent, groups, metrics): - _check_integrity(groups.values(), {parent}) + # duplication is needed so that the compiler will compile the elements of keys/groups + keys: FrozenOrderedDict[str, Unaliased[Value]] + """Values to always output in a projection. Unique expressions across `groups` and `grouping_sets` and friends.""" + + groups: FrozenOrderedDict[str, Unaliased[Value]] + """Grouping keys. 
Equivalent to `keys` when no grouping sets, rollups, or cubes are present.""" + + grouping_sets: VarTuple[VarTuple[VarTuple[Value]]] = () + rollups: VarTuple[VarTuple[Value]] = () + cubes: VarTuple[VarTuple[Value]] = () + + def __init__(self, parent, keys, groups, metrics, grouping_sets, rollups, cubes): + _check_integrity(keys.values(), {parent}) _check_integrity(metrics.values(), {parent}) - if duplicates := groups.keys() & metrics.keys(): + if duplicates := keys.keys() & metrics.keys(): raise RelationError( f"Cannot add {duplicates} to aggregate, they are already in the groupby" ) - super().__init__(parent=parent, groups=groups, metrics=metrics) + super().__init__( + parent=parent, + keys=keys, + groups=groups, + metrics=metrics, + grouping_sets=grouping_sets, + rollups=rollups, + cubes=cubes, + ) @attribute def values(self): - return FrozenOrderedDict({**self.groups, **self.metrics}) + return FrozenOrderedDict({**self.keys, **self.metrics}) @attribute def schema(self): diff --git a/ibis/expr/tests/snapshots/test_format/test_aggregate_arg_names/repr.txt b/ibis/expr/tests/snapshots/test_format/test_aggregate_arg_names/repr.txt index 3fa92abc9f91a..d903d35a03455 100644 --- a/ibis/expr/tests/snapshots/test_format/test_aggregate_arg_names/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_aggregate_arg_names/repr.txt @@ -12,9 +12,12 @@ r0 := UnboundTable: alltypes k time Aggregate[r0] - groups: - key1: r0.g - key2: Round(r0.f, digits=0) metrics: c: Sum(r0.c) - d: Mean(r0.d) \ No newline at end of file + d: Mean(r0.d) + keys: + key1: r0.g + key2: Round(r0.f, digits=0) + groups: + key1: r0.g + key2: Round(r0.f, digits=0) \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_complex_repr/repr.txt b/ibis/expr/tests/snapshots/test_format/test_complex_repr/repr.txt index 9334f1c07925e..bd6f683507390 100644 --- a/ibis/expr/tests/snapshots/test_format/test_complex_repr/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_complex_repr/repr.txt @@ -10,9 +10,11 @@ r2 := Project[r1] x: r1.a + 42 r3 := Aggregate[r2] - groups: - x: r2.x metrics: y: Sum(r2.a) + keys: + x: r2.x + groups: + x: r2.x Limit[r3, n=10] \ No newline at end of file diff --git a/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_table/repr.txt b/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_table/repr.txt index 2e6f5c480f1ad..a91b0a04a28f7 100644 --- a/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_table/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_table/repr.txt @@ -7,10 +7,12 @@ r1 := Filter[r0] InValues(value=r0.dest, options=['ORD', 'JFK', 'SFO']) r2 := Aggregate[r1] - groups: - dest: r1.dest metrics: Mean(arrdelay): Mean(r1.arrdelay) + keys: + dest: r1.dest + groups: + dest: r1.dest r3 := Sort[r2] desc r2['Mean(arrdelay)'] diff --git a/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt b/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt index 128ffd518dd6b..c42ec8eef354e 100644 --- a/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt +++ b/ibis/expr/tests/snapshots/test_format/test_memoize_filtered_tables_in_join/repr.txt @@ -5,11 +5,14 @@ r0 := UnboundTable: purchases amount float64 r1 := Aggregate[r0] + metrics: + total: Sum(r0.amount) + keys: + region: r0.region + kind: r0.kind groups: region: r0.region kind: r0.kind - metrics: - total: Sum(r0.amount) r2 := Filter[r1] r1.kind == 'foo' diff --git 
a/ibis/expr/tests/test_newrels.py b/ibis/expr/tests/test_newrels.py index 60724c10396f7..f9d0f6ed43179 100644 --- a/ibis/expr/tests/test_newrels.py +++ b/ibis/expr/tests/test_newrels.py @@ -936,12 +936,9 @@ def test_aggregate(): agg = t.aggregate(by=[t.bool_col], metrics=[t.int_col.sum()]) expected = Aggregate( parent=t, - groups={ - "bool_col": t.bool_col, - }, - metrics={ - "Sum(int_col)": t.int_col.sum(), - }, + keys={"bool_col": t.bool_col}, + groups={"bool_col": t.bool_col}, + metrics={"Sum(int_col)": t.int_col.sum()}, ) assert agg.op() == expected @@ -1071,10 +1068,8 @@ def test_aggregate_field_dereferencing(): charge_ = discount_price_ * (1 + f.l_tax) assert a.op() == Aggregate( parent=f, - groups={ - "l_returnflag": f.l_returnflag, - "l_linestatus": f.l_linestatus, - }, + keys={"l_returnflag": f.l_returnflag, "l_linestatus": f.l_linestatus}, + groups={"l_returnflag": f.l_returnflag, "l_linestatus": f.l_linestatus}, metrics={ "sum_qty": f.l_quantity.sum(), "sum_base_price": f.l_extendedprice.sum(), @@ -1117,7 +1112,7 @@ def test_filter_condition_referencing_agg_without_groupby_turns_it_into_a_subque total = (r2.float_col * r2.int_col).sum() subquery = ops.ScalarSubquery( - ops.Aggregate(r2, groups={}, metrics={total.get_name(): total}) + ops.Aggregate(r2, keys={}, groups={}, metrics={total.get_name(): total}) ).to_expr() expected = Filter(parent=r3, predicates=[r3.value > subquery * 0.0001]) @@ -1433,12 +1428,9 @@ def test_self_view_join_followed_by_aggregate_correctly_dereference_fields(): ).to_expr() expected_agg = ops.Aggregate( parent=join, - groups={ - "g": join.g, - }, - metrics={ - "metric": (join.total - join.total_right).max(), - }, + keys={"g": join.g}, + groups={"g": join.g}, + metrics={"metric": (join.total - join.total_right).max()}, ).to_expr() assert join.equals(expected_join) assert agg.equals(expected_agg) diff --git a/ibis/expr/types/groupby.py b/ibis/expr/types/groupby.py index 193e114a8c735..5e9128a294576 100644 --- a/ibis/expr/types/groupby.py +++ b/ibis/expr/types/groupby.py @@ -16,18 +16,21 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Annotated +from typing import TYPE_CHECKING from public import public import ibis +import ibis.common.exceptions as exc import ibis.expr.datatypes as dt import ibis.expr.operations as ops import ibis.expr.types as ir -from ibis.common.grounds import Concrete -from ibis.common.patterns import Length # noqa: TCH001 +from ibis.common.deferred import Deferred, deferrable +from ibis.common.grounds import Annotable, Concrete +from ibis.common.selectors import Expandable from ibis.common.typing import VarTuple # noqa: TCH001 from ibis.expr.rewrites import rewrite_window_input +from ibis.util import experimental if TYPE_CHECKING: from collections.abc import Sequence @@ -38,10 +41,19 @@ class GroupedTable(Concrete): """An intermediate table expression to hold grouping information.""" table: ops.Relation - groupings: Annotated[VarTuple[ops.Value], Length(at_least=1)] + # groupings may be empty when some form of grouping sets is provided + # + # groupings are *strictly* the keys the user has explicitly requested to + # group by that are not part of a grouping set + groupings: VarTuple[ops.Value] orderings: VarTuple[ops.SortKey] = () havings: VarTuple[ops.Value[dt.Boolean]] = () + grouping_sets: VarTuple[VarTuple[VarTuple[ir.Value]]] = () + rollups: VarTuple[VarTuple[ir.Value]] = () + cubes: VarTuple[VarTuple[ir.Value]] = () + def __getitem__(self, args): # Shortcut for projection with
window functions return self.select(*args) @@ -61,7 +73,12 @@ def aggregate(self, *metrics, **kwds) -> ir.Table: """Compute aggregates over a group by.""" metrics = self.table.to_expr().bind(*metrics, **kwds) return self.table.to_expr().aggregate( - metrics, by=self.groupings, having=self.havings + metrics, + by=self.groupings, + having=self.havings, + grouping_sets=self.grouping_sets, + rollups=self.rollups, + cubes=self.cubes, ) agg = aggregate @@ -218,6 +235,10 @@ def _selectables(self, *exprs, **kwexprs): -------- [`GroupedTable.mutate`](#ibis.expr.types.groupby.GroupedTable.mutate) """ + if self.grouping_sets or self.rollups or self.cubes: + raise exc.UnsupportedOperationError( + "Grouping sets, rollups, and cubes are not supported in grouped `mutate` or `select`" + ) table = self.table.to_expr() values = table.bind(*exprs, **kwexprs) window = ibis.window(group_by=self.groupings, order_by=self.orderings) @@ -262,6 +283,7 @@ def over( order_by=order_by, ) + # TODO: reject grouping sets here return self.__class__( self.table, self.by, @@ -279,7 +301,14 @@ def count(self) -> ir.Table: The aggregated table """ table = self.table.to_expr() - return table.aggregate(table.count(), by=self.groupings, having=self.havings) + return table.aggregate( + table.count(), + by=self.groupings, + having=self.havings, + grouping_sets=self.grouping_sets, + rollups=self.rollups, + cubes=self.cubes, + ) size = count @@ -314,3 +343,342 @@ def __init__(self, arr, parent): class GroupedNumbers(GroupedArray): mean = _group_agg_dispatch("mean") sum = _group_agg_dispatch("sum") + + +class GroupingSets(Annotable, Expandable): + exprs: VarTuple[VarTuple[str | ir.Value | Deferred]] + + def expand(self, table: ir.Table) -> Sequence[ir.Value]: + # produce all unique expressions in the grouping set, rollup or cube + values = [] + for expr in self.exprs: + values.append(tuple(table.bind(expr))) + return values + + +class GroupingSetsShorthand(Annotable, Expandable): + exprs: VarTuple[str | ir.Value | Deferred] + + def expand(self, table: ir.Table) -> Sequence[ir.Value]: + # produce all unique expressions in the grouping set, rollup or cube + values = [] + for expr in self.exprs: + values.extend(table.bind(expr)) + return values + + +class Rollup(GroupingSetsShorthand): + pass + + +class Cube(GroupingSetsShorthand): + pass + + +@public +@experimental +def rollup(*dims): + """Construct a rollup. + + Rollups are a shorthand for grouping sets that are sequentially more + coarse-grained aggregations. + + Conceptually, a rollup is a union of grouping sets, where each grouping + set is a subset of the previous one. + + Here's some SQL showing `ROLLUP` equivalence to standard-issue `GROUP BY`: + + ```sql + -- 1. grouping set is a, b + SELECT a, b, count(*) n + FROM t + GROUP BY a, b + + UNION ALL + + -- 2. grouping set is a (rolled up from a, b) + SELECT a, NULL, count(*) n + FROM t + GROUP BY a + + UNION ALL + + -- 3.
no grouping set, i.e., all rows (rolled up from a) + SELECT NULL, NULL, count(*) n + FROM t + ``` + + See Also + -------- + cube + grouping_sets + + Examples + -------- + >>> import ibis + >>> from ibis import _ + >>> ibis.options.interactive = True + >>> t = ibis.examples.penguins.fetch() + >>> t.head() + ┏━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━┓ + ┃ species ┃ island    ┃ bill_length_mm ┃ bill_depth_mm ┃ flipper_length_mm ┃ … ┃ + ┡━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━┩ + │ string  │ string    │ float64        │ float64       │ int64             │ … │ + ├─────────┼───────────┼────────────────┼───────────────┼───────────────────┼───┤ + │ Adelie  │ Torgersen │           39.1 │          18.7 │               181 │ … │ + │ Adelie  │ Torgersen │           39.5 │          17.4 │               186 │ … │ + │ Adelie  │ Torgersen │           40.3 │          18.0 │               195 │ … │ + │ Adelie  │ Torgersen │           NULL │          NULL │              NULL │ … │ + │ Adelie  │ Torgersen │           36.7 │          19.3 │               193 │ … │ + └─────────┴───────────┴────────────────┴───────────────┴───────────────────┴───┘ + >>> ( + ...     t.group_by(ibis.rollup(_.island, _.sex)) + ...     .agg(mean_bill_length=_.bill_length_mm.mean()) + ...     .order_by( + ...         _.island.asc(nulls_first=True), + ...         _.sex.asc(nulls_first=True), + ...         _.mean_bill_length.desc(), + ...     ) + ... ) + ┏━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓ + ┃ island    ┃ sex    ┃ mean_bill_length ┃ + ┡━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩ + │ string    │ string │ float64          │ + ├───────────┼────────┼──────────────────┤ + │ NULL      │ NULL   │        43.921930 │ + │ Biscoe    │ NULL   │        45.625000 │ + │ Biscoe    │ NULL   │        45.257485 │ + │ Biscoe    │ female │        43.307500 │ + │ Biscoe    │ male   │        47.119277 │ + │ Dream     │ NULL   │        44.167742 │ + │ Dream     │ NULL   │        37.500000 │ + │ Dream     │ female │        42.296721 │ + │ Dream     │ male   │        46.116129 │ + │ Torgersen │ NULL   │        38.950980 │ + │ …         │ …      │                … │ + └───────────┴────────┴──────────────────┘ + """ + return Rollup(dims) + + +@public +@experimental +def cube(*dims): + """Construct a cube. + + ::: {.callout-note} + ## Cubes can be very expensive to compute. + ::: + + Cubes are a shorthand for grouping sets that contain all possible ways + to aggregate a set of grouping keys. + + Conceptually, a cube is a union of grouping sets, where each grouping + set is a member of the set of all subsets of the grouping keys (the + power set). + + See Also + -------- + rollup + grouping_sets + + Examples + -------- + >>> import ibis + >>> from ibis import _ + >>> ibis.options.interactive = True + >>> t = ibis.examples.penguins.fetch() + >>> t.head() + ┏━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━┓ + ┃ species ┃ island    ┃ bill_length_mm ┃ bill_depth_mm ┃ flipper_length_mm ┃ … ┃ + ┡━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━┩ + │ string  │ string    │ float64        │ float64       │ int64             │ … │ + ├─────────┼───────────┼────────────────┼───────────────┼───────────────────┼───┤ + │ Adelie  │ Torgersen │           39.1 │          18.7 │               181 │ … │ + │ Adelie  │ Torgersen │           39.5 │          17.4 │               186 │ … │ + │ Adelie  │ Torgersen │           40.3 │          18.0 │               195 │ … │ + │ Adelie  │ Torgersen │           NULL │          NULL │              NULL │ … │ + │ Adelie  │ Torgersen │           36.7 │          19.3 │               193 │ … │ + └─────────┴───────────┴────────────────┴───────────────┴───────────────────┴───┘ + >>> ( + ...     t.group_by(ibis.cube("island", "sex")) + ...     .agg(mean_bill_length=_.bill_length_mm.mean()) + ...     .order_by( + ...         _.island.asc(nulls_first=True), + ...         _.sex.asc(nulls_first=True), + ...         _.mean_bill_length.desc(), + ...     ) + ...
) + ┏━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓ + ┃ island ┃ sex    ┃ mean_bill_length ┃ + ┡━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩ + │ string │ string │ float64          │ + ├────────┼────────┼──────────────────┤ + │ NULL   │ NULL   │        43.921930 │ + │ NULL   │ NULL   │        41.300000 │ + │ NULL   │ female │        42.096970 │ + │ NULL   │ male   │        45.854762 │ + │ Biscoe │ NULL   │        45.625000 │ + │ Biscoe │ NULL   │        45.257485 │ + │ Biscoe │ female │        43.307500 │ + │ Biscoe │ male   │        47.119277 │ + │ Dream  │ NULL   │        44.167742 │ + │ Dream  │ NULL   │        37.500000 │ + │ …      │ …      │                … │ + └────────┴────────┴──────────────────┘ + """ + return Cube(dims) + + +@public +@experimental +def grouping_sets(*dims): + """Construct grouping sets. + + See Also + -------- + rollup + cube + + Examples + -------- + >>> import ibis + >>> from ibis import _ + >>> ibis.options.interactive = True + >>> t = ibis.examples.penguins.fetch() + >>> t.head() + ┏━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━┓ + ┃ species ┃ island    ┃ bill_length_mm ┃ bill_depth_mm ┃ flipper_length_mm ┃ … ┃ + ┡━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━┩ + │ string  │ string    │ float64        │ float64       │ int64             │ … │ + ├─────────┼───────────┼────────────────┼───────────────┼───────────────────┼───┤ + │ Adelie  │ Torgersen │           39.1 │          18.7 │               181 │ … │ + │ Adelie  │ Torgersen │           39.5 │          17.4 │               186 │ … │ + │ Adelie  │ Torgersen │           40.3 │          18.0 │               195 │ … │ + │ Adelie  │ Torgersen │           NULL │          NULL │              NULL │ … │ + │ Adelie  │ Torgersen │           36.7 │          19.3 │               193 │ … │ + └─────────┴───────────┴────────────────┴───────────────┴───────────────────┴───┘ + >>> ( + ...     t.group_by(ibis.grouping_sets((), _.island, (_.island, _.sex))) + ...     .agg(mean_bill_length=_.bill_length_mm.mean()) + ...     .order_by( + ...         _.island.asc(nulls_first=True), + ...         _.sex.asc(nulls_first=True), + ...         _.mean_bill_length.desc(), + ...     ) + ... ) + ┏━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓ + ┃ island    ┃ sex    ┃ mean_bill_length ┃ + ┡━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩ + │ string    │ string │ float64          │ + ├───────────┼────────┼──────────────────┤ + │ NULL      │ NULL   │        43.921930 │ + │ Biscoe    │ NULL   │        45.625000 │ + │ Biscoe    │ NULL   │        45.257485 │ + │ Biscoe    │ female │        43.307500 │ + │ Biscoe    │ male   │        47.119277 │ + │ Dream     │ NULL   │        44.167742 │ + │ Dream     │ NULL   │        37.500000 │ + │ Dream     │ female │        42.296721 │ + │ Dream     │ male   │        46.116129 │ + │ Torgersen │ NULL   │        38.950980 │ + │ …         │ …      │                … │ + └───────────┴────────┴──────────────────┘ + + The previous example is equivalent to using a rollup: + + >>> ( + ...     t.group_by(ibis.rollup(_.island, _.sex)) + ...     .agg(mean_bill_length=_.bill_length_mm.mean()) + ...     .order_by( + ...         _.island.asc(nulls_first=True), + ...         _.sex.asc(nulls_first=True), + ...         _.mean_bill_length.desc(), + ...     ) + ... ) + ┏━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓ + ┃ island    ┃ sex    ┃ mean_bill_length ┃ + ┡━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩ + │ string    │ string │ float64          │ + ├───────────┼────────┼──────────────────┤ + │ NULL      │ NULL   │        43.921930 │ + │ Biscoe    │ NULL   │        45.625000 │ + │ Biscoe    │ NULL   │        45.257485 │ + │ Biscoe    │ female │        43.307500 │ + │ Biscoe    │ male   │        47.119277 │ + │ Dream     │ NULL   │        44.167742 │ + │ Dream     │ NULL   │        37.500000 │ + │ Dream     │ female │        42.296721 │ + │ Dream     │ male   │        46.116129 │ + │ Torgersen │ NULL   │        38.950980 │ + │ …         │ …      │                … │ + └───────────┴────────┴──────────────────┘ + """ + return GroupingSets(tuple(map(tuple, map(ibis.util.promote_list, dims)))) + + +@experimental +@deferrable +def group_id(first, *rest) -> ir.IntegerScalar: + """Return the grouping ID for a set of columns.
+ + Input columns must be part of the group by clause. + + ::: {.callout-note} + ## This function can only be called in a group by context. + ::: + + Returns + ------- + IntegerScalar + An integer whose bits indicate, for each input column, whether that + column is rolled up (aggregated over) in the current row's grouping set. + + Examples + -------- + >>> import ibis + >>> from ibis import _ + >>> ibis.options.interactive = True + >>> t = ibis.examples.penguins.fetch() + >>> t.head() + ┏━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━┓ + ┃ species ┃ island    ┃ bill_length_mm ┃ bill_depth_mm ┃ flipper_length_mm ┃ … ┃ + ┡━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━┩ + │ string  │ string    │ float64        │ float64       │ int64             │ … │ + ├─────────┼───────────┼────────────────┼───────────────┼───────────────────┼───┤ + │ Adelie  │ Torgersen │           39.1 │          18.7 │               181 │ … │ + │ Adelie  │ Torgersen │           39.5 │          17.4 │               186 │ … │ + │ Adelie  │ Torgersen │           40.3 │          18.0 │               195 │ … │ + │ Adelie  │ Torgersen │           NULL │          NULL │              NULL │ … │ + │ Adelie  │ Torgersen │           36.7 │          19.3 │               193 │ … │ + └─────────┴───────────┴────────────────┴───────────────┴───────────────────┴───┘ + >>> ( + ...     t.group_by(ibis.rollup(_.island, _.sex)) + ...     .agg( + ...         group_id=ibis.group_id(_.island, _.sex), + ...         mean_bill_length=_.bill_length_mm.mean(), + ...     ) + ...     .relocate(_.group_id) + ...     .order_by( + ...         _.group_id.desc(), + ...         _.island.asc(nulls_first=True), + ...         _.sex.asc(nulls_first=True), + ...         _.mean_bill_length.desc(), + ...     ) + ... ) + ┏━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓ + ┃ group_id ┃ island    ┃ sex    ┃ mean_bill_length ┃ + ┡━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩ + │ int64    │ string    │ string │ float64          │ + ├──────────┼───────────┼────────┼──────────────────┤ + │        3 │ NULL      │ NULL   │        43.921930 │ + │        1 │ Biscoe    │ NULL   │        45.257485 │ + │        1 │ Dream     │ NULL   │        44.167742 │ + │        1 │ Torgersen │ NULL   │        38.950980 │ + │        0 │ Biscoe    │ NULL   │        45.625000 │ + │        0 │ Biscoe    │ female │        43.307500 │ + │        0 │ Biscoe    │ male   │        47.119277 │ + │        0 │ Dream     │ NULL   │        37.500000 │ + │        0 │ Dream     │ female │        42.296721 │ + │        0 │ Dream     │ male   │        46.116129 │ + │        … │ …         │ …      │                … │ + └──────────┴───────────┴────────┴──────────────────┘ + """ + return ops.GroupID((first, *rest)).to_expr() diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index c7b6619639a84..fbbc5ee96c572 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -6,6 +6,7 @@ import warnings from collections import deque from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence +from itertools import chain from keyword import iskeyword from typing import TYPE_CHECKING, Any, Literal, NoReturn, overload @@ -18,13 +19,15 @@ import ibis.expr.operations as ops import ibis.expr.schema as sch from ibis import util +from ibis.common.collections import frozendict from ibis.common.deferred import Deferred, Resolver from ibis.common.selectors import Expandable, Selector from ibis.expr.rewrites import DerefMap from ibis.expr.types.core import Expr, _FixedTextJupyterMixin from ibis.expr.types.generic import Value, literal +from ibis.expr.types.groupby import Cube, GroupingSets, Rollup from ibis.expr.types.temporal import TimestampColumn -from ibis.util import deprecated +from ibis.util import deprecated, flatten_iterable if TYPE_CHECKING: import pandas as pd @@ -887,9 +890,35 @@ def group_by( """ from ibis.expr.types.groupby import GroupedTable + def partition_groups(*by, **key_exprs): + grouping_sets = [] + rollups = [] + cubes =
[] + groups = [] + for b in by: + if isinstance(b, Rollup): + rollups.append(b.expand(self)) + elif isinstance(b, Cube): + cubes.append(b.expand(self)) + elif isinstance(b, GroupingSets): + grouping_sets.append(b.expand(self)) + else: + groups.extend(self.bind(b)) + + groups.extend(self.bind(**key_exprs)) + return groups, grouping_sets, rollups, cubes + by = tuple(v for v in by if v is not None) - groups = self.bind(*by, **key_exprs) - return GroupedTable(self, groups) + groups, grouping_sets, rollups, cubes = partition_groups(*by, **key_exprs) + if not (groups or grouping_sets or rollups or cubes): + raise com.IbisInputError("No grouping keys provided") + return GroupedTable( + self, + groupings=groups, + grouping_sets=grouping_sets, + rollups=rollups, + cubes=cubes, + ) # TODO(kszucs): shouldn't this be ibis.rowid() instead not bound to a specific table? def rowid(self) -> ir.IntegerValue: @@ -937,6 +966,9 @@ def aggregate( metrics: Sequence[ir.Scalar] | None = (), by: Sequence[ir.Value] | None = (), having: Sequence[ir.BooleanValue] | None = (), + grouping_sets=(), + rollups=(), + cubes=(), **kwargs: ir.Value, ) -> Table: """Aggregate a table with a given set of reductions grouping by `by`. @@ -956,6 +988,12 @@ ::: {.callout-warning} ## Expressions like `x is None` return `bool` and **will not** generate a SQL comparison to `NULL` ::: + grouping_sets + Grouping sets to aggregate over, in addition to `by` + rollups + Sequences of grouping keys to roll up + cubes + Sequences of grouping keys to cube kwargs Named aggregate expressions @@ -1025,8 +1063,29 @@ else: metrics[metric.name] = metric - # construct the aggregate node - agg = ops.Aggregate(node, groups, metrics).to_expr() + keys = frozendict( + toolz.unique( + chain( + groups.items(), + ( + (g.op().name, g.op()) + for g in flatten_iterable( + chain(*grouping_sets, *rollups, *cubes) + ) + ), + ), + ) + ) + + agg = ops.Aggregate( + node, + groups=groups, + metrics=metrics, + keys=keys, + grouping_sets=grouping_sets, + rollups=rollups, + cubes=cubes, + ).to_expr() if having: # apply the having clause @@ -3006,10 +3065,7 @@ def join( | ir.BooleanColumn | Literal[True] | Literal[False] - | tuple[ - str | ir.Column | ir.Deferred, - str | ir.Column | ir.Deferred, - ] + | tuple[str | ir.Column | Deferred, str | ir.Column | Deferred] ] ) = (), how: JoinKind = "inner", diff --git a/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_database_table/repr.txt b/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_database_table/repr.txt index afa7b6662e2f0..74487b1330e4d 100644 --- a/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_database_table/repr.txt +++ b/ibis/tests/expr/snapshots/test_format_sql_operations/test_memoize_database_table/repr.txt @@ -21,9 +21,12 @@ r3 := JoinChain[r0] g: r2.g Aggregate[r3] - groups: - g: r3.g - key: r3.key metrics: foo: Mean(r3.f - r3.value) - bar: Sum(r3.f) \ No newline at end of file + bar: Sum(r3.f) + keys: + g: r3.g + key: r3.key + groups: + g: r3.g + key: r3.key \ No newline at end of file diff --git a/ibis/tests/expr/test_aggregation.py b/ibis/tests/expr/test_aggregation.py index 0192c27997aaf..f8534c02293f6 100644 --- a/ibis/tests/expr/test_aggregation.py +++ b/ibis/tests/expr/test_aggregation.py @@ -73,3 +73,28 @@ def test_aggregation_where(table, func): assert r1.equals(r3) assert r1.equals(r4) assert r1.op().where.equals(table.bools.op()) + + +@pytest.fixture +def t(): + return ibis.table(schema={"a": "string", "b": "string", "c": "string"}) + + +def test_rollup(t): + expr = t.group_by(ibis.rollup("a", "b",
"c")).agg(n=_.count()) + result = ibis.to_sql(expr, dialect="duckdb") + assert len(result) + + +def test_grouping_sets(t): + gs = ibis.grouping_sets(("a",), ("b",)) + expr = t.group_by(gs).agg(n=_.count()) + result = ibis.to_sql(expr, dialect="duckdb") + assert len(result) + + +def test_cube(t): + gs = ibis.cube("a", "b") + expr = t.group_by(gs).agg(n=_.count()) + result = ibis.to_sql(expr, dialect="duckdb") + assert len(result) diff --git a/ibis/tests/expr/test_analysis.py b/ibis/tests/expr/test_analysis.py index f470e030ee262..06fb60fc8faf4 100644 --- a/ibis/tests/expr/test_analysis.py +++ b/ibis/tests/expr/test_analysis.py @@ -149,6 +149,7 @@ def test_filter_self_join(): agged = purchases.group_by(["region", "kind"]).aggregate(metric) assert agged.op() == ops.Aggregate( parent=purchases, + keys={"region": purchases.region, "kind": purchases.kind}, groups={"region": purchases.region, "kind": purchases.kind}, metrics={"total": purchases.amount.sum()}, ) diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index b0ebe5b76355c..4e62e02f919bf 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -701,6 +701,7 @@ def test_aggregate_having_implicit_metric(table): implicit_having_metric = table.aggregate(metric, by=by, having=having) expected_aggregate = ops.Aggregate( parent=table, + keys={"g": table.g}, groups={"g": table.g}, metrics={"total": table.f.sum(), table.c.sum().get_name(): table.c.sum()}, ) @@ -730,6 +731,7 @@ def test_agg_having_explicit_metric(table): ) expected_aggregate = ops.Aggregate( parent=table, + keys={"g": table.g}, groups={"g": table.g}, metrics={"total": table.f.sum(), "sum": table.c.sum()}, ) @@ -784,6 +786,7 @@ def test_group_by_having_api(table): agg = ops.Aggregate( parent=table, + keys={"g": table.g}, groups={"g": table.g}, metrics={"foo": table.f.sum(), "Mean(d)": table.d.mean()}, ).to_expr() @@ -842,7 +845,7 @@ def test_groupby_convenience(table): ids=["list", "tuple", "none", "selector"], ) def test_group_by_nothing(table, group): - with pytest.raises(ValidationError): + with pytest.raises(com.IbisInputError, match="No grouping keys"): table.group_by(group) @@ -1502,6 +1505,7 @@ def test_having(table): agg = ops.Aggregate( parent=m, + keys={"foo": m.foo}, groups={"foo": m.foo}, metrics={"CountStar()": ops.CountStar(m), "Sum(foo)": ops.Sum(m.foo)}, ).to_expr()