From 42b8dcf0fe2fad79cf3ce2632424d35a985d922c Mon Sep 17 00:00:00 2001 From: serramatutu Date: Fri, 13 Dec 2024 10:47:07 +0100 Subject: [PATCH] Add integration test for metric alias in query --- .../integration/configured_test_case.py | 11 +++++++++-- .../integration/test_cases/itest_simple.yaml | 16 ++++++++++++++++ .../integration/test_configured_cases.py | 13 +++++++++++-- 3 files changed, 36 insertions(+), 4 deletions(-) diff --git a/tests_metricflow/integration/configured_test_case.py b/tests_metricflow/integration/configured_test_case.py index 1c677d17df..2e3d58d5f4 100644 --- a/tests_metricflow/integration/configured_test_case.py +++ b/tests_metricflow/integration/configured_test_case.py @@ -4,7 +4,7 @@ import os from collections import OrderedDict from enum import Enum -from typing import Dict, Optional, Sequence, Tuple +from typing import Dict, Optional, Sequence, Tuple, Union import yaml from dbt_semantic_interfaces.implementations.base import FrozenBaseModel @@ -37,6 +37,13 @@ def __repr__(self) -> str: # noqa: D105 return f"{self.__class__.__name__}.{self.name}" +class IntegrationTestMetric(FrozenBaseModel): + """A metric for an integration test.""" + + name: str + alias: Optional[str] + + class ConfiguredIntegrationTestCase(FrozenBaseModel): """Integration test case parsed from YAML files.""" @@ -51,7 +58,7 @@ class Config: # noqa: D106 # The SQL query that can be run to obtain the expected results. check_query: str file_path: str - metrics: Tuple[str, ...] = () + metrics: Tuple[Union[IntegrationTestMetric, str], ...] = () group_bys: Tuple[str, ...] = () group_by_objs: Tuple[Dict, ...] = () order_bys: Tuple[str, ...] = () diff --git a/tests_metricflow/integration/test_cases/itest_simple.yaml b/tests_metricflow/integration/test_cases/itest_simple.yaml index e4dc5cb5d6..11a2bdb62b 100644 --- a/tests_metricflow/integration/test_cases/itest_simple.yaml +++ b/tests_metricflow/integration/test_cases/itest_simple.yaml @@ -46,6 +46,22 @@ integration_test: GROUP BY ds --- +integration_test: + name: simple_query_with_alias + description: Tests selecting a metric with an alias and an associated local dimension. + model: SIMPLE_MODEL + metrics: + - name: "booking_value" + alias: "booking_alias" + group_bys: ["booking__is_instant"] + check_query: | + SELECT + SUM(booking_value) AS booking_alias + , is_instant AS booking__is_instant + FROM {{ source_schema }}.fct_bookings + GROUP BY + is_instant +--- integration_test: name: simple_query_with_joined_dimension_on_unique_id description: Query a metric with a joined dimension where the join key is a unique identifier. diff --git a/tests_metricflow/integration/test_configured_cases.py b/tests_metricflow/integration/test_configured_cases.py index c7763bbb43..ab487246b0 100644 --- a/tests_metricflow/integration/test_configured_cases.py +++ b/tests_metricflow/integration/test_configured_cases.py @@ -14,7 +14,11 @@ from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat from metricflow_semantics.protocols.query_parameter import DimensionOrEntityQueryParameter -from metricflow_semantics.specs.query_param_implementations import DimensionOrEntityParameter, TimeDimensionParameter +from metricflow_semantics.specs.query_param_implementations import ( + DimensionOrEntityParameter, + MetricParameter, + TimeDimensionParameter, +) from metricflow_semantics.test_helpers.config_helpers import MetricFlowTestConfiguration from metricflow_semantics.time.time_constants import ISO8601_PYTHON_FORMAT, ISO8601_PYTHON_TS_FORMAT from metricflow_semantics.time.time_spine_source import TimeSpineSource @@ -274,7 +278,12 @@ def test_case( group_by.append(DimensionOrEntityParameter(**kwargs)) query_result = engine.query( MetricFlowQueryRequest.create_with_random_request_id( - metric_names=case.metrics, + metrics=tuple( + MetricParameter(name=m, alias=None) + if isinstance(m, str) + else MetricParameter(name=m.name, alias=m.alias) + for m in case.metrics + ), group_by_names=case.group_bys if len(case.group_bys) > 0 else None, group_by=tuple(group_by) if len(group_by) > 0 else None, limit=case.limit,