66from sqlglot .errors import OptimizeError
77from sqlglot .helper import while_changing
88from sqlglot .optimizer .scope import find_all_in_scope
9- from sqlglot .optimizer .simplify import flatten , rewrite_between , uniq_sort
9+ from sqlglot .optimizer .simplify import Simplifier , flatten
1010
1111logger = logging .getLogger ("sqlglot" )
1212
@@ -28,14 +28,16 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
2828 Returns:
2929 sqlglot.Expression: normalized expression
3030 """
31+ simplifier = Simplifier (annotate_new_expressions = False )
32+
3133 for node in tuple (expression .walk (prune = lambda e : isinstance (e , exp .Connector ))):
3234 if isinstance (node , exp .Connector ):
3335 if normalized (node , dnf = dnf ):
3436 continue
3537 root = node is expression
3638 original = node .copy ()
3739
38- node .transform (rewrite_between , copy = False )
40+ node .transform (simplifier . rewrite_between , copy = False )
3941 distance = normalization_distance (node , dnf = dnf , max_ = max_distance )
4042
4143 if distance > max_distance :
@@ -46,7 +48,10 @@ def normalize(expression: exp.Expression, dnf: bool = False, max_distance: int =
4648
4749 try :
4850 node = node .replace (
49- while_changing (node , lambda e : distributive_law (e , dnf , max_distance ))
51+ while_changing (
52+ node ,
53+ lambda e : distributive_law (e , dnf , max_distance , simplifier = simplifier ),
54+ )
5055 )
5156 except OptimizeError as e :
5257 logger .info (e )
@@ -146,7 +151,7 @@ def _predicate_lengths(expression, dnf, max_=float("inf"), depth=0):
146151 yield from _predicate_lengths (right , dnf , max_ , depth )
147152
148153
149- def distributive_law (expression , dnf , max_distance ):
154+ def distributive_law (expression , dnf , max_distance , simplifier = None ):
150155 """
151156 x OR (y AND z) -> (x OR y) AND (x OR z)
152157 (x AND y) OR (y AND z) -> (x OR y) AND (x OR z) AND (y OR y) AND (y OR z)
@@ -168,32 +173,34 @@ def distributive_law(expression, dnf, max_distance):
168173 from_func = exp .and_ if from_exp == exp .And else exp .or_
169174 to_func = exp .and_ if to_exp == exp .And else exp .or_
170175
176+ simplifier = simplifier or Simplifier (annotate_new_expressions = False )
177+
171178 if isinstance (a , to_exp ) and isinstance (b , to_exp ):
172179 if len (tuple (a .find_all (exp .Connector ))) > len (tuple (b .find_all (exp .Connector ))):
173- return _distribute (a , b , from_func , to_func )
174- return _distribute (b , a , from_func , to_func )
180+ return _distribute (a , b , from_func , to_func , simplifier )
181+ return _distribute (b , a , from_func , to_func , simplifier )
175182 if isinstance (a , to_exp ):
176- return _distribute (b , a , from_func , to_func )
183+ return _distribute (b , a , from_func , to_func , simplifier )
177184 if isinstance (b , to_exp ):
178- return _distribute (a , b , from_func , to_func )
185+ return _distribute (a , b , from_func , to_func , simplifier )
179186
180187 return expression
181188
182189
183- def _distribute (a , b , from_func , to_func ):
190+ def _distribute (a , b , from_func , to_func , simplifier ):
184191 if isinstance (a , exp .Connector ):
185192 exp .replace_children (
186193 a ,
187194 lambda c : to_func (
188- uniq_sort (flatten (from_func (c , b .left ))),
189- uniq_sort (flatten (from_func (c , b .right ))),
195+ simplifier . uniq_sort (flatten (from_func (c , b .left ))),
196+ simplifier . uniq_sort (flatten (from_func (c , b .right ))),
190197 copy = False ,
191198 ),
192199 )
193200 else :
194201 a = to_func (
195- uniq_sort (flatten (from_func (a , b .left ))),
196- uniq_sort (flatten (from_func (a , b .right ))),
202+ simplifier . uniq_sort (flatten (from_func (a , b .left ))),
203+ simplifier . uniq_sort (flatten (from_func (a , b .right ))),
197204 copy = False ,
198205 )
199206
0 commit comments