Skip to content

Commit

Permalink
Organize DataflowPlanBuilder recursive case handling and move MetricLookUp.measures_for_metric().
Browse files Browse the repository at this point in the history

This more cleanly separates the recursive handling of derived metrics into
cases:

* Build a node to compute a simple / base metric.
* Build a node to compute a derived metric.
* Build a node to compute a metric of any type.
* Build a node to compute many metrics.

In addition, this also moves MetricLookUp.measures_for_metric() into the
DataflowPlanBuilder because it seemed a little out of place in MetricLookUp
due to the dependence on the column association resolver.
  • Loading branch information
plypaul committed Nov 15, 2023
1 parent a2f779b commit 0361b0e
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 100 deletions.
195 changes: 136 additions & 59 deletions metricflow/dataflow/builder/dataflow_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dbt_semantic_interfaces.pretty_print import pformat_big_objects
from dbt_semantic_interfaces.protocols.metric import MetricTimeWindow, MetricType
from dbt_semantic_interfaces.references import (
MetricReference,
TimeDimensionReference,
)
from dbt_semantic_interfaces.type_enums.time_granularity import TimeGranularity
Expand Down Expand Up @@ -64,6 +65,7 @@
TimeDimensionSpec,
WhereFilterSpec,
)
from metricflow.specs.where_filter_transform import WhereSpecFactory
from metricflow.sql.sql_plan import SqlJoinType

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -162,14 +164,113 @@ def build_plan(

return plan

def _build_base_metric_output_node(
    self,
    metric_spec: MetricSpec,
    queried_linkable_specs: LinkableSpecSet,
    where_constraint: Optional[WhereFilterSpec] = None,
    time_range_constraint: Optional[TimeRangeConstraint] = None,
) -> ComputeMetricsNode:
    """Build the node that computes a metric which is not defined in terms of other metrics.

    The metric's single input measure is resolved, an aggregated-measure node is built for it, and that
    node is wrapped in a computed-metrics node.

    Args:
        metric_spec: Spec of the metric to compute.
        queried_linkable_specs: Linkable specs forwarded unchanged to `build_aggregated_measure`.
        where_constraint: Optional filter; combined with `metric_spec.constraint` when both are set.
        time_range_constraint: Optional time range forwarded unchanged to `build_aggregated_measure`.

    Returns:
        The result of `build_computed_metrics_node` for the aggregated measure.

    Raises:
        AssertionError: If the metric does not resolve to exactly one input measure.
    """
    reference = metric_spec.reference
    metric = self._metric_lookup.get_metric(reference)

    input_measure_specs = self._measures_for_metric(
        metric_reference=reference,
        column_association_resolver=self._column_association_resolver,
    )
    assert len(input_measure_specs) == 1, "Simple and cumulative metrics must have one input measure."
    input_measure_spec = input_measure_specs[0]

    # The keyword name is kept as-is since it is passed through to the pretty-printer.
    message = (
        f"For {metric_spec}, needed measure is:\n"
        f"{pformat_big_objects(metric_input_measure_spec=input_measure_spec)}"
    )
    logger.info(message)

    # Merge the caller's filter with the filter attached to the metric spec (either may be absent).
    metric_constraint = metric_spec.constraint
    if not metric_constraint:
        combined_where = where_constraint
    elif where_constraint:
        combined_where = where_constraint.combine(metric_constraint)
    else:
        combined_where = metric_constraint

    # Window / grain-to-date are only read from the metric's type params for cumulative metrics.
    is_cumulative = metric.type == MetricType.CUMULATIVE
    if is_cumulative:
        window = metric.type_params.window
        grain_to_date = metric.type_params.grain_to_date
    else:
        window = None
        grain_to_date = None

    aggregated_measures_node = self.build_aggregated_measure(
        metric_input_measure_spec=input_measure_spec,
        metric_spec=metric_spec,
        queried_linkable_specs=queried_linkable_specs,
        where_constraint=combined_where,
        time_range_constraint=time_range_constraint,
        cumulative=is_cumulative,
        cumulative_window=window,
        cumulative_grain_to_date=grain_to_date,
    )
    return self.build_computed_metrics_node(
        metric_spec=metric_spec,
        aggregated_measures_node=aggregated_measures_node,
    )

def _build_derived_metric_output_node(
    self,
    metric_spec: MetricSpec,
    queried_linkable_specs: LinkableSpecSet,
    where_constraint: Optional[WhereFilterSpec] = None,
    time_range_constraint: Optional[TimeRangeConstraint] = None,
) -> ComputeMetricsNode:
    """Build the node that computes a metric defined in terms of other metrics.

    The input metrics are looked up, a parent node computing all of them is built via
    `_build_metrics_output_node`, and a `ComputeMetricsNode` for `metric_spec` is placed on top of it.

    Args:
        metric_spec: Spec of the metric to compute.
        queried_linkable_specs: Linkable specs forwarded unchanged when building the input metrics.
        where_constraint: Optional filter forwarded unchanged when building the input metrics.
        time_range_constraint: Optional time range forwarded unchanged when building the input metrics.

    Returns:
        A `ComputeMetricsNode` whose parent computes the input metrics.
    """
    reference = metric_spec.reference
    metric = self._metric_lookup.get_metric(reference)
    input_metric_specs = self._metric_lookup.metric_input_specs_for_metric(
        metric_reference=reference,
        column_association_resolver=self._column_association_resolver,
    )

    # The keyword name is kept as-is since it is passed through to the pretty-printer.
    message = (
        f"For {metric.type} metric: {metric_spec}, needed metrics are:\n"
        f"{pformat_big_objects(metric_input_specs=input_metric_specs)}"
    )
    logger.info(message)

    # Recursive step: the inputs may themselves be metrics of any type.
    inputs_node = self._build_metrics_output_node(
        metric_specs=input_metric_specs,
        queried_linkable_specs=queried_linkable_specs,
        where_constraint=where_constraint,
        time_range_constraint=time_range_constraint,
    )
    return ComputeMetricsNode(parent_node=inputs_node, metric_specs=[metric_spec])

def _build_any_metric_output_node(
    self,
    metric_spec: MetricSpec,
    queried_linkable_specs: LinkableSpecSet,
    where_constraint: Optional[WhereFilterSpec] = None,
    time_range_constraint: Optional[TimeRangeConstraint] = None,
) -> ComputeMetricsNode:
    """Build the node that computes a metric, dispatching on the metric's type.

    Simple and cumulative metrics go to `_build_base_metric_output_node`; ratio and derived metrics go to
    `_build_derived_metric_output_node`. All arguments are forwarded unchanged to the chosen builder.

    Args:
        metric_spec: Spec of the metric to compute.
        queried_linkable_specs: Linkable specs requested alongside the metric.
        where_constraint: Optional filter to apply.
        time_range_constraint: Optional time range to apply.

    Returns:
        The node produced by the builder selected for the metric's type.
    """
    metric_type = self._metric_lookup.get_metric(metric_spec.reference).type

    if metric_type is MetricType.SIMPLE or metric_type is MetricType.CUMULATIVE:
        return self._build_base_metric_output_node(
            metric_spec=metric_spec,
            queried_linkable_specs=queried_linkable_specs,
            where_constraint=where_constraint,
            time_range_constraint=time_range_constraint,
        )

    if metric_type is MetricType.RATIO or metric_type is MetricType.DERIVED:
        return self._build_derived_metric_output_node(
            metric_spec=metric_spec,
            queried_linkable_specs=queried_linkable_specs,
            where_constraint=where_constraint,
            time_range_constraint=time_range_constraint,
        )

    # Reached only for a metric type that neither branch above handles.
    assert_values_exhausted(metric_type)

def _build_metrics_output_node(
self,
metric_specs: Sequence[MetricSpec],
queried_linkable_specs: LinkableSpecSet,
where_constraint: Optional[WhereFilterSpec] = None,
time_range_constraint: Optional[TimeRangeConstraint] = None,
) -> BaseOutput:
"""Builds a computed metrics output node.
"""Builds a node that computes all requested metrics.
Args:
metric_specs: Specs for metrics to compute.
Expand All @@ -178,72 +279,19 @@ def _build_metrics_output_node(
time_range_constraint: Time range constraint used to compute the metric.
"""
output_nodes: List[BaseOutput] = []
compute_metrics_node: Optional[ComputeMetricsNode] = None

for metric_spec in metric_specs:
logger.info(f"Generating compute metrics node for {metric_spec}")
metric_reference = metric_spec.reference
metric = self._metric_lookup.get_metric(metric_reference)

if metric.type is MetricType.DERIVED or metric.type is MetricType.RATIO:
metric_input_specs = self._metric_lookup.metric_input_specs_for_metric(
metric_reference=metric_reference,
column_association_resolver=self._column_association_resolver,
)
logger.info(
f"For {metric.type} metric: {metric_spec}, needed metrics are:\n"
f"{pformat_big_objects(metric_input_specs=metric_input_specs)}"
)
compute_metrics_node = ComputeMetricsNode(
parent_node=self._build_metrics_output_node(
metric_specs=metric_input_specs,
queried_linkable_specs=queried_linkable_specs,
where_constraint=where_constraint,
time_range_constraint=time_range_constraint,
),
metric_specs=[metric_spec],
)
elif metric.type is MetricType.SIMPLE or MetricType.CUMULATIVE:
metric_input_measure_specs = self._metric_lookup.measures_for_metric(
metric_reference=metric_reference,
column_association_resolver=self._column_association_resolver,
)
assert (
len(metric_input_measure_specs) == 1
), "Simple and cumulative metrics must have one input measure."
metric_input_measure_spec = metric_input_measure_specs[0]
self._metric_lookup.get_metric(metric_spec.reference)

logger.info(
f"For {metric_spec}, needed measure is:\n"
f"{pformat_big_objects(metric_input_measure_spec=metric_input_measure_spec)}"
)
combined_where = where_constraint
if metric_spec.constraint:
combined_where = (
combined_where.combine(metric_spec.constraint) if combined_where else metric_spec.constraint
)
aggregated_measures_node = self.build_aggregated_measure(
metric_input_measure_spec=metric_input_measure_spec,
output_nodes.append(
self._build_any_metric_output_node(
metric_spec=metric_spec,
queried_linkable_specs=queried_linkable_specs,
where_constraint=combined_where,
where_constraint=where_constraint,
time_range_constraint=time_range_constraint,
cumulative=metric.type == MetricType.CUMULATIVE,
cumulative_window=metric.type_params.window if metric.type == MetricType.CUMULATIVE else None,
cumulative_grain_to_date=(
metric.type_params.grain_to_date if metric.type == MetricType.CUMULATIVE else None
),
)
compute_metrics_node = self.build_computed_metrics_node(
metric_spec=metric_spec,
aggregated_measures_node=aggregated_measures_node,
)
else:
assert_values_exhausted(metric.type)

assert compute_metrics_node is not None

output_nodes.append(compute_metrics_node)
)

assert len(output_nodes) > 0, "ComputeMetricsNode was not properly constructed"

Expand Down Expand Up @@ -597,6 +645,35 @@ def build_computed_metrics_node(
metric_specs=[metric_spec],
)

def _measures_for_metric(
    self,
    metric_reference: MetricReference,
    column_association_resolver: ColumnAssociationResolver,
) -> Sequence[MetricInputMeasureSpec]:
    """Return one input-measure spec per input measure of the referenced metric.

    Args:
        metric_reference: Reference to the metric whose input measures are requested.
        column_association_resolver: Resolver handed to `WhereSpecFactory` when converting each input
            measure's filter into a constraint.

    Returns:
        A tuple of `MetricInputMeasureSpec`, in the order of `metric.input_measures`.
    """
    return tuple(
        MetricInputMeasureSpec(
            measure_spec=MeasureSpec(
                element_name=input_measure.name,
                # `.get` is used on the lookup, so a measure without an entry yields None here.
                non_additive_dimension_spec=self._semantic_model_lookup.non_additive_dimension_specs_by_measure.get(
                    input_measure.measure_reference
                ),
            ),
            constraint=WhereSpecFactory(
                column_association_resolver=column_association_resolver,
            ).create_from_where_filter_intersection(input_measure.filter),
            alias=input_measure.alias,
            join_to_timespine=input_measure.join_to_timespine,
            fill_nulls_with=input_measure.fill_nulls_with,
        )
        for input_measure in self._metric_lookup.get_metric(metric_reference).input_measures
    )

def build_aggregated_measure(
self,
metric_input_measure_spec: MetricInputMeasureSpec,
Expand Down
31 changes: 0 additions & 31 deletions metricflow/model/semantics/metric_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
from metricflow.specs.column_assoc import ColumnAssociationResolver
from metricflow.specs.specs import (
LinkableInstanceSpec,
MeasureSpec,
MetricInputMeasureSpec,
MetricSpec,
)
from metricflow.specs.where_filter_transform import WhereSpecFactory
Expand Down Expand Up @@ -123,35 +121,6 @@ def configured_input_measure_for_metric(self, metric_reference: MetricReference)
else:
assert_values_exhausted(metric.type)

def measures_for_metric(
    self,
    metric_reference: MetricReference,
    column_association_resolver: ColumnAssociationResolver,
) -> Sequence[MetricInputMeasureSpec]:
    """Return one input-measure spec per input measure of the referenced metric.

    Args:
        metric_reference: Reference to the metric whose input measures are requested.
        column_association_resolver: Resolver handed to `WhereSpecFactory` when converting each input
            measure's filter into a constraint.

    Returns:
        A tuple of `MetricInputMeasureSpec`, in the order of `metric.input_measures`.
    """
    return tuple(
        MetricInputMeasureSpec(
            measure_spec=MeasureSpec(
                element_name=measure.name,
                # `.get` is used on the lookup, so a measure without an entry yields None here.
                non_additive_dimension_spec=self._semantic_model_lookup.non_additive_dimension_specs_by_measure.get(
                    measure.measure_reference
                ),
            ),
            constraint=WhereSpecFactory(
                column_association_resolver=column_association_resolver,
            ).create_from_where_filter_intersection(measure.filter),
            alias=measure.alias,
            join_to_timespine=measure.join_to_timespine,
            fill_nulls_with=measure.fill_nulls_with,
        )
        for measure in self.get_metric(metric_reference).input_measures
    )

def contains_cumulative_or_time_offset_metric(self, metric_references: Sequence[MetricReference]) -> bool:
"""Returns true if any of the specs correspond to a cumulative metric or a derived metric with time offset."""
for metric_reference in metric_references:
Expand Down
10 changes: 0 additions & 10 deletions metricflow/protocols/semantics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from metricflow.specs.specs import (
LinkableInstanceSpec,
MeasureSpec,
MetricInputMeasureSpec,
MetricSpec,
NonAdditiveDimensionSpec,
)
Expand Down Expand Up @@ -165,15 +164,6 @@ def metric_references(self) -> Sequence[MetricReference]:
def get_metric(self, metric_reference: MetricReference) -> Metric: # noqa:D
raise NotImplementedError

@abstractmethod
def measures_for_metric(
    self,
    metric_reference: MetricReference,
    column_association_resolver: ColumnAssociationResolver,
) -> Sequence[MetricInputMeasureSpec]:
    """Return the measure specs required to compute the metric.

    Args:
        metric_reference: Reference to the metric whose input measures are requested.
        column_association_resolver: Resolver supplied by the caller; presumably used to resolve the
            filters attached to the input measures -- confirm against the concrete implementation.

    Returns:
        The input measure specs for the metric.

    Raises:
        NotImplementedError: Always, in this abstract declaration; implementations must override it.
    """
    raise NotImplementedError

@abstractmethod
def contains_cumulative_or_time_offset_metric(self, metric_references: Sequence[MetricReference]) -> bool:
"""Returns true if any of the specs correspond to a cumulative metric or a derived metric with time offset."""
Expand Down

0 comments on commit 0361b0e

Please sign in to comment.