Skip to content

Commit

Permalink
Move SqlCteNode to a separate file.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Jan 23, 2025
1 parent d401015 commit eac06df
Show file tree
Hide file tree
Showing 19 changed files with 139 additions and 126 deletions.
3 changes: 2 additions & 1 deletion metricflow/plan_conversion/to_sql_plan/dataflow_to_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.plan_conversion.instance_converters import CreateSelectColumnsForInstances
from metricflow.plan_conversion.to_sql_plan.dataflow_to_subquery import DataflowNodeToSqlSubqueryVisitor
from metricflow.sql.sql_plan import SqlCteNode, SqlSelectColumn
from metricflow.sql.sql_cte_node import SqlCteNode
from metricflow.sql.sql_plan import SqlSelectColumn
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_table_node import SqlTableNode

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@
)
from metricflow.plan_conversion.to_sql_plan.sql_join_builder import ColumnEqualityDescription, SqlPlanJoinBuilder
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteNode
from metricflow.sql.sql_plan import (
SqlCteNode,
SqlSelectColumn,
)
from metricflow.sql.sql_select_node import SqlJoinDescription, SqlOrderByDescription, SqlSelectStatementNode
Expand Down
3 changes: 1 addition & 2 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer
from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteAliasMapping, SqlCteNode
from metricflow.sql.sql_plan import (
SqlCteAliasMapping,
SqlCteNode,
SqlPlanNode,
SqlPlanNodeVisitor,
)
Expand Down
2 changes: 1 addition & 1 deletion metricflow/sql/optimizer/cte_alias_to_cte_node_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat

from metricflow.sql.sql_plan import SqlCteAliasMapping
from metricflow.sql.sql_cte_node import SqlCteAliasMapping
from metricflow.sql.sql_select_node import SqlSelectStatementNode

logger = logging.getLogger(__name__)
Expand Down
3 changes: 1 addition & 2 deletions metricflow/sql/optimizer/cte_mapping_lookup_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@

from metricflow.sql.optimizer.cte_alias_to_cte_node_mapping import SqlCteAliasMappingLookup
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteAliasMapping, SqlCteNode
from metricflow.sql.sql_plan import (
SqlCteAliasMapping,
SqlCteNode,
SqlPlanNode,
SqlPlanNodeVisitor,
)
Expand Down
3 changes: 1 addition & 2 deletions metricflow/sql/optimizer/required_column_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@
from metricflow.sql.optimizer.cte_alias_to_cte_node_mapping import SqlCteAliasMappingLookup
from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteAliasMapping, SqlCteNode
from metricflow.sql.sql_plan import (
SqlCteAliasMapping,
SqlCteNode,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectColumn,
Expand Down
2 changes: 1 addition & 1 deletion metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteNode
from metricflow.sql.sql_plan import (
SqlCteNode,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectColumn,
Expand Down
2 changes: 1 addition & 1 deletion metricflow/sql/optimizer/table_alias_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteNode
from metricflow.sql.sql_plan import (
SqlCteNode,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectColumn,
Expand Down
2 changes: 1 addition & 1 deletion metricflow/sql/render/sql_plan_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
)
from metricflow.sql.render.rendering_constants import SqlRenderingConstants
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteNode
from metricflow.sql.sql_plan import (
SqlCteNode,
SqlPlan,
SqlPlanNode,
SqlPlanNodeVisitor,
Expand Down
3 changes: 2 additions & 1 deletion metricflow/sql/sql_ctas_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.sql.sql_plan import SqlCteAliasMapping, SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn
from metricflow.sql.sql_cte_node import SqlCteAliasMapping
from metricflow.sql.sql_plan import SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_table_node import SqlTableNode

Expand Down
113 changes: 113 additions & 0 deletions metricflow/sql/sql_cte_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

from dataclasses import dataclass
from functools import cached_property
from typing import override, Optional, Sequence, Tuple, Mapping

from metricflow.sql.sql_plan import SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_table_node import SqlTableNode
from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.visitor import VisitorOutputT


@dataclass(frozen=True, eq=False)
class SqlCteNode(SqlPlanNode):
"""Represents a single common table expression."""

select_statement: SqlPlanNode
cte_alias: str

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()
assert len(self.parent_nodes) == 1

@staticmethod
def create(select_statement: SqlPlanNode, cte_alias: str) -> SqlCteNode: # noqa: D102
return SqlCteNode(
parent_nodes=(select_statement,),
select_statement=select_statement,
cte_alias=cte_alias,
)

def with_new_select(self, new_select_statement: SqlPlanNode) -> SqlCteNode:
"""Return a node with the same attributes but with the new SELECT statement."""
return SqlCteNode.create(
select_statement=new_select_statement,
cte_alias=self.cte_alias,
)

@override
def accept(self, visitor: SqlPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
return visitor.visit_cte_node(self)

@property
@override
def as_select_node(self) -> Optional[SqlSelectStatementNode]:
return None

@property
@override
def as_sql_table_node(self) -> Optional[SqlTableNode]:
return None

@property
@override
def description(self) -> str:
return "CTE"

@classmethod
@override
def id_prefix(cls) -> IdPrefix:
return StaticIdPrefix.SQL_PLAN_COMMON_TABLE_EXPRESSION_ID_PREFIX

@override
def nearest_select_columns(self, cte_source_mapping: SqlCteAliasMapping) -> Optional[Sequence[SqlSelectColumn]]:
return self.select_statement.nearest_select_columns(cte_source_mapping)

@override
def copy(self) -> SqlCteNode:
return SqlCteNode(
parent_nodes=self.parent_nodes,
select_statement=self.select_statement,
cte_alias=self.cte_alias,
)


@dataclass(frozen=True)
class SqlCteAliasMapping(Mergeable):
"""Thin, dict-like object that maps an alias to the associated `SqlCteNode`.
When merged, the entries from the right mapping take precedence over the entries from the left.
"""

cte_alias_to_cte_node_items: Tuple[Tuple[str, SqlCteNode], ...] = ()

@staticmethod
def create(cte_alias_to_cte_node_mapping: Mapping[str, SqlCteNode]) -> SqlCteAliasMapping: # noqa: D102
cte_alias_to_cte_node_pairs = []
for cte_alias, cte_node in cte_alias_to_cte_node_mapping.items():
cte_alias_to_cte_node_pairs.append((cte_alias, cte_node))

return SqlCteAliasMapping(cte_alias_to_cte_node_items=tuple(cte_alias_to_cte_node_pairs))

@cached_property
def _cte_alias_to_cte_node_dict(self) -> Mapping[str, SqlCteNode]:
return {item[0]: item[1] for item in self.cte_alias_to_cte_node_items}

def get_cte_node_for_alias(self, cte_alias: str) -> Optional[SqlCteNode]:
"""Return the associated `SqlCteNode` for the given alias, or None if the given alias is not known."""
return self._cte_alias_to_cte_node_dict.get(cte_alias)

@override
def merge(self, other: SqlCteAliasMapping) -> SqlCteAliasMapping:
new_mapping = dict(self._cte_alias_to_cte_node_dict)
for cte_alias, cte_node in other.cte_alias_to_cte_node_items:
new_mapping[cte_alias] = cte_node
return SqlCteAliasMapping.create(new_mapping)

@classmethod
@override
def empty_instance(cls) -> SqlCteAliasMapping:
return SqlCteAliasMapping()
110 changes: 4 additions & 106 deletions metricflow/sql/sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,16 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import cached_property
from typing import Generic, Mapping, Optional, Sequence, Tuple
from typing import Generic, Optional, Sequence

from metricflow_semantics.collection_helpers.merger import Mergeable
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.id_prefix import StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagId, DagNode, MetricFlowDag
from metricflow_semantics.sql.sql_exprs import SqlColumnReferenceExpression, SqlExpressionNode
from metricflow_semantics.visitor import VisitorOutputT
from typing_extensions import Self, override
from typing_extensions import Self

from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteAliasMapping, SqlCteNode
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode
Expand Down Expand Up @@ -147,104 +146,3 @@ def __init__(self, render_node: SqlPlanNode, plan_id: Optional[DagId] = None) ->
@property
def render_node(self) -> SqlPlanNode: # noqa: D102
return self._render_node


@dataclass(frozen=True, eq=False)
class SqlCteNode(SqlPlanNode):
"""Represents a single common table expression."""

select_statement: SqlPlanNode
cte_alias: str

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()
assert len(self.parent_nodes) == 1

@staticmethod
def create(select_statement: SqlPlanNode, cte_alias: str) -> SqlCteNode: # noqa: D102
return SqlCteNode(
parent_nodes=(select_statement,),
select_statement=select_statement,
cte_alias=cte_alias,
)

def with_new_select(self, new_select_statement: SqlPlanNode) -> SqlCteNode:
"""Return a node with the same attributes but with the new SELECT statement."""
return SqlCteNode.create(
select_statement=new_select_statement,
cte_alias=self.cte_alias,
)

@override
def accept(self, visitor: SqlPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
return visitor.visit_cte_node(self)

@property
@override
def as_select_node(self) -> Optional[SqlSelectStatementNode]:
return None

@property
@override
def as_sql_table_node(self) -> Optional[SqlTableNode]:
return None

@property
@override
def description(self) -> str:
return "CTE"

@classmethod
@override
def id_prefix(cls) -> IdPrefix:
return StaticIdPrefix.SQL_PLAN_COMMON_TABLE_EXPRESSION_ID_PREFIX

@override
def nearest_select_columns(self, cte_source_mapping: SqlCteAliasMapping) -> Optional[Sequence[SqlSelectColumn]]:
return self.select_statement.nearest_select_columns(cte_source_mapping)

@override
def copy(self) -> SqlCteNode:
return SqlCteNode(
parent_nodes=self.parent_nodes,
select_statement=self.select_statement,
cte_alias=self.cte_alias,
)


@dataclass(frozen=True)
class SqlCteAliasMapping(Mergeable):
"""Thin, dict-like object that maps an alias to the associated `SqlCteNode`.
When merged, the entries from the right mapping take precedence over the entries from the left.
"""

cte_alias_to_cte_node_items: Tuple[Tuple[str, SqlCteNode], ...] = ()

@staticmethod
def create(cte_alias_to_cte_node_mapping: Mapping[str, SqlCteNode]) -> SqlCteAliasMapping: # noqa: D102
cte_alias_to_cte_node_pairs = []
for cte_alias, cte_node in cte_alias_to_cte_node_mapping.items():
cte_alias_to_cte_node_pairs.append((cte_alias, cte_node))

return SqlCteAliasMapping(cte_alias_to_cte_node_items=tuple(cte_alias_to_cte_node_pairs))

@cached_property
def _cte_alias_to_cte_node_dict(self) -> Mapping[str, SqlCteNode]:
return {item[0]: item[1] for item in self.cte_alias_to_cte_node_items}

def get_cte_node_for_alias(self, cte_alias: str) -> Optional[SqlCteNode]:
"""Return the associated `SqlCteNode` for the given alias, or None if the given alias is not known."""
return self._cte_alias_to_cte_node_dict.get(cte_alias)

@override
def merge(self, other: SqlCteAliasMapping) -> SqlCteAliasMapping:
new_mapping = dict(self._cte_alias_to_cte_node_dict)
for cte_alias, cte_node in other.cte_alias_to_cte_node_items:
new_mapping[cte_alias] = cte_node
return SqlCteAliasMapping.create(new_mapping)

@classmethod
@override
def empty_instance(cls) -> SqlCteAliasMapping:
return SqlCteAliasMapping()
3 changes: 2 additions & 1 deletion metricflow/sql/sql_select_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from metricflow_semantics.sql.sql_join_type import SqlJoinType
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.sql.sql_plan import SqlCteAliasMapping, SqlCteNode, SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn
from metricflow.sql.sql_cte_node import SqlCteAliasMapping, SqlCteNode
from metricflow.sql.sql_plan import SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn
from metricflow.sql.sql_table_node import SqlTableNode


Expand Down
3 changes: 2 additions & 1 deletion metricflow/sql/sql_select_text_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.sql.sql_plan import SqlCteAliasMapping, SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn
from metricflow.sql.sql_cte_node import SqlCteAliasMapping
from metricflow.sql.sql_plan import SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_table_node import SqlTableNode

Expand Down
3 changes: 2 additions & 1 deletion metricflow/sql/sql_table_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.sql.sql_plan import SqlCteAliasMapping, SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn
from metricflow.sql.sql_cte_node import SqlCteAliasMapping
from metricflow.sql.sql_plan import SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn
from metricflow.sql.sql_select_node import SqlSelectStatementNode


Expand Down
2 changes: 1 addition & 1 deletion tests_metricflow/sql/optimizer/test_cte_column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

from metricflow.sql.optimizer.column_pruner import SqlColumnPrunerOptimizer
from metricflow.sql.render.sql_plan_renderer import DefaultSqlPlanRenderer, SqlPlanRenderer
from metricflow.sql.sql_cte_node import SqlCteNode
from metricflow.sql.sql_plan import (
SqlCteNode,
SqlSelectColumn,
)
from metricflow.sql.sql_select_node import SqlJoinDescription, SqlSelectStatementNode
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

from metricflow.sql.optimizer.rewriting_sub_query_reducer import SqlRewritingSubQueryReducer
from metricflow.sql.render.sql_plan_renderer import DefaultSqlPlanRenderer, SqlPlanRenderer
from metricflow.sql.sql_cte_node import SqlCteNode
from metricflow.sql.sql_plan import (
SqlCteNode,
SqlSelectColumn,
)
from metricflow.sql.sql_select_node import SqlJoinDescription, SqlSelectStatementNode
Expand Down
Loading

0 comments on commit eac06df

Please sign in to comment.