diff --git a/metricflow-semantics/metricflow_semantics/instances.py b/metricflow-semantics/metricflow_semantics/instances.py index 779a2c879..c6a0f8e2b 100644 --- a/metricflow-semantics/metricflow_semantics/instances.py +++ b/metricflow-semantics/metricflow_semantics/instances.py @@ -49,7 +49,9 @@ class MdoInstance(ABC, Generic[SpecT]): @property def associated_column(self) -> ColumnAssociation: """Helper for getting the associated column until support for multiple associated columns is added.""" - assert len(self.associated_columns) == 1 + assert ( + len(self.associated_columns) == 1 + ), f"Expected exactly one column for {self.__class__.__name__}, but got {self.associated_columns}" return self.associated_columns[0] def accept(self, visitor: InstanceVisitor[VisitorOutputT]) -> VisitorOutputT: diff --git a/metricflow/dataset/sql_dataset.py b/metricflow/dataset/sql_dataset.py index 363dbac33..214d101eb 100644 --- a/metricflow/dataset/sql_dataset.py +++ b/metricflow/dataset/sql_dataset.py @@ -4,7 +4,7 @@ from dbt_semantic_interfaces.references import SemanticModelReference from metricflow_semantics.assert_one_arg import assert_exactly_one_arg_set -from metricflow_semantics.instances import EntityInstance, InstanceSet +from metricflow_semantics.instances import EntityInstance, InstanceSet, TimeDimensionInstance from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat from metricflow_semantics.specs.column_assoc import ColumnAssociation from metricflow_semantics.specs.dimension_spec import DimensionSpec @@ -122,30 +122,32 @@ def column_association_for_dimension( return column_associations_to_return[0] - def column_association_for_time_dimension( - self, - time_dimension_spec: TimeDimensionSpec, - ) -> ColumnAssociation: - """Given the name of the time dimension, return the set of columns associated with it in the data set.""" + def instances_for_time_dimensions( + self, time_dimension_specs: Sequence[TimeDimensionSpec] + ) -> List[TimeDimensionInstance]: + """Return the instances associated with these specs in the data set.""" matching_instances = 0 - column_associations_to_return = None + instances_to_return: List[TimeDimensionInstance] = [] for time_dimension_instance in self.instance_set.time_dimension_instances: - if time_dimension_instance.spec == time_dimension_spec: - column_associations_to_return = time_dimension_instance.associated_columns + if time_dimension_instance.spec in time_dimension_specs: + instances_to_return.append(time_dimension_instance) matching_instances += 1 - if matching_instances > 1: + if matching_instances != len(time_dimension_specs): raise RuntimeError( - f"More than one time dimension instance with spec {time_dimension_spec} in " - f"instance set: {self.instance_set}" + f"Unexpected number of time dimension instances found matching specs.\nSpecs: {time_dimension_specs}\n" + f"Instances: {instances_to_return}" ) - if not column_associations_to_return: - raise RuntimeError( - f"No time dimension instances with spec {time_dimension_spec} in instance set: {self.instance_set}" - ) + return instances_to_return - return column_associations_to_return[0] + def instance_for_time_dimension(self, time_dimension_spec: TimeDimensionSpec) -> TimeDimensionInstance: + """Given the name of the time dimension, return the instance associated with it in the data set.""" + return self.instances_for_time_dimensions([time_dimension_spec])[0] + + def column_association_for_time_dimension(self, time_dimension_spec: TimeDimensionSpec) -> ColumnAssociation: + """Given the name of the time dimension, return the set of columns associated with it in the data set.""" + return self.instance_for_time_dimension(time_dimension_spec).associated_column @property @override diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 61d3d510e..e85eb2af2 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -1470,16 +1470,8 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet self._column_association_resolver, OrderedDict({parent_alias: parent_instance_set}) ) - # Select matching instance from time spine data set (using base grain - custom grain will be joined in a later node). - original_time_spine_dim_instance: Optional[TimeDimensionInstance] = None - for time_dimension_instance in time_spine_dataset.instance_set.time_dimension_instances: - if time_dimension_instance.spec == agg_time_dimension_instance_for_join.spec: - original_time_spine_dim_instance = time_dimension_instance - break - assert original_time_spine_dim_instance, ( - "Couldn't find requested agg_time_dimension_instance_for_join in time spine data set, which " - f"indicates it may have been configured incorrectly. Expected: {agg_time_dimension_instance_for_join.spec};" - f" Got: {[instance.spec for instance in time_spine_dataset.instance_set.time_dimension_instances]}" + original_time_spine_dim_instance = time_spine_dataset.instance_for_time_dimension( + agg_time_dimension_instance_for_join.spec ) time_spine_column_select_expr: Union[ SqlColumnReferenceExpression, SqlDateTruncExpression @@ -1590,17 +1582,10 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod # New dataset will be joined to parent dataset without a subquery, so use the same FROM alias as the parent node. parent_alias = parent_data_set.checked_sql_select_node.from_source_alias - parent_time_dimension_instance: Optional[TimeDimensionInstance] = None - for instance in parent_data_set.instance_set.time_dimension_instances: - if instance.spec == node.time_dimension_spec.with_base_grain(): - parent_time_dimension_instance = instance - break - parent_column: Optional[SqlSelectColumn] = None - assert parent_time_dimension_instance, ( - "JoinToCustomGranularityNode's expected time_dimension_spec not found in parent dataset instances. " - f"This indicates internal misconfiguration. Expected: {node.time_dimension_spec.with_base_grain()}; " - f"Got: {[instance.spec for instance in parent_data_set.instance_set.time_dimension_instances]}" + parent_time_dimension_instance = parent_data_set.instance_for_time_dimension( + node.time_dimension_spec.with_base_grain() ) + parent_column: Optional[SqlSelectColumn] = None for select_column in parent_data_set.checked_sql_select_node.select_columns: if select_column.column_alias == parent_time_dimension_instance.associated_column.column_name: parent_column = select_column