From b5b6d3f592fdcd6e734aeeafbe812ab9f4d36a55 Mon Sep 17 00:00:00 2001 From: Paul Yang Date: Thu, 14 Nov 2024 09:34:27 -0800 Subject: [PATCH] Fix bug with `MetricFlowQueryRequest.sql_optimization_level` handling (#1524) `MetricFlowQueryRequest` has the `sql_optimization_level` field which is supposed to control the SQL optimization level. However, it was not getting handled correctly, so this PR fixes that issue. --- metricflow/engine/metricflow_engine.py | 13 +++++----- metricflow/execution/dataflow_to_execution.py | 5 ++++ .../sql/optimizer/optimization_levels.py | 4 +++ .../integration/test_mf_engine.py | 25 +++++++++++++++++++ .../test_dataflow_to_execution.py | 13 +++------- 5 files changed, 45 insertions(+), 15 deletions(-) diff --git a/metricflow/engine/metricflow_engine.py b/metricflow/engine/metricflow_engine.py index d5916a11d1..319aa1a2b6 100644 --- a/metricflow/engine/metricflow_engine.py +++ b/metricflow/engine/metricflow_engine.py @@ -409,11 +409,6 @@ def __init__( column_association_resolver=self._column_association_resolver, semantic_manifest_lookup=self._semantic_manifest_lookup, ) - self._to_execution_plan_converter = DataflowToExecutionPlanConverter( - sql_plan_converter=self._to_sql_query_plan_converter, - sql_plan_renderer=self._sql_client.sql_query_plan_renderer, - sql_client=sql_client, - ) self._executor = SequentialPlanExecutor() self._query_parser = query_parser or MetricFlowQueryParser( @@ -539,7 +534,13 @@ def _create_execution_plan(self, mf_query_request: MetricFlowQueryRequest) -> Me ) logger.info(LazyFormat("Building execution plan")) - convert_to_execution_plan_result = self._to_execution_plan_converter.convert_to_execution_plan(dataflow_plan) + _to_execution_plan_converter = DataflowToExecutionPlanConverter( + sql_plan_converter=self._to_sql_query_plan_converter, + sql_plan_renderer=self._sql_client.sql_query_plan_renderer, + sql_client=self._sql_client, + sql_optimization_level=mf_query_request.sql_optimization_level, + ) + convert_to_execution_plan_result = _to_execution_plan_converter.convert_to_execution_plan(dataflow_plan) return MetricFlowExplainResult( query_spec=query_spec, dataflow_plan=dataflow_plan, diff --git a/metricflow/execution/dataflow_to_execution.py b/metricflow/execution/dataflow_to_execution.py index 3a438a9f7c..fa0518e847 100644 --- a/metricflow/execution/dataflow_to_execution.py +++ b/metricflow/execution/dataflow_to_execution.py @@ -40,6 +40,7 @@ from metricflow.plan_conversion.convert_to_sql_plan import ConvertToSqlPlanResult from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter from metricflow.protocols.sql_client import SqlClient +from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel from metricflow.sql.render.sql_plan_renderer import SqlPlanRenderResult, SqlQueryPlanRenderer logger = logging.getLogger(__name__) @@ -53,6 +54,7 @@ def __init__( sql_plan_converter: DataflowToSqlQueryPlanConverter, sql_plan_renderer: SqlQueryPlanRenderer, sql_client: SqlClient, + sql_optimization_level: SqlQueryOptimizationLevel, ) -> None: """Constructor. @@ -60,15 +62,18 @@ def __init__( sql_plan_converter: Converts a dataflow plan node to a SQL query plan sql_plan_renderer: Converts a SQL query plan to SQL text sql_client: The client to use for running queries. + sql_optimization_level: The optimization level to use for generating the SQL. """ self._sql_plan_converter = sql_plan_converter self._sql_plan_renderer = sql_plan_renderer self._sql_client = sql_client + self._optimization_level = sql_optimization_level def _convert_to_sql_plan(self, node: DataflowPlanNode) -> ConvertToSqlPlanResult: logger.debug(LazyFormat(lambda: f"Generating SQL query plan from {node.node_id}")) result = self._sql_plan_converter.convert_to_sql_query_plan( sql_engine_type=self._sql_client.sql_engine_type, + optimization_level=self._optimization_level, dataflow_plan_node=node, ) logger.debug(LazyFormat(lambda: f"Generated SQL query plan is:\n{result.sql_plan.structure_text()}")) diff --git a/metricflow/sql/optimizer/optimization_levels.py b/metricflow/sql/optimizer/optimization_levels.py index f1f965d739..061e14ca1b 100644 --- a/metricflow/sql/optimizer/optimization_levels.py +++ b/metricflow/sql/optimizer/optimization_levels.py @@ -23,6 +23,10 @@ class SqlQueryOptimizationLevel(Enum): O4 = "O4" O5 = "O5" + @staticmethod + def default_level() -> SqlQueryOptimizationLevel: # noqa: D102 + return SqlQueryOptimizationLevel.O4 + @dataclass(frozen=True) class SqlGenerationOptionSet: diff --git a/tests_metricflow/integration/test_mf_engine.py b/tests_metricflow/integration/test_mf_engine.py index 9aa1373701..b933835aa8 100644 --- a/tests_metricflow/integration/test_mf_engine.py +++ b/tests_metricflow/integration/test_mf_engine.py @@ -3,6 +3,8 @@ from _pytest.fixtures import FixtureRequest from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration +from metricflow.engine.metricflow_engine import MetricFlowExplainResult, MetricFlowQueryRequest +from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel from tests_metricflow.integration.conftest import IntegrationTestHelpers from tests_metricflow.snapshot_utils import assert_object_snapshot_equal @@ -16,3 +18,26 @@ def test_list_dimensions( # noqa: D103 obj_id="result0", obj=sorted([dim.qualified_name for dim in it_helpers.mf_engine.list_dimensions()]), ) + + +def test_sql_optimization_level(it_helpers: IntegrationTestHelpers) -> None: + """Check that different SQL optimization levels produce different SQL.""" + assert ( + SqlQueryOptimizationLevel.default_level() != SqlQueryOptimizationLevel.O0 + ), "The default optimization level should be different from the lowest level." + explain_result_at_default_level: MetricFlowExplainResult = it_helpers.mf_engine.explain( + MetricFlowQueryRequest.create_with_random_request_id( + metric_names=("bookings",), + group_by_names=("metric_time",), + sql_optimization_level=SqlQueryOptimizationLevel.default_level(), + ) + ) + explain_result_at_level_0: MetricFlowExplainResult = it_helpers.mf_engine.explain( + MetricFlowQueryRequest.create_with_random_request_id( + metric_names=("bookings",), + group_by_names=("metric_time",), + sql_optimization_level=SqlQueryOptimizationLevel.O0, + ) + ) + + assert explain_result_at_default_level.rendered_sql.sql_query != explain_result_at_level_0.rendered_sql.sql_query diff --git a/tests_metricflow/plan_conversion/test_dataflow_to_execution.py b/tests_metricflow/plan_conversion/test_dataflow_to_execution.py index b92a26a2ec..c5b9503f3c 100644 --- a/tests_metricflow/plan_conversion/test_dataflow_to_execution.py +++ b/tests_metricflow/plan_conversion/test_dataflow_to_execution.py @@ -15,6 +15,7 @@ from metricflow.execution.dataflow_to_execution import DataflowToExecutionPlanConverter from metricflow.plan_conversion.dataflow_to_sql import DataflowToSqlQueryPlanConverter from metricflow.protocols.sql_client import SqlClient +from metricflow.sql.optimizer.optimization_levels import SqlQueryOptimizationLevel from metricflow.sql.render.sql_plan_renderer import DefaultSqlQueryPlanRenderer from tests_metricflow.snapshot_utils import assert_execution_plan_text_equal @@ -30,6 +31,7 @@ def make_execution_plan_converter( # noqa: D103 ), sql_plan_renderer=DefaultSqlQueryPlanRenderer(), sql_client=sql_client, + sql_optimization_level=SqlQueryOptimizationLevel.O4, ) @@ -172,17 +174,10 @@ def test_multihop_joined_plan( ) ) - to_execution_plan_converter = DataflowToExecutionPlanConverter( - sql_plan_converter=DataflowToSqlQueryPlanConverter( - column_association_resolver=DunderColumnAssociationResolver( - partitioned_multi_hop_join_semantic_manifest_lookup - ), - semantic_manifest_lookup=partitioned_multi_hop_join_semantic_manifest_lookup, - ), - sql_plan_renderer=DefaultSqlQueryPlanRenderer(), + to_execution_plan_converter = make_execution_plan_converter( + semantic_manifest_lookup=partitioned_multi_hop_join_semantic_manifest_lookup, sql_client=sql_client, ) - execution_plan = to_execution_plan_converter.convert_to_execution_plan(dataflow_plan).execution_plan assert_execution_plan_text_equal(