draft of deferred caching #1185

Draft · wants to merge 1 commit into main
58 changes: 24 additions & 34 deletions bigframes/session/executor.py
@@ -177,10 +177,6 @@ def get_row_count(self, array_value: bigframes.core.ArrayValue) -> int:
def cached(
self,
array_value: bigframes.core.ArrayValue,
*,
force: bool = False,
use_session: bool = False,
cluster_cols: Sequence[str] = (),
) -> None:
raise NotImplementedError("cached not implemented for this executor")

@@ -209,6 +205,9 @@ def __init__(
bigframes.core.compile.SQLCompiler(strict=strictly_ordered)
)
self.strictly_ordered: bool = strictly_ordered
self._need_caching: weakref.WeakSet[
nodes.BigFrameNode
] = weakref.WeakSet()
self._cached_executions: weakref.WeakKeyDictionary[
nodes.BigFrameNode, nodes.BigFrameNode
] = weakref.WeakKeyDictionary()
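For context on the new bookkeeping: `_need_caching` uses a `weakref.WeakSet`, matching the existing `_cached_executions` `WeakKeyDictionary`, so marking a node for caching never extends its lifetime. A minimal standalone sketch (illustrative only, with a stand-in `Node` class rather than `nodes.BigFrameNode`) of why weak containers fit here:

```python
import weakref


class Node:
    """Stand-in for nodes.BigFrameNode; plain class instances accept weak refs."""


need_caching: weakref.WeakSet = weakref.WeakSet()
cached_executions: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()

node, replacement = Node(), Node()
need_caching.add(node)
cached_executions[node] = replacement  # original plan node -> cached equivalent
assert len(need_caching) == 1 and len(cached_executions) == 1

del node  # drop the last strong reference to the plan node
# CPython's refcounting collects the node immediately, emptying both
# containers, so stale plan nodes never pin memory or cache state.
assert len(need_caching) == 0 and len(cached_executions) == 0
```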
@@ -443,19 +442,10 @@ def get_row_count(self, array_value: bigframes.core.ArrayValue) -> int:
def cached(
self,
array_value: bigframes.core.ArrayValue,
*,
force: bool = False,
use_session: bool = False,
cluster_cols: Sequence[str] = (),
) -> None:
"""Write the block to a session table."""
# use a heuristic for whether something needs to be cached
if (not force) and self._is_trivially_executable(array_value):
return
elif use_session:
self._cache_with_session_awareness(array_value)
else:
self._cache_with_cluster_cols(array_value, cluster_cols=cluster_cols)
self._need_caching.add(array_value.node)
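The behavioral change here is that `cached()` no longer issues a BigQuery job eagerly; it only records intent, and materialization happens later (see `_cache_subtrees` below). A minimal sketch of that deferred pattern, using hypothetical standalone names rather than the real bigframes API:

```python
import weakref


class Node:
    """Stand-in for a plan node."""


class DeferredCachingExecutor:
    def __init__(self) -> None:
        self._need_caching: weakref.WeakSet = weakref.WeakSet()

    def cached(self, node: Node) -> None:
        # O(1) and side-effect free: just mark the node for later.
        self._need_caching.add(node)

    def execute(self, node: Node) -> None:
        # Flush pending cache requests only when a result is actually needed.
        for pending in list(self._need_caching):
            print(f"materializing {pending!r} to a session temp table")
        self._need_caching.clear()
        print(f"running query for {node!r}")


executor = DeferredCachingExecutor()
root = Node()
executor.cached(root)   # returns immediately; no query issued
executor.execute(root)  # the materialization cost is paid here instead
```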

def _local_get_row_count(
self, array_value: bigframes.core.ArrayValue
@@ -466,6 +456,12 @@
return tree_properties.row_count(plan)

# Helpers
def _cache_subtrees(self, root: nodes.BigFrameNode):
for child in root.child_nodes:
self._cache_subtrees(child)
if root in self._need_caching:
self._cache_with_offsets(root)
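`_cache_subtrees` is a post-order walk: children are visited before their parent, so by the time a marked parent is materialized, any marked subtree beneath it has already been replaced by a scan of its cached table. A toy illustration (hypothetical `Node` type keyed by name for simplicity; the real code tests WeakSet membership on `nodes.BigFrameNode`):

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    children: list["Node"] = field(default_factory=list)


def cache_subtrees(root: Node, need_caching: set[str]) -> None:
    # Post-order: recurse first, then handle the current node.
    for child in root.children:
        cache_subtrees(child, need_caching)
    if root.name in need_caching:
        print(f"caching {root.name}")  # stand-in for _cache_with_offsets


tree = Node("join", [Node("scan_a"), Node("agg", [Node("scan_b")])])
cache_subtrees(tree, need_caching={"agg", "join"})
# prints "caching agg" before "caching join"
```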

def _run_execute_query(
self,
sql: str,
@@ -531,6 +527,19 @@ def _wait_on_job(
def replace_cached_subtrees(self, node: nodes.BigFrameNode) -> nodes.BigFrameNode:
return tree_properties.replace_nodes(node, (dict(self._cached_executions)))

def _materialize_to_cache(self, node: nodes.BigFrameNode):
session_forest = [obj._block._expr.node for obj in node.session.objects]
# So there are a few different options to balance when clustering
# 1. Cluster by ordering - this is often correlated with other
# 2. Cluster by immediate need - This is often best
# 3. Cluster by session need
cluster_cols = bigframes.session.planner.select_cluster_cols(node, session_forest)
if cluster_cols or not self.strictly_ordered:
self._cache_with_cluster_cols(bigframes.core.ArrayValue(node), cluster_cols)
else:
# even in strict ordering mode, probably enough to materialize some ordering key
self._cache_with_offsets(bigframes.core.ArrayValue(node))
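The branch structure in `_materialize_to_cache` amounts to a three-way choice. A compressed sketch of the decision, with plain strings standing in for the actual cache calls:

```python
def choose_materialization(cluster_cols: list[str], strictly_ordered: bool) -> str:
    if cluster_cols:
        # Predicates over these columns can prune blocks on later reads.
        return f"temp table clustered by ({', '.join(cluster_cols)})"
    if not strictly_ordered:
        # No useful cluster cols and no strict ordering: plain temp table.
        return "plain temp table"
    # Strict ordering with no cluster cols: persist an offsets column so the
    # total order survives the round trip through BigQuery.
    return "temp table with generated row offsets"


print(choose_materialization(["user_id"], strictly_ordered=True))
print(choose_materialization([], strictly_ordered=False))
print(choose_materialization([], strictly_ordered=True))
```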

def _is_trivially_executable(self, array_value: bigframes.core.ArrayValue):
"""
Can the block be evaluated very cheaply?
@@ -580,25 +589,6 @@ def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue):
).node
self._cached_executions[array_value.node] = cached_replacement

def _cache_with_session_awareness(
self,
array_value: bigframes.core.ArrayValue,
) -> None:
session_forest = [obj._block._expr.node for obj in array_value.session.objects]
# These node types are cheap to re-compute
target, cluster_cols = bigframes.session.planner.session_aware_cache_plan(
array_value.node, list(session_forest)
)
cluster_cols_sql_names = [id.sql for id in cluster_cols]
if len(cluster_cols) > 0:
self._cache_with_cluster_cols(
bigframes.core.ArrayValue(target), cluster_cols_sql_names
)
elif self.strictly_ordered:
self._cache_with_offsets(bigframes.core.ArrayValue(target))
else:
self._cache_with_cluster_cols(bigframes.core.ArrayValue(target), [])

def _simplify_with_caching(self, array_value: bigframes.core.ArrayValue):
"""Attempts to handle the complexity by caching duplicated subtrees and breaking the query into pieces."""
# Apply existing caching first
@@ -626,7 +616,7 @@ def _cache_most_complex_subtree(self, node: nodes.BigFrameNode) -> bool:
# No good subtrees to cache, just return original tree
return False

self._cache_with_cluster_cols(bigframes.core.ArrayValue(selection), [])
self._materialize_to_cache(selection)
return True

def _sql_as_cached_temp_table(
80 changes: 32 additions & 48 deletions bigframes/session/planner.py
@@ -15,71 +15,55 @@
from __future__ import annotations

import itertools
from typing import Sequence, Tuple
from typing import Sequence

import bigframes.core.expression as ex
import bigframes.core.identifiers as ids
import bigframes.core.nodes as nodes
import bigframes.core.pruning as predicate_pruning
import bigframes.core.tree_properties as traversals
import functools
import bigframes.dtypes
import bigframes.core.expression


def session_aware_cache_plan(
root: nodes.BigFrameNode, session_forest: Sequence[nodes.BigFrameNode]
) -> Tuple[nodes.BigFrameNode, list[ids.ColumnId]]:
def select_cluster_cols(
cache_node: nodes.BigFrameNode, session_forest: Sequence[nodes.BigFrameNode]
) -> set[ids.ColumnId]:
"""
Determines the best node to cache given a target and a list of object roots for objects in a session.

Returns the node to cache, and optionally a clustering column.
Determines the best cluster cols for materializing a target node given a list of session trees.
"""
node_counts = traversals.count_nodes(session_forest)
# These node types are cheap to re-compute, so it makes more sense to cache their children.
de_cachable_types = (nodes.FilterNode, nodes.ProjectionNode, nodes.SelectionNode)
caching_target = cur_node = root
caching_target_refs = node_counts.get(caching_target, 0)

filters: list[
ex.Expression
] = [] # accumulate filters into this as traverse downwards
clusterable_cols: set[ids.ColumnId] = set()
while isinstance(cur_node, de_cachable_types):
if isinstance(cur_node, nodes.FilterNode):
@functools.cache
def find_direct_predicates(cache_node: nodes.BigFrameNode, root_node: nodes.BigFrameNode) -> set[bigframes.core.expression.Expression]:
if isinstance(root_node, nodes.FilterNode):
# Filter node doesn't define any variables, so no need to chain expressions
filters.append(cur_node.predicate)
elif isinstance(cur_node, nodes.ProjectionNode):
filters.append(root_node.predicate)
elif isinstance(root_node, nodes.ProjectionNode):
# Projection defines the variables that are used in the filter expressions, need to substitute variables with their scalar expressions
# that instead reference variables in the child node.
bindings = {name: expr for expr, name in cur_node.assignments}
bindings = {name: expr for expr, name in root_node.assignments}
filters = [
i.bind_refs(bindings, allow_partial_bindings=True) for i in filters
]
elif isinstance(cur_node, nodes.SelectionNode):
bindings = {output: input for input, output in cur_node.input_output_pairs}
elif isinstance(root_node, nodes.SelectionNode):
bindings = {output: input for input, output in root_node.input_output_pairs}
filters = [i.bind_refs(bindings) for i in filters]
else:
raise ValueError(f"Unexpected de-cached node: {cur_node}")
return frozenset().union(*(find_direct_predicates(cache_node, child) for child in root_node.child_nodes))
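The substitution steps above are what keep a predicate meaningful as it is pushed below projections and selections: each `bind_refs` call rewrites references to a node's outputs into expressions over that node's inputs. A toy expression language showing the idea (this is not `bigframes.core.expression`; the real implementation also supports partial bindings):

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class Ref:
    """A reference to a named column."""

    name: str

    def bind_refs(self, bindings: dict[str, Expr]) -> Expr:
        return bindings.get(self.name, self)


@dataclass(frozen=True)
class Const:
    value: object

    def bind_refs(self, bindings: dict[str, Expr]) -> Expr:
        return self


@dataclass(frozen=True)
class Gt:
    """left > right comparison."""

    left: Expr
    right: Expr

    def bind_refs(self, bindings: dict[str, Expr]) -> Expr:
        return Gt(self.left.bind_refs(bindings), self.right.bind_refs(bindings))


Expr = Ref | Const | Gt

# A filter `renamed > 10` sits above a selection that renames `x` to `renamed`.
predicate = Gt(Ref("renamed"), Const(10))
bindings: dict[str, Expr] = {"renamed": Ref("x")}  # output -> input, as in SelectionNode
print(predicate.bind_refs(bindings))  # Gt(left=Ref(name='x'), right=Const(value=10))
```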


cur_node = cur_node.child
cur_node_refs = node_counts.get(cur_node, 0)
if cur_node_refs > caching_target_refs:
caching_target, caching_target_refs = cur_node, cur_node_refs
cluster_compatible_cols = {
field.id
for field in cur_node.fields
if bigframes.dtypes.is_clusterable(field.dtype)
}
# Cluster cols only consider the target object and not other session objects
clusterable_cols = set(
itertools.chain.from_iterable(
map(
lambda f: predicate_pruning.cluster_cols_for_predicate(
f, cluster_compatible_cols
),
filters,
)
)
cluster_compatible_cols = {
field.id
for field in cur_node.fields
if bigframes.dtypes.is_clusterable(field.dtype)
}
clusterable_cols = set(
itertools.chain.from_iterable(
map(
lambda f: predicate_pruning.cluster_cols_for_predicate(
f, cluster_compatible_cols
),
filters,
)
)
)
# BQ supports up to 4 cluster columns, just prioritize by alphabetical ordering
# TODO: Prioritize caching columns by estimated filter selectivity
return caching_target, sorted(list(clusterable_cols))[:4]
return sorted(list(clusterable_cols))[:4]
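One constraint worth noting on this last line: BigQuery supports at most four clustering columns per table, hence the `[:4]` cap; the alphabetical sort is just a deterministic tie-break until the selectivity-based ranking from the TODO lands. For example, with hypothetical column names:

```python
# Hypothetical candidate set drawn from session predicates.
clusterable_cols = {"ts", "region", "user_id", "event_name", "flag"}
print(sorted(clusterable_cols)[:4])
# ['event_name', 'flag', 'region', 'ts'] -- alphabetical, capped at four
```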