Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Sep 30, 2024
1 parent f175b9b commit 179ff72
Show file tree
Hide file tree
Showing 366 changed files with 43,172 additions and 82,208 deletions.
43 changes: 27 additions & 16 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,28 +801,34 @@ def visit_write_to_result_table_node(self, node: WriteToResultTableNode) -> SqlD

def visit_filter_elements_node(self, node: FilterElementsNode) -> SqlDataSet:
"""Generates the query that realizes the behavior of FilterElementsNode."""
from_data_set: SqlDataSet = node.parent_node.accept(self)
output_instance_set = from_data_set.instance_set.transform(FilterElements(node.include_specs))
from_data_set_alias = self._next_unique_table_alias()

# Also, the output columns should always follow the resolver format.
output_instance_set = output_instance_set.transform(ChangeAssociatedColumns(self._column_association_resolver))

# This creates select expressions for all columns referenced in the instance set.
select_columns = output_instance_set.transform(
CreateSelectColumnsForInstances(from_data_set_alias, self._column_association_resolver)
).as_tuple()
parent_data_set: SqlDataSet = node.parent_node.accept(self)
output_instance_set = parent_data_set.instance_set.transform(FilterElements(node.include_specs))
output_column_names = [instance.associated_column.column_name for instance in output_instance_set.as_tuple]
output_select_columns = [
select_column
for select_column in parent_data_set.checked_sql_select_node.select_columns
if select_column.column_alias in output_column_names
]
# The WHERE clause is the limiting factor: the columns it needs must be in the SELECT statement.
# One option is a conditional: use a subquery if there is a WHERE clause, otherwise don't.
where = parent_data_set.checked_sql_select_node.where

# If distinct values requested, group by all select columns.
group_bys = select_columns if node.distinct else ()
group_bys = tuple(output_select_columns if node.distinct else parent_data_set.checked_sql_select_node.group_bys)
return SqlDataSet(
instance_set=output_instance_set,
# TODO: Add a method that builds this node from the parent node with override params.
sql_select_node=SqlSelectStatementNode.create(
description=node.description,
select_columns=select_columns,
from_source=from_data_set.checked_sql_select_node,
from_source_alias=from_data_set_alias,
description=node.parent_node.description + "\n" + node.description,
select_columns=tuple(output_select_columns),
from_source=parent_data_set.checked_sql_select_node.from_source,
from_source_alias=parent_data_set.checked_sql_select_node.from_source_alias,
join_descs=parent_data_set.checked_sql_select_node.join_descs,
where=where,
group_bys=group_bys,
order_bys=parent_data_set.checked_sql_select_node.order_bys,
limit=parent_data_set.checked_sql_select_node.limit,
distinct=parent_data_set.checked_sql_select_node.distinct,
),
)

Expand Down Expand Up @@ -1499,6 +1505,11 @@ def visit_join_to_custom_granularity_node(self, node: JoinToCustomGranularityNod
from_source=parent_data_set.checked_sql_select_node.from_source,
from_source_alias=parent_alias,
join_descs=parent_data_set.checked_sql_select_node.join_descs + (join_description,),
where=parent_data_set.checked_sql_select_node.where,
group_bys=parent_data_set.checked_sql_select_node.group_bys,
order_bys=parent_data_set.checked_sql_select_node.order_bys,
limit=parent_data_set.checked_sql_select_node.limit,
distinct=parent_data_set.checked_sql_select_node.distinct,
),
)

Expand Down
Loading

0 comments on commit 179ff72

Please sign in to comment.