Skip to content

Commit

Permalink
Memoize results in SourceScanOptimizer for better CTE generation (#…
Browse files Browse the repository at this point in the history
…1563)

CTE generation depends on finding identical nodes in the dataflow
plan. Previously, the `SourceScanOptimizer` would create new nodes
whenever it was able to optimize a branch. It would also create new
nodes in some cases where doing so was unnecessary.

This PR updates `SourceScanOptimizer` to memoize results so that
identical nodes are used in cases where the dataflow branch is the same.

This results in snapshot changes for tests where the optimizer
previously prevented CTEs from being used - please review commit by commit.
  • Loading branch information
plypaul authored Dec 11, 2024
1 parent d42ed8a commit 50c1c5a
Show file tree
Hide file tree
Showing 280 changed files with 10,193 additions and 9,799 deletions.
3 changes: 3 additions & 0 deletions metricflow/dataflow/dataflow_plan_analyzer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
from collections import defaultdict
from typing import Dict, FrozenSet, Mapping, Sequence, Set

Expand All @@ -8,6 +9,8 @@
from metricflow.dataflow.dataflow_plan import DataflowPlan, DataflowPlanNode
from metricflow.dataflow.dataflow_plan_visitor import DataflowPlanNodeVisitorWithDefaultHandler

logger = logging.getLogger(__name__)


class DataflowPlanAnalyzer:
"""Class to determine more complex properties of the dataflow plan.
Expand Down
15 changes: 11 additions & 4 deletions metricflow/dataflow/optimizer/source_scan/cm_branch_combiner.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(self, left_branch_node: DataflowPlanNode) -> None: # noqa: D107
self._current_left_node: DataflowPlanNode = left_branch_node

def _log_visit_node_type(self, node: DataflowPlanNode) -> None:
logger.debug(LazyFormat(lambda: f"Visiting {node}"))
logger.debug(lambda: f"Visiting {node.node_id}")

def _log_combine_failure(
self,
Expand All @@ -142,8 +142,10 @@ def _log_combine_failure(
) -> None:
logger.debug(
LazyFormat(
lambda: f"Because {combine_failure_reason}, unable to combine nodes "
f"left_node={left_node} right_node={right_node}",
"Unable to combine nodes",
combine_failure_reason=combine_failure_reason,
left_node=left_node.node_id,
right_node=right_node.node_id,
)
)

Expand All @@ -154,7 +156,12 @@ def _log_combine_success(
combined_node: DataflowPlanNode,
) -> None:
logger.debug(
LazyFormat(lambda: f"Combined left_node={left_node} right_node={right_node} combined_node: {combined_node}")
LazyFormat(
"Successfully combined nodes",
left_node=left_node.node_id,
right_node=right_node.node_id,
combined_node=combined_node.node_id,
)
)

def _combine_parent_branches(self, current_right_node: DataflowPlanNode) -> Optional[Sequence[DataflowPlanNode]]:
Expand Down
73 changes: 50 additions & 23 deletions metricflow/dataflow/optimizer/source_scan/source_scan_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from dataclasses import dataclass
from typing import List, Optional, Sequence
from typing import Dict, List, Optional, Sequence

from metricflow_semantics.dag.id_prefix import StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagId
Expand Down Expand Up @@ -110,21 +110,33 @@ class SourceScanOptimizer(
parents.
"""

def __init__(self) -> None: # noqa: D107
self._node_to_result: Dict[DataflowPlanNode, OptimizeBranchResult] = {}

def _log_visit_node_type(self, node: DataflowPlanNode) -> None:
logger.debug(LazyFormat(lambda: f"Visiting {node}"))
logger.debug(LazyFormat(lambda: f"Visiting {node.node_id}"))

def _default_base_output_handler(
self,
node: DataflowPlanNode,
) -> OptimizeBranchResult:
optimized_parents: Sequence[OptimizeBranchResult] = tuple(
parent_node.accept(self) for parent_node in node.parent_nodes
)
# Parents should always be DataflowPlanNode
return OptimizeBranchResult(
optimized_branch=node.with_new_parents(tuple(x.optimized_branch for x in optimized_parents))
memoized_result = self._node_to_result.get(node)
if memoized_result is not None:
return memoized_result

optimized_parent_nodes: Sequence[DataflowPlanNode] = tuple(
parent_node.accept(self).optimized_branch for parent_node in node.parent_nodes
)

# If no optimization is done, use the same nodes so that common operations can be identified for CTE generation.
if tuple(node.parent_nodes) == optimized_parent_nodes:
result = OptimizeBranchResult(optimized_branch=node)
else:
result = OptimizeBranchResult(optimized_branch=node.with_new_parents(optimized_parent_nodes))

self._node_to_result[node] = result
return result

def visit_source_node(self, node: ReadSqlSourceNode) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
return self._default_base_output_handler(node)
Expand All @@ -144,18 +156,26 @@ def visit_window_reaggregation_node(self, node: WindowReaggregationNode) -> Opti
def visit_compute_metrics_node(self, node: ComputeMetricsNode) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
# Run the optimizer on the parent branch to handle derived metrics, which are defined recursively in the DAG.

memoized_result = self._node_to_result.get(node)
if memoized_result is not None:
return memoized_result

optimized_parent_result: OptimizeBranchResult = node.parent_node.accept(self)
if optimized_parent_result.optimized_branch is not None:
return OptimizeBranchResult(
result = OptimizeBranchResult(
optimized_branch=ComputeMetricsNode.create(
parent_node=optimized_parent_result.optimized_branch,
metric_specs=node.metric_specs,
for_group_by_source_node=node.for_group_by_source_node,
aggregated_to_elements=node.aggregated_to_elements,
)
)
else:
result = OptimizeBranchResult(optimized_branch=node)

return OptimizeBranchResult(optimized_branch=node)
self._node_to_result[node] = result
return result

def visit_order_by_limit_node(self, node: OrderByLimitNode) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
Expand Down Expand Up @@ -220,11 +240,16 @@ def visit_combine_aggregated_outputs_node( # noqa: D102
self, node: CombineAggregatedOutputsNode
) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
# The parent node of the CombineAggregatedOutputsNode can be either ComputeMetricsNodes or CombineAggregatedOutputsNodes

memoized_result = self._node_to_result.get(node)
if memoized_result is not None:
return memoized_result

# The parent node of the CombineAggregatedOutputsNode can be either ComputeMetricsNodes or
# CombineAggregatedOutputsNodes.
# Stores the result of running this optimizer on each parent branch separately.
optimized_parent_branches = []
logger.debug(LazyFormat(lambda: f"{node} has {len(node.parent_nodes)} parent branches"))
logger.debug(LazyFormat(lambda: f"{node.node_id} has {len(node.parent_nodes)} parent branches"))

# Run the optimizer on the parent branch to handle derived metrics, which are defined recursively in the DAG.
for parent_branch in node.parent_nodes:
Expand Down Expand Up @@ -257,14 +282,17 @@ def visit_combine_aggregated_outputs_node( # noqa: D102
logger.debug(lambda: f"Got {len(combined_parent_branches)} branches after combination")
assert len(combined_parent_branches) > 0

# If we were able to reduce the parent branches of the CombineAggregatedOutputsNode into a single one, there's no need
# for a CombineAggregatedOutputsNode.
# If we were able to reduce the parent branches of the CombineAggregatedOutputsNode into a single one, there's
# no need for a CombineAggregatedOutputsNode.
if len(combined_parent_branches) == 1:
return OptimizeBranchResult(optimized_branch=combined_parent_branches[0])
result = OptimizeBranchResult(optimized_branch=combined_parent_branches[0])
else:
result = OptimizeBranchResult(
optimized_branch=CombineAggregatedOutputsNode.create(parent_nodes=combined_parent_branches)
)

return OptimizeBranchResult(
optimized_branch=CombineAggregatedOutputsNode.create(parent_nodes=combined_parent_branches)
)
self._node_to_result[node] = result
return result

def visit_constrain_time_range_node(self, node: ConstrainTimeRangeNode) -> OptimizeBranchResult: # noqa: D102
self._log_visit_node_type(node)
Expand All @@ -289,11 +317,10 @@ def optimize(self, dataflow_plan: DataflowPlan) -> DataflowPlan: # noqa: D102

logger.debug(
LazyFormat(
lambda: f"Optimized:\n\n"
f"{dataflow_plan.sink_node.structure_text()}\n\n"
f"to:\n\n"
f"{optimized_result.optimized_branch.structure_text()}",
),
"Optimized dataflow plan",
original_plan=dataflow_plan.sink_node.structure_text(),
optimized_plan=optimized_result.optimized_branch.structure_text(),
)
)

return DataflowPlan(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,25 @@ docstring:
sql_engine: BigQuery
---
-- Compute Metrics via Expressions
WITH sma_28019_cte AS (
-- Read Elements From Semantic Model 'visits_source'
-- Metric Time Dimension 'ds'
SELECT
DATETIME_TRUNC(ds, day) AS metric_time__day
, user_id AS user
, 1 AS visits
FROM ***************************.fct_visits visits_source_src_28000
)

SELECT
metric_time__day
metric_time__day AS metric_time__day
, CAST(buys AS FLOAT64) / CAST(NULLIF(visits, 0) AS FLOAT64) AS visit_buy_conversion_rate
FROM (
-- Combine Aggregated Outputs
SELECT
COALESCE(subq_21.metric_time__day, subq_32.metric_time__day) AS metric_time__day
COALESCE(subq_21.metric_time__day, subq_31.metric_time__day) AS metric_time__day
, MAX(subq_21.visits) AS visits
, MAX(subq_32.buys) AS buys
, MAX(subq_31.buys) AS buys
FROM (
-- Constrain Output with WHERE
-- Pass Only Elements: ['visits', 'metric_time__day']
Expand All @@ -22,12 +32,12 @@ FROM (
metric_time__day
, SUM(visits) AS visits
FROM (
-- Read Elements From Semantic Model 'visits_source'
-- Metric Time Dimension 'ds'
-- Read From CTE For node_id=sma_28019
SELECT
DATETIME_TRUNC(ds, day) AS metric_time__day
, 1 AS visits
FROM ***************************.fct_visits visits_source_src_28000
metric_time__day
, sma_28019_cte.user
, visits
FROM sma_28019_cte sma_28019_cte
) subq_18
WHERE metric_time__day = '2020-01-01'
GROUP BY
Expand All @@ -43,50 +53,49 @@ FROM (
FROM (
-- Dedupe the fanout with mf_internal_uuid in the conversion data set
SELECT DISTINCT
FIRST_VALUE(subq_25.visits) OVER (
FIRST_VALUE(subq_24.visits) OVER (
PARTITION BY
subq_28.user
, subq_28.metric_time__day
, subq_28.mf_internal_uuid
ORDER BY subq_25.metric_time__day DESC
subq_27.user
, subq_27.metric_time__day
, subq_27.mf_internal_uuid
ORDER BY subq_24.metric_time__day DESC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS visits
, FIRST_VALUE(subq_25.metric_time__day) OVER (
, FIRST_VALUE(subq_24.metric_time__day) OVER (
PARTITION BY
subq_28.user
, subq_28.metric_time__day
, subq_28.mf_internal_uuid
ORDER BY subq_25.metric_time__day DESC
subq_27.user
, subq_27.metric_time__day
, subq_27.mf_internal_uuid
ORDER BY subq_24.metric_time__day DESC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS metric_time__day
, FIRST_VALUE(subq_25.user) OVER (
, FIRST_VALUE(subq_24.user) OVER (
PARTITION BY
subq_28.user
, subq_28.metric_time__day
, subq_28.mf_internal_uuid
ORDER BY subq_25.metric_time__day DESC
subq_27.user
, subq_27.metric_time__day
, subq_27.mf_internal_uuid
ORDER BY subq_24.metric_time__day DESC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS user
, subq_28.mf_internal_uuid AS mf_internal_uuid
, subq_28.buys AS buys
, subq_27.mf_internal_uuid AS mf_internal_uuid
, subq_27.buys AS buys
FROM (
-- Constrain Output with WHERE
-- Pass Only Elements: ['visits', 'metric_time__day', 'user']
SELECT
metric_time__day
, subq_23.user
, subq_22.user
, visits
FROM (
-- Read Elements From Semantic Model 'visits_source'
-- Metric Time Dimension 'ds'
-- Read From CTE For node_id=sma_28019
SELECT
DATETIME_TRUNC(ds, day) AS metric_time__day
, user_id AS user
, 1 AS visits
FROM ***************************.fct_visits visits_source_src_28000
) subq_23
metric_time__day
, sma_28019_cte.user
, visits
FROM sma_28019_cte sma_28019_cte
) subq_22
WHERE metric_time__day = '2020-01-01'
) subq_25
) subq_24
INNER JOIN (
-- Read Elements From Semantic Model 'buys_source'
-- Metric Time Dimension 'ds'
Expand All @@ -97,19 +106,19 @@ FROM (
, 1 AS buys
, GENERATE_UUID() AS mf_internal_uuid
FROM ***************************.fct_buys buys_source_src_28000
) subq_28
) subq_27
ON
(
subq_25.user = subq_28.user
subq_24.user = subq_27.user
) AND (
(subq_25.metric_time__day <= subq_28.metric_time__day)
(subq_24.metric_time__day <= subq_27.metric_time__day)
)
) subq_29
) subq_28
GROUP BY
metric_time__day
) subq_32
) subq_31
ON
subq_21.metric_time__day = subq_32.metric_time__day
subq_21.metric_time__day = subq_31.metric_time__day
GROUP BY
metric_time__day
) subq_33
) subq_32
Loading

0 comments on commit 50c1c5a

Please sign in to comment.