Skip to content

Commit

Permalink
Add sql generators to LinearRegression and Ridge.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 591602164
  • Loading branch information
tcya authored and meterstick-copybara committed Apr 15, 2024
1 parent bcfa3cb commit e115f4e
Show file tree
Hide file tree
Showing 9 changed files with 713 additions and 268 deletions.
18 changes: 17 additions & 1 deletion meterstick_custom_metrics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2233,7 +2233,7 @@
},
"source": [
"##Custom Operation\n",
"Writing a custom `Operation` is more complex. Typically an `Operation` needs to compute some util `Metric`s. A common one is its child `Metric`. The tricky part is how to make sure the additional computations interact correctly with the cache. First take a look at the Caching section below to understand how caching works in `Meterstick`. Then here is a decision tree to help you.\n",
"Writing a custom `Operation` is more complex. Typically an `Operation` needs to compute some util `Metric`s. A common one is its child `Metric`. The tricky part is how to make sure the additional computations interact correctly with the cache. First take a look at the Caching section above to understand how caching works in `Meterstick`. Then here is a decision tree to help you.\n",
"\n",
"\n",
" +-----------------------------------+ \n",
Expand Down Expand Up @@ -2272,6 +2272,22 @@
"1. We don't check if the input data is consistent when using caching so users need to make sure the util `Metric` is computed on the same data as other `Metric`s are. If the util `Metric` is computed on the data passed to the method your are overriding, it's safe to use the recommended methods below that will save the result to cache.\n",
"1. A very common scenerio is that an `Operation` needs to compute the child `Metric` first. Use `compute_child` or `compute_child_sql` to do so. Oftentimes the child `Metric` needs to be computed with an extended `split_by`, for example, `PercentChange('grp', 'base', Sum(x)).compute_on(data, 'foo')` will need to compute `Sum(x).compute_on(data, ['foo', 'grp'])` first. The recommended way is that you register the extra dimensions, `grp` in the example, in `__init__()`. Then the default `compute_child` and `compute_child_sql` will return the result of the child `Metric` you want. You only need to implement the `compute_on_children` then.\n",
"1. The extra dimensions are stored in `self.extra_split_by`. There is another attribute `extra_index` which stores the indexes the `Operation` adds. When unspecified, it will be set to the value of `self.extra_split_by`. For complex `Operations` the two can differ. For example, in the computation of `MH(condition_column, baseline_key, stratified_by, child)`, we need to compute `child.compute_on(df, split_by + [condition_column, stratified_by])` so the `extra_split_by` is `[condition_column, stratified_by]`. However, `stratified_by` won't show up in the final result so you need to explicitly set the `extra_index` to `condition_column`.\n",
"1. If you need to store part of the extra dimensions in another attribute, make it a property that is dynamically computed from `extra_split_by` or `extra_index`. Do NOT make it statically assigned in the `__init__`. For example, in `MH`, we don't do\n",
"```\n",
"def __init__(self, ..., stratified_by, ...):\n",
" ...\n",
" self.stratified_by = stratified_by\n",
"```\n",
"Instead, we do\n",
"```\n",
"@property\n",
"def stratified_by(self):\n",
" return self.extra_split_by[len(self.extra_index):]\n",
"@stratified_by.setter\n",
"def stratified_by(self, stratified_by):\n",
" self.extra_split_by[len(self.extra_index):] = stratified_by\n",
"```\n",
"The reason is that in SQL generation, if the dimensions in the attribute have any special character, they will get sanitized so queries based on them need to be adjusted. We will take care of `extra_split_by` and `extra_index` but we don't have knowledge about your attributes so if they are static the SQL query might be invalid.\n",
"1. When you need to manually construct the extended `split_by`, make the extra dimensions come after the original `split_by`. That's how we do it for all built-in `Operation`s so the caching could be maximized.\n",
"1. Try to vectorize the `Operation` as much as possible. If it's hard to vectorize, often you can at least compute the child `Metric` in a vectorized way by calling `compute_child`. Then implement `compute(self, df_slice)` which handles a slice of the data returned by the child `Metric`. See `CumulativeDistribution` below for an example.\n",
"1. When you need to compute a util `Metric` other than the child `Metric`, use `compute_util_metric_on` or `compute_util_metric_on_sql`. `compute_child` and `compute_child_sql` are just wrappers of them for the child `Metric`.\n",
Expand Down
32 changes: 29 additions & 3 deletions metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,7 +677,7 @@ def compute_on_sql_sql_mode(self, table, split_by=None, execute=None):
"""Executes the query from to_sql() and process the result."""
query = self.to_sql(table, split_by)
res = execute(str(query))
extra_idx = list(utils.get_extra_idx(self, return_superset=True))
extra_idx = list(self.get_extra_idx(return_superset=True))
indexes = split_by + extra_idx if split_by else extra_idx
columns = [a.alias_raw for a in query.groupby.add(query.columns)]
columns[:len(indexes)] = indexes
Expand All @@ -692,7 +692,7 @@ def to_sql(self, table, split_by: Optional[Union[Text, List[Text]]] = None):
"""Generates SQL query for the metric."""
global_filter = utils.get_global_filter(self)
indexes = sql.Columns(split_by).add(
utils.get_extra_idx(self, return_superset=True)
self.get_extra_idx(return_superset=True)
)
with_data = sql.Datasources()
if isinstance(table, sql.Sql) and table.with_data:
Expand Down Expand Up @@ -941,6 +941,32 @@ def add_edges(metric):
add_edges(self)
return dot.to_string()

def get_extra_idx(self, return_superset=False):
"""Collects the extra indexes added by self and its descendants.
Args:
return_superset: If to return the superset of extra indexes if metric has
incompatible indexes.
Returns:
A tuple of column names which are just the index of metric.compute_on(df).
"""
extra_idx = self.extra_index[:]
children_idx = [
c.get_extra_idx(return_superset)
for c in self.children
if utils.is_metric(c)
]
if len(set(children_idx)) > 1:
if not return_superset:
raise ValueError('Incompatible indexes!')
children_idx_superset = set()
children_idx_superset.update(*children_idx)
children_idx = [list(children_idx_superset)]
if children_idx:
extra_idx += list(children_idx[0])
return tuple(extra_idx)

def traverse(self, include_self=True, include_constants=False):
ms = [self] if include_self else list(self.children)
while ms:
Expand Down Expand Up @@ -1291,7 +1317,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes,
The global with_data which holds all datasources we need in the WITH
clause.
"""
utils.get_extra_idx(self) # Check if indexes are compatible.
self.get_extra_idx() # Check if indexes are compatible.
local_filter = (
sql.Filters(self.where_).add(local_filter).remove(global_filter)
)
Expand Down
Loading

0 comments on commit e115f4e

Please sign in to comment.