Skip to content

Commit

Permalink
Add sql generators to LinearRegression and Ridge.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 591602164
  • Loading branch information
tcya authored and meterstick-copybara committed Apr 3, 2024
1 parent bcfa3cb commit df7dde9
Show file tree
Hide file tree
Showing 8 changed files with 399 additions and 117 deletions.
32 changes: 29 additions & 3 deletions metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ def compute_on_sql_sql_mode(self, table, split_by=None, execute=None):
"""Executes the query from to_sql() and process the result."""
query = self.to_sql(table, split_by)
res = execute(str(query))
extra_idx = list(utils.get_extra_idx(self, return_superset=True))
extra_idx = list(self.get_extra_idx(return_superset=True))
indexes = split_by + extra_idx if split_by else extra_idx
columns = [a.alias_raw for a in query.groupby.add(query.columns)]
columns[:len(indexes)] = indexes
Expand All @@ -692,7 +692,7 @@ def to_sql(self, table, split_by: Optional[Union[Text, List[Text]]] = None):
"""Generates SQL query for the metric."""
global_filter = utils.get_global_filter(self)
indexes = sql.Columns(split_by).add(
utils.get_extra_idx(self, return_superset=True)
self.get_extra_idx(return_superset=True)
)
with_data = sql.Datasources()
if isinstance(table, sql.Sql) and table.with_data:
Expand Down Expand Up @@ -941,6 +941,32 @@ def add_edges(metric):
add_edges(self)
return dot.to_string()

def get_extra_idx(self, return_superset=False):
  """Collects the extra indexes added by self and its descendants.

  Args:
    return_superset: Whether to return the union of the children's extra
      indexes when the children disagree, instead of raising.

  Returns:
    A tuple of column names which are just the index of
    metric.compute_on(df): a copy of self.extra_index followed by the extra
    indexes collected from the children. When the children disagree and
    return_superset is True, the children's columns are merged in first-seen
    order with duplicates removed.

  Raises:
    ValueError: If the children have different extra indexes and
      return_superset is False.
  """
  # Slice-copy so that extending the result never mutates self.extra_index.
  extra_idx = self.extra_index[:]
  children_idx = [
      c.get_extra_idx(return_superset)
      for c in self.children
      if utils.is_metric(c)
  ]
  # More than one distinct tuple means the children do not agree.
  if len(set(children_idx)) > 1:
    if not return_superset:
      raise ValueError('Incompatible indexes!')
    # Merge through a dict rather than a set: a dict keeps insertion order, so
    # the resulting column order is deterministic instead of depending on
    # string hashing, which can differ between interpreter runs.
    merged = dict.fromkeys(col for idx in children_idx for col in idx)
    children_idx = [tuple(merged)]
  if children_idx:
    extra_idx += list(children_idx[0])
  return tuple(extra_idx)

def traverse(self, include_self=True, include_constants=False):
ms = [self] if include_self else list(self.children)
while ms:
Expand Down Expand Up @@ -1291,7 +1317,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes,
The global with_data which holds all datasources we need in the WITH
clause.
"""
utils.get_extra_idx(self) # Check if indexes are compatible.
self.get_extra_idx() # Check if indexes are compatible.
local_filter = (
sql.Filters(self.where_).add(local_filter).remove(global_filter)
)
Expand Down
Loading

0 comments on commit df7dde9

Please sign in to comment.