diff --git a/pydough/conversion/hybrid_decorrelater.py b/pydough/conversion/hybrid_decorrelater.py index 48a1329d..a4cd5f4a 100644 --- a/pydough/conversion/hybrid_decorrelater.py +++ b/pydough/conversion/hybrid_decorrelater.py @@ -1,5 +1,5 @@ """ -Logic for applying decorrelation to hybrid trees before relational conversion +Logic for applying de-correlation to hybrid trees before relational conversion if the correlate is not a semi/anti join. """ @@ -29,35 +29,85 @@ class Decorrelater: """ - TODO + Class that encapsulates the logic used for de-correlation of hybrid trees. """ def make_decorrelate_parent( self, hybrid: HybridTree, child_idx: int, required_steps: int ) -> HybridTree: """ - TODO + Creates a snapshot of the ancestry of the hybrid tree that contains + a correlated child, without any of its children, its descendants, or + any pipeline operators that do not need to be there. + + Args: + `hybrid`: The hybrid tree to create a snapshot of in order to aid + in the de-correlation of a correlated child. + `child_idx`: The index of the correlated child of hybrid that the + snapshot is being created to aid in the de-correlation of. + `required_steps`: The index of the last pipeline operator that + needs to be included in the snapshot in order for the child to be + derivable. + + Returns: + A snapshot of `hybrid` and its ancestry in the hybrid tree, without + any of its children or pipeline operators that occur during + or after the derivation of the correlated child, or without any of + its descendants. """ if isinstance(hybrid.pipeline[0], HybridPartition) and child_idx == 0: + # Special case: if the correlated child is the data argument of a + # partition operation, then the parent to snapshot is actually the + # parent of the level containing the partition operation. In this + # case, all of the parent's children & pipeline operators should be + # included in the snapshot. 
assert hybrid.parent is not None return self.make_decorrelate_parent( hybrid.parent, len(hybrid.parent.children), len(hybrid.pipeline) ) + # Temporarily detach the successor of the current level, then create a + # deep copy of the current level (which will include its ancestors), + # then reattach the successor back to the original. This ensures that + # the descendants of the current level are not included when providing + # the parent to the correlated child as its new ancestor. + successor: HybridTree | None = hybrid.successor hybrid._successor = None new_hybrid: HybridTree = copy.deepcopy(hybrid) + hybrid._successor = successor + # Ensure the new parent only includes the children & pipeline operators + # that it has to. new_hybrid._children = new_hybrid._children[:child_idx] new_hybrid._pipeline = new_hybrid._pipeline[: required_steps + 1] - # breakpoint() return new_hybrid def remove_correl_refs( self, expr: HybridExpr, parent: HybridTree, child_height: int ) -> HybridExpr: """ - TODO + Recursively & destructively removes correlated references within a + hybrid expression if they point to a specific correlated ancestor + hybrid tree, and replaces them with corresponding BACK references. + + Args: + `expr`: The hybrid expression to remove correlated references from. + `parent`: The correlated ancestor hybrid tree that the correlated + references should point to when they are targeted for removal. + `child_height`: The height of the correlated child within the + hybrid tree that the correlated references point to. This is + the number of BACK indices to shift by when replacing the + correlated reference with a BACK reference. + + Returns: + The hybrid expression with all correlated references to `parent` + replaced with corresponding BACK references. The replacement also + happens in-place. """ match expr: case HybridCorrelExpr(): + # If the correlated reference points to the parent, then + # replace it with a BACK reference. 
Otherwise, recursively + # transform its input expression in case it contains another + # correlated reference. if expr.hybrid is parent: result: HybridExpr | None = expr.expr.shift_back(child_height) assert result is not None @@ -66,10 +116,14 @@ def remove_correl_refs( expr.expr = self.remove_correl_refs(expr.expr, parent, child_height) return expr case HybridFunctionExpr(): + # For regular functions, recursively transform all of their + # arguments. for idx, arg in enumerate(expr.args): expr.args[idx] = self.remove_correl_refs(arg, parent, child_height) return expr case HybridWindowExpr(): + # For window functions, recursively transform all of their + # arguments, partition keys, and order keys. for idx, arg in enumerate(expr.args): expr.args[idx] = self.remove_correl_refs(arg, parent, child_height) for idx, arg in enumerate(expr.partition_args): @@ -88,6 +142,8 @@ def remove_correl_refs( | HybridLiteralExpr() | HybridColumnExpr() ): + # All other expression types do not require any transformation + # to de-correlate since they cannot contain correlations. return expr case _: raise NotImplementedError( @@ -102,13 +158,37 @@ def correl_ref_purge( child_height: int, ) -> None: """ - TODO + The recursive procedure to remove correlated references from the + expressions of a hybrid tree or any of its ancestors or children if + they refer to a specific correlated ancestor that is being removed. + + Args: + `level`: The current level of the hybrid tree to remove correlated + references from. + `old_parent`: The correlated ancestor hybrid tree that the correlated + references should point to when they are targeted for removal. + `new_parent`: The ancestor of `level` that removal should stop at + because it is the transposed snapshot of `old_parent`, and + therefore it & its ancestors cannot contain any more correlated + references that would be targeted for removal. 
+ `child_height`: The height of the correlated child within the + hybrid tree that the correlated references point to. This is + the number of BACK indices to shift by when replacing the + correlated reference with a BACK reference. """ while level is not None and level is not new_parent: + # First, recursively remove any targeted correlated references from + # the children of the current level. for child in level.children: self.correl_ref_purge( child.subtree, old_parent, new_parent, child_height ) + # Then, remove any correlated references from the pipeline + # operators of the current level. Usually this just means + # transforming the terms/orderings/unique keys of the operation, + # but specific operation types will require special casing if they + # have additional expressions stored in other fields that need to be + # transformed. for operation in level.pipeline: for name, expr in operation.terms.items(): operation.terms[name] = self.remove_correl_refs( @@ -131,6 +211,8 @@ def correl_ref_purge( operation.condition = self.remove_correl_refs( operation.condition, old_parent, child_height ) + # Repeat the process on the ancestor until either loop guard + # condition is no longer True. level = level.parent def decorrelate_child( @@ -141,7 +223,13 @@ def decorrelate_child( is_aggregate: bool, ) -> None: """ - TODO + Runs the logic to de-correlate a child of a hybrid tree that contains + a correlated reference. This involves linking the child to a new parent + as its ancestor, the parent being a snapshot of the original hybrid + tree that contained the correlated child as a child. The transformed + child can now replace correlated references with BACK references that + point to terms in its newly expanded ancestry, and the original hybrid + tree can now join onto this child using its uniqueness keys. """ # First, find the height of the child subtree & its top-most level. 
child_root: HybridTree = child.subtree @@ -168,6 +256,7 @@ def decorrelate_child( current_level = current_level.parent additional_levels += 1 child.subtree.join_keys = new_join_keys + # If aggregating, do the same with the aggregation keys. if is_aggregate: new_agg_keys: list[HybridExpr] = [] assert child.subtree.join_keys is not None @@ -183,6 +272,7 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: # hybrid tree. if hybrid.parent is not None: hybrid._parent = self.decorrelate_hybrid_tree(hybrid.parent) + hybrid._parent._successor = hybrid # Iterate across all the children and recursively decorrelate them. for child in hybrid.children: child.subtree = self.decorrelate_hybrid_tree(child.subtree) @@ -224,7 +314,17 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree: def run_hybrid_decorrelation(hybrid: HybridTree) -> HybridTree: """ - TODO + Invokes the procedure to remove correlated references from a hybrid tree + before relational conversion if those correlated references are invalid + (e.g. not from a semi/anti join). + + Args: + `hybrid`: The hybrid tree to remove correlated references from. + + Returns: + The hybrid tree with all invalid correlated references removed as the + tree structure is re-written to allow them to be replaced with BACK + references. The transformation is also done in-place. """ decorr: Decorrelater = Decorrelater() return decorr.decorrelate_hybrid_tree(hybrid)