Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented date_part in where filter #852

Merged
merged 2 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20231107-180843.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Features
body: Implemented date_part in where filter.
time: 2023-11-07T18:08:43.67846-06:00
custom:
Author: DevonFulcher
Issue: None
13 changes: 10 additions & 3 deletions metricflow/specs/dimension_spec_resolver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Sequence
from typing import Optional, Sequence

from dbt_semantic_interfaces.call_parameter_sets import (
DimensionCallParameterSet,
Expand All @@ -10,6 +10,7 @@
from dbt_semantic_interfaces.naming.dundered import DunderedNameFormatter
from dbt_semantic_interfaces.references import DimensionReference, EntityReference, TimeDimensionReference
from dbt_semantic_interfaces.type_enums import TimeGranularity
from dbt_semantic_interfaces.type_enums.date_part import DatePart

from metricflow.specs.specs import DEFAULT_TIME_GRANULARITY, DimensionSpec, TimeDimensionSpec

Expand All @@ -35,16 +36,21 @@ def resolve_dimension_spec(self, name: str, entity_path: Sequence[str]) -> Dimen
)

def resolve_time_dimension_spec(
self, name: str, time_granularity_name: TimeGranularity, entity_path: Sequence[str]
self,
name: str,
time_granularity: Optional[TimeGranularity],
entity_path: Sequence[str],
date_part: Optional[DatePart],
) -> TimeDimensionSpec:
"""Resolve TimeDimension spec with the call_parameter_sets."""
structured_name = DunderedNameFormatter.parse_name(name)
call_parameter_set = TimeDimensionCallParameterSet(
time_dimension_reference=TimeDimensionReference(element_name=structured_name.element_name),
time_granularity=time_granularity_name,
time_granularity=time_granularity,
entity_path=(
tuple(EntityReference(element_name=arg) for arg in entity_path) + structured_name.entity_links
),
date_part=date_part,
)
assert call_parameter_set in self._call_parameter_sets.time_dimension_call_parameter_sets
return TimeDimensionSpec(
Expand All @@ -56,4 +62,5 @@ def resolve_time_dimension_spec(
if call_parameter_set.time_granularity is not None
else DEFAULT_TIME_GRANULARITY
),
date_part=call_parameter_set.date_part,
)
39 changes: 28 additions & 11 deletions metricflow/specs/where_filter_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
QueryInterfaceDimensionFactory,
)
from dbt_semantic_interfaces.type_enums import TimeGranularity
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from typing_extensions import override

from metricflow.errors.errors import InvalidQuerySyntax
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.dimension_spec_resolver import DimensionSpecResolver
from metricflow.specs.specs import TimeDimensionSpec
from metricflow.specs.specs import DimensionSpec, InstanceSpec, TimeDimensionSpec


class WhereFilterDimension(ProtocolHint[QueryInterfaceDimension]):
Expand All @@ -37,32 +38,48 @@ def __init__( # noqa
self._column_association_resolver = column_association_resolver
self._name = name
self._entity_path = entity_path
self.dimension_spec = self._dimension_spec_resolver.resolve_dimension_spec(name, entity_path)
self.time_dimension_spec: Optional[TimeDimensionSpec] = None
self.dimension_spec: DimensionSpec = self._dimension_spec_resolver.resolve_dimension_spec(
self._name, self._entity_path
)
self.date_part_name: Optional[str] = None
self.time_granularity_name: Optional[str] = None

@property
def time_dimension_spec(self) -> TimeDimensionSpec:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, so this returns a TimeDimension with DEFAULT_TIME_GRANULARITY if neither grain nor date_part are set. Interesting. I think that makes the most sense with where we're heading with time dimension expressions.

"""TimeDimensionSpec that results from the builder-pattern configuration."""
return self._dimension_spec_resolver.resolve_time_dimension_spec(
self._name,
TimeGranularity(self.time_granularity_name) if self.time_granularity_name else None,
self._entity_path,
DatePart(self.date_part_name) if self.date_part_name else None,
)

def grain(self, time_granularity_name: str) -> QueryInterfaceDimension:
"""The time granularity."""
self.time_dimension_spec = self._dimension_spec_resolver.resolve_time_dimension_spec(
self._name, TimeGranularity(time_granularity_name), self._entity_path
)
self.time_granularity_name = time_granularity_name
return self

def date_part(self, _date_part: str) -> QueryInterfaceDimension:
def date_part(self, date_part_name: str) -> QueryInterfaceDimension:
"""The date_part requested to extract."""
raise InvalidQuerySyntax("date_part isn't currently supported in the where parameter")
self.date_part_name = date_part_name
return self

def descending(self, _is_descending: bool) -> QueryInterfaceDimension:
"""Set the sort order for order-by."""
raise InvalidQuerySyntax("descending is invalid in the where parameter")

def _get_spec(self) -> InstanceSpec:
"""Get either the TimeDimensionSpec or DimensionSpec."""
if self.time_granularity_name or self.date_part_name:
return self.time_dimension_spec
return self.dimension_spec

def __str__(self) -> str:
"""Returns the column name.

Important in the Jinja sandbox.
"""
return self._column_association_resolver.resolve_spec(
self.time_dimension_spec or self.dimension_spec
).column_name
return self._column_association_resolver.resolve_spec(self._get_spec()).column_name


class WhereFilterDimensionFactory(ProtocolHint[QueryInterfaceDimensionFactory]):
Expand Down
8 changes: 5 additions & 3 deletions metricflow/specs/where_filter_time_dimension.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
QueryInterfaceTimeDimensionFactory,
)
from dbt_semantic_interfaces.type_enums import TimeGranularity
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from typing_extensions import override

from metricflow.errors.errors import InvalidQuerySyntax
Expand Down Expand Up @@ -68,10 +69,11 @@ def create(
raise InvalidQuerySyntax(
"Can't set descending in the where clause. Try setting descending in the order_by clause instead"
)
if date_part_name:
raise InvalidQuerySyntax("date_part_name isn't currently supported in the where parameter")
time_dimension_spec = self._dimension_spec_resolver.resolve_time_dimension_spec(
time_dimension_name, TimeGranularity(time_granularity_name), entity_path
time_dimension_name,
TimeGranularity(time_granularity_name) if time_dimension_name else None,
entity_path,
DatePart(date_part_name) if date_part_name else None,
)
self.time_dimension_specs.append(time_dimension_spec)
column_name = self._column_association_resolver.resolve_spec(time_dimension_spec).column_name
Expand Down
6 changes: 3 additions & 3 deletions metricflow/specs/where_filter_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,12 @@ def create_from_where_filter(self, where_filter: WhereFilter) -> WhereFilterSpec
)

"""
Dimensions that are created with a grain parameter, Dimension(...).grain(...), are
added to dimension_specs otherwise they are add to time_dimension_factory.time_dimension_specs
Dimensions that are created with a grain or date_part parameter, Dimension(...).grain(...), are
added to time_dimension_factory.time_dimension_specs otherwise they are add to dimension_specs
"""
dimension_specs = []
for dimension in dimension_factory.created:
if dimension.time_dimension_spec:
if dimension.time_granularity_name or dimension.date_part_name:
time_dimension_factory.time_dimension_specs.append(dimension.time_dimension_spec)
else:
dimension_specs.append(dimension.dimension_spec)
Expand Down
89 changes: 89 additions & 0 deletions metricflow/test/model/test_where_filter_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from dbt_semantic_interfaces.implementations.filters.where_filter import PydanticWhereFilter
from dbt_semantic_interfaces.references import EntityReference
from dbt_semantic_interfaces.type_enums.date_part import DatePart
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity

from metricflow.query.query_exceptions import InvalidQueryException
Expand Down Expand Up @@ -98,6 +99,94 @@ def test_time_dimension_in_filter( # noqa: D
)


def test_date_part_in_filter( # noqa: D
column_association_resolver: ColumnAssociationResolver,
) -> None:
where_filter = PydanticWhereFilter(where_sql_template="{{ Dimension('metric_time').date_part('year') }} = '2020'")

where_filter_spec = WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter(where_filter)

assert where_filter_spec.where_sql == "metric_time__extract_year = '2020'"
assert where_filter_spec.linkable_spec_set == LinkableSpecSet(
dimension_specs=(),
time_dimension_specs=(
TimeDimensionSpec(
element_name="metric_time",
entity_links=(),
time_granularity=TimeGranularity.DAY,
date_part=DatePart.YEAR,
),
),
entity_specs=(),
)


@pytest.mark.parametrize(
"where_sql",
(
("{{ TimeDimension('metric_time', 'WEEK', date_part_name='year') }} = '2020'"),
("{{ Dimension('metric_time').date_part('year').grain('WEEK') }} = '2020'"),
("{{ Dimension('metric_time').grain('WEEK').date_part('year') }} = '2020'"),
Comment on lines +129 to +131
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hope these aren't case-sensitive. I may add some additional tests on after we merge.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had fixed the case sensitivity here dbt-labs/dbt-semantic-interfaces#207. So, date_part was case-sensitive as of this PR, but that change to DSI fixes it.

),
)
def test_date_part_and_grain_in_filter( # noqa: D
column_association_resolver: ColumnAssociationResolver, where_sql: str
) -> None:
where_filter = PydanticWhereFilter(where_sql_template=where_sql)

where_filter_spec = WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter(where_filter)

assert where_filter_spec.where_sql == "metric_time__extract_year = '2020'"
assert where_filter_spec.linkable_spec_set == LinkableSpecSet(
dimension_specs=(),
time_dimension_specs=(
TimeDimensionSpec(
element_name="metric_time",
entity_links=(),
time_granularity=TimeGranularity.WEEK,
date_part=DatePart.YEAR,
),
),
entity_specs=(),
)


@pytest.mark.parametrize(
"where_sql",
(
("{{ TimeDimension('metric_time', 'WEEK', date_part_name='day') }} = '2020'"),
("{{ Dimension('metric_time').date_part('day').grain('WEEK') }} = '2020'"),
("{{ Dimension('metric_time').grain('WEEK').date_part('day') }} = '2020'"),
),
)
def test_date_part_less_than_grain_in_filter( # noqa: D
column_association_resolver: ColumnAssociationResolver, where_sql: str
) -> None:
where_filter = PydanticWhereFilter(where_sql_template=where_sql)

where_filter_spec = WhereSpecFactory(
column_association_resolver=column_association_resolver,
).create_from_where_filter(where_filter)

assert where_filter_spec.where_sql == "metric_time__extract_day = '2020'"
assert where_filter_spec.linkable_spec_set == LinkableSpecSet(
dimension_specs=(),
time_dimension_specs=(
TimeDimensionSpec(
element_name="metric_time",
entity_links=(),
time_granularity=TimeGranularity.WEEK,
date_part=DatePart.DAY,
),
),
entity_specs=(),
)


def test_entity_in_filter( # noqa: D
column_association_resolver: ColumnAssociationResolver,
) -> None:
Expand Down