diff --git a/metricflow/plan_conversion/to_sql_plan/dataflow_to_subquery.py b/metricflow/plan_conversion/to_sql_plan/dataflow_to_subquery.py index c7e44715d..15c88a439 100644 --- a/metricflow/plan_conversion/to_sql_plan/dataflow_to_subquery.py +++ b/metricflow/plan_conversion/to_sql_plan/dataflow_to_subquery.py @@ -117,8 +117,8 @@ SelectOnlyLinkableSpecs, ) from metricflow.plan_conversion.to_sql_plan.sql_join_builder import ColumnEqualityDescription, SqlPlanJoinBuilder +from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode from metricflow.sql.sql_plan import ( - SqlCreateTableAsNode, SqlCteNode, SqlSelectColumn, ) diff --git a/metricflow/sql/optimizer/column_pruner.py b/metricflow/sql/optimizer/column_pruner.py index af7ac36e7..1880f5203 100644 --- a/metricflow/sql/optimizer/column_pruner.py +++ b/metricflow/sql/optimizer/column_pruner.py @@ -9,15 +9,15 @@ from metricflow.sql.optimizer.required_column_aliases import SqlMapRequiredColumnAliasesVisitor from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping +from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode from metricflow.sql.sql_plan import ( - SqlCreateTableAsNode, SqlCteAliasMapping, SqlCteNode, SqlPlanNode, SqlPlanNodeVisitor, - SqlSelectQueryFromClauseNode, ) from metricflow.sql.sql_select_node import SqlSelectStatementNode +from metricflow.sql.sql_select_text_node import SqlSelectTextNode from metricflow.sql.sql_table_node import SqlTableNode logger = logging.getLogger(__name__) @@ -87,7 +87,7 @@ def visit_table_node(self, node: SqlTableNode) -> SqlPlanNode: """There are no SELECT columns in this node, so pruning cannot apply.""" return node - def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanNode: + def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanNode: """Pruning cannot be done here since this is an arbitrary user-provided SQL query.""" return node diff --git a/metricflow/sql/optimizer/cte_mapping_lookup_builder.py b/metricflow/sql/optimizer/cte_mapping_lookup_builder.py index f8f781bb8..096eb19f6 100644 --- a/metricflow/sql/optimizer/cte_mapping_lookup_builder.py +++ b/metricflow/sql/optimizer/cte_mapping_lookup_builder.py @@ -8,15 +8,15 @@ from typing_extensions import override from metricflow.sql.optimizer.cte_alias_to_cte_node_mapping import SqlCteAliasMappingLookup +from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode from metricflow.sql.sql_plan import ( - SqlCreateTableAsNode, SqlCteAliasMapping, SqlCteNode, SqlPlanNode, SqlPlanNodeVisitor, - SqlSelectQueryFromClauseNode, ) from metricflow.sql.sql_select_node import SqlSelectStatementNode +from metricflow.sql.sql_select_text_node import SqlSelectTextNode from metricflow.sql.sql_table_node import SqlTableNode logger = logging.getLogger(__name__) @@ -81,7 +81,7 @@ def visit_table_node(self, node: SqlTableNode) -> None: self._default_handler(node) @override - def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> None: + def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> None: self._default_handler(node) @override diff --git a/metricflow/sql/optimizer/required_column_aliases.py b/metricflow/sql/optimizer/required_column_aliases.py index f66daa70f..fb23a14b2 100644 --- a/metricflow/sql/optimizer/required_column_aliases.py +++ b/metricflow/sql/optimizer/required_column_aliases.py @@ -10,16 +10,16 @@ from metricflow.sql.optimizer.cte_alias_to_cte_node_mapping import SqlCteAliasMappingLookup from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping +from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode from metricflow.sql.sql_plan import ( - SqlCreateTableAsNode, SqlCteAliasMapping, SqlCteNode, SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn, - SqlSelectQueryFromClauseNode, ) from metricflow.sql.sql_select_node import SqlSelectStatementNode +from metricflow.sql.sql_select_text_node import SqlSelectTextNode from metricflow.sql.sql_table_node import SqlTableNode logger = logging.getLogger(__name__) @@ -264,7 +264,7 @@ def visit_table_node(self, node: SqlTableNode) -> None: """There are no SELECT columns in this node, so pruning cannot apply.""" return - def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> None: + def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> None: """Pruning cannot be done here since this is an arbitrary user-provided SQL query.""" return diff --git a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py index da69e790e..2d7ac6b89 100644 --- a/metricflow/sql/optimizer/rewriting_sub_query_reducer.py +++ b/metricflow/sql/optimizer/rewriting_sub_query_reducer.py @@ -18,15 +18,15 @@ from typing_extensions import override from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer +from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode from metricflow.sql.sql_plan import ( - SqlCreateTableAsNode, SqlCteNode, SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn, - SqlSelectQueryFromClauseNode, ) from metricflow.sql.sql_select_node import SqlJoinDescription, SqlOrderByDescription, SqlSelectStatementNode +from metricflow.sql.sql_select_text_node import SqlSelectTextNode from metricflow.sql.sql_table_node import SqlTableNode logger = logging.getLogger(__name__) @@ -764,7 +764,7 @@ def _get_matching_column_for_order_by( def visit_table_node(self, node: SqlTableNode) -> SqlPlanNode: # noqa: D102 return node - def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanNode: # noqa: D102 + def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanNode: # noqa: D102 return node def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlPlanNode: # noqa: D102 @@ -839,7 +839,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanNo def visit_table_node(self, node: SqlTableNode) -> SqlPlanNode: # noqa: D102 return node - def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanNode: # noqa: D102 + def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanNode: # noqa: D102 return node def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlPlanNode: # noqa: D102 diff --git a/metricflow/sql/optimizer/table_alias_simplifier.py b/metricflow/sql/optimizer/table_alias_simplifier.py index 8339d2244..be24f18ec 100644 --- a/metricflow/sql/optimizer/table_alias_simplifier.py +++ b/metricflow/sql/optimizer/table_alias_simplifier.py @@ -5,15 +5,15 @@ from typing_extensions import override from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer +from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode from metricflow.sql.sql_plan import ( - SqlCreateTableAsNode, SqlCteNode, SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn, - SqlSelectQueryFromClauseNode, ) from metricflow.sql.sql_select_node import SqlJoinDescription, SqlOrderByDescription, SqlSelectStatementNode +from metricflow.sql.sql_select_text_node import SqlSelectTextNode from metricflow.sql.sql_table_node import SqlTableNode logger = logging.getLogger(__name__) @@ -83,7 +83,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanNo def visit_table_node(self, node: SqlTableNode) -> SqlPlanNode: # noqa: D102 return node - def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanNode: # noqa: D102 + def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanNode: # noqa: D102 return node def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlPlanNode: # noqa: D102 diff --git a/metricflow/sql/render/sql_plan_renderer.py b/metricflow/sql/render/sql_plan_renderer.py index 9f86b9733..afca824d0 100644 --- a/metricflow/sql/render/sql_plan_renderer.py +++ b/metricflow/sql/render/sql_plan_renderer.py @@ -18,16 +18,16 @@ SqlExpressionRenderResult, ) from metricflow.sql.render.rendering_constants import SqlRenderingConstants +from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode from metricflow.sql.sql_plan import ( - SqlCreateTableAsNode, SqlCteNode, SqlPlan, SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn, - SqlSelectQueryFromClauseNode, ) from metricflow.sql.sql_select_node import SqlJoinDescription, SqlOrderByDescription, SqlSelectStatementNode +from metricflow.sql.sql_select_text_node import SqlSelectTextNode from metricflow.sql.sql_table_node import SqlTableNode logger = logging.getLogger(__name__) @@ -347,7 +347,7 @@ def visit_table_node(self, node: SqlTableNode) -> SqlPlanRenderResult: # noqa: bind_parameter_set=SqlBindParameterSet(), ) - def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanRenderResult: # noqa: D102 + def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanRenderResult: # noqa: D102 return SqlPlanRenderResult( sql=node.select_query.rstrip(), bind_parameter_set=SqlBindParameterSet(), diff --git a/metricflow/sql/sql_ctas_node.py b/metricflow/sql/sql_ctas_node.py new file mode 100644 index 000000000..04b17a03d --- /dev/null +++ b/metricflow/sql/sql_ctas_node.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Sequence, override + +from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix +from metricflow_semantics.sql.sql_table import SqlTable +from metricflow_semantics.visitor import VisitorOutputT + +from metricflow.sql.sql_plan import SqlCteAliasMapping, SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn +from metricflow.sql.sql_select_node import SqlSelectStatementNode +from metricflow.sql.sql_table_node import SqlTableNode + + +@dataclass(frozen=True, eq=False) +class SqlCreateTableAsNode(SqlPlanNode): + """An SQL node representing a CREATE TABLE AS statement. + + Attributes: + sql_table: The SQL table to create. + """ + + sql_table: SqlTable + + def __post_init__(self) -> None: # noqa: D105 + super().__post_init__() + assert len(self.parent_nodes) == 1 + + @staticmethod + def create(sql_table: SqlTable, parent_node: SqlPlanNode) -> SqlCreateTableAsNode: # noqa: D102 + return SqlCreateTableAsNode( + parent_nodes=(parent_node,), + sql_table=sql_table, + ) + + @override + def accept(self, visitor: SqlPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: + return visitor.visit_create_table_as_node(self) + + @property + @override + def as_select_node(self) -> Optional[SqlSelectStatementNode]: + return None + + @property + @override + def as_sql_table_node(self) -> Optional[SqlTableNode]: + return None + + @property + @override + def description(self) -> str: + return f"Create table {repr(self.sql_table.sql)}" + + @property + def parent_node(self) -> SqlPlanNode: # noqa: D102 + return self.parent_nodes[0] + + @classmethod + @override + def id_prefix(cls) -> IdPrefix: + return StaticIdPrefix.SQL_PLAN_CREATE_TABLE_AS_ID_PREFIX + + @override + def nearest_select_columns(self, cte_source_mapping: SqlCteAliasMapping) -> Optional[Sequence[SqlSelectColumn]]: + return self.parent_node.nearest_select_columns(cte_source_mapping) + + @override + def copy(self) -> SqlCreateTableAsNode: + return SqlCreateTableAsNode(parent_nodes=self.parent_nodes, sql_table=self.sql_table) diff --git a/metricflow/sql/sql_plan.py b/metricflow/sql/sql_plan.py index 1e60f7cd7..49e792602 100644 --- a/metricflow/sql/sql_plan.py +++ b/metricflow/sql/sql_plan.py @@ -12,11 +12,12 @@ from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DagId, DagNode, MetricFlowDag from metricflow_semantics.sql.sql_exprs import SqlColumnReferenceExpression, SqlExpressionNode -from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.visitor import VisitorOutputT from typing_extensions import Self, override +from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode from metricflow.sql.sql_select_node import SqlSelectStatementNode +from metricflow.sql.sql_select_text_node import SqlSelectTextNode from metricflow.sql.sql_table_node import SqlTableNode logger = logging.getLogger(__name__) @@ -89,7 +90,7 @@ def visit_table_node(self, node: SqlTableNode) -> VisitorOutputT: # noqa: D102 raise NotImplementedError @abstractmethod - def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> VisitorOutputT: # noqa: D102 + def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> VisitorOutputT: # noqa: D102 raise NotImplementedError @abstractmethod @@ -127,110 +128,6 @@ def reference_from(self, source_table_alias: str) -> SqlColumnReferenceExpressio ) -@dataclass(frozen=True, eq=False) -class SqlSelectQueryFromClauseNode(SqlPlanNode): - """An SQL select query that can go in the FROM clause. - - Attributes: - select_query: The SQL select query to include in the FROM clause. - """ - - select_query: str - - @staticmethod - def create(select_query: str) -> SqlSelectQueryFromClauseNode: # noqa: D102 - return SqlSelectQueryFromClauseNode( - parent_nodes=(), - select_query=select_query, - ) - - @classmethod - def id_prefix(cls) -> IdPrefix: # noqa: D102 - return StaticIdPrefix.SQL_PLAN_QUERY_FROM_CLAUSE_ID_PREFIX - - @property - def description(self) -> str: # noqa: D102 - return "Read From a Select Query" - - def accept(self, visitor: SqlPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 - return visitor.visit_query_from_clause_node(self) - - @property - def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102 - return None - - @override - def nearest_select_columns(self, cte_source_mapping: SqlCteAliasMapping) -> Optional[Sequence[SqlSelectColumn]]: - return None - - @property - @override - def as_sql_table_node(self) -> Optional[SqlTableNode]: - return None - - @override - def copy(self) -> SqlSelectQueryFromClauseNode: - return SqlSelectQueryFromClauseNode(parent_nodes=self.parent_nodes, select_query=self.select_query) - - -@dataclass(frozen=True, eq=False) -class SqlCreateTableAsNode(SqlPlanNode): - """An SQL node representing a CREATE TABLE AS statement. - - Attributes: - sql_table: The SQL table to create. - """ - - sql_table: SqlTable - - def __post_init__(self) -> None: # noqa: D105 - super().__post_init__() - assert len(self.parent_nodes) == 1 - - @staticmethod - def create(sql_table: SqlTable, parent_node: SqlPlanNode) -> SqlCreateTableAsNode: # noqa: D102 - return SqlCreateTableAsNode( - parent_nodes=(parent_node,), - sql_table=sql_table, - ) - - @override - def accept(self, visitor: SqlPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: - return visitor.visit_create_table_as_node(self) - - @property - @override - def as_select_node(self) -> Optional[SqlSelectStatementNode]: - return None - - @property - @override - def as_sql_table_node(self) -> Optional[SqlTableNode]: - return None - - @property - @override - def description(self) -> str: - return f"Create table {repr(self.sql_table.sql)}" - - @property - def parent_node(self) -> SqlPlanNode: # noqa: D102 - return self.parent_nodes[0] - - @classmethod - @override - def id_prefix(cls) -> IdPrefix: - return StaticIdPrefix.SQL_PLAN_CREATE_TABLE_AS_ID_PREFIX - - @override - def nearest_select_columns(self, cte_source_mapping: SqlCteAliasMapping) -> Optional[Sequence[SqlSelectColumn]]: - return self.parent_node.nearest_select_columns(cte_source_mapping) - - @override - def copy(self) -> SqlCreateTableAsNode: - return SqlCreateTableAsNode(parent_nodes=self.parent_nodes, sql_table=self.sql_table) - - class SqlPlan(MetricFlowDag[SqlPlanNode]): """Model for an SQL statement as a DAG.""" diff --git a/metricflow/sql/sql_select_text_node.py b/metricflow/sql/sql_select_text_node.py new file mode 100644 index 000000000..50cedc73b --- /dev/null +++ b/metricflow/sql/sql_select_text_node.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Sequence, override + +from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix +from metricflow_semantics.visitor import VisitorOutputT + +from metricflow.sql.sql_plan import SqlCteAliasMapping, SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn +from metricflow.sql.sql_select_node import SqlSelectStatementNode +from metricflow.sql.sql_table_node import SqlTableNode + + +@dataclass(frozen=True, eq=False) +class SqlSelectTextNode(SqlPlanNode): + """An SQL select query that can go in the FROM clause. + + Attributes: + select_query: The SQL select query to include in the FROM clause. + """ + + select_query: str + + @staticmethod + def create(select_query: str) -> SqlSelectTextNode: # noqa: D102 + return SqlSelectTextNode( + parent_nodes=(), + select_query=select_query, + ) + + @classmethod + def id_prefix(cls) -> IdPrefix: # noqa: D102 + return StaticIdPrefix.SQL_PLAN_QUERY_FROM_CLAUSE_ID_PREFIX + + @property + def description(self) -> str: # noqa: D102 + return "Read From a Select Query" + + def accept(self, visitor: SqlPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT: # noqa: D102 + return visitor.visit_query_from_clause_node(self) + + @property + def as_select_node(self) -> Optional[SqlSelectStatementNode]: # noqa: D102 + return None + + @override + def nearest_select_columns(self, cte_source_mapping: SqlCteAliasMapping) -> Optional[Sequence[SqlSelectColumn]]: + return None + + @property + @override + def as_sql_table_node(self) -> Optional[SqlTableNode]: + return None + + @override + def copy(self) -> SqlSelectTextNode: + return SqlSelectTextNode(parent_nodes=self.parent_nodes, select_query=self.select_query) diff --git a/metricflow/sql/sql_table_node.py b/metricflow/sql/sql_table_node.py index 58f97d4d2..720d88e65 100644 --- a/metricflow/sql/sql_table_node.py +++ b/metricflow/sql/sql_table_node.py @@ -1,15 +1,16 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Sequence, Optional, override +from typing import Optional, Sequence, override -from metricflow.sql.sql_plan import SqlPlanNode, SqlPlanNodeVisitor, SqlCteAliasMapping, SqlSelectColumn -from metricflow.sql.sql_select_node import SqlSelectStatementNode from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix from metricflow_semantics.dag.mf_dag import DisplayedProperty from metricflow_semantics.sql.sql_table import SqlTable from metricflow_semantics.visitor import VisitorOutputT +from metricflow.sql.sql_plan import SqlCteAliasMapping, SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn +from metricflow.sql.sql_select_node import SqlSelectStatementNode + @dataclass(frozen=True, eq=False) class SqlTableNode(SqlPlanNode): diff --git a/tests_metricflow/sql/test_sql_plan_render.py b/tests_metricflow/sql/test_sql_plan_render.py index 5c0bf6c05..54f1ea2a7 100644 --- a/tests_metricflow/sql/test_sql_plan_render.py +++ b/tests_metricflow/sql/test_sql_plan_render.py @@ -19,8 +19,8 @@ from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration from metricflow.protocols.sql_client import SqlClient +from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode from metricflow.sql.sql_plan import ( - SqlCreateTableAsNode, SqlSelectColumn, ) from metricflow.sql.sql_select_node import SqlJoinDescription, SqlOrderByDescription, SqlSelectStatementNode