Skip to content

Commit

Permalink
Added documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
knassre-bodo committed Feb 10, 2025
1 parent 8d75668 commit 2cfe999
Showing 1 changed file with 108 additions and 8 deletions.
116 changes: 108 additions & 8 deletions pydough/conversion/hybrid_decorrelater.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Logic for applying decorrelation to hybrid trees before relational conversion
Logic for applying de-correlation to hybrid trees before relational conversion
if the correlate is not a semi/anti join.
"""

Expand Down Expand Up @@ -29,35 +29,85 @@

class Decorrelater:
"""
TODO
Class that encapsulates the logic used for de-correlation of hybrid trees.
"""

def make_decorrelate_parent(
self, hybrid: HybridTree, child_idx: int, required_steps: int
) -> HybridTree:
"""
TODO
Creates a snapshot of the ancestry of the hybrid tree that contains
a correlated child, without any of its children, its descendants, or
any pipeline operators that do not need to be there.
Args:
`hybrid`: The hybrid tree to create a snapshot of in order to aid
in the de-correlation of a correlated child.
`child_idx`: The index of the correlated child of hybrid that the
snapshot is being created to aid in the de-correlation of.
`required_steps`: The index of the last pipeline operator that
needs to be included in the snapshot in order for the child to be
derivable.
Returns:
A snapshot of `hybrid` and its ancestry in the hybrid tree, without
without any of its children or pipeline operators that occur during
or after the derivation of the correlated child, or without any of
its descendants.
"""
if isinstance(hybrid.pipeline[0], HybridPartition) and child_idx == 0:
# Special case: if the correlated child is the data argument of a
# partition operation, then the parent to snapshot is actually the
# parent of the level containing the partition operation. In this
# case, all of the parent's children & pipeline operators should be
# included in the snapshot.
assert hybrid.parent is not None
return self.make_decorrelate_parent(
hybrid.parent, len(hybrid.parent.children), len(hybrid.pipeline)
)
# Temporarily detach the successor of the current level, then create a
# deep copy of the current level (which will include its ancestors),
# then reattach the successor back to the original. This ensures that
# the descendants of the current level are not included when providing
# the parent to the correlated child as its new ancestor.
successor: HybridTree | None = hybrid.successor
hybrid._successor = None
new_hybrid: HybridTree = copy.deepcopy(hybrid)
hybrid._successor = successor
# Ensure the new parent only includes the children & pipeline operators
# that is has to.
new_hybrid._children = new_hybrid._children[:child_idx]
new_hybrid._pipeline = new_hybrid._pipeline[: required_steps + 1]
# breakpoint()
return new_hybrid

def remove_correl_refs(
self, expr: HybridExpr, parent: HybridTree, child_height: int
) -> HybridExpr:
"""
TODO
Recursively & destructively removes correlated references within a
hybrid expression if they point to a specific correlated ancestor
hybrid tree, and replaces them with corresponding BACK references.
Args:
`expr`: The hybrid expression to remove correlated references from.
`parent`: The correlated ancestor hybrid tree that the correlated
references should point to when they are targeted for removal.
`child_height`: The height of the correlated child within the
hybrid tree that the correlated references is point to. This is
the number of BACK indices to shift by when replacing the
correlated reference with a BACK reference.
Returns:
The hybrid expression with all correlated references to `parent`
replaced with corresponding BACK references. The replacement also
happens in-place.
"""
match expr:
case HybridCorrelExpr():
# If the correlated reference points to the parent, then
# replace it with a BACK reference. Otherwise, recursively
# transform its input expression in case it contains another
# correlated reference.
if expr.hybrid is parent:
result: HybridExpr | None = expr.expr.shift_back(child_height)
assert result is not None
Expand All @@ -66,10 +116,14 @@ def remove_correl_refs(
expr.expr = self.remove_correl_refs(expr.expr, parent, child_height)
return expr
case HybridFunctionExpr():
# For regular functions, recursively transform all of their
# arguments.
for idx, arg in enumerate(expr.args):
expr.args[idx] = self.remove_correl_refs(arg, parent, child_height)
return expr
case HybridWindowExpr():
# For window functions, recursively transform all of their
# arguments, partition keys, and order keys.
for idx, arg in enumerate(expr.args):
expr.args[idx] = self.remove_correl_refs(arg, parent, child_height)
for idx, arg in enumerate(expr.partition_args):
Expand All @@ -88,6 +142,8 @@ def remove_correl_refs(
| HybridLiteralExpr()
| HybridColumnExpr()
):
# All other expression types do not require any transformation
# to de-correlate since they cannot contain correlations.
return expr
case _:
raise NotImplementedError(
Expand All @@ -102,13 +158,37 @@ def correl_ref_purge(
child_height: int,
) -> None:
"""
TODO
The recursive procedure to remove correlated references from the
expressions of a hybrid tree or any of its ancestors or children if
they refer to a specific correlated ancestor that is being removed.
Args:
`level`: The current level of the hybrid tree to remove correlated
references from.
`old_parent`: The correlated ancestor hybrid tree that the correlated
references should point to when they are targeted for removal.
`new_parent`: The ancestor of `level` that removal should stop at
because it is the transposed snapshot of `old_parent`, and
therefore it & its ancestors cannot contain any more correlated
references that would be targeted for removal.
`child_height`: The height of the correlated child within the
hybrid tree that the correlated references is point to. This is
the number of BACK indices to shift by when replacing the
correlated reference with a BACK
"""
while level is not None and level is not new_parent:
# First, recursively remove any targeted correlated references from
# the children of the current level.
for child in level.children:
self.correl_ref_purge(
child.subtree, old_parent, new_parent, child_height
)
# Then, remove any correlated references from the pipeline
# operators of the current level. Usually this just means
# transforming the terms/orderings/unique keys of the operation,
# but specific operation types will require special casing if they
# have additional expressions stored in other field that need to be
# transformed.
for operation in level.pipeline:
for name, expr in operation.terms.items():
operation.terms[name] = self.remove_correl_refs(
Expand All @@ -131,6 +211,8 @@ def correl_ref_purge(
operation.condition = self.remove_correl_refs(
operation.condition, old_parent, child_height
)
# Repeat the process on the ancestor until either loop guard
# condition is no longer True.
level = level.parent

def decorrelate_child(
Expand All @@ -141,7 +223,13 @@ def decorrelate_child(
is_aggregate: bool,
) -> None:
"""
TODO
Runs the logic to de-correlate a child of a hybrid tree that contains
a correlated reference. This involves linking the child to a new parent
as its ancestor, the parent being a snapshot of the original hybrid
tree that contained the correlated child as a child. The transformed
child can now replace correlated references with BACK references that
point to terms in its newly expanded ancestry, and the original hybrid
tree cna now join onto this child using its uniqueness keys.
"""
# First, find the height of the child subtree & its top-most level.
child_root: HybridTree = child.subtree
Expand All @@ -168,6 +256,7 @@ def decorrelate_child(
current_level = current_level.parent
additional_levels += 1
child.subtree.join_keys = new_join_keys
# If aggregating, do the same with the aggregation keys.
if is_aggregate:
new_agg_keys: list[HybridExpr] = []
assert child.subtree.join_keys is not None
Expand All @@ -183,6 +272,7 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree:
# hybrid tree.
if hybrid.parent is not None:
hybrid._parent = self.decorrelate_hybrid_tree(hybrid.parent)
hybrid._parent._successor = hybrid
# Iterate across all the children and recursively decorrelate them.
for child in hybrid.children:
child.subtree = self.decorrelate_hybrid_tree(child.subtree)
Expand Down Expand Up @@ -224,7 +314,17 @@ def decorrelate_hybrid_tree(self, hybrid: HybridTree) -> HybridTree:

def run_hybrid_decorrelation(hybrid: HybridTree) -> HybridTree:
"""
TODO
Invokes the procedure to remove correlated references from a hybrid tree
before relational conversion if those correlated references are invalid
(e.g. not from a semi/anti join).
Args:
`hybrid`: The hybrid tree to remove correlated references from.
Returns:
The hybrid tree with all invalid correlated references removed as the
tree structure is re-written to allow them to be replaced with BACK
references. The transformation is also done in-place.
"""
decorr: Decorrelater = Decorrelater()
return decorr.decorrelate_hybrid_tree(hybrid)

0 comments on commit 2cfe999

Please sign in to comment.