Commit 05c5181

fix(optimizer)!: refactor Connector simplification to factor in types (#6152)
* fix(optimizer)!: simplify AND with static BOOLEAN type
* fix tests
* Fix case
* fix style

--------

Co-authored-by: George Sittas <[email protected]>
1 parent 6bd59ac commit 05c5181
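
A hedged illustration of the breaking change described above: Connector (AND/OR) simplification now factors in expression types, e.g. folding a conjunct like flag AND TRUE when flag is statically known to be BOOLEAN. The table, column and schema below are illustrative, the exact simplified output depends on the dialect and the inferred types, and the module-level simplify entry point is assumed to remain available after the refactor.

from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import annotate_types
from sqlglot.optimizer.simplify import simplify  # assumed to remain the public entry point

# Illustrative schema: "flag" is a static BOOLEAN column.
expr = parse_one("SELECT * FROM t WHERE flag AND TRUE")
typed = annotate_types(expr, schema={"t": {"flag": "BOOLEAN"}})

# With type information attached, the simplifier can treat the AND as a boolean
# connector and fold the tautological TRUE conjunct.
print(simplify(typed).sql())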

File tree: 6 files changed (+1084, -968 lines)


sqlglot/optimizer/annotate_types.py

Lines changed: 29 additions & 6 deletions
@@ -37,6 +37,7 @@ def annotate_types(
     expression_metadata: t.Optional[ExpressionMetadataType] = None,
     coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
     dialect: DialectType = None,
+    overwrite_types: bool = True,
 ) -> E:
     """
     Infers the types of an expression, annotating its AST accordingly.
@@ -52,16 +53,22 @@ def annotate_types(
     Args:
         expression: Expression to annotate.
         schema: Database schema.
-        annotators: Maps expression type to corresponding annotation function.
+        expression_metadata: Maps expression type to corresponding annotation function.
         coerces_to: Maps expression type to set of types that it can be coerced into.
+        overwrite_types: Re-annotate the existing AST types.
 
     Returns:
         The expression annotated with types.
     """
 
     schema = ensure_schema(schema, dialect=dialect)
 
-    return TypeAnnotator(schema, expression_metadata, coerces_to).annotate(expression)
+    return TypeAnnotator(
+        schema=schema,
+        expression_metadata=expression_metadata,
+        coerces_to=coerces_to,
+        overwrite_types=overwrite_types,
+    ).annotate(expression)
 
 
 def _coerce_date_literal(l: exp.Expression, unit: t.Optional[exp.Expression]) -> exp.DataType.Type:
@@ -178,6 +185,7 @@ def __init__(
         expression_metadata: t.Optional[ExpressionMetadataType] = None,
         coerces_to: t.Optional[t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]]] = None,
         binary_coercions: t.Optional[BinaryCoercions] = None,
+        overwrite_types: bool = True,
     ) -> None:
         self.schema = schema
         self.expression_metadata = (
@@ -202,6 +210,14 @@ def __init__(
         # would reprocess the entire subtree to coerce the types of its operands' projections
         self._setop_column_types: t.Dict[int, t.Dict[str, exp.DataType | exp.DataType.Type]] = {}
 
+        # When set to False, this enables partial annotation by skipping already-annotated nodes
+        self._overwrite_types = overwrite_types
+
+    def clear(self) -> None:
+        self._visited.clear()
+        self._null_expressions.clear()
+        self._setop_column_types.clear()
+
     def _set_type(
         self, expression: exp.Expression, target_type: t.Optional[exp.DataType | exp.DataType.Type]
     ) -> None:
@@ -219,9 +235,12 @@ def _set_type(
         elif prev_type and t.cast(exp.DataType, prev_type).this == exp.DataType.Type.NULL:
             self._null_expressions.pop(expression_id, None)
 
-    def annotate(self, expression: E) -> E:
-        for scope in traverse_scope(expression):
-            self.annotate_scope(scope)
+    def annotate(self, expression: E, annotate_scope: bool = True) -> E:
+        # This flag is used to avoid costly scope traversals when we only care about annotating
+        # non-column expressions (partial type inference), e.g., when simplifying in the optimizer
+        if annotate_scope:
+            for scope in traverse_scope(expression):
+                self.annotate_scope(scope)
 
         # This takes care of non-traversable expressions
         expression = self._maybe_annotate(expression)
@@ -373,7 +392,11 @@ def annotate_scope(self, scope: Scope) -> None:
         scope.expression.meta["query_type"] = struct_type
 
     def _maybe_annotate(self, expression: E) -> E:
-        if id(expression) in self._visited:
+        if id(expression) in self._visited or (
+            not self._overwrite_types
+            and expression.type
+            and not expression.is_type(exp.DataType.Type.UNKNOWN)
+        ):
             return expression  # We've already inferred the expression's type
 
         spec = self.expression_metadata.get(expression.__class__)
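
A minimal usage sketch (not part of the commit) exercising the knobs this diff introduces: overwrite_types on annotate_types/TypeAnnotator, the annotate_scope flag on annotate(), and the new clear() method. The queries are illustrative, and ensure_schema(None) simply builds an empty schema.

from sqlglot import parse_one
from sqlglot.optimizer.annotate_types import TypeAnnotator, annotate_types
from sqlglot.schema import ensure_schema

# Default behavior: every node is (re)annotated, as before.
typed = annotate_types(parse_one("SELECT CAST(a AS INT) + 1 FROM t"))

# Partial annotation: nodes that already carry a non-UNKNOWN type are left untouched.
annotate_types(typed, overwrite_types=False)

# For standalone, non-column expressions the scope traversal can be skipped entirely.
annotator = TypeAnnotator(schema=ensure_schema(None), overwrite_types=False)
added = annotator.annotate(parse_one("1 + 2"), annotate_scope=False)
print(added.type)  # the Add node should now carry an inferred numeric type

# clear() resets the annotator's cached per-expression state before reusing it.
annotator.clear()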

sqlglot/optimizer/normalize.py

Lines changed: 20 additions & 13 deletions
@@ -6,7 +6,7 @@
 from sqlglot.errors import OptimizeError
 from sqlglot.helper import while_changing
 from sqlglot.optimizer.scope import find_all_in_scope
-from sqlglot.optimizer.simplify import flatten, rewrite_between, uniq_sort
+from sqlglot.optimizer.simplify import Simplifier, flatten
 
 logger = logging.getLogger("sqlglot")
 
@@ -28,14 +28,16 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
     Returns:
         sqlglot.Expression: normalized expression
     """
+    simplifier = Simplifier(annotate_new_expressions=False)
+
     for node in tuple(expression.walk(prune=lambda e: isinstance(e, exp.Connector))):
         if isinstance(node, exp.Connector):
             if normalized(node, dnf=dnf):
                 continue
             root = node is expression
             original = node.copy()
 
-            node.transform(rewrite_between, copy=False)
+            node.transform(simplifier.rewrite_between, copy=False)
             distance = normalization_distance(node, dnf=dnf, max_=max_distance)
 
             if distance > max_distance:
@@ -46,7 +48,10 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
 
             try:
                 node = node.replace(
-                    while_changing(node, lambda e: distributive_law(e, dnf, max_distance))
+                    while_changing(
+                        node,
+                        lambda e: distributive_law(e, dnf, max_distance, simplifier=simplifier),
+                    )
                 )
             except OptimizeError as e:
                 logger.info(e)
@@ -146,7 +151,7 @@ def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0):
         yield from _predicate_lengths(right, dnf, max_, depth)
 
 
-def distributive_law(expression, dnf, max_distance):
+def distributive_law(expression, dnf, max_distance, simplifier=None):
     """
     x OR (y AND z) -> (x OR y) AND (x OR z)
     (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
@@ -168,32 +173,34 @@ def distributive_law(expression, dnf, max_distance):
     from_func = exp.and_ if from_exp == exp.And else exp.or_
     to_func = exp.and_ if to_exp == exp.And else exp.or_
 
+    simplifier = simplifier or Simplifier(annotate_new_expressions=False)
+
     if isinstance(a, to_exp) and isinstance(b, to_exp):
         if len(tuple(a.find_all(exp.Connector))) > len(tuple(b.find_all(exp.Connector))):
-            return _distribute(a, b, from_func, to_func)
-        return _distribute(b, a, from_func, to_func)
+            return _distribute(a, b, from_func, to_func, simplifier)
+        return _distribute(b, a, from_func, to_func, simplifier)
     if isinstance(a, to_exp):
-        return _distribute(b, a, from_func, to_func)
+        return _distribute(b, a, from_func, to_func, simplifier)
    if isinstance(b, to_exp):
-        return _distribute(a, b, from_func, to_func)
+        return _distribute(a, b, from_func, to_func, simplifier)
 
     return expression
 
 
-def _distribute(a, b, from_func, to_func):
+def _distribute(a, b, from_func, to_func, simplifier):
     if isinstance(a, exp.Connector):
         exp.replace_children(
             a,
             lambda c: to_func(
-                uniq_sort(flatten(from_func(c, b.left))),
-                uniq_sort(flatten(from_func(c, b.right))),
+                simplifier.uniq_sort(flatten(from_func(c, b.left))),
+                simplifier.uniq_sort(flatten(from_func(c, b.right))),
                 copy=False,
             ),
         )
     else:
         a = to_func(
-            uniq_sort(flatten(from_func(a, b.left))),
-            uniq_sort(flatten(from_func(a, b.right))),
+            simplifier.uniq_sort(flatten(from_func(a, b.left))),
+            simplifier.uniq_sort(flatten(from_func(a, b.right))),
             copy=False,
         )
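
For context, a small usage sketch (illustrative, not from the repo): the public normalize() entry point is unchanged by this refactor, but it now builds a single Simplifier(annotate_new_expressions=False) up front and threads it through rewrite_between, distributive_law and _distribute instead of importing the old module-level rewrite_between/uniq_sort helpers.

from sqlglot import parse_one
from sqlglot.optimizer.normalize import normalize, normalized

expr = parse_one("(x AND y) OR z")

# Conjunctive normal form (the default); typically yields (x OR z) AND (y OR z).
cnf = normalize(expr.copy())
print(cnf.sql(), normalized(cnf))

# Disjunctive normal form; this input is already in DNF, so it should pass through.
dnf = normalize(expr.copy(), dnf=True)
print(dnf.sql(), normalized(dnf, dnf=True))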
