Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sql): fuse distinct with other select nodes when possible #9923

Merged
merged 1 commit into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 12 additions & 8 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,10 @@ class SQLGlotCompiler(abc.ABC):
one_to_zero_index,
add_one_to_nth_value_input,
)
"""A sequence of rewrites to apply to the expression tree before compilation."""
"""A sequence of rewrites to apply to the expression tree before SQL-specific transforms."""

post_rewrites: tuple[type[pats.Replace], ...] = ()
"""A sequence of rewrites to apply to the expression tree after SQL-specific transforms."""

no_limit_value: sge.Null | None = None
"""The value to use to indicate no limit."""
Expand Down Expand Up @@ -606,6 +609,7 @@ def translate(self, op, *, params: Mapping[ir.Value, Any]) -> sge.Expression:
op,
params=params,
rewrites=self.rewrites,
post_rewrites=self.post_rewrites,
fuse_selects=options.sql.fuse_selects,
)

Expand Down Expand Up @@ -1257,9 +1261,11 @@ def _cleanup_names(self, exprs: Mapping[str, sge.Expression]):
else:
yield value.as_(name, quoted=self.quoted, copy=False)

def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_keys):
def visit_Select(
self, op, *, parent, selections, predicates, qualified, sort_keys, distinct
):
# if we've constructed a useless projection return the parent relation
if not (selections or predicates or qualified or sort_keys):
if not (selections or predicates or qualified or sort_keys or distinct):
return parent

result = parent
Expand All @@ -1286,6 +1292,9 @@ def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_ke
if sort_keys:
result = result.order_by(*sort_keys, copy=False)

if distinct:
result = result.distinct()

return result

def visit_DummyTable(self, op, *, values):
Expand Down Expand Up @@ -1470,11 +1479,6 @@ def visit_Limit(self, op, *, parent, n, offset):
return result.subquery(alias, copy=False)
return result

def visit_Distinct(self, op, *, parent):
return (
sg.select(STAR, copy=False).distinct(copy=False).from_(parent, copy=False)
)

def visit_CTE(self, op, *, parent):
return sg.table(parent.alias_or_name, quoted=self.quoted)

Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
exclude_unsupported_window_frame_from_ops,
exclude_unsupported_window_frame_from_rank,
exclude_unsupported_window_frame_from_row_number,
split_select_distinct_with_order_by,
)
from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit

Expand Down Expand Up @@ -113,6 +114,7 @@ class BigQueryCompiler(SQLGlotCompiler):
exclude_unsupported_window_frame_from_rank,
*SQLGlotCompiler.rewrites,
)
post_rewrites = (split_select_distinct_with_order_by,)

supports_qualify = True

Expand Down
3 changes: 3 additions & 0 deletions ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ibis.backends.sql.compilers.base import FALSE, NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import DataFusionType
from ibis.backends.sql.dialects import DataFusion
from ibis.backends.sql.rewrites import split_select_distinct_with_order_by
from ibis.common.temporal import IntervalUnit, TimestampUnit
from ibis.expr.operations.udf import InputType

Expand All @@ -26,6 +27,8 @@ class DataFusionCompiler(SQLGlotCompiler):

agg = AggGen(supports_filter=True, supports_order_by=True)

post_rewrites = (split_select_distinct_with_order_by,)

UNSUPPORTED_OPS = (
ops.ArgMax,
ops.ArgMin,
Expand Down
11 changes: 9 additions & 2 deletions ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
exclude_unsupported_window_frame_from_row_number,
p,
replace,
split_select_distinct_with_order_by,
)
from ibis.common.deferred import var

Expand Down Expand Up @@ -69,6 +70,7 @@ class MSSQLCompiler(SQLGlotCompiler):
rewrite_rows_range_order_by_window,
*SQLGlotCompiler.rewrites,
)
post_rewrites = (split_select_distinct_with_order_by,)
copy_func_args = True

UNSUPPORTED_OPS = (
Expand Down Expand Up @@ -479,9 +481,11 @@ def visit_All(self, op, *, arg, where):
arg = self.if_(where, arg, NULL)
return sge.Min(this=arg)

def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_keys):
def visit_Select(
self, op, *, parent, selections, predicates, qualified, sort_keys, distinct
):
# if we've constructed a useless projection return the parent relation
if not (selections or predicates or qualified or sort_keys):
if not (selections or predicates or qualified or sort_keys or distinct):
return parent

result = parent
Expand All @@ -500,6 +504,9 @@ def visit_Select(self, op, *, parent, selections, predicates, qualified, sort_ke
if sort_keys:
result = result.order_by(*sort_keys, copy=False)

if distinct:
result = result.distinct()

return result

def visit_TimestampAdd(self, op, *, left, right):
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler
from ibis.backends.sql.datatypes import PostgresType
from ibis.backends.sql.dialects import Postgres
from ibis.backends.sql.rewrites import split_select_distinct_with_order_by
from ibis.common.exceptions import InvalidDecoratorError
from ibis.util import gen_name

Expand All @@ -41,6 +42,7 @@ class PostgresCompiler(SQLGlotCompiler):

dialect = Postgres
type_mapper = PostgresType
post_rewrites = (split_select_distinct_with_order_by,)

agg = AggGen(supports_filter=True, supports_order_by=True)

Expand Down
8 changes: 7 additions & 1 deletion ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
from ibis.backends.sql.compilers.base import FALSE, NULL, STAR, SQLGlotCompiler
from ibis.backends.sql.datatypes import PySparkType
from ibis.backends.sql.dialects import PySpark
from ibis.backends.sql.rewrites import FirstValue, LastValue, p
from ibis.backends.sql.rewrites import (
FirstValue,
LastValue,
p,
split_select_distinct_with_order_by,
)
from ibis.common.patterns import replace
from ibis.config import options
from ibis.expr.operations.udf import InputType
Expand Down Expand Up @@ -51,6 +56,7 @@ class PySparkCompiler(SQLGlotCompiler):
dialect = PySpark
type_mapper = PySparkType
rewrites = (offset_to_filter, *SQLGlotCompiler.rewrites)
post_rewrites = (split_select_distinct_with_order_by,)

UNSUPPORTED_OPS = (
ops.RowID,
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ibis.backends.sql.dialects import Trino
from ibis.backends.sql.rewrites import (
exclude_unsupported_window_frame_from_ops,
split_select_distinct_with_order_by,
)
from ibis.util import gen_name

Expand All @@ -39,6 +40,7 @@ class TrinoCompiler(SQLGlotCompiler):
exclude_unsupported_window_frame_from_ops,
*SQLGlotCompiler.rewrites,
)
post_rewrites = (split_select_distinct_with_order_by,)
quoted = True

NAN = sg.func("nan")
Expand Down
99 changes: 98 additions & 1 deletion ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
predicates: VarTuple[ops.Value[dt.Boolean]] = ()
qualified: VarTuple[ops.Value[dt.Boolean]] = ()
sort_keys: VarTuple[ops.SortKey] = ()
distinct: bool = False

def is_star_selection(self):
return tuple(self.values.items()) == tuple(self.parent.fields.items())
Expand Down Expand Up @@ -128,6 +129,12 @@
return Select(_.parent, selections=_.values, sort_keys=_.keys)


@replace(p.Distinct)
def distinct_to_select(_, **kwargs):
    """Lower a Distinct node into a Select node flagged as distinct.

    The resulting Select keeps the matched node's parent and projects the
    node's ``values`` unchanged; only the ``distinct`` flag is set.
    """
    parent, columns = _.parent, _.values
    return Select(parent, selections=columns, distinct=True)


@replace(p.DropColumns)
def drop_columns_to_select(_, **kwargs):
"""Convert a DropColumns node to a Select node."""
Expand Down Expand Up @@ -244,6 +251,48 @@
if _.parent.find_below(blocking, filter=ops.Value):
return _

if _.parent.distinct:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are also dragons here.

# The inner query is distinct.
#
# If the outer query is distinct, it's only safe to merge if it's a simple subselection:
# - Fusing in the presence of non-deterministic calls in the select would lead to
# incorrect results
# - Fusing in the presence of expensive calls in the select would lead to potential
# performance pitfalls
if _.distinct and not all(
isinstance(v, ops.Field) for v in _.selections.values()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this cover the alias-in-outer-project case?

SELECT a, b AS c
FROM (
  SELECT DISTINCT
    a, b
  FROM t
)

would become

SELECT DISTINCT
  a, b AS c
FROM t

If not, fine to either handle later or perhaps never if it doesn't come up.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This branch is for both being distinct, so:

SELECT DISTINCT a, b as c
FROM (SELECT DISTINCT a, b FROM t)

In this case, yes, the aliases are properly handled.

The branch on line 270 handles the SELECT ... FROM (SELECT DISTINCT ...) case. Right now it only works with SELECT * cases, but we might be able to make it work with outer queries that rename columns but otherwise select all of them. Right now I don't think that's worth it.

):
return _

Check warning on line 265 in ibis/backends/sql/rewrites.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/rewrites.py#L265

Added line #L265 was not covered by tests

# If the outer query isn't distinct, it's only safe to merge if the outer is a SELECT *:
# - If new columns are added, they might be non-distinct, changing the distinctness
# - If previous columns are removed, that would also change the distinctness
if not _.distinct and not _.is_star_selection():
return _

distinct = True
elif _.distinct:
# The outer query is distinct and the inner isn't. It's only safe to merge if either
# - The inner query isn't ordered
# - The outer query is a SELECT *
#
# Otherwise we run the risk that the outer query drops columns needed for the ordering of
# the inner query - many backends don't allow SELECT DISTINCT queries to order by columns
# that aren't present in their selection, like
#
# SELECT DISTINCT a, b FROM t ORDER BY c --- some backends will explode at this
#
# An alternate solution would be to drop the inner ORDER BY clause, because the backend will
# ignore it anyway when it appears in a subquery. That feels potentially risky though; it is
# better to generate the SQL as written.
if _.parent.sort_keys and not _.is_star_selection():
jcrist marked this conversation as resolved.
Show resolved Hide resolved
return _

Check warning on line 289 in ibis/backends/sql/rewrites.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/rewrites.py#L289

Added line #L289 was not covered by tests
jcrist marked this conversation as resolved.
Show resolved Hide resolved

distinct = True
else:
# Neither query is distinct, safe to merge
distinct = False

subs = {ops.Field(_.parent, k): v for k, v in _.parent.values.items()}
selections = {k: v.replace(subs, filter=ops.Value) for k, v in _.selections.items()}

Expand All @@ -266,6 +315,7 @@
predicates=unique_predicates,
qualified=unique_qualified,
sort_keys=unique_sort_keys,
distinct=distinct,
)
return result if complexity(result) <= complexity(_) else _

Expand All @@ -289,6 +339,7 @@
node: ops.Node,
params: Mapping[ops.ScalarParameter, Any],
rewrites: Sequence[Pattern] = (),
post_rewrites: Sequence[Pattern] = (),
fuse_selects: bool = True,
) -> tuple[ops.Node, list[ops.Node]]:
"""Lower the ibis expression graph to a SQL-like relational algebra.
Expand All @@ -300,7 +351,9 @@
params
A mapping of scalar parameters to their values.
rewrites
Supplementary rewrites to apply to the expression graph.
Supplementary rewrites to apply before SQL-specific transforms.
post_rewrites
Supplementary rewrites to apply after SQL-specific transforms.
fuse_selects
Whether to merge subsequent Select nodes into one where possible.

Expand All @@ -322,6 +375,7 @@
| project_to_select
| filter_to_select
| sort_to_select
| distinct_to_select
| fill_null_to_select
| drop_null_to_select
| drop_columns_to_select
Expand All @@ -335,6 +389,9 @@
else:
simplified = sqlized

if post_rewrites:
simplified = simplified.replace(reduce(operator.or_, post_rewrites))

# extract common table expressions while wrapping them in a CTE node
ctes = extract_ctes(simplified)

Expand All @@ -351,6 +408,46 @@
# supplemental rewrites selectively used on a per-backend basis


@replace(Select)
def split_select_distinct_with_order_by(_):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are dragons here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missed opportunity to call this supplant_merged_with_ordered_distinct 😂

"""Split a `SELECT DISTINCT ... ORDER BY` query when needed.

Some databases (postgres, pyspark, ...) have issues with two types of
ordered select distinct statements:

```
--- ORDER BY with an expression instead of a name in the select list
SELECT DISTINCT a, b FROM t ORDER BY a + 1

--- ORDER BY using a qualified column name, rather than the alias in the select list
SELECT DISTINCT a, b as x FROM t ORDER BY b --- or t.b
```

We solve both these cases by splitting everything except the `ORDER BY`
into a subquery.

```
SELECT DISTINCT a, b FROM t WHERE a > 10 ORDER BY a + 1
--- is rewritten as ->
SELECT * FROM (SELECT DISTINCT a, b FROM t WHERE a > 10) ORDER BY a + 1
```
"""
# risingwave and pyspark also don't allow qualified names as sort keys, like
# SELECT DISTINCT t.a FROM t ORDER BY t.a
# To avoid having specific rewrite rules for these backends to use only
# local names, we always split SELECT DISTINCT from ORDER BY here. Otherwise we
# could also avoid splitting if all sort keys appear in the select list.
if _.distinct and _.sort_keys:
inner = _.copy(sort_keys=())
subs = {v: ops.Field(inner, k) for k, v in inner.values.items()}
sort_keys = tuple(s.replace(subs, filter=ops.Value) for s in _.sort_keys)
selections = {
k: v.replace(subs, filter=ops.Value) for k, v in _.selections.items()
}
return Select(inner, selections=selections, sort_keys=sort_keys)
return _


@replace(p.WindowFunction(func=p.NTile(y), order_by=()))
def add_order_by_to_empty_ranking_window_functions(_, **kwargs):
"""Add an ORDER BY clause to rank window functions that don't have one."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
SELECT DISTINCT
*
FROM (
SELECT
"t0"."string_col"
FROM "functional_alltypes" AS "t0"
) AS "t1"
"t0"."string_col"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
SELECT DISTINCT
*
FROM (
SELECT
"t0"."string_col",
"t0"."int_col"
FROM "functional_alltypes" AS "t0"
) AS "t1"
"t0"."string_col",
"t0"."int_col"
FROM "functional_alltypes" AS "t0"
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
SELECT DISTINCT
"t0"."a",
"t0"."b",
"t0"."c"
FROM "test" AS "t0"
WHERE
"t0"."c" > 10 AND "t0"."a" > 10
ORDER BY
"t0"."a" ASC
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SELECT DISTINCT
"t0"."a",
"t0"."b",
"t0"."c"
FROM "test" AS "t0"
WHERE
"t0"."c" > 10 AND "t0"."a" > 10
Loading
Loading