Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api): add distinct option to collect #10121

Merged
merged 1 commit into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1003,7 +1003,10 @@ def array_collect(op, in_group_by=False, **kw):
if op.order_by:
keys = [translate(k.expr, **kw).filter(predicate) for k in op.order_by]
descending = [k.descending for k in op.order_by]
arg = arg.sort_by(keys, descending=descending)
arg = arg.sort_by(keys, descending=descending, nulls_last=True)

if op.distinct:
arg = arg.unique(maintain_order=op.order_by is not None)

# Polars' behavior changes for `implode` within a `group_by` currently.
# See https://github.com/pola-rs/polars/issues/16756
Expand Down
26 changes: 17 additions & 9 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,16 +479,24 @@
return self.f.parse_timestamp(format_str, arg, timezone)
return self.f.parse_datetime(format_str, arg)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
    """Compile ``ArrayCollect`` to BigQuery's ``ARRAY_AGG``.

    BigQuery's ``ARRAY_AGG`` cannot combine a filter with either
    ``IGNORE NULLS`` or ``DISTINCT``, so those parameter combinations raise.
    A plain ``where`` is folded into the argument by nulling out non-matching
    rows, which are then dropped by ``IGNORE NULLS``.

    Raises
    ------
    UnsupportedOperationError
        If `where` is combined with `include_null=True` or `distinct=True`.
    """
    if where is not None:
        if include_null:
            raise com.UnsupportedOperationError(
                "Combining `include_null=True` and `where` is not supported by bigquery"
            )
        if distinct:
            raise com.UnsupportedOperationError(
                "Combining `distinct=True` and `where` is not supported by bigquery"
            )
        # Push the predicate into the argument: filtered-out rows become NULL
        # and are removed by the IGNORE NULLS clause added below.
        # NOTE: was `compiler.if_(...)` in the diff; `compiler` is not a name
        # in scope here — sibling backends (pyspark, snowflake) use `self.if_`.
        arg = self.if_(where, arg, NULL)
    if distinct:
        arg = sge.Distinct(expressions=[arg])
    if order_by:
        arg = sge.Order(this=arg, expressions=order_by)
    if not include_null:
        arg = sge.IgnoreNulls(this=arg)
    return self.f.array_agg(arg)

Check warning on line 499 in ibis/backends/sql/compilers/bigquery/__init__.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/bigquery/__init__.py#L498-L499

Added lines #L498 - L499 were not covered by tests

def _neg_idx_to_pos(self, arg, idx):
return self.if_(idx < 0, self.f.array_length(arg) + idx, idx)
Expand Down
5 changes: 3 additions & 2 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,12 +611,13 @@ def visit_ArrayUnion(self, op, *, left, right):
def visit_ArrayZip(self, op: ops.ArrayZip, *, arg, **_: Any) -> str:
return self.f.arrayZip(*arg)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
    """Compile ``ArrayCollect`` to ClickHouse's ``groupArray``.

    ClickHouse provides `groupUniqArray` as the distinct variant, so
    `distinct=True` simply swaps the aggregate function.

    Raises
    ------
    UnsupportedOperationError
        If `include_null=True`; ClickHouse's array aggregates drop nulls.
    """
    if include_null:
        raise com.UnsupportedOperationError(
            "`include_null=True` is not supported by the clickhouse backend"
        )
    func = self.agg.groupUniqArray if distinct else self.agg.groupArray
    return func(arg, where=where, order_by=order_by)

def visit_First(self, op, *, arg, where, order_by, include_null):
if include_null:
Expand Down
6 changes: 5 additions & 1 deletion ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,11 @@ def visit_ArrayRepeat(self, op, *, arg, times):
def visit_ArrayPosition(self, op, *, arg, other):
return self.f.coalesce(self.f.array_position(arg, other), 0)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null):
def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
if distinct:
raise com.UnsupportedOperationError(
"`collect` with `distinct=True` is not supported"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"`collect` with `distinct=True` is not supported"
"`collect` with `distinct=True` is not supported by DataFusion"

)
if not include_null:
cond = arg.is_(sg.not_(NULL, copy=False))
where = cond if where is None else sge.And(this=cond, expression=where)
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,12 @@ def visit_ArrayPosition(self, op, *, arg, other):
self.f.coalesce(self.f.list_indexof(arg, other), 0),
)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
    """Compile ``ArrayCollect`` to DuckDB's ``array_agg``.

    Null exclusion (`include_null=False`) is folded into the aggregate's
    filter; `distinct=True` wraps the argument in a DISTINCT expression.
    """
    if not include_null:
        # Exclude nulls via the FILTER clause, combined with any user filter.
        cond = arg.is_(sg.not_(NULL, copy=False))
        where = cond if where is None else sge.And(this=cond, expression=where)
    if distinct:
        arg = sge.Distinct(expressions=[arg])
    return self.agg.array_agg(arg, where=where, order_by=order_by)

def visit_ArrayIndex(self, op, *, arg, index):
Expand Down
12 changes: 8 additions & 4 deletions ibis/backends/sql/compilers/flink.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,20 +572,24 @@ def visit_MapMerge(self, op: ops.MapMerge, *, left, right):
def visit_StructColumn(self, op, *, names, values):
return self.cast(sge.Struct(expressions=list(values)), op.dtype)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
    """Compile ``ArrayCollect`` to Flink's ``ARRAY_AGG``.

    Raises
    ------
    UnsupportedOperationError
        If `order_by` is given; Flink's aggregates don't support ordering.
    """
    if order_by:
        raise com.UnsupportedOperationError(
            "ordering of order-sensitive aggregations via `order_by` is "
            "not supported for this backend"
        )
    if not include_null:
        cond = arg.is_(sg.not_(NULL, copy=False))
        where = cond if where is None else sge.And(this=cond, expression=where)
    # the only way to get filtering *and* respecting nulls is to use
    # `FILTER` syntax, but it's broken in various ways for other aggregates
    out = self.f.array_agg(arg)
    if where is not None:
        out = sge.Filter(this=out, expression=sge.Where(this=where))
    if distinct:
        # TODO: Flink supposedly supports `ARRAY_AGG(DISTINCT ...)`, but it
        # doesn't work with filtering (either `include_null=False` or
        # additional filtering). Their `array_distinct` function does maintain
        # ordering though, so we can use it here.
        out = self.f.array_distinct(out)
    return out

def visit_Strip(self, op, *, arg):
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,10 +372,12 @@ def visit_ArrayIntersect(self, op, *, left, right):
)
)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
    """Compile ``ArrayCollect`` to Postgres's ``array_agg``.

    Null exclusion (`include_null=False`) is folded into the aggregate's
    filter; `distinct=True` wraps the argument in a DISTINCT expression.
    """
    if not include_null:
        # Exclude nulls via the FILTER clause, combined with any user filter.
        cond = arg.is_(sg.not_(NULL, copy=False))
        where = cond if where is None else sge.And(this=cond, expression=where)
    if distinct:
        arg = sge.Distinct(expressions=[arg])
    return self.agg.array_agg(arg, where=where, order_by=order_by)

def visit_First(self, op, *, arg, where, order_by, include_null):
Expand Down
8 changes: 6 additions & 2 deletions ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,12 +432,16 @@ def visit_ArrayContains(self, op, *, arg, other):
def visit_ArrayStringJoin(self, op, *, arg, sep):
return self.f.concat_ws(sep, arg)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
    """Compile ``ArrayCollect`` to PySpark's ``array_agg``.

    The `where` predicate is folded into the argument (non-matching rows
    become NULL, which PySpark's array_agg drops), since `include_null=True`
    is rejected anyway.

    Raises
    ------
    UnsupportedOperationError
        If `include_null=True`.
    """
    if include_null:
        raise com.UnsupportedOperationError(
            "`include_null=True` is not supported by the pyspark backend"
        )
    # Explicit None check for consistency with the other backends; the
    # predicate is a sqlglot expression, not a plain boolean.
    if where is not None:
        arg = self.if_(where, arg, NULL)
    if distinct:
        arg = sge.Distinct(expressions=[arg])
    return self.agg.array_agg(arg, order_by=order_by)

def visit_StringFind(self, op, *, arg, substr, start, end):
if end is not None:
Expand Down
17 changes: 14 additions & 3 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,25 +452,36 @@
timestamp_units_to_scale = {"s": 0, "ms": 3, "us": 6, "ns": 9}
return self.f.to_timestamp(arg, timestamp_units_to_scale[unit.short])

def _array_collect(self, *, arg, where, order_by, include_null, distinct=False):
    """Shared helper compiling array-collecting aggregates to Snowflake's
    ``ARRAY_AGG`` (ordering applied via ``WITHIN GROUP``).

    Raises
    ------
    UnsupportedOperationError
        If `include_null=True`, or if `distinct=True` is combined with a
        `where` filter.
    """
    if include_null:
        raise com.UnsupportedOperationError(
            "`include_null=True` is not supported by the snowflake backend"
        )
    if where is not None and distinct:
        raise com.UnsupportedOperationError(
            "Combining `distinct=True` and `where` is not supported by snowflake"
        )

    if where is not None:
        # Non-matching rows become NULL, which Snowflake's ARRAY_AGG drops.
        arg = self.if_(where, arg, NULL)

    if distinct:
        arg = sge.Distinct(expressions=[arg])

    out = self.f.array_agg(arg)

    if order_by:
        out = sge.WithinGroup(this=out, expression=sge.Order(expressions=order_by))

    return out

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
    """Compile ``ArrayCollect`` by delegating to the shared ``_array_collect``
    helper (also used by other collecting aggregates)."""
    return self._array_collect(
        arg=arg,
        where=where,
        order_by=order_by,
        include_null=include_null,
        distinct=distinct,
    )

def visit_First(self, op, *, arg, where, order_by, include_null):
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,12 @@ def visit_ArrayContains(self, op, *, arg, other):
NULL,
)

def visit_ArrayCollect(self, op, *, arg, where, order_by, include_null, distinct):
    """Compile ``ArrayCollect`` to Trino's ``array_agg``.

    Null exclusion (`include_null=False`) is folded into the aggregate's
    filter; `distinct=True` wraps the argument in a DISTINCT expression.
    """
    if not include_null:
        # Exclude nulls via the FILTER clause, combined with any user filter.
        cond = arg.is_(sg.not_(NULL, copy=False))
        where = cond if where is None else sge.And(this=cond, expression=where)
    if distinct:
        arg = sge.Distinct(expressions=[arg])
    return self.agg.array_agg(arg, where=where, order_by=order_by)

def visit_JSONGetItem(self, op, *, arg, index):
Expand Down
107 changes: 58 additions & 49 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import itertools
from datetime import date
from operator import methodcaller

Expand Down Expand Up @@ -1301,67 +1302,75 @@ def test_group_concat_ordered(alltypes, df, filtered):
assert result == expected


def gen_test_collect_marks(distinct, filtered, ordered, include_null):
    """Yield the xfail/notimpl marks for one parameter combination of
    ``test_collect``.

    The marks for this test fail for different combinations of parameters.
    Rather than set `strict=False` (which can let bugs sneak through), we
    split the mark generation into a function.
    """
    if distinct:
        yield pytest.mark.notimpl(["datafusion"], raises=com.UnsupportedOperationError)
    if ordered:
        yield pytest.mark.notimpl(
            ["clickhouse", "pyspark", "flink"], raises=com.UnsupportedOperationError
        )
    if include_null:
        yield pytest.mark.notimpl(
            ["clickhouse", "pyspark", "snowflake"], raises=com.UnsupportedOperationError
        )

    # Handle special cases
    if filtered and distinct:
        yield pytest.mark.notimpl(
            ["bigquery", "snowflake"],
            raises=com.UnsupportedOperationError,
            reason="Can't combine where and distinct",
        )
    elif filtered and include_null:
        yield pytest.mark.notimpl(
            ["bigquery"],
            raises=com.UnsupportedOperationError,
            reason="Can't combine where and include_null",
        )
    elif include_null:
        yield pytest.mark.notimpl(
            ["bigquery"],
            raises=GoogleBadRequest,
            reason="BigQuery can't retrieve arrays with null values",
        )


@pytest.mark.notimpl(
    ["druid", "exasol", "impala", "mssql", "mysql", "oracle", "sqlite"],
    raises=com.OperationNotDefinedError,
)
@pytest.mark.parametrize(
    "distinct, filtered, ordered, include_null",
    [
        # Full cross-product of the four boolean options, with per-backend
        # marks generated by gen_test_collect_marks for each combination.
        param(*ps, marks=list(gen_test_collect_marks(*ps)))
        for ps in itertools.product(*([[True, False]] * 4))
    ],
)
def test_collect(alltypes, df, distinct, filtered, ordered, include_null):
    """Check `collect` against a pandas reference for every combination of
    `distinct`, `where`, `order_by`, and `include_null`."""
    expr = alltypes.mutate(x=_.string_col.nullif("3")).x.collect(
        where=((_.id % 13 == 0) if filtered else None),
        include_null=include_null,
        distinct=distinct,
        order_by=(_.x.desc() if ordered else ()),
    )
    res = expr.execute()

    # Build the expected result in pandas, mirroring each option.
    x = df.string_col.where(df.string_col != "3", None)
    if filtered:
        x = x[df.id % 13 == 0]
    if not include_null:
        x = x.dropna()
    if distinct:
        x = x.drop_duplicates()
    # Sort with None last; the key tuple keeps None comparable to strings.
    sol = sorted(x, key=lambda x: (x is not None, x), reverse=True)

    if not ordered:
        # If unordered, order afterwards so we can compare
        res = sorted(res, key=lambda x: (x is not None, x), reverse=True)

    assert res == sol


Expand Down
11 changes: 10 additions & 1 deletion ibis/expr/operations/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import ibis.expr.datashape as ds
import ibis.expr.datatypes as dt
import ibis.expr.rules as rlz
from ibis.common.annotations import attribute
from ibis.common.annotations import ValidationError, attribute
from ibis.common.typing import VarTuple # noqa: TCH001
from ibis.expr.operations.core import Column, Value
from ibis.expr.operations.relations import Relation # noqa: TCH001
Expand Down Expand Up @@ -376,6 +376,15 @@ class ArrayCollect(Filterable, Reduction):
arg: Column
order_by: VarTuple[SortKey] = ()
include_null: bool = False
distinct: bool = False

def __init__(self, arg, order_by, distinct, **kwargs):
    """Validate that `distinct=True` only orders by the collected column.

    After deduplication any other sort key would be ambiguous (several
    source rows with different key values collapse into one element), so
    the combination is rejected up front.
    """
    if distinct and order_by and [arg] != [key.expr for key in order_by]:
        raise ValidationError(
            "`collect` with `order_by` and `distinct=True` may only "
            "order by the collected column"
        )
    super().__init__(arg=arg, order_by=order_by, distinct=distinct, **kwargs)

@attribute
def dtype(self):
Expand Down
20 changes: 20 additions & 0 deletions ibis/expr/tests/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import ibis
import ibis.expr.operations as ops
from ibis import _
from ibis.common.annotations import ValidationError
from ibis.common.deferred import Deferred
from ibis.common.exceptions import IbisTypeError

Expand Down Expand Up @@ -161,3 +162,22 @@ def test_ordered_aggregations_no_order(method):
q3 = func(order_by=())
assert q1.equals(q2)
assert q1.equals(q3)


def test_collect_distinct():
    """`collect(distinct=True)` accepts ordering only by the collected
    expression itself and rejects any other sort key."""
    t = ibis.table({"a": "string", "b": "int", "c": "int"}, name="t")
    # Fine
    t.a.collect(distinct=True)
    t.a.collect(distinct=True, order_by=t.a.desc())
    (t.a + 1).collect(distinct=True, order_by=(t.a + 1).desc())

    with pytest.raises(ValidationError, match="only order by the collected column"):
        t.b.collect(distinct=True, order_by=t.a)
    with pytest.raises(ValidationError, match="only order by the collected column"):
        t.b.collect(
            distinct=True,
            order_by=(
                t.a,
                t.b,
            ),
        )
Loading