-
Notifications
You must be signed in to change notification settings - Fork 609
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(api): add distinct
option to collect
#10121
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -479,16 +479,24 @@ | |||||
return self.f.parse_timestamp(format_str, arg, timezone) | ||||||
return self.f.parse_datetime(format_str, arg) | ||||||
|
||||||
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): | ||||||
if where is not None and include_null: | ||||||
raise com.UnsupportedOperationError( | ||||||
"Combining `include_null=True` and `where` is not supported " | ||||||
"by bigquery" | ||||||
) | ||||||
out = self.agg.array_agg(arg, where=where, order_by=order_by) | ||||||
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct): | ||||||
if where is not None: | ||||||
if include_null: | ||||||
raise com.UnsupportedOperationError( | ||||||
"Combining `include_null=True` and `where` is not supported by bigquery" | ||||||
) | ||||||
if distinct: | ||||||
raise com.UnsupportedOperationError( | ||||||
"Combining `distinct=True` and `where` is not supported by bigquery" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
) | ||||||
arg = compiler.if_(where, arg, NULL) | ||||||
if distinct: | ||||||
arg = sge.Distinct(expressions=[arg]) | ||||||
if order_by: | ||||||
arg = sge.Order(this=arg, expressions=order_by) | ||||||
if not include_null: | ||||||
out = sge.IgnoreNulls(this=out) | ||||||
return out | ||||||
arg = sge.IgnoreNulls(this=arg) | ||||||
return self.f.array_agg(arg) | ||||||
|
||||||
def _neg_idx_to_pos(self, arg, idx): | ||||||
return self.if_(idx < 0, self.f.array_length(arg) + idx, idx) | ||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -327,7 +327,11 @@ def visit_ArrayRepeat(self, op, *, arg, times): | |||||
def visit_ArrayPosition(self, op, *, arg, other): | ||||||
return self.f.coalesce(self.f.array_position(arg, other), 0) | ||||||
|
||||||
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): | ||||||
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct): | ||||||
if distinct: | ||||||
raise com.UnsupportedOperationError( | ||||||
"`collect` with `distinct=True` is not supported" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
) | ||||||
if not include_null: | ||||||
cond = arg.is_(sg.not_(NULL, copy=False)) | ||||||
where = cond if where is None else sge.And(this=cond, expression=where) | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -572,20 +572,24 @@ def visit_MapMerge(self, op: ops.MapMerge, *, left, right): | |
def visit_StructColumn(self, op, *, names, values): | ||
return self.cast(sge.Struct(expressions=list(values)), op.dtype) | ||
|
||
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): | ||
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct): | ||
if order_by: | ||
raise com.UnsupportedOperationError( | ||
"ordering of order-sensitive aggregations via `order_by` is " | ||
"not supported for this backend" | ||
) | ||
# the only way to get filtering *and* respecting nulls is to use | ||
# `FILTER` syntax, but it's broken in various ways for other aggregates | ||
out = self.f.array_agg(arg) | ||
if not include_null: | ||
cond = arg.is_(sg.not_(NULL, copy=False)) | ||
where = cond if where is None else sge.And(this=cond, expression=where) | ||
out = self.f.array_agg(arg) | ||
if where is not None: | ||
out = sge.Filter(this=out, expression=sge.Where(this=where)) | ||
if distinct: | ||
# TODO: Flink supposedly supports `ARRAY_AGG(DISTINCT ...)`, but it | ||
# doesn't work with filtering (either `include_null=False` or | ||
# additional filtering). Their `array_distinct` function does maintain | ||
# ordering though, so we can use it here. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Flink does some really weird stuff here. One parameter combo somehow results in filtering out all but two null values for unknown reasons. |
||
out = self.f.array_distinct(out) | ||
return out | ||
|
||
def visit_Strip(self, op, *, arg): | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -452,25 +452,36 @@ | |||||
timestamp_units_to_scale = {"s": 0, "ms": 3, "us": 6, "ns": 9} | ||||||
return self.f.to_timestamp(arg, timestamp_units_to_scale[unit.short]) | ||||||
|
||||||
def _array_collect(self, *, arg, where, order_by, include_null): | ||||||
def _array_collect(self, *, arg, where, order_by, include_null, distinct=False): | ||||||
if include_null: | ||||||
raise com.UnsupportedOperationError( | ||||||
"`include_null=True` is not supported by the snowflake backend" | ||||||
) | ||||||
if where is not None and distinct: | ||||||
raise com.UnsupportedOperationError( | ||||||
"Combining `distinct=True` and `where` is not supported by snowflake" | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Only because it's a proper noun and not a generic one :) Alternatively, we can leave this for a follow up and audit all the backends for proper noun capitalization. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know there's more of these in the codebase, I'll leave this as a follow up. |
||||||
) | ||||||
|
||||||
if where is not None: | ||||||
arg = self.if_(where, arg, NULL) | ||||||
|
||||||
if distinct: | ||||||
arg = sge.Distinct(expressions=[arg]) | ||||||
|
||||||
out = self.f.array_agg(arg) | ||||||
|
||||||
if order_by: | ||||||
out = sge.WithinGroup(this=out, expression=sge.Order(expressions=order_by)) | ||||||
|
||||||
return out | ||||||
|
||||||
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null): | ||||||
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct): | ||||||
return self._array_collect( | ||||||
arg=arg, where=where, order_by=order_by, include_null=include_null | ||||||
arg=arg, | ||||||
where=where, | ||||||
order_by=order_by, | ||||||
include_null=include_null, | ||||||
distinct=distinct, | ||||||
) | ||||||
|
||||||
def visit_First(self, op, *, arg, where, order_by, include_null): | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
from __future__ import annotations | ||
|
||
import itertools | ||
from datetime import date | ||
from operator import methodcaller | ||
|
||
|
@@ -1301,67 +1302,75 @@ def test_group_concat_ordered(alltypes, df, filtered): | |
assert result == expected | ||
|
||
|
||
@pytest.mark.notimpl( | ||
["druid", "exasol", "impala", "mssql", "mysql", "oracle", "sqlite"], | ||
raises=com.OperationNotDefinedError, | ||
) | ||
@pytest.mark.notimpl( | ||
["clickhouse", "pyspark", "flink"], raises=com.UnsupportedOperationError | ||
) | ||
@pytest.mark.parametrize("filtered", [True, False]) | ||
def test_collect_ordered(alltypes, df, filtered): | ||
ibis_cond = (_.id % 13 == 0) if filtered else None | ||
pd_cond = (df.id % 13 == 0) if filtered else True | ||
result = ( | ||
alltypes.filter(_.bigint_col == 10) | ||
.id.cast("str") | ||
.collect(where=ibis_cond, order_by=_.id.desc()) | ||
.execute() | ||
) | ||
expected = list( | ||
df.id[(df.bigint_col == 10) & pd_cond].sort_values(ascending=False).astype(str) | ||
) | ||
assert result == expected | ||
def gen_test_collect_marks(distinct, filtered, ordered, include_null): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The other option was to add a bunch of |
||
"""The marks for this test fail for different combinations of parameters. | ||
Rather than set `strict=False` (which can let bugs sneak through), we split | ||
the mark generation into a function""" | ||
if distinct: | ||
yield pytest.mark.notimpl(["datafusion"], raises=com.UnsupportedOperationError) | ||
if ordered: | ||
yield pytest.mark.notimpl( | ||
["clickhouse", "pyspark", "flink"], raises=com.UnsupportedOperationError | ||
) | ||
if include_null: | ||
yield pytest.mark.notimpl( | ||
["clickhouse", "pyspark", "snowflake"], raises=com.UnsupportedOperationError | ||
) | ||
|
||
# Handle special cases | ||
if filtered and distinct: | ||
yield pytest.mark.notimpl( | ||
["bigquery", "snowflake"], | ||
raises=com.UnsupportedOperationError, | ||
reason="Can't combine where and distinct", | ||
) | ||
elif filtered and include_null: | ||
yield pytest.mark.notimpl( | ||
["bigquery"], | ||
raises=com.UnsupportedOperationError, | ||
reason="Can't combine where and include_null", | ||
) | ||
elif include_null: | ||
yield pytest.mark.notimpl( | ||
["bigquery"], | ||
raises=GoogleBadRequest, | ||
reason="BigQuery can't retrieve arrays with null values", | ||
) | ||
|
||
|
||
@pytest.mark.notimpl( | ||
["druid", "exasol", "impala", "mssql", "mysql", "oracle", "sqlite"], | ||
raises=com.OperationNotDefinedError, | ||
) | ||
@pytest.mark.parametrize("filtered", [True, False]) | ||
@pytest.mark.parametrize( | ||
"include_null", | ||
"distinct, filtered, ordered, include_null", | ||
[ | ||
False, | ||
param( | ||
True, | ||
marks=[ | ||
pytest.mark.notimpl( | ||
["clickhouse", "pyspark", "snowflake"], | ||
raises=com.UnsupportedOperationError, | ||
reason="`include_null=True` is not supported", | ||
), | ||
pytest.mark.notimpl( | ||
["bigquery"], | ||
raises=com.UnsupportedOperationError, | ||
reason="Can't mix `where` and `include_null=True`", | ||
strict=False, | ||
), | ||
], | ||
), | ||
param(*ps, marks=list(gen_test_collect_marks(*ps))) | ||
for ps in itertools.product(*([[True, False]] * 4)) | ||
], | ||
) | ||
def test_collect(alltypes, df, filtered, include_null): | ||
ibis_cond = (_.id % 13 == 0) if filtered else None | ||
pd_cond = (df.id % 13 == 0) if filtered else slice(None) | ||
expr = ( | ||
alltypes.string_col.nullif("3") | ||
.collect(where=ibis_cond, include_null=include_null) | ||
.length() | ||
def test_collect(alltypes, df, distinct, filtered, ordered, include_null): | ||
expr = alltypes.mutate(x=_.string_col.nullif("3")).x.collect( | ||
where=((_.id % 13 == 0) if filtered else None), | ||
include_null=include_null, | ||
distinct=distinct, | ||
order_by=(_.x.desc() if ordered else ()), | ||
) | ||
res = expr.execute() | ||
vals = df.string_col if include_null else df.string_col[df.string_col != "3"] | ||
sol = len(vals[pd_cond]) | ||
|
||
x = df.string_col.where(df.string_col != "3", None) | ||
if filtered: | ||
x = x[df.id % 13 == 0] | ||
if not include_null: | ||
x = x.dropna() | ||
if distinct: | ||
x = x.drop_duplicates() | ||
sol = sorted(x, key=lambda x: (x is not None, x), reverse=True) | ||
|
||
if not ordered: | ||
# If unordered, order afterwards so we can compare | ||
res = sorted(res, key=lambda x: (x is not None, x), reverse=True) | ||
|
||
assert res == sol | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.