Skip to content

Commit

Permalink
Rename to SqlSelectTextNode.
Browse files Browse the repository at this point in the history
  • Loading branch information
plypaul committed Jan 23, 2025
1 parent 9cdd52f commit d401015
Show file tree
Hide file tree
Showing 12 changed files with 155 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,8 @@
SelectOnlyLinkableSpecs,
)
from metricflow.plan_conversion.to_sql_plan.sql_join_builder import ColumnEqualityDescription, SqlPlanJoinBuilder
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlSelectColumn,
)
Expand Down
6 changes: 3 additions & 3 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
from metricflow.sql.optimizer.required_column_aliases import SqlMapRequiredColumnAliasesVisitor
from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer
from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteAliasMapping,
SqlCteNode,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectQueryFromClauseNode,
)
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -87,7 +87,7 @@ def visit_table_node(self, node: SqlTableNode) -> SqlPlanNode:
"""There are no SELECT columns in this node, so pruning cannot apply."""
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanNode:
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanNode:
"""Pruning cannot be done here since this is an arbitrary user-provided SQL query."""
return node

Expand Down
6 changes: 3 additions & 3 deletions metricflow/sql/optimizer/cte_mapping_lookup_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@
from typing_extensions import override

from metricflow.sql.optimizer.cte_alias_to_cte_node_mapping import SqlCteAliasMappingLookup
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteAliasMapping,
SqlCteNode,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectQueryFromClauseNode,
)
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -81,7 +81,7 @@ def visit_table_node(self, node: SqlTableNode) -> None:
self._default_handler(node)

@override
def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> None:
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> None:
self._default_handler(node)

@override
Expand Down
6 changes: 3 additions & 3 deletions metricflow/sql/optimizer/required_column_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@

from metricflow.sql.optimizer.cte_alias_to_cte_node_mapping import SqlCteAliasMappingLookup
from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteAliasMapping,
SqlCteNode,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
)
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -264,7 +264,7 @@ def visit_table_node(self, node: SqlTableNode) -> None:
"""There are no SELECT columns in this node, so pruning cannot apply."""
return

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> None:
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> None:
"""Pruning cannot be done here since this is an arbitrary user-provided SQL query."""
return

Expand Down
8 changes: 4 additions & 4 deletions metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
)
from metricflow.sql.sql_select_node import SqlJoinDescription, SqlOrderByDescription, SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -764,7 +764,7 @@ def _get_matching_column_for_order_by(
def visit_table_node(self, node: SqlTableNode) -> SqlPlanNode: # noqa: D102
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanNode: # noqa: D102
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanNode: # noqa: D102
return node

def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlPlanNode: # noqa: D102
Expand Down Expand Up @@ -839,7 +839,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanNo
def visit_table_node(self, node: SqlTableNode) -> SqlPlanNode: # noqa: D102
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanNode: # noqa: D102
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanNode: # noqa: D102
return node

def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlPlanNode: # noqa: D102
Expand Down
6 changes: 3 additions & 3 deletions metricflow/sql/optimizer/table_alias_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
)
from metricflow.sql.sql_select_node import SqlJoinDescription, SqlOrderByDescription, SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -83,7 +83,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanNo
def visit_table_node(self, node: SqlTableNode) -> SqlPlanNode: # noqa: D102
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanNode: # noqa: D102
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanNode: # noqa: D102
return node

def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlPlanNode: # noqa: D102
Expand Down
6 changes: 3 additions & 3 deletions metricflow/sql/render/sql_plan_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@
SqlExpressionRenderResult,
)
from metricflow.sql.render.rendering_constants import SqlRenderingConstants
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlPlan,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
)
from metricflow.sql.sql_select_node import SqlJoinDescription, SqlOrderByDescription, SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -347,7 +347,7 @@ def visit_table_node(self, node: SqlTableNode) -> SqlPlanRenderResult: # noqa:
bind_parameter_set=SqlBindParameterSet(),
)

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanRenderResult: # noqa: D102
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanRenderResult: # noqa: D102
return SqlPlanRenderResult(
sql=node.select_query.rstrip(),
bind_parameter_set=SqlBindParameterSet(),
Expand Down
70 changes: 70 additions & 0 deletions metricflow/sql/sql_ctas_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Sequence, override

from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.visitor import VisitorOutputT

from metricflow.sql.sql_plan import SqlCteAliasMapping, SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_table_node import SqlTableNode


@dataclass(frozen=True, eq=False)
class SqlCreateTableAsNode(SqlPlanNode):
"""An SQL node representing a CREATE TABLE AS statement.
Attributes:
sql_table: The SQL table to create.
"""

sql_table: SqlTable

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()
assert len(self.parent_nodes) == 1

@staticmethod
def create(sql_table: SqlTable, parent_node: SqlPlanNode) -> SqlCreateTableAsNode: # noqa: D102
return SqlCreateTableAsNode(
parent_nodes=(parent_node,),
sql_table=sql_table,
)

@override
def accept(self, visitor: SqlPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
return visitor.visit_create_table_as_node(self)

@property
@override
def as_select_node(self) -> Optional[SqlSelectStatementNode]:
return None

@property
@override
def as_sql_table_node(self) -> Optional[SqlTableNode]:
return None

@property
@override
def description(self) -> str:
return f"Create table {repr(self.sql_table.sql)}"

@property
def parent_node(self) -> SqlPlanNode: # noqa: D102
return self.parent_nodes[0]

@classmethod
@override
def id_prefix(cls) -> IdPrefix:
return StaticIdPrefix.SQL_PLAN_CREATE_TABLE_AS_ID_PREFIX

@override
def nearest_select_columns(self, cte_source_mapping: SqlCteAliasMapping) -> Optional[Sequence[SqlSelectColumn]]:
return self.parent_node.nearest_select_columns(cte_source_mapping)

@override
def copy(self) -> SqlCreateTableAsNode:
return SqlCreateTableAsNode(parent_nodes=self.parent_nodes, sql_table=self.sql_table)
109 changes: 3 additions & 106 deletions metricflow/sql/sql_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.dag.mf_dag import DagId, DagNode, MetricFlowDag
from metricflow_semantics.sql.sql_exprs import SqlColumnReferenceExpression, SqlExpressionNode
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.visitor import VisitorOutputT
from typing_extensions import Self, override

from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -89,7 +90,7 @@ def visit_table_node(self, node: SqlTableNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> VisitorOutputT: # noqa: D102
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> VisitorOutputT: # noqa: D102
raise NotImplementedError

@abstractmethod
Expand Down Expand Up @@ -127,110 +128,6 @@ def reference_from(self, source_table_alias: str) -> SqlColumnReferenceExpressio
)


@dataclass(frozen=True, eq=False)
class SqlSelectQueryFromClauseNode(SqlPlanNode):
"""An SQL select query that can go in the FROM clause.
Attributes:
select_query: The SQL select query to include in the FROM clause.
"""

select_query: str

@staticmethod
def create(select_query: str) -> SqlSelectQueryFromClauseNode: # noqa: D102
return SqlSelectQueryFromClauseNode(
parent_nodes=(),
select_query=select_query,
)

@classmethod
def id_prefix(cls) -> IdPrefix: # noqa: D102
return StaticIdPrefix.SQL_PLAN_QUERY_FROM_CLAUSE_ID_PREFIX

@property
def description(self) -> str: # noqa: D102
return "Read From a Select Query"

def accept(self, visitor: SqlPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102
return visitor.visit_query_from_clause_node(self)

@property
def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102
return None

@override
def nearest_select_columns(self, cte_source_mapping: SqlCteAliasMapping) -> Optional[Sequence[SqlSelectColumn]]:
return None

@property
@override
def as_sql_table_node(self) -> Optional[SqlTableNode]:
return None

@override
def copy(self) -> SqlSelectQueryFromClauseNode:
return SqlSelectQueryFromClauseNode(parent_nodes=self.parent_nodes, select_query=self.select_query)


@dataclass(frozen=True, eq=False)
class SqlCreateTableAsNode(SqlPlanNode):
"""An SQL node representing a CREATE TABLE AS statement.
Attributes:
sql_table: The SQL table to create.
"""

sql_table: SqlTable

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()
assert len(self.parent_nodes) == 1

@staticmethod
def create(sql_table: SqlTable, parent_node: SqlPlanNode) -> SqlCreateTableAsNode: # noqa: D102
return SqlCreateTableAsNode(
parent_nodes=(parent_node,),
sql_table=sql_table,
)

@override
def accept(self, visitor: SqlPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
return visitor.visit_create_table_as_node(self)

@property
@override
def as_select_node(self) -> Optional[SqlSelectStatementNode]:
return None

@property
@override
def as_sql_table_node(self) -> Optional[SqlTableNode]:
return None

@property
@override
def description(self) -> str:
return f"Create table {repr(self.sql_table.sql)}"

@property
def parent_node(self) -> SqlPlanNode: # noqa: D102
return self.parent_nodes[0]

@classmethod
@override
def id_prefix(cls) -> IdPrefix:
return StaticIdPrefix.SQL_PLAN_CREATE_TABLE_AS_ID_PREFIX

@override
def nearest_select_columns(self, cte_source_mapping: SqlCteAliasMapping) -> Optional[Sequence[SqlSelectColumn]]:
return self.parent_node.nearest_select_columns(cte_source_mapping)

@override
def copy(self) -> SqlCreateTableAsNode:
return SqlCreateTableAsNode(parent_nodes=self.parent_nodes, sql_table=self.sql_table)


class SqlPlan(MetricFlowDag[SqlPlanNode]):
"""Model for an SQL statement as a DAG."""

Expand Down
Loading

0 comments on commit d401015

Please sign in to comment.