Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Entity object syntax for metric filters #1250

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,6 @@ def date_part(self, date_part_name: str) -> QueryInterfaceDimension:
date_part=DatePart(date_part_name.lower()),
)

def descending(self, _is_descending: bool) -> QueryInterfaceDimension:
"""Set the sort order for order-by."""
raise InvalidQuerySyntax("descending is invalid in the where parameter")

def __str__(self) -> str:
"""Returns the column name.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,9 @@
from metricflow_semantics.specs.rendered_spec_tracker import RenderedSpecTracker


class WhereFilterEntity(ProtocolHint[QueryInterfaceEntity]):
class RenderedWhereFilterEntity:
"""An entity that is passed in through the where filter parameter."""

@override
def _implements_protocol(self) -> QueryInterfaceEntity:
return self

def __init__( # noqa
self,
column_association_resolver: ColumnAssociationResolver,
Expand All @@ -50,12 +46,6 @@ def __init__( # noqa
self._time_grain = time_grain
self._date_part = date_part

def descending(self, _is_descending: bool) -> QueryInterfaceEntity:
"""Set the sort order for order-by."""
raise InvalidQuerySyntax(
"Can't set descending in the where clause. Try setting descending in the order_by clause instead"
)

def __str__(self) -> str:
"""Returns the column name.

Expand All @@ -79,16 +69,12 @@ def __str__(self) -> str:
return column_association.column_name


class WhereFilterEntityFactory(ProtocolHint[QueryInterfaceEntityFactory]):
"""Creates a WhereFilterEntity.
class RenderedWhereFilterEntityFactory(ProtocolHint[QueryInterfaceEntityFactory]):
"""Creates a RenderedWhereFilterEntity.

Each call to `create` adds an EntitySpec to entity_specs.
"""

@override
def _implements_protocol(self) -> QueryInterfaceEntityFactory:
return self

def __init__( # noqa
self,
column_association_resolver: ColumnAssociationResolver,
Expand All @@ -105,7 +91,7 @@ def create(self, entity_name: str, entity_path: Sequence[str] = ()) -> WhereFilt
"""Create a RenderedWhereFilterEntity."""
structured_name = DunderedNameFormatter.parse_name(entity_name.lower())

return WhereFilterEntity(
return RenderedWhereFilterEntity(
column_association_resolver=self._column_association_resolver,
resolved_spec_lookup=self._resolved_spec_lookup,
where_filter_location=self._where_filter_location,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Sequence
from typing import Sequence, Union

from dbt_semantic_interfaces.call_parameter_sets import (
MetricCallParameterSet,
Expand All @@ -18,6 +18,7 @@
)
from metricflow_semantics.specs.column_assoc import ColumnAssociationResolver
from metricflow_semantics.specs.rendered_spec_tracker import RenderedSpecTracker
from metricflow_semantics.specs.where_filter_entity import WhereFilterEntity


class WhereFilterMetric(ProtocolHint[QueryInterfaceMetric]):
Expand Down Expand Up @@ -94,13 +95,21 @@ def __init__( # noqa
self._where_filter_location = where_filter_location
self._rendered_spec_tracker = rendered_spec_tracker

def create(self, metric_name: str, group_by: Sequence[str] = ()) -> WhereFilterMetric:
def create(self, metric_name: str, group_by: Sequence[Union[str, WhereFilterEntity]] = ()) -> WhereFilterMetric:
"""Create a WhereFilterMetric."""
return WhereFilterMetric(
column_association_resolver=self._column_association_resolver,
resolved_spec_lookup=self._resolved_spec_lookup,
where_filter_location=self._where_filter_location,
rendered_spec_tracker=self._rendered_spec_tracker,
element_name=metric_name,
group_by=tuple(LinkableElementReference(group_by_name.lower()) for group_by_name in group_by),
group_by=tuple(
LinkableElementReference(
# TODO: add entity links
group_by_item.lower()
if isinstance(group_by_item, str)
else group_by_item._element_name
)
for group_by_item in group_by
),
)
Original file line number Diff line number Diff line change
Expand Up @@ -772,3 +772,11 @@ metric:
denominator:
name: listings
filter: "{{ Metric('views', ['listing']) }} > 10"
---
metric:
name: really_active_listings
description: Listings with more than 5 bookings
type: simple
type_params:
measure: listings
filter: "{{ Metric('bookings', [Entity('listing')]) }} > 5"
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import dateutil.relativedelta
from dateutil.relativedelta import relativedelta
from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.type_enums import TimeGranularity
from typing_extensions import override

Expand All @@ -32,8 +31,8 @@ def _relative_delta_for_window(self, time_granularity: TimeGranularity, count: i
return relativedelta(months=count * 3)
elif time_granularity is TimeGranularity.YEAR:
return relativedelta(years=count)
else:
assert_values_exhausted(time_granularity)
# else:
# assert_values_exhausted(time_granularity)

@override
def expand_time_constraint_to_fill_granularity(
Expand Down Expand Up @@ -70,8 +69,8 @@ def adjust_to_start_of_period(
return date_to_adjust + relativedelta(month=10, day=1)
elif time_granularity is TimeGranularity.YEAR:
return date_to_adjust + relativedelta(month=1, day=1)
else:
assert_values_exhausted(time_granularity)
# else:
# assert_values_exhausted(time_granularity)

@override
def adjust_to_end_of_period(
Expand All @@ -94,8 +93,8 @@ def adjust_to_end_of_period(
return date_to_adjust + relativedelta(month=12, day=31)
elif time_granularity is TimeGranularity.YEAR:
return date_to_adjust + relativedelta(month=12, day=31)
else:
assert_values_exhausted(time_granularity)
# else:
# assert_values_exhausted(time_granularity)

@override
def expand_time_constraint_for_cumulative_metric(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@
'listing__nested_fill_nulls_without_time_spine',
'listing__non_referred_bookings_pct',
'listing__popular_listing_bookings_per_booker',
'listing__really_active_listings',
'listing__referred_bookings',
'listing__smallest_listing',
'listing__twice_bookings_fill_nulls_with_0_without_time_spine',
Expand Down Expand Up @@ -291,6 +292,7 @@
'user__listing__user__nested_fill_nulls_without_time_spine',
'user__listing__user__non_referred_bookings_pct',
'user__listing__user__popular_listing_bookings_per_booker',
'user__listing__user__really_active_listings',
'user__listing__user__referred_bookings',
'user__listing__user__smallest_listing',
'user__listing__user__twice_bookings_fill_nulls_with_0_without_time_spine',
Expand All @@ -299,6 +301,7 @@
'user__listings',
'user__lux_listings',
'user__popular_listing_bookings_per_booker',
'user__really_active_listings',
'user__regional_starting_balance_ratios',
'user__revenue',
'user__revenue_all_time',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ Model Join-Path Entity Links
('listings_latest',) ("('listing',)", "('listing',)") nested_fill_nulls_without_time_spine ['JOINED', 'METRIC']
('listings_latest',) ("('listing',)", "('listing',)") non_referred_bookings_pct ['JOINED', 'METRIC']
('listings_latest',) ("('listing',)", "('listing',)") popular_listing_bookings_per_booker ['JOINED', 'METRIC']
('listings_latest',) ("('listing',)", "('listing',)") really_active_listings ['JOINED', 'METRIC']
('listings_latest',) ("('listing',)", "('listing',)") referred_bookings ['JOINED', 'METRIC']
('listings_latest',) ("('listing',)", "('listing',)") smallest_listing ['JOINED', 'METRIC']
('listings_latest',) ("('listing',)", "('listing',)") twice_bookings_fill_nulls_with_0_without_time_spine ['JOINED', 'METRIC']
Expand Down Expand Up @@ -168,6 +169,7 @@ Model Join-Path Entity Links
('listings_latest',) ("('user',)", "('listing', 'user')") nested_fill_nulls_without_time_spine ['JOINED', 'METRIC']
('listings_latest',) ("('user',)", "('listing', 'user')") non_referred_bookings_pct ['JOINED', 'METRIC']
('listings_latest',) ("('user',)", "('listing', 'user')") popular_listing_bookings_per_booker ['JOINED', 'METRIC']
('listings_latest',) ("('user',)", "('listing', 'user')") really_active_listings ['JOINED', 'METRIC']
('listings_latest',) ("('user',)", "('listing', 'user')") referred_bookings ['JOINED', 'METRIC']
('listings_latest',) ("('user',)", "('listing', 'user')") smallest_listing ['JOINED', 'METRIC']
('listings_latest',) ("('user',)", "('listing', 'user')") twice_bookings_fill_nulls_with_0_without_time_spine ['JOINED', 'METRIC']
Expand All @@ -183,6 +185,7 @@ Model Join-Path Entity Links
('listings_latest',) ("('user',)", "('user',)") listings ['JOINED', 'METRIC']
('listings_latest',) ("('user',)", "('user',)") lux_listings ['JOINED', 'METRIC']
('listings_latest',) ("('user',)", "('user',)") popular_listing_bookings_per_booker ['JOINED', 'METRIC']
('listings_latest',) ("('user',)", "('user',)") really_active_listings ['JOINED', 'METRIC']
('listings_latest',) ("('user',)", "('user',)") regional_starting_balance_ratios ['JOINED', 'METRIC']
('listings_latest',) ("('user',)", "('user',)") revenue ['JOINED', 'METRIC']
('listings_latest',) ("('user',)", "('user',)") revenue_all_time ['JOINED', 'METRIC']
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
'company__user__company__listings',
'company__user__company__lux_listings',
'company__user__company__popular_listing_bookings_per_booker',
'company__user__company__really_active_listings',
'company__user__company__regional_starting_balance_ratios',
'company__user__company__revenue',
'company__user__company__revenue_all_time',
Expand Down Expand Up @@ -361,6 +362,7 @@
'listing__nested_fill_nulls_without_time_spine',
'listing__non_referred_bookings_pct',
'listing__popular_listing_bookings_per_booker',
'listing__really_active_listings',
'listing__referred_bookings',
'listing__smallest_listing',
'listing__twice_bookings_fill_nulls_with_0_without_time_spine',
Expand Down Expand Up @@ -414,6 +416,7 @@
'lux_listing__listing__lux_listing__nested_fill_nulls_without_time_spine',
'lux_listing__listing__lux_listing__non_referred_bookings_pct',
'lux_listing__listing__lux_listing__popular_listing_bookings_per_booker',
'lux_listing__listing__lux_listing__really_active_listings',
'lux_listing__listing__lux_listing__referred_bookings',
'lux_listing__listing__lux_listing__smallest_listing',
'lux_listing__listing__lux_listing__twice_bookings_fill_nulls_with_0_without_time_spine',
Expand Down Expand Up @@ -525,6 +528,7 @@
'user__listing__user__nested_fill_nulls_without_time_spine',
'user__listing__user__non_referred_bookings_pct',
'user__listing__user__popular_listing_bookings_per_booker',
'user__listing__user__really_active_listings',
'user__listing__user__referred_bookings',
'user__listing__user__smallest_listing',
'user__listing__user__twice_bookings_fill_nulls_with_0_without_time_spine',
Expand All @@ -533,6 +537,7 @@
'user__listings',
'user__lux_listings',
'user__popular_listing_bookings_per_booker',
'user__really_active_listings',
'user__regional_starting_balance_ratios',
'user__revenue',
'user__revenue_all_time',
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ classifiers = [
dependencies = [
"Jinja2>=3.1.3",
"PyYAML>=6.0, <7.0.0",
"dbt-semantic-interfaces>=0.5.1, <0.6.0",
"graphviz>=0.18.2, <0.21",
"more-itertools>=8.10.0, <10.2.0",
"pydantic>=1.10.0, <1.11.0",
Expand Down Expand Up @@ -114,6 +113,7 @@ description = "Environment for development. Includes a DuckDB-backed client."
pre-install-commands = [
"pip install -e ./metricflow-semantics",
"pip install -e ./dbt-metricflow[duckdb]",
"pip install -e ../dbt-semantic-interfaces",
]

features = [
Expand Down
27 changes: 26 additions & 1 deletion tests_metricflow/query_rendering/test_metric_filter_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,31 @@ def test_metric_with_metric_in_where_filter(
)


@pytest.mark.sql_engine_snapshot
def test_metric_filter_with_entity_object_syntax(
    request: FixtureRequest,
    mf_test_configuration: MetricFlowTestConfiguration,
    dataflow_plan_builder: DataflowPlanBuilder,
    sql_client: SqlClient,
    dataflow_to_sql_converter: DataflowToSqlQueryPlanConverter,
    query_parser: MetricFlowQueryParser,
) -> None:
    """Tests rendering a metric whose filter uses the Entity object syntax in a Metric group-by.

    Queries `really_active_listings` grouped by `metric_time__day` and checks the
    SQL rendered from the resulting dataflow plan.
    """
    # NOTE(review): presumably resolves to the `really_active_listings` fixture metric
    # filtered with "{{ Metric('bookings', [Entity('listing')]) }} > 5" -- confirm
    # against the test manifest.
    parse_result = query_parser.parse_and_validate_query(
        metric_names=("really_active_listings",),
        group_by_names=("metric_time__day",),
    )
    plan = dataflow_plan_builder.build_plan(parse_result.query_spec)
    sink_node = plan.sink_node

    convert_and_check(
        request=request,
        mf_test_configuration=mf_test_configuration,
        dataflow_to_sql_converter=dataflow_to_sql_converter,
        sql_client=sql_client,
        node=sink_node,
    )


@pytest.mark.sql_engine_snapshot
def test_query_with_derived_metric_in_where_filter(
request: FixtureRequest,
Expand Down Expand Up @@ -242,7 +267,7 @@ def test_metric_filtered_by_itself(
query_spec = query_parser.parse_and_validate_query(
metric_names=("bookers",),
where_constraint=PydanticWhereFilter(
where_sql_template="{{ Metric('bookers', ['guest']) }} > 1.00",
where_sql_template="{{ Metric('bookers', ['listing']) }} > 1.00",
),
).query_spec
dataflow_plan = dataflow_plan_builder.build_plan(query_spec)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,23 @@ FROM (
FROM (
-- Constrain Output with WHERE
SELECT
subq_14.guest__bookers
subq_14.listing__bookers
, subq_14.bookers
FROM (
-- Pass Only Elements: ['bookers', 'guest__bookers']
-- Pass Only Elements: ['bookers', 'listing__bookers']
SELECT
subq_13.guest__bookers
subq_13.listing__bookers
, subq_13.bookers
FROM (
-- Join Standard Outputs
SELECT
subq_6.guest AS guest
, subq_12.guest__bookers AS guest__bookers
subq_6.listing AS listing
, subq_12.listing__bookers AS listing__bookers
, subq_6.bookers AS bookers
FROM (
-- Pass Only Elements: ['bookers', 'guest']
-- Pass Only Elements: ['bookers', 'listing']
SELECT
subq_5.guest
subq_5.listing
, subq_5.bookers
FROM (
-- Metric Time Dimension 'ds'
Expand Down Expand Up @@ -227,24 +227,24 @@ FROM (
) subq_5
) subq_6
LEFT OUTER JOIN (
-- Pass Only Elements: ['guest', 'guest__bookers']
-- Pass Only Elements: ['listing', 'listing__bookers']
SELECT
subq_11.guest
, subq_11.guest__bookers
subq_11.listing
, subq_11.listing__bookers
FROM (
-- Compute Metrics via Expressions
SELECT
subq_10.guest
, subq_10.bookers AS guest__bookers
subq_10.listing
, subq_10.bookers AS listing__bookers
FROM (
-- Aggregate Measures
SELECT
subq_9.guest
subq_9.listing
, COUNT(DISTINCT subq_9.bookers) AS bookers
FROM (
-- Pass Only Elements: ['bookers', 'guest']
-- Pass Only Elements: ['bookers', 'listing']
SELECT
subq_8.guest
subq_8.listing
, subq_8.bookers
FROM (
-- Metric Time Dimension 'ds'
Expand Down Expand Up @@ -443,15 +443,15 @@ FROM (
) subq_8
) subq_9
GROUP BY
guest
listing
) subq_10
) subq_11
) subq_12
ON
subq_6.guest = subq_12.guest
subq_6.listing = subq_12.listing
) subq_13
) subq_14
WHERE guest__bookers > 1.00
WHERE listing__bookers > 1.00
) subq_15
) subq_16
) subq_17
Loading