diff --git a/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py b/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py index 40cc5fe81a..07b0d73f4f 100644 --- a/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py +++ b/metricflow-semantics/metricflow_semantics/specs/linkable_spec_set.py @@ -102,13 +102,28 @@ def as_tuple(self) -> Tuple[LinkableInstanceSpec, ...]: # noqa: D102 ) ) + def add_specs( + self, + dimension_specs: Tuple[DimensionSpec, ...] = (), + time_dimension_specs: Tuple[TimeDimensionSpec, ...] = (), + entity_specs: Tuple[EntitySpec, ...] = (), + group_by_metric_specs: Tuple[GroupByMetricSpec, ...] = (), + ) -> LinkableSpecSet: + """Return a new set with the new specs in addition to the existing ones.""" + return LinkableSpecSet( + dimension_specs=self.dimension_specs + dimension_specs, + time_dimension_specs=self.time_dimension_specs + time_dimension_specs, + entity_specs=self.entity_specs + entity_specs, + group_by_metric_specs=self.group_by_metric_specs + group_by_metric_specs, + ) + @override def merge(self, other: LinkableSpecSet) -> LinkableSpecSet: - return LinkableSpecSet( - dimension_specs=self.dimension_specs + other.dimension_specs, - time_dimension_specs=self.time_dimension_specs + other.time_dimension_specs, - entity_specs=self.entity_specs + other.entity_specs, - group_by_metric_specs=self.group_by_metric_specs + other.group_by_metric_specs, + return self.add_specs( + dimension_specs=other.dimension_specs, + time_dimension_specs=other.time_dimension_specs, + entity_specs=other.entity_specs, + group_by_metric_specs=other.group_by_metric_specs, ) @classmethod diff --git a/metricflow/dataflow/builder/dataflow_plan_builder.py b/metricflow/dataflow/builder/dataflow_plan_builder.py index c7cdc5ebbd..0ad7e9ed38 100644 --- a/metricflow/dataflow/builder/dataflow_plan_builder.py +++ b/metricflow/dataflow/builder/dataflow_plan_builder.py @@ -182,7 +182,6 @@ def _build_query_output_node( where_filter_specs=(), pushdown_enabled_types=frozenset({PredicateInputType.TIME_RANGE_CONSTRAINT}), ) - return self._build_metrics_output_node( metric_specs=tuple( MetricSpec( @@ -236,6 +235,13 @@ def _optimize_plan(self, plan: DataflowPlan, optimizations: FrozenSet[DataflowPl return plan + def _get_minimum_metric_time_spec_for_metric(self, metric_reference: MetricReference) -> TimeDimensionSpec: + """Gets the minimum metric time spec for the given metric reference.""" + min_granularity = ExpandedTimeGranularity.from_time_granularity( + self._metric_lookup.get_min_queryable_time_granularity(metric_reference) + ) + return DataSet.metric_time_dimension_spec(min_granularity) + def _build_aggregated_conversion_node( self, metric_spec: MetricSpec, @@ -307,14 +313,11 @@ def _build_aggregated_conversion_node( # Get the time dimension used to calculate the conversion window # Currently, both the base/conversion measure uses metric_time as it's the default agg time dimension. # However, eventually, there can be user-specified time dimensions used for this calculation. - default_granularity = ExpandedTimeGranularity.from_time_granularity( - self._metric_lookup.get_min_queryable_time_granularity(metric_spec.reference) - ) - metric_time_dimension_spec = DataSet.metric_time_dimension_spec(default_granularity) + min_metric_time_spec = self._get_minimum_metric_time_spec_for_metric(metric_spec.reference) # Filter the source nodes with only the required specs needed for the calculation constant_property_specs = [] - required_local_specs = [base_measure_spec.measure_spec, entity_spec, metric_time_dimension_spec] + list( + required_local_specs = [base_measure_spec.measure_spec, entity_spec, min_metric_time_spec] + list( base_measure_recipe.required_local_linkable_specs.as_tuple ) for constant_property in constant_properties or []: @@ -345,10 +348,10 @@ def _build_aggregated_conversion_node( # adjusted in the opposite direction. join_conversion_node = JoinConversionEventsNode.create( base_node=unaggregated_base_measure_node, - base_time_dimension_spec=metric_time_dimension_spec, + base_time_dimension_spec=min_metric_time_spec, conversion_node=unaggregated_conversion_measure_node, conversion_measure_spec=conversion_measure_spec.measure_spec, - conversion_time_dimension_spec=metric_time_dimension_spec, + conversion_time_dimension_spec=min_metric_time_spec, unique_identifier_keys=(MetadataSpec(MetricFlowReservedKeywords.MF_INTERNAL_UUID.value),), entity_spec=entity_spec, window=window, @@ -444,21 +447,19 @@ def _build_cumulative_metric_output_node( predicate_pushdown_state: PredicatePushdownState, for_group_by_source_node: bool = False, ) -> DataflowPlanNode: - # TODO: [custom granularity] Figure out how to support custom granularities as defaults - default_granularity = ExpandedTimeGranularity.from_time_granularity( - self._metric_lookup.get_min_queryable_time_granularity(metric_spec.reference) - ) + min_metric_time_spec = self._get_minimum_metric_time_spec_for_metric(metric_spec.reference) + min_granularity = min_metric_time_spec.time_granularity queried_agg_time_dimensions = queried_linkable_specs.included_agg_time_dimension_specs_for_metric( metric_reference=metric_spec.reference, metric_lookup=self._metric_lookup ) - query_includes_agg_time_dimension_with_default_granularity = False + query_includes_agg_time_dimension_with_min_granularity = False for time_dimension_spec in queried_agg_time_dimensions: - if time_dimension_spec.time_granularity == default_granularity: - query_includes_agg_time_dimension_with_default_granularity = True + if time_dimension_spec.time_granularity == min_granularity: + query_includes_agg_time_dimension_with_min_granularity = True break - if query_includes_agg_time_dimension_with_default_granularity or not queried_agg_time_dimensions: + if query_includes_agg_time_dimension_with_min_granularity or len(queried_agg_time_dimensions) == 0: return self._build_base_metric_output_node( metric_spec=metric_spec, queried_linkable_specs=queried_linkable_specs, @@ -467,14 +468,11 @@ def _build_cumulative_metric_output_node( for_group_by_source_node=for_group_by_source_node, ) - # If a cumulative metric is queried without default granularity, it will need to be aggregated twice - + # If a cumulative metric is queried without its minimum granularity, it will need to be aggregated twice: # once as a normal metric, and again using a window function to narrow down to one row per granularity period. # In this case, add metric time at the default granularity to the linkable specs. It will be used in the order by # clause of the window function and later excluded from the output selections. - default_metric_time = DataSet.metric_time_dimension_spec(default_granularity) - include_linkable_specs = queried_linkable_specs.merge( - LinkableSpecSet(time_dimension_specs=(default_metric_time,)) - ) + include_linkable_specs = queried_linkable_specs.add_specs(time_dimension_specs=(min_metric_time_spec,)) compute_metrics_node = self._build_base_metric_output_node( metric_spec=metric_spec, queried_linkable_specs=include_linkable_specs, @@ -485,7 +483,7 @@ def _build_cumulative_metric_output_node( return WindowReaggregationNode.create( parent_node=compute_metrics_node, metric_spec=metric_spec, - order_by_spec=default_metric_time, + order_by_spec=min_metric_time_spec, partition_by_specs=queried_linkable_specs.as_tuple, ) @@ -1613,10 +1611,6 @@ def _build_aggregated_measure_from_measure_source_node( # If querying an offset metric, join to time spine before aggregation. if before_aggregation_time_spine_join_description is not None: - assert queried_agg_time_dimension_specs, ( - "Joining to time spine requires querying with metric time or the appropriate agg_time_dimension." - "This should have been caught by validations." - ) assert before_aggregation_time_spine_join_description.join_type is SqlJoinType.INNER, ( f"Expected {SqlJoinType.INNER} for joining to time spine before aggregation. Remove this if there's a " f"new use case."