Skip to content

Commit

Permalink
Rename SqlQuery to SqlStatement (#1572)
Browse files Browse the repository at this point in the history
This PR renames `SqlQuery` to `SqlStatement` since it can now contain
"CREATE TABLE ... AS ..." instead of just a `SELECT ...`. This also
makes some renames / updates to code referring to `sql_query`.
  • Loading branch information
plypaul authored Dec 14, 2024
1 parent 1c6fed0 commit dd0e8c3
Show file tree
Hide file tree
Showing 14 changed files with 86 additions and 74 deletions.
4 changes: 2 additions & 2 deletions dbt-metricflow/dbt_metricflow/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,9 @@ def query(
if explain:
assert explain_result
sql = (
explain_result.rendered_sql_without_descriptions.sql_query
explain_result.sql_statement.without_descriptions.sql
if not show_sql_descriptions
else explain_result.rendered_sql.sql_query
else explain_result.sql_statement.sql
)
if show_dataflow_plan:
click.echo("🔎 Generated Dataflow Plan + SQL (remove --explain to see data):")
Expand Down
38 changes: 17 additions & 21 deletions metricflow/engine/metricflow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from metricflow.execution.dataflow_to_execution import (
DataflowToExecutionPlanConverter,
)
from metricflow.execution.execution_plan import ExecutionPlan, SqlQuery
from metricflow.execution.execution_plan import ExecutionPlan, SqlStatement
from metricflow.execution.executor import SequentialPlanExecutor
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
from metricflow.protocols.sql_client import SqlClient
Expand Down Expand Up @@ -177,35 +177,31 @@ class MetricFlowExplainResult:
output_table: Optional[SqlTable] = None

@property
def rendered_sql(self) -> SqlQuery:
def sql_statement(self) -> SqlStatement:
"""Return the SQL query that would be run for the given query."""
execution_plan = self.execution_plan
if len(execution_plan.tasks) != 1:
raise NotImplementedError(
f"Multiple tasks in the execution plan not yet supported. Got tasks: {execution_plan.tasks}"
str(
LazyFormat(
"Multiple tasks in the execution plan not yet supported.",
tasks=[task.task_id for task in execution_plan.tasks],
)
)
)

sql_query = execution_plan.tasks[0].sql_query
if not sql_query:
sql_statement = execution_plan.tasks[0].sql_statement
if not sql_statement:
raise NotImplementedError(
f"Execution plan tasks without a SQL query not yet supported. Got tasks: {execution_plan.tasks}"
str(
LazyFormat(
"Execution plan tasks without a SQL statement are not yet supported.",
tasks=[task.task_id for task in execution_plan.tasks],
)
)
)

return sql_query

@property
def rendered_sql_without_descriptions(self) -> SqlQuery:
"""Return the SQL query without the inline descriptions."""
sql_query = self.rendered_sql
return SqlQuery(
sql_query="\n".join(
filter(
lambda line: not line.strip().startswith("--"),
sql_query.sql_query.split("\n"),
)
),
bind_parameter_set=sql_query.bind_parameter_set,
)
return sql_statement

@property
def execution_plan(self) -> ExecutionPlan: # noqa: D102
Expand Down
8 changes: 4 additions & 4 deletions metricflow/execution/dataflow_to_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
ExecutionPlan,
SelectSqlQueryToDataTableTask,
SelectSqlQueryToTableTask,
SqlQuery,
SqlStatement,
)
from metricflow.plan_conversion.convert_to_sql_plan import ConvertToSqlPlanResult
from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter
Expand Down Expand Up @@ -91,7 +91,7 @@ def visit_write_to_result_data_table_node(self, node: WriteToResultDataTableNode
leaf_tasks=(
SelectSqlQueryToDataTableTask.create(
sql_client=self._sql_client,
sql_query=SqlQuery(render_sql_result.sql, render_sql_result.bind_parameter_set),
sql_statement=SqlStatement(render_sql_result.sql, render_sql_result.bind_parameter_set),
),
)
)
Expand All @@ -109,8 +109,8 @@ def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> Conv
leaf_tasks=(
SelectSqlQueryToTableTask.create(
sql_client=self._sql_client,
sql_query=SqlQuery(
sql_query=render_sql_result.sql,
sql_statement=SqlStatement(
sql=render_sql_result.sql,
bind_parameter_set=render_sql_result.bind_parameter_set,
),
output_table=node.output_sql_table,
Expand Down
74 changes: 43 additions & 31 deletions metricflow/execution/execution_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ class ExecutionPlanTask(DagNode["ExecutionPlanTask"], Visitable, ABC):
for these nodes as it seems more intuitive.
Attributes:
sql_query: If this runs a SQL query, return the associated SQL.
sql_statement: The SQL statement that this task runs, or None if it does not run SQL.
"""

sql_query: Optional[SqlQuery]
sql_statement: Optional[SqlStatement]

@abstractmethod
def execute(self) -> TaskExecutionResult:
Expand All @@ -44,13 +44,26 @@ def task_id(self) -> NodeId:


@dataclass(frozen=True)
class SqlQuery:
"""A SQL query that can be run along with bind parameters."""
class SqlStatement:
"""Encapsulates a SQL statement along with the bind parameters that should be used."""

# This field will be renamed as it is confusing given the class name.
sql_query: str
sql: str
bind_parameter_set: SqlBindParameterSet

@property
def without_descriptions(self) -> SqlStatement:
"""Return the SQL query without the inline descriptions."""
return SqlStatement(
sql="\n".join(
filter(
lambda line: not line.strip().startswith("--"),
self.sql.split("\n"),
)
),
bind_parameter_set=self.bind_parameter_set,
)


@dataclass(frozen=True)
class TaskExecutionError(Exception):
Expand Down Expand Up @@ -80,7 +93,7 @@ class SelectSqlQueryToDataTableTask(ExecutionPlanTask):
Attributes:
sql_client: The SQL client used to run the query.
sql_query: The SQL query to run.
sql_statement: The SQL statement to run.
parent_nodes: The parent tasks for this execution plan task.
"""

Expand All @@ -90,12 +103,12 @@ class SelectSqlQueryToDataTableTask(ExecutionPlanTask):
@staticmethod
def create( # noqa: D102
sql_client: SqlClient,
sql_query: SqlQuery,
sql_statement: SqlStatement,
parent_nodes: Sequence[ExecutionPlanTask] = (),
) -> SelectSqlQueryToDataTableTask:
return SelectSqlQueryToDataTableTask(
sql_client=sql_client,
sql_query=sql_query,
sql_statement=sql_statement,
parent_nodes=tuple(parent_nodes),
)

Expand All @@ -109,31 +122,30 @@ def description(self) -> str: # noqa: D102

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
sql_query = self.sql_query
assert sql_query is not None, f"{self.sql_query=} should have been set during creation."
return tuple(super().displayed_properties) + (DisplayedProperty(key="sql_query", value=sql_query.sql_query),)
assert self.sql_statement is not None, f"{self.sql_statement=} should have been set during creation."
return tuple(super().displayed_properties) + (DisplayedProperty(key="sql", value=self.sql_statement.sql),)

def execute(self) -> TaskExecutionResult: # noqa: D102
start_time = time.time()
sql_query = self.sql_query
assert sql_query is not None, f"{self.sql_query=} should have been set during creation."
sql_statement = self.sql_statement
assert sql_statement is not None, f"{self.sql_statement=} should have been set during creation."

df = self.sql_client.query(
sql_query.sql_query,
sql_bind_parameter_set=sql_query.bind_parameter_set,
sql_statement.sql,
sql_bind_parameter_set=sql_statement.bind_parameter_set,
)

end_time = time.time()
return TaskExecutionResult(
start_time=start_time,
end_time=end_time,
sql=sql_query.sql_query,
bind_params=sql_query.bind_parameter_set,
sql=sql_statement.sql,
bind_params=sql_statement.bind_parameter_set,
df=df,
)

def __repr__(self) -> str: # noqa: D105
return f"{self.__class__.__name__}(sql_query='{self.sql_query}')"
return f"{self.__class__.__name__}(sql_statement={self.sql_statement!r})"


@dataclass(frozen=True)
Expand All @@ -144,7 +156,7 @@ class SelectSqlQueryToTableTask(ExecutionPlanTask):
Attributes:
sql_client: The SQL client used to run the query.
sql_query: The SQL query to run.
sql_statement: The SQL statement to run.
output_table: The table where the results will be written.
"""

Expand All @@ -154,13 +166,13 @@ class SelectSqlQueryToTableTask(ExecutionPlanTask):
@staticmethod
def create( # noqa: D102
sql_client: SqlClient,
sql_query: SqlQuery,
sql_statement: SqlStatement,
output_table: SqlTable,
parent_nodes: Sequence[ExecutionPlanTask] = (),
) -> SelectSqlQueryToTableTask:
return SelectSqlQueryToTableTask(
sql_client=sql_client,
sql_query=sql_query,
sql_statement=sql_statement,
output_table=output_table,
parent_nodes=tuple(parent_nodes),
)
Expand All @@ -175,31 +187,31 @@ def description(self) -> str: # noqa: D102

@property
def displayed_properties(self) -> Sequence[DisplayedProperty]: # noqa: D102
sql_query = self.sql_query
assert sql_query is not None, f"{self.sql_query=} should have been set during creation."
sql_statement = self.sql_statement
assert sql_statement is not None, f"{self.sql_statement=} should have been set during creation."
return tuple(super().displayed_properties) + (
DisplayedProperty(key="sql_query", value=sql_query.sql_query),
DisplayedProperty(key="sql_statement", value=sql_statement.sql),
DisplayedProperty(key="output_table", value=self.output_table),
DisplayedProperty(key="bind_parameter_set", value=sql_query.bind_parameter_set),
DisplayedProperty(key="bind_parameter_set", value=sql_statement.bind_parameter_set),
)

def execute(self) -> TaskExecutionResult: # noqa: D102
sql_query = self.sql_query
assert sql_query is not None, f"{self.sql_query=} should have been set during creation."
sql_statement = self.sql_statement
assert sql_statement is not None, f"{self.sql_statement=} should have been set during creation."
start_time = time.time()
logger.debug(LazyFormat(lambda: f"Dropping table {self.output_table} in case it already exists"))
self.sql_client.execute(f"DROP TABLE IF EXISTS {self.output_table.sql}")
logger.debug(LazyFormat(lambda: f"Creating table {self.output_table} using a query"))
self.sql_client.execute(
sql_query.sql_query,
sql_bind_parameter_set=sql_query.bind_parameter_set,
sql_statement.sql,
sql_bind_parameter_set=sql_statement.bind_parameter_set,
)

end_time = time.time()
return TaskExecutionResult(start_time=start_time, end_time=end_time, sql=sql_query.sql_query)
return TaskExecutionResult(start_time=start_time, end_time=end_time, sql=sql_statement.sql)

def __repr__(self) -> str: # noqa: D105
return f"{self.__class__.__name__}(sql_query='{self.sql_query}', output_table={self.output_table})"
return f"{self.__class__.__name__}(sql_statement={self.sql_statement!r}, output_table={self.output_table})"


class ExecutionPlan(MetricFlowDag[ExecutionPlanTask]):
Expand Down
4 changes: 2 additions & 2 deletions metricflow/validation/data_warehouse_model_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,8 +453,8 @@ def _gen_explain_query_task_query_and_params(
) -> Tuple[str, SqlBindParameterSet]:
explain_result: MetricFlowExplainResult = mf_engine.explain(mf_request=mf_request)
return (
explain_result.rendered_sql_without_descriptions.sql_query,
explain_result.rendered_sql_without_descriptions.bind_parameter_set,
explain_result.sql_statement.without_descriptions.sql,
explain_result.sql_statement.without_descriptions.bind_parameter_set,
)

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions tests_metricflow/engine/test_explain.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _explain_one_query(mf_engine: MetricFlowEngine) -> str:
explain_result: MetricFlowExplainResult = mf_engine.explain(
MetricFlowQueryRequest.create_with_random_request_id(saved_query_name="p0_booking")
)
return explain_result.rendered_sql.sql_query
return explain_result.sql_statement.sql


def test_concurrent_explain_consistency(
Expand Down Expand Up @@ -64,7 +64,7 @@ def test_optimization_level(
sql_optimization_level=optimization_level,
)
)
results[optimization_level.value] = explain_result.rendered_sql_without_descriptions.sql_query
results[optimization_level.value] = explain_result.sql_statement.without_descriptions.sql

assert_str_snapshot_equal(
request=request,
Expand Down
2 changes: 1 addition & 1 deletion tests_metricflow/execution/noop_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def create( # noqa: D102
) -> NoOpExecutionPlanTask:
return NoOpExecutionPlanTask(
parent_nodes=tuple(parent_tasks),
sql_query=None,
sql_statement=None,
should_error=should_error,
)

Expand Down
8 changes: 4 additions & 4 deletions tests_metricflow/execution/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
ExecutionPlan,
SelectSqlQueryToDataTableTask,
SelectSqlQueryToTableTask,
SqlQuery,
SqlStatement,
)
from metricflow.execution.executor import SequentialPlanExecutor
from metricflow.protocols.sql_client import SqlClient, SqlEngine
from tests_metricflow.sql.compare_data_table import assert_data_tables_equal


def test_read_sql_task(sql_client: SqlClient) -> None: # noqa: D103
task = SelectSqlQueryToDataTableTask.create(sql_client, SqlQuery("SELECT 1 AS foo", SqlBindParameterSet()))
task = SelectSqlQueryToDataTableTask.create(sql_client, SqlStatement("SELECT 1 AS foo", SqlBindParameterSet()))
execution_plan = ExecutionPlan(leaf_tasks=[task], dag_id=DagId.from_str("plan0"))

results = SequentialPlanExecutor().execute_plan(execution_plan)
Expand All @@ -44,8 +44,8 @@ def test_write_table_task( # noqa: D103
output_table = SqlTable(schema_name=mf_test_configuration.mf_system_schema, table_name=f"test_table_{random_id()}")
task = SelectSqlQueryToTableTask.create(
sql_client=sql_client,
sql_query=SqlQuery(
sql_query=f"CREATE TABLE {output_table.sql} AS SELECT 1 AS foo",
sql_statement=SqlStatement(
sql=f"CREATE TABLE {output_table.sql} AS SELECT 1 AS foo",
bind_parameter_set=SqlBindParameterSet(),
),
output_table=output_table,
Expand Down
6 changes: 3 additions & 3 deletions tests_metricflow/integration/test_rendered_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_render_query( # noqa: D103
request=request,
mf_test_configuration=mf_test_configuration,
snapshot_id="query0",
sql=result.rendered_sql.sql_query,
sql=result.sql_statement.sql,
sql_engine=it_helpers.sql_client.sql_engine_type,
)

Expand Down Expand Up @@ -64,7 +64,7 @@ def test_id_enumeration( # noqa: D103
request=request,
mf_test_configuration=mf_test_configuration,
snapshot_id="query",
sql=result.rendered_sql.sql_query,
sql=result.sql_statement.sql,
sql_engine=sql_client.sql_engine_type,
)

Expand All @@ -80,6 +80,6 @@ def test_id_enumeration( # noqa: D103
request=request,
mf_test_configuration=mf_test_configuration,
snapshot_id="query",
sql=result.rendered_sql.sql_query,
sql=result.sql_statement.sql,
sql_engine=sql_client.sql_engine_type,
)
Loading

0 comments on commit dd0e8c3

Please sign in to comment.