From aa2c91f1b030ee842ecb83a0ade0d2d1249de854 Mon Sep 17 00:00:00 2001 From: Xunmo Yang Date: Thu, 2 Nov 2023 11:12:18 -0700 Subject: [PATCH] Enable optimizations in the 'mixed' mode of Jackknife/Bootstrap.compute_on_sql. PiperOrigin-RevId: 578917438 --- README.md | 10 +- meterstick_demo.ipynb | 237 ++++++++++++++++++- metrics.py | 387 +++++++++++++++++++++++++++--- metrics_test.py | 60 ++++- models.py | 35 +-- models_test.py | 49 ++-- operations.py | 533 ++++++++++++++++++++++++++++++++++++++---- operations_test.py | 26 ++- 8 files changed, 1198 insertions(+), 139 deletions(-) diff --git a/README.md b/README.md index a12497a..1a33726 100644 --- a/README.md +++ b/README.md @@ -59,6 +59,8 @@ Currently built-in metrics include: Sum(denominator)`. + `Quantile(variable, quantile(s))`: calculates the `quantile(s)` quantile for `variable`. ++ `Nth(variable, sort_by, n, ascending=True, dropna=False)` computes the `n`th + value after sorting by `sort_by`. + `Variance(variable, unbiased=True)`: calculates the variance of `variable`; `unbiased` determines whether the unbiased (sample) or population estimate is used. @@ -300,10 +302,10 @@ It can help you to sanity check complex Metrics. ## SQL -You can get the SQL query for all built-in Metrics and Operations (except -weighted Quantile) by calling `to_sql(sql_data_source, -split_by)` on the Metric. `sql_data_source` could be a table or a subquery. The -dialect it uses is the [standard SQL](https://cloud.google.com/bigquery/docs/reference/standard-sql) +You can get the SQL query for all built-in Metrics and Operations by calling +`to_sql(sql_data_source, split_by)` on the Metric. `sql_data_source` could be a +table or a subquery. The dialect it uses is the +[standard SQL](https://cloud.google.com/bigquery/docs/reference/standard-sql) in Google Cloud's BigQuery. For example, ```python diff --git a/meterstick_demo.ipynb b/meterstick_demo.ipynb index 287cdbc..1a16e7d 100644 --- a/meterstick_demo.ipynb +++ b/meterstick_demo.ipynb @@ -2299,6 +2299,236 @@ "Mean('clicks', 'impressions').compute_on(df)" ] }, + { + "cell_type": "markdown", + "metadata": { + "id": "OKK1H6_3qszU" + }, + "source": [ + "## Nth\n", + "\n", + "`Nth(var, sort_by, n, ascending=True, dropna=False)` computes the `n`th value of `var` after sorting by `sort_by`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "height": 81 + }, + "executionInfo": { + "elapsed": 55, + "status": "ok", + "timestamp": 1697859374809, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "qeZdhz-Zfpd1", + "outputId": "5f55b436-a454-49e2-928c-b2901173a613" + }, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + " \u003cdiv id=\"df-78fb1da5-44ea-49d5-9127-76c891a02da1\" class=\"colab-df-container\"\u003e\n", + " \u003cdiv\u003e\n", + "\u003cstyle scoped\u003e\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "\u003c/style\u003e\n", + "\u003ctable border=\"1\" class=\"dataframe\"\u003e\n", + " \u003cthead\u003e\n", + " \u003ctr style=\"text-align: right;\"\u003e\n", + " \u003cth\u003e\u003c/th\u003e\n", + " \u003cth\u003e1st(clicks) sort by impressions asc\u003c/th\u003e\n", + " \u003c/tr\u003e\n", + " \u003c/thead\u003e\n", + " \u003ctbody\u003e\n", + " \u003ctr\u003e\n", + " \u003cth\u003e0\u003c/th\u003e\n", + " \u003ctd\u003e1.163701\u003c/td\u003e\n", + " \u003c/tr\u003e\n", + " \u003c/tbody\u003e\n", + "\u003c/table\u003e\n", + "\u003c/div\u003e\n", + " \u003cdiv class=\"colab-df-buttons\"\u003e\n", + "\n", + " \u003cdiv class=\"colab-df-container\"\u003e\n", + " \u003cbutton class=\"colab-df-convert\" onclick=\"convertToInteractive('df-78fb1da5-44ea-49d5-9127-76c891a02da1')\"\n", + " title=\"Convert this dataframe to an interactive table.\"\n", + " style=\"display:none;\"\u003e\n", + "\n", + " \u003csvg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\"\u003e\n", + " \u003cpath d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/\u003e\n", + " \u003c/svg\u003e\n", + " \u003c/button\u003e\n", + "\n", + " \u003cstyle\u003e\n", + " .colab-df-container {\n", + " display:flex;\n", + " gap: 12px;\n", + " }\n", + "\n", + " .colab-df-convert {\n", + " background-color: #E8F0FE;\n", + " border: none;\n", + " border-radius: 50%;\n", + " cursor: pointer;\n", + " display: none;\n", + " fill: #1967D2;\n", + " height: 32px;\n", + " padding: 0 0 0 0;\n", + " width: 32px;\n", + " }\n", + "\n", + " .colab-df-convert:hover {\n", + " background-color: #E2EBFA;\n", + " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", + " fill: #174EA6;\n", + " }\n", + "\n", + " .colab-df-buttons div {\n", + " margin-bottom: 4px;\n", + " }\n", + "\n", + " [theme=dark] .colab-df-convert {\n", + " background-color: #3B4455;\n", + " fill: #D2E3FC;\n", + " }\n", + "\n", + " [theme=dark] .colab-df-convert:hover {\n", + " background-color: #434B5C;\n", + " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", + " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", + " fill: #FFFFFF;\n", + " }\n", + " \u003c/style\u003e\n", + "\n", + " \u003cscript\u003e\n", + " const buttonEl =\n", + " document.querySelector('#df-78fb1da5-44ea-49d5-9127-76c891a02da1 button.colab-df-convert');\n", + " buttonEl.style.display =\n", + " google.colab.kernel.accessAllowed ? 
'block' : 'none';\n", + "\n", + " async function convertToInteractive(key) {\n", + " const element = document.querySelector('#df-78fb1da5-44ea-49d5-9127-76c891a02da1');\n", + " const dataTable =\n", + " await google.colab.kernel.invokeFunction('convertToInteractive',\n", + " [key], {});\n", + " if (!dataTable) return;\n", + "\n", + " const docLinkHtml = 'Like what you see? Visit the ' +\n", + " '\u003ca target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb\u003edata table notebook\u003c/a\u003e'\n", + " + ' to learn more about interactive tables.';\n", + " element.innerHTML = '';\n", + " dataTable['output_type'] = 'display_data';\n", + " await google.colab.output.renderOutput(dataTable, element);\n", + " const docLink = document.createElement('div');\n", + " docLink.innerHTML = docLinkHtml;\n", + " element.appendChild(docLink);\n", + " }\n", + " \u003c/script\u003e\n", + " \u003c/div\u003e\n", + "\n", + " \u003c/div\u003e\n", + " \u003c/div\u003e\n" + ], + "text/plain": [ + " 1st(clicks) sort by impressions asc\n", + "0 1.163701" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Nth('clicks', 'impressions', 0).compute_on(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "executionInfo": { + "elapsed": 54, + "status": "ok", + "timestamp": 1697859834379, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "2Nk-ymNiVP4M", + "outputId": "57fd30f0-666d-4310-ac44-a257a20b2fc4" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "1 1.163701\n", + "Name: clicks, dtype: float64" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df.sort_values('impressions').clicks.head(1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 54, + "status": "ok", + "timestamp": 1697093923046, + "user": { + "displayName": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "V6uR6l6nrQ-u", + "outputId": "bbfc39de-bc38-4ea3-86f9-e3a6359197c2" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# n can be negative and it's equivalent to reversing n and ascending together.\n", + "Nth('x', 'y', -1) == Nth('x', 'y', 0, False)" + ] + }, { "cell_type": "markdown", "metadata": { @@ -18261,12 +18491,9 @@ "source": [ "#SQL\n", "\n", - "You can easily get SQL query for all built-in Metrics and Operations, except for weighted Quantile/CV/Correlation/Cov, by calling\n", - "\n", - "\u003e to_sql(sql_table, split_by).\n", + "You can easily get SQL query for all built-in Metrics and Operations by calling `to_sql(sql_table, split_by)`.\n", "\n", - "You can also directly execute the query by calling\n", - "\u003e compute_on_sql(sql_table, split_by, execute, melted),\n", + "You can also directly execute the query by calling `compute_on_sql(sql_table, split_by, execute, melted)`,\n", "\n", "where `execute` is a function that can execute SQL queries. 
The return is very similar to compute_on().\n", "\n", diff --git a/metrics.py b/metrics.py index e5e22de..3853854 100644 --- a/metrics.py +++ b/metrics.py @@ -109,6 +109,7 @@ def to_sql(table, split_by=None): 'Mean', 'Max', 'Min', + 'Nth', 'Quantile', 'Variance', 'StandardDeviation', @@ -654,8 +655,19 @@ def compute_through_sql(self, table, split_by, execute, mode): pass if self.where: table = sql.Sql(None, table, self.where) - res = self.compute_on_sql_mixed_mode(table, split_by, execute, mode) - return self.to_series_or_number_if_not_operation(res) + try: + res = self.compute_on_sql_mixed_mode(table, split_by, execute, mode) + return self.to_series_or_number_if_not_operation(res) + except NotImplementedError: + raise + except Exception as e: # pylint: disable=broad-except + if mode: + raise ValueError( + 'Please see the root cause of the failure above. You can try' + ' `mode=None` to see if it helps.' + ) from e + else: + raise def to_series_or_number_if_not_operation(self, df): return self.to_series_or_number(df) if not self.is_operation else df @@ -749,8 +761,11 @@ def compute_on_sql_mixed_mode(self, table, split_by, execute, mode=None): children = self.compute_children_sql(table, split_by, execute, mode) return self.compute_on_children(children, split_by) - def compute_children_sql(self, table, split_by, execute, mode=None): + def compute_children_sql( + self, table, split_by, execute, mode, *args, **kwargs + ): """The return should be similar to compute_children().""" + del args, kwargs # unused children = [] for c in self.children: if not isinstance(c, Metric): @@ -1299,9 +1314,15 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, local_filter, with_data)[0] for c in self.children ] + children_sql_copy = copy.deepcopy(children_sql) incompatible_sqls = sql.Datasources() - for child_sql in children_sql: - incompatible_sqls.merge(sql.Datasource(child_sql, 'MetricListChildTable')) + child_table_aliases = [] + for i, child_sql in enumerate(children_sql): + child_table_aliases.append( + incompatible_sqls.merge( + sql.Datasource(child_sql, 'MetricListChildTable') + ) + ) name_tmpl = self.name_tmpl or '{}' if len(incompatible_sqls) == 1: @@ -1312,10 +1333,23 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, return res, with_data columns = sql.Columns(indexes.aliases) - for i, (alias, table) in enumerate(incompatible_sqls.children.items()): - data = sql.Datasource(table, alias) - alias = with_data.merge(data) - for c in table.columns: + alias_lookup = {} + from_data = None + for i, child_sql in enumerate(children_sql_copy): + child_table_alias = child_table_aliases[i] + if child_table_alias in alias_lookup: + alias = alias_lookup[child_table_alias] + else: + table = incompatible_sqls.children[child_table_alias] + data = sql.Datasource(table, child_table_alias) + alias = with_data.merge(data) + alias_lookup[child_table_alias] = alias + if i == 0: + from_data = alias + else: + join = 'FULL' if indexes else 'CROSS' + from_data = sql.Join(from_data, alias, join=join, using=indexes) + for c in child_sql.columns: if c not in columns: columns.add( sql.Column( @@ -1323,11 +1357,6 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, alias=name_tmpl.format(c.alias_raw), ) ) - if i == 0: - from_data = alias - else: - join = 'FULL' if indexes else 'CROSS' - from_data = sql.Join(from_data, alias, join=join, using=indexes) query = sql.Sql(columns, from_data) if self.columns: @@ -1892,6 +1921,152 @@ def get_sql_columns(self, 
local_filter): return sql.Column(self.var, 'MIN({})', self.name, local_filter) +class Nth(SimpleMetric): + """The n-th value of var when sorting by sort_by. + + Attributes: + var: Column to compute on. + var2: Column to sort by. + n: The `n`-th value to get. + ascending: If to sort in ascending order. + dropna: If to drop NA in var before counting. + name: Name of the Metric. + where: A string or list of strings to be concatenated that will be passed to + df.query() as a prefilter. + And all other attributes inherited from SimpleMetric. + """ + + def __init__( + self, + var: Text, + sort_by: Text, + n: int, + ascending: bool = True, + dropna: bool = False, + name: Optional[Text] = None, + where: Optional[Union[Text, Sequence[Text]]] = None, + additional_fingerprint_attrs: Optional[List[str]] = None, + ): + if not isinstance(n, int): + raise ValueError('n must be an integer.') + if n < 0: + n = -n - 1 + ascending = not ascending + self.n = n + self.ascending = ascending + self.dropna = dropna + self.var2 = sort_by + i = n + 1 + if i % 10 == 1 and i % 100 != 11: + tmpl = f'{i}st' + elif i % 10 == 2 and i % 100 != 12: + tmpl = f'{i}nd' + elif i % 10 == 3 and i % 100 != 13: + tmpl = f'{i}rd' + else: + tmpl = f'{i}th' + order = 'asc' if ascending else 'desc' + name_tmpl = '%s({}) sort by %s %s' % (tmpl, sort_by, order) + additional_fingerprint_attrs = (additional_fingerprint_attrs or []) + [ + 'n', + 'dropna', + 'ascending', + ] + super(Nth, self).__init__( + var, + name, + name_tmpl, + where, + additional_fingerprint_attrs=additional_fingerprint_attrs + ) + + def compute_slices(self, df, split_by=None): + if self.dropna: + df = df.dropna(subset=[self.var]) + df = df.sort_values(self.var2, ascending=self.ascending) + if split_by: + return self.group(df, split_by).nth(self.n)[self.var] + if self.n > len(df) - 1: + return np.nan + return df[self.var].values[self.n] + + def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, + local_filter, with_data): + """Gets the SQL query and WITH clause. + + If there is no local filter, the metric can be expressed in one line like + ARRAY_AGG(var IGNORE NULLS ORDER BY sort_by LIMIT n + 1)[SAFE_OFFSET(n)]. In + that case we will fall back to get_sql_columns(). + Otherwise the metric requires multiple subquries. We wil first add + SELECT split_by, var, sort_by FROM table WHERE local_filter + global_filter + to with_data + then generate one line query like above on the subquery. + + Args: + table: The table we want to query from. + split_by: The columns that we use to split the data. + global_filter: The sql.Filters that can be applied to the whole Metric + tree. + indexes: The columns that we shouldn't apply any arithmetic operation. + local_filter: The sql.Filters that have been accumulated so far. + with_data: A global variable that contains all the WITH clauses we need. + + Returns: + The SQL instance for metric, without the WITH clause component. + The global with_data which holds all datasources we need in the WITH + clause. 
+ """ + local_filter = ( + sql.Filters(self.where_).add(local_filter).remove(global_filter) + ) + if not local_filter: + return super(Nth, self).get_sql_and_with_clause( + table, split_by, global_filter, indexes, None, with_data + ) + all_filters = sql.Filters(local_filter).add(global_filter) + if self.dropna: + all_filters.add(f'{self.var} IS NOT NULL') + split_by = sql.Columns(split_by) + var = sql.Column(self.var, alias=self.var) + var2 = sql.Column(self.var2, alias=self.var2) + filtered_sql = sql.Sql( + sql.Columns(split_by).add([var, var2]), table, all_filters + ) + filtered_table = sql.Datasource(filtered_sql, 'WeightedQuantileFiltered') + filtered_table_alias = with_data.merge(filtered_table) + no_filter = Nth( + var.alias, + var2.alias, + n=self.n, + dropna=self.dropna, + ascending=self.ascending, + name=self.name + ) + return super(Nth, no_filter).get_sql_and_with_clause( + filtered_table_alias, split_by.aliases, None, indexes, None, with_data + ) + + def get_sql_columns(self, local_filter): + if local_filter: + raise ValueError( + 'This case should be handled by get_sql_and_with_clause() already.' + ) + order = '' if self.ascending else ' DESC' + dropna = ' IGNORE NULLS' if self.dropna else '' + tmpl = 'ARRAY_AGG({}%s ORDER BY %s%s LIMIT %s)[SAFE_OFFSET(%s)]' % ( + dropna, + self.var2, + order, + self.n + 1, + self.n, + ) + return sql.Column( + self.var, + tmpl, + self.name, + ) + + class Quantile(SimpleMetric): """Quantile estimator. @@ -1938,34 +2113,31 @@ def __init__(self, super(Quantile, self).__init__(var, name, name_tmpl, where, ['quantile', 'weight', 'interpolation']) - def compute(self, df): - """Adapted from https://stackoverflow.com/a/29677616/12728137.""" - if not self.weight: - raise ValueError('Weight is missing in %s.' % self.name) - - sample_weight = np.array(df[self.weight]) - values = np.array(df[self.var]) - sorter = np.argsort(values) - values = values[sorter] - sample_weight = sample_weight[sorter] - weighted_quantiles = np.cumsum(sample_weight) - 0.5 * sample_weight - weighted_quantiles /= np.sum(sample_weight) - res = np.interp(self.quantile, weighted_quantiles, values) - if self.one_quantile: - return res - return pd.DataFrame( - [res], - columns=[self.name_tmpl.format(self.var, q) for q in self.quantile]) - def compute_slices(self, df, split_by=None): if self.weight: - # When there is weight, just loop through slices. - return super(Quantile, self).compute_slices(df, split_by) + # Adapted from https://stackoverflow.com/a/29677616/12728137. 
+ def interp(d): + res = np.interp(self.quantile, d[self.weight], d[self.var]) + if self.one_quantile: + return res + return pd.DataFrame( + [res], + columns=[self.name_tmpl.format(self.var, q) for q in self.quantile]) + + df = df.groupby(split_by + [self.var])[self.weight].sum() + weighted_quantiles = self.group(df, split_by).cumsum() - 0.5 * df + weighted_quantiles /= self.group(df, split_by).sum() + if split_by: + weighted_quantiles = weighted_quantiles.reset_index(self.var) + return self.group(weighted_quantiles, split_by).apply(interp) + else: + weighted_quantiles = weighted_quantiles.to_frame().reset_index() + return interp(weighted_quantiles) + res = self.group(df, split_by)[self.var].quantile( self.quantile, interpolation=self.interpolation) if self.one_quantile: return res - if split_by: res = res.unstack() res.columns = [self.name_tmpl.format(self.var, c) for c in res] @@ -1977,7 +2149,7 @@ def compute_slices(self, df, split_by=None): def get_sql_columns(self, local_filter): """Get SQL columns.""" if self.weight: - raise ValueError('SQL for weighted quantile is not supported!') + raise ValueError('SQL for weighted quantile should already be handled!') if self.one_quantile: alias = 'quantile(%s, %s)' % (self.var, self.quantile) return sql.Column( @@ -1995,6 +2167,147 @@ def get_sql_columns(self, local_filter): sql.Column(self.var, query % int(100 * q), alias, local_filter)) return sql.Columns(quantiles) + def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, + local_filter, with_data): + """Gets the SQL for weighted quantile. + + The query is constructed as following. + 1. Add three subqueries below to the WITH clause: + AggregatedQuantileWeights AS (SELECT + split_by, + val, + SUM(weight) AS weight + FROM T + GROUP BY split_by, val), + QuantileWeights AS (SELECT + split_by, + val, + SAFE_DIVIDE(SUM(weight) OVER (PARTITION BY split_by ORDER BY val) + - 0.5 * weight, + SUM(weight) OVER (PARTITION BY split_by)) AS weight + FROM AggregatedQuantileWeights + ORDER BY split_by, val), + PairedQuantileWeights AS (SELECT + split_by, + val, + weight, + LAG(weight) OVER (PARTITION BY split_by ORDER BY val) AS prev_weight, + LEAD(weight) OVER (PARTITION BY split_by ORDER BY val) AS next_weight, + LEAD(val) OVER (PARTITION BY split_by ORDER BY val) AS next_value + FROM QuantileWeights) + 2. For each quantile q, SELECT + SUM(IF((prev_weight IS NULL AND q <= weight) OR + (next_weight IS NULL AND q >= weight), + val, + IF(q BETWEEN weight AND next_weight, + (next_value * (q - weight) + (next_weight - q) * val) / + (next_weight - weight), + 0))). + + Args: + table: The table we want to query from. + split_by: The columns that we use to split the data. + global_filter: The sql.Filters that can be applied to the whole Metric + tree. + indexes: The columns that we shouldn't apply any arithmetic operation. + local_filter: The sql.Filters that have been accumulated so far. + with_data: A global variable that contains all the WITH clauses we need. + + Returns: + The SQL instance for metric, without the WITH clause component. + The global with_data which holds all datasources we need in the WITH + clause. + """ + if not self.weight: # Fall back to get_sql_columns(). 
+ return super(Quantile, self).get_sql_and_with_clause( + table, split_by, global_filter, indexes, local_filter, with_data + ) + if self.interpolation != 'linear': + raise NotImplementedError('Only linear interpolation is supported!') + local_filter = ( + sql.Filters(self.where_).add(local_filter).remove(global_filter) + ) + split_by_and_value = sql.Columns(split_by).add(self.var) + weight = sql.Column( + self.weight, 'SUM({})', filters=local_filter, alias=self.weight + ) + cols = sql.Columns(split_by_and_value).add(weight) + deduped_weight_sql = sql.Sql(cols, table, global_filter, split_by_and_value) + deduped_weight_alias = with_data.merge( + sql.Datasource(deduped_weight_sql, 'AggregatedQuantileWeights') + ) + + v = split_by_and_value.aliases[-1] + w = weight.alias + split_by = sql.Columns(split_by.aliases) + split_by_and_value = sql.Columns(split_by_and_value.aliases) + total_weight = sql.Column(w, 'SUM({})', partition=split_by) + cum_weight = sql.Column( + w, + 'SUM({})', + partition=split_by, + order=v, + ) + normalized_weights = (cum_weight - 0.5 * sql.Column(w)) / total_weight + cols = sql.Columns(split_by_and_value).add(normalized_weights.set_alias(w)) + normalized_weights_sql = sql.Sql( + cols, deduped_weight_alias, orderby=split_by_and_value + ) + normalized_weights_alias = with_data.merge( + sql.Datasource(normalized_weights_sql, 'QuantileWeights') + ) + + prev_w = sql.Column( + w, + 'LAG({})', + 'prev_weight', + partition=split_by, + order=v + ) + next_w = sql.Column( + w, + 'LEAD({})', + 'next_weight', + partition=split_by, + order=v + ) + next_val = sql.Column( + v, + 'LEAD({})', + 'next_value', + partition=split_by, + order=v + ) + paired_weights_cols = sql.Columns(cols.aliases).add( + (prev_w, next_w, next_val) + ) + paired_weights_sql = sql.Sql(paired_weights_cols, normalized_weights_alias) + paired_weights_alias = with_data.merge( + sql.Datasource(paired_weights_sql, 'PairedQuantileWeights') + ) + + prev_w = prev_w.alias + next_w = next_w.alias + next_v = next_val.alias + cols = sql.Columns(split_by) + quantiles = [self.quantile] if self.one_quantile else self.quantile + for q in quantiles: + interp = ( + f'({next_v} * ({q} - {w}) + ({next_w} - {q}) * {v})' + f' / ({next_w} - {w})' + ) + cols.add( + sql.Column( + f"""SUM(IF(({prev_w} IS NULL AND {q} <= {w}) OR ({next_w} IS NULL AND {q} >= {w}), {v}, + IF({q} BETWEEN {w} AND {next_w}, {interp}, 0)))""", + alias=f'{self.weight}-weighted quantile({self.var}, {q})', + ) + ) + if self.one_quantile: + cols[-1].set_alias(self.name) + res_sql = sql.Sql(cols, paired_weights_alias, groupby=split_by) + return res_sql, with_data + class Variance(SimpleMetric): """Variance estimator. 
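The metrics.py changes above add the `Nth` metric and, for weighted `Quantile`, both a vectorized pandas path and a SQL path built on centered cumulative weights. The underlying math is the same in both: sort by value, take cumulative weights minus half of each weight, normalize, and linearly interpolate. A minimal NumPy sketch of that formula, independent of meterstick (the column names `x` and `w` and the toy data are illustrative only, not from this patch):

```python
import numpy as np
import pandas as pd


def weighted_quantile(values, weights, q):
    """Linear-interpolation weighted quantile on centered cumulative weights."""
    order = np.argsort(values)
    values = np.asarray(values, dtype=float)[order]
    weights = np.asarray(weights, dtype=float)[order]
    # Centered cumulative weights, normalized to [0, 1].
    cum = np.cumsum(weights) - 0.5 * weights
    cum /= weights.sum()
    return np.interp(q, cum, values)


df = pd.DataFrame({'x': [1.0, 2.0, 3.0, 4.0], 'w': [1, 1, 1, 5]})
# The heavy weight on 4 pulls the median up to 3.5 (unweighted median is 2.5).
print(weighted_quantile(df.x, df.w, 0.5))
```
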
diff --git a/metrics_test.py b/metrics_test.py index 57c5e69..6523143 100644 --- a/metrics_test.py +++ b/metrics_test.py @@ -33,7 +33,7 @@ from pandas import testing # pylint: disable=g-long-lambda -METRICS_TO_TEST = metrics_to_test = [ +METRICS_TO_TEST = [ ('Ratio', metrics.Ratio('X', 'Y'), lambda d: d.X.sum() / d.Y.sum()), ('Sum', metrics.Sum('X'), lambda d: d.X.sum()), ('Count', metrics.Count('X'), lambda d: d.X.size), @@ -56,6 +56,12 @@ ), ('Max', metrics.Max('X'), lambda d: d.X.max()), ('Min', metrics.Min('X'), lambda d: d.X.min()), + ('Nth', metrics.Nth('X', 'Y', 1), lambda d: d.sort_values('Y').X.values[1]), + ( + 'Nth desc', + metrics.Nth('X', 'Y', 2, False), + lambda d: d.sort_values('Y', ascending=False).X.values[2], + ), ('Quantile', metrics.Quantile('X', 0.2), lambda d: d.X.quantile(0.2)), ('Variance', metrics.Variance('X', True), lambda d: d.X.var()), ( @@ -381,11 +387,57 @@ def test_weighted_quantile_multiple_quantiles_split_by_melted(self): names=['Metric', 'grp'])) testing.assert_frame_equal(output, expected) + def test_nth_na(self): + df = pd.DataFrame({'x': [np.nan, 1], 'w': [1, 0]}) + m = metrics.Nth('x', 'w', 1) + output = m.compute_on(df) + expected = pd.DataFrame({'2nd(x) sort by w asc': [np.NaN]}) + testing.assert_frame_equal(output, expected) + + def test_nth_dropna(self): + df = pd.DataFrame({'x': [np.nan, 1], 'w': [0, 1]}) + m = metrics.Nth('x', 'w', 0, dropna=True) + output = m.compute_on(df) + expected = pd.DataFrame({'1st(x) sort by w asc': [1.0]}) + testing.assert_frame_equal(output, expected) + + def test_nth_n_larger_than_df_len(self): + df = pd.DataFrame({'x': [np.nan, 1], 'w': [0, 1]}) + m = metrics.Nth('x', 'w', 2) + output = m.compute_on(df) + expected = pd.DataFrame({'3rd(x) sort by w asc': [np.NaN]}) + testing.assert_frame_equal(output, expected) + + def test_nth_negative_n(self): + output = metrics.Nth('x', 'w', -1) + expected = metrics.Nth('x', 'w', 0, False) + self.assertEqual(output, expected) + def test_cov_invalid_ddof(self): df = pd.DataFrame({'X': np.random.rand(3), 'w': np.array([1, 1, 2])}) m = metrics.Cov('X', 'X', ddof=5, fweight='w') self.assertTrue(pd.isnull(m.compute_on(df, return_dataframe=False))) + def test_large_metriclist(self): + df = pd.DataFrame({ + 'X': np.random.rand(100) + 5, + 'Y': np.random.rand(100) + 5, + 'w': np.random.rand(100) + 5, + 'w2': np.random.randint(100, size=100) + 5, + 'g1': np.random.randint(3, size=100), + 'g2': np.random.choice(list('ab'), size=100), + }) + ms = [] + for c in METRICS_TO_TEST: + m = copy.deepcopy(c[1]) + m.name = c[0] + m.where = 'X > %.4f' % (np.random.rand() * 20 + 5) + ms.append(m) + m = metrics.MetricList(ms) + actual = m.compute_on(df) + expected = pd.concat((c.compute_on(df) for c in m), axis=1) + testing.assert_frame_equal(expected, actual) + class TestCompositeMetric(absltest.TestCase): """Tests for composition of two metrics.""" @@ -857,6 +909,12 @@ def test_different_metrics_have_different_fingerprints(self): metrics.Mean('x', 'y'), metrics.Max('x'), metrics.Min('x'), + metrics.Nth('x', 'y', 2), + metrics.Nth('x', 'y', 3), + metrics.Nth('z', 'y', 2), + metrics.Nth('x', 'z', 2), + metrics.Nth('x', 'z', 2, False), + metrics.Nth('x', 'z', 2, False, True), metrics.Quantile('x'), metrics.Quantile('x', 0.2), metrics.Variance('x', True), diff --git a/models.py b/models.py index bbac45e..5e67a10 100644 --- a/models.py +++ b/models.py @@ -71,8 +71,10 @@ def __init__( """ if y and not isinstance(y, metrics.Metric): raise ValueError('y must be a Metric!') - if y and count_features(y) != 1: - 
raise ValueError('y must be a 1D array but is %iD!' % count_features(y)) + if y and operations.count_features(y) != 1: + raise ValueError( + 'y must be a 1D array but is %iD!' % operations.count_features(y) + ) self.group_by = [group_by] if isinstance(group_by, str) else group_by or [] if isinstance(x, Sequence): x = metrics.MetricList(x) @@ -82,7 +84,7 @@ def __init__( self.y = y child = metrics.MetricList((y, x)) self.model = model - self.k = count_features(x) + self.k = operations.count_features(x) self.model_name = model_name if not name and x and y: x_names = ( @@ -159,7 +161,7 @@ def __call__(self, child): model = super(Model, self).__call__(child) model.y = child[0] model.x = metrics.MetricList(child[1:]) - model.k = count_features(model.x) + model.k = operations.count_features(model.x) x_names = [m.name for m in model.x] model.name = '%s(%s ~ %s)' % ( model.model_name, @@ -170,31 +172,6 @@ def __call__(self, child): return model -def count_features(m: metrics.Metric): - """Gets the width of the result of m.compute_on().""" - if not m: - return 0 - if isinstance(m, Model): - return m.k - if isinstance(m, metrics.MetricList): - return sum([count_features(i) for i in m]) - if isinstance(m, operations.MetricWithCI): - return ( - count_features(m.children[0]) * 3 - if m.confidence - else count_features(m.children[0]) * 2 - ) - if isinstance(m, operations.Operation): - return count_features(m.children[0]) - if isinstance(m, metrics.CompositeMetric): - return max([count_features(i) for i in m.children]) - if isinstance(m, metrics.Quantile): - if m.one_quantile: - return 1 - return len(m.quantile) - return 1 - - class LinearRegression(Model): """A class that can fit a linear regression.""" diff --git a/models_test.py b/models_test.py index 246ac0b..b5e4798 100644 --- a/models_test.py +++ b/models_test.py @@ -335,28 +335,45 @@ def test_interaction_with_other_metric(self): def test_count_features(self): s = metrics.Sum('x') - self.assertEqual(models.count_features(metrics.Sum('x')), 1) - self.assertEqual(models.count_features(metrics.MetricList([s, s])), 2) + self.assertEqual(operations.count_features(metrics.Sum('x')), 1) + self.assertEqual(operations.count_features(metrics.MetricList([s, s])), 2) self.assertEqual( - models.count_features( - metrics.MetricList([metrics.Sum('x'), - metrics.MetricList([s])])), 2) + operations.count_features( + metrics.MetricList([metrics.Sum('x'), metrics.MetricList([s])]) + ), + 2, + ) self.assertEqual( - models.count_features(operations.AbsoluteChange('a', 'b', s)), 1) + operations.count_features(operations.AbsoluteChange('a', 'b', s)), 1 + ) self.assertEqual( - models.count_features( + operations.count_features( operations.AbsoluteChange( - 'a', 'b', metrics.MetricList([s, metrics.MetricList([s])]))), 2) + 'a', 'b', metrics.MetricList([s, metrics.MetricList([s])]) + ) + ), + 2, + ) self.assertEqual( - models.count_features( + operations.count_features( operations.AbsoluteChange( - 'a', 'b', - metrics.MetricList([ - operations.AbsoluteChange('a', 'b', - metrics.MetricList([s, s])) - ]))), 2) - self.assertEqual(models.count_features(metrics.Ratio('x', 'y')), 1) - self.assertEqual(models.count_features(metrics.MetricList([s, s]) / 2), 2) + 'a', + 'b', + metrics.MetricList( + [ + operations.AbsoluteChange( + 'a', 'b', metrics.MetricList([s, s]) + ) + ] + ), + ) + ), + 2, + ) + self.assertEqual(operations.count_features(metrics.Ratio('x', 'y')), 1) + self.assertEqual( + operations.count_features(metrics.MetricList([s, s]) / 2), 2 + ) def 
test_symmetrize_triangular(self): actual = models.symmetrize_triangular([1, 2, 3, 4, 5, 6]) diff --git a/operations.py b/operations.py index 9c10ea5..3f2363b 100644 --- a/operations.py +++ b/operations.py @@ -30,6 +30,31 @@ from scipy import stats +def count_features(m: metrics.Metric): + """Gets the width of the result of m.compute_on().""" + if not m: + return 0 + if isinstance(m, metrics.MetricList): + return sum([count_features(i) for i in m]) + if isinstance(m, MetricWithCI): + return ( + count_features(m.children[0]) * 3 + if m.confidence + else count_features(m.children[0]) * 2 + ) + if isinstance(m, (CUPED, PrePostChange)): + return count_features(m.children[0][0]) + if isinstance(m, Operation): + return count_features(m.children[0]) + if isinstance(m, metrics.CompositeMetric): + return max([count_features(i) for i in m.children]) + if isinstance(m, metrics.Quantile): + if m.one_quantile: + return 1 + return len(m.quantile) + return 1 + + class Operation(metrics.Metric): """A special kind of Metric that operates on other Metric instance(s). @@ -464,15 +489,10 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, """ if not isinstance(self, (PercentChange, AbsoluteChange)): raise ValueError('Not a PercentChange nor AbsoluteChange!') - local_filter = ( - sql.Filters(self.where_).add(local_filter).remove(global_filter) + cond_cols = sql.Columns(self.extra_index) + raw_table_sql, with_data = self.get_change_raw_sql( + table, split_by, global_filter, indexes, local_filter, with_data ) - - child = self.children[0] - cond_cols = sql.Columns(self.extra_split_by) - groupby = sql.Columns(split_by).add(cond_cols) - raw_table_sql, with_data = child.get_sql_and_with_clause( - table, groupby, global_filter, indexes, local_filter, with_data) raw_table = sql.Datasource(raw_table_sql, 'ChangeRaw') raw_table_alias = with_data.merge(raw_table) @@ -517,6 +537,19 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, sql.Join(raw_table_alias, base_table_alias, join=join, using=using), cond), with_data + def get_change_raw_sql( + self, table, split_by, global_filter, indexes, local_filter, with_data + ): + """Gets the query where the comparison will be carried out.""" + local_filter = ( + sql.Filters(self.where_).add(local_filter).remove(global_filter) + ) + groupby = sql.Columns(split_by).add(self.extra_split_by) + raw_table_sql, with_data = self.children[0].get_sql_and_with_clause( + table, groupby, global_filter, indexes, local_filter, with_data + ) + return raw_table_sql, with_data + class PercentChange(Comparison): """Percent change estimator on a Metric. @@ -715,6 +748,30 @@ def adjust_value(self, child, covariates, split_by): # 1. It's faster. See the comments in Metric.compute_slices(). # 2. It ensures that the result is formatted correctly. class Adjust(metrics.Metric): + """Adjusts the value by fitting controlling for the covariates. + + See the class doc for adjustment details. Essentially for every slice for + comparison, we fit a linear regression child = c + k * covariate and use c + as the adjusted value for PercentChange computation later. + Because we center covariate first, when there is only one covariate, k can + be computed as Covariance(child, covariate) / Var(covariate, covariate) + and c = avg(child) - k * avg(covariate). 
+ """ + + def compute_slices(self, df, split_by: Optional[List[Text]] = None): + child = df.iloc[:, :len_child] + cov = df.iloc[:, len_child:] + if len(cov.columns) > 1: + return super(Adjust, self).compute_slices(df, split_by) + adjusted = df.groupby(split_by, observed=True).mean() + cov_col = cov.columns[0] + cov_adjusted = adjusted.iloc[:, -1] + for c in child: + theta = ( + metrics.Cov(c, cov_col) / metrics.Variance(cov_col) + ).compute_on(df, split_by, return_dataframe=False) + adjusted[c] = adjusted[c] - cov_adjusted * theta + return adjusted.iloc[:, :len_child] def compute(self, df_slice): child_slice = df_slice.iloc[:, :len_child] @@ -731,10 +788,96 @@ def compute_children_sql(self, table, split_by, execute, mode=None): child = child.iloc[:, :1] return self.adjust_value(child, covariates, split_by) - def get_sql_and_with_clause( + def get_change_raw_sql( self, table, split_by, global_filter, indexes, local_filter, with_data ): - raise NotImplementedError + """Generates PrePost-adjusted values for PercentChange computation. + + This function generats subquries like + WITH PrePostRaw AS (SELECT + split_by, + stratified_by, + condition_column, + child_metric, + covariate + FROM T + GROUP BY split_by, stratified_by, condition_column), + PrePostcovariateCentered AS (SELECT + split_by, + stratified_by, + condition_column, + child_metric, + covariate - AVG(covariate) OVER (PARTITION BY split_by) AS covariate + FROM PrePostRaw), + ChangeRaw AS (SELECT + split_by, + condition_column, + AVG(child_metric) - SAFE_DIVIDE(AVG(covariate) * COVAR_SAMP(child_metric, + covariate), VAR_SAMP(covariate)) AS child_metric + FROM PrePostcovariateCentered + GROUP BY split_by, condition_column) + + Args: + table: The table we want to query from. + split_by: The columns that we use to split the data. + global_filter: The sql.Filters that can be applied to the whole Metric + tree. + indexes: The columns that we shouldn't apply any arithmetic operation. + local_filter: The sql.Filters that have been accumulated so far. + with_data: A global variable that contains all the WITH clauses we need. + + Returns: + The SQL instance for metric, without the WITH clause component. + The global with_data which holds all datasources we need in the WITH + clause. 
+ """ + if count_features(self.children[0][1]) > 1: + raise NotImplementedError + local_filter = ( + sql.Filters(self.where_).add(local_filter).remove(global_filter) + ) + all_split_by = sql.Columns(split_by).add(self.extra_split_by) + all_indexes = sql.Columns(split_by).add(self.extra_index) + child_sql, with_data = self.children[0].get_sql_and_with_clause( + table, all_split_by, global_filter, indexes, local_filter, with_data) + child_table = sql.Datasource(child_sql, 'PrePostRaw') + child_table_alias = with_data.merge(child_table) + + split_by = split_by.aliases + all_split_by = all_split_by.aliases + all_indexes = all_indexes.aliases + cols = [ + sql.Column(c.alias, alias=c.alias_raw) + for c in child_sql.all_columns[:-1] + ] + covariate = child_sql.all_columns[-1].alias + covariate_mean = sql.Column(covariate, 'AVG({})', partition=split_by) + covariate_centered = (sql.Column(covariate) - covariate_mean).set_alias( + covariate + ) + cols.append(covariate_centered) + covariate_centered_sql = sql.Sql(cols, child_table_alias) + covariate_centered_table = sql.Datasource( + covariate_centered_sql, 'PrePostcovariateCentered' + ) + covariate_centered_table_alias = with_data.merge(covariate_centered_table) + + to_adjust = [] + for c in child_sql.all_columns[:-1]: + if c.alias in all_split_by: + continue + adjusted = metrics.Mean(c.alias) - metrics.Mean(covariate) * metrics.Cov( + c.alias, covariate + ) / metrics.Variance(covariate) + to_adjust.append(adjusted.set_name(c.alias_raw)) + return metrics.MetricList(to_adjust).get_sql_and_with_clause( + covariate_centered_table_alias, + all_indexes, + None, + all_indexes, + None, + with_data, + ) class CUPED(AbsoluteChange): @@ -742,7 +885,7 @@ class CUPED(AbsoluteChange): Computes the absolute change after controlling for preperiod metrics. Essentially, if the data only has a baseline and a treatment slice, CUPED - 1. centers the covariates + 1. centers the covariates (we skip it because it doesn't affect the result). 2. fit child ~ intercept + covariate. And the intercept is the adjusted effect and has a smaller variance than child. See https://exp-platform.com/cuped for more details. @@ -758,8 +901,7 @@ class CUPED(AbsoluteChange): children: A MetricList whose first element is the Metric we want to compute change on and the rest is the covariates for adjustment. include_base: A boolean for whether the baseline condition should be - included in the output. - And all other attributes inherited from Operation. + included in the output. And all other attributes inherited from Operation. """ def __init__(self, @@ -814,15 +956,6 @@ def adjust_value(self, child, covariates, split_by): The adjusted values of the child (post metrics). """ from sklearn import linear_model # pylint: disable=g-import-not-at-top - # Don't use "-=". For multiindex it might go wrong. The reason is DataFrame - # has different implementations for __sub__ and __isub__. ___isub__ tries - # to reindex to update in place which sometimes lead to lots of NAs. - if split_by: - covariates = ( - covariates - covariates.groupby(split_by, observed=True).mean() - ) - else: - covariates = covariates - covariates.mean() # Align child with covariates in case there is any missing slices. covariates = covariates.reorder_levels(child.index.names) aligned = pd.concat([child, covariates], axis=1) @@ -834,12 +967,34 @@ def adjust_value(self, child, covariates, split_by): # 1. It's faster. See the comments in Metric.compute_slices(). # 2. It ensures that the result is formatted correctly. 
class Adjust(metrics.Metric): + """Adjusts the value by fitting controlling for the covariates. + + Essentially we fit a linear regression child = c + θ * covariate. + and use child - θ * covariate as the adjusted value. When there is only + one covariate, θ can be computed as + Covariance(child, covariate) / Var(covariate, covariate) + """ + + def compute_slices(self, df, split_by: Optional[List[Text]] = None): + child = df.iloc[:, :len_child] + cov = df.iloc[:, len_child:] + if len(cov.columns) > 1: + return super(Adjust, self).compute_slices(df, split_by) + adjusted = df.groupby(split_by + extra_index, observed=True).mean() + cov_col = cov.columns[0] + cov_adjusted = adjusted.iloc[:, -1] + for c in child: + theta = ( + metrics.Cov(c, cov_col) / metrics.Variance(cov_col) + ).compute_on(df, split_by, return_dataframe=False) + adjusted[c] = adjusted[c] - cov_adjusted * theta + return adjusted.iloc[:, :len_child] def compute(self, df_slice): child_slice = df_slice.iloc[:, :len_child] cov = df_slice.iloc[:, len_child:] adjusted = df_slice.groupby(extra_index, observed=True).mean() - for c in aligned.iloc[:, :len_child]: + for c in child_slice: theta = lm.fit(cov, child_slice[c]).coef_ adjusted[c] = adjusted[c] - adjusted.iloc[:, len_child:].dot(theta) return adjusted.iloc[:, :len_child] @@ -853,10 +1008,102 @@ def compute_children_sql(self, table, split_by, execute, mode=None): child = child.iloc[:, :1] return self.adjust_value(child, covariates, split_by) - def get_sql_and_with_clause( + def get_change_raw_sql( self, table, split_by, global_filter, indexes, local_filter, with_data ): - raise NotImplementedError + """Generates CUPED-adjusted values for AbsoluteChange computation. + + This function generats subquries like + WITH CUPEDRaw AS (SELECT + split_by, + stratified_by, + condition_column, + child_metric, + covariate + FROM T + GROUP BY split_by, stratified_by, condition_column), + CUPEDTheta AS (SELECT + split_by, + SAFE_DIVIDE(COVAR_SAMP(child_metric, covariate), VAR_SAMP(covariate)) + AS child_metric_theta + FROM CUPEDRaw + GROUP BY split_by), + ChangeRaw AS (SELECT + split_by, + condition_column, + AVG(child_metric) - AVG(child_metric_theta * covariate) AS child_metric + FROM CUPEDRaw + FULL JOIN + CUPEDTheta + USING (split_by) + GROUP BY split_by, condition_column) + + Args: + table: The table we want to query from. + split_by: The columns that we use to split the data. + global_filter: The sql.Filters that can be applied to the whole Metric + tree. + indexes: The columns that we shouldn't apply any arithmetic operation. + local_filter: The sql.Filters that have been accumulated so far. + with_data: A global variable that contains all the WITH clauses we need. + + Returns: + The SQL instance for metric, without the WITH clause component. + The global with_data which holds all datasources we need in the WITH + clause. 
+ """ + if count_features(self.children[0][1]) > 1: + raise NotImplementedError + local_filter = ( + sql.Filters(self.where_).add(local_filter).remove(global_filter) + ) + all_split_by = sql.Columns(split_by).add(self.extra_split_by) + all_indexes = sql.Columns(split_by).add(self.extra_index) + child_sql, with_data = self.children[0].get_sql_and_with_clause( + table, all_split_by, global_filter, indexes, local_filter, with_data) + child_table = sql.Datasource(child_sql, 'CUPEDRaw') + child_table_alias = with_data.merge(child_table) + + split_by = split_by.aliases + all_split_by = all_split_by.aliases + all_indexes = all_indexes.aliases + cols = [] + for c in child_sql.columns: + if c.alias not in all_split_by: + cols.append(c) + covariate = cols.pop().alias + theta = metrics.MetricList( + [ + ( + metrics.Cov(c.alias, covariate) / metrics.Variance(covariate) + ).set_name(f'{c.alias}_theta') + for c in cols + ] + ) + theta_sql, with_data = theta.get_sql_and_with_clause( + child_table_alias, split_by, None, all_indexes, None, with_data + ) + theta_table = sql.Datasource(theta_sql, 'CUPEDTheta') + theta_table_alias = with_data.merge(theta_table) + + to_adjust = sql.Columns( + [ + ( + sql.Column(c.alias, 'AVG({})') + - sql.Column((c.alias, covariate), 'AVG({}_theta * {})') + ).set_alias(c.alias_raw) + for c in cols + ] + ) + join = 'FULL' if split_by else 'CROSS' + adjusted_sql = sql.Sql( + to_adjust, + sql.Join( + child_table_alias, theta_table_alias, using=split_by, join=join + ), + groupby=all_indexes, + ) + return adjusted_sql, with_data class MH(Comparison): @@ -1676,7 +1923,30 @@ def compute_through_sql(self, table, split_by, execute, mode): ) from e def compute_on_sql_sql_mode(self, table, split_by=None, execute=None): - """Computes self in a SQL query and process the result.""" + """Computes self in a SQL query and process the result. + + We first execute the SQL query then process the result. + When confidence is not specified, for each child Metric, the SQL query + returns two columns. The result columns are like + metric1, metric1 jackknife SE, metric2, metric2 jackknife SE, ... + When confidence is specified, for each child Metric, the SQL query + returns four columns. The result columns are like + metric1, metric1 CI lower, metric1 CI upper, + metric2, metric2 CI lower, metric2 CI upper, + ... + metricN, metricN CI lower, metricN CI upper, + metric 1 base value, metric 2 base value, ..., metricN base value. + The base value columns only exist when the child is a PercentChange or + AbsoluteChange. Base values are the raw value the comparison is carried out. + + Args: + table: The table we want to query from. + split_by: The columns that we use to split the data. + execute: A function that can executes a SQL query and returns a DataFrame. + + Returns: + The result DataFrame of Jackknife/Bootstrap. + """ res = super(MetricWithCI, self).compute_on_sql_sql_mode(table, split_by, execute) sub_dfs = [] @@ -1730,21 +2000,92 @@ def compute_on_sql_sql_mode(self, table, split_by=None, execute=None): return self.add_base_to_res(res, base) def compute_on_sql_mixed_mode(self, table, split_by, execute, mode=None): + """Computes the child in SQL and the rest in Python. + + There are two parts. First we compute the standard errors. Then we join it + with point estimate. When the child is a Comparison, we also compute the + base value for the display(). + For the first part, we preaggregate the data when possible. See the docs of + Bootstrap.compute_slices about the details of the preaggregation. 
The + structure of this part is similar to to_sql(). Note that we apply + preaggregation to both Jackknife and Bootstrap even though + Jackknife.compute_slices doesn't have preaggregation. We don't do + preaggregation in Jackknife.compute_slices because it already cuts the + corner to get leave-one-out-estimates. Adding preaggregation actually slow + things down. Here in 'mixed' mode we don't cut the corner to get LOO so + preaggregation makes sense. + Then we compute the point estimate, join it with the standard error, and do + some data manipulations. + + Args: + table: The table we want to query from. + split_by: The columns that we use to split the data. + execute: A function that can executes a SQL query and returns a DataFrame. + mode: It's always 'mixed' or 'magic' otherwise we won't be here. + + Returns: + The result DataFrame of Jackknife/Bootstrap. + """ batch_size = self._runtime_batch_size or self.sql_batch_size - replicates = self.compute_children_sql( - table, split_by, execute, mode, batch_size + if self.has_been_preaggregated or not self.can_precompute(): + if self.where: + table = sql.Sql(None, table, self.where) + self_no_filter = copy.deepcopy(self) + self_no_filter.where = None + return self_no_filter.compute_on_sql_mixed_mode( + table, split_by, execute, mode + ) + + replicates = self.compute_children_sql( + table, split_by, execute, mode, batch_size + ) + std = self.compute_on_children(replicates, split_by) + point_est = self.compute_child_sql( + table, split_by, execute, True, mode=mode + ) + res = point_est.join(utils.melt(std)) + if self.confidence: + res[self.prefix + ' CI-lower'] = ( + res.iloc[:, 0] - res[self.prefix + ' CI-lower'] + ) + res[self.prefix + ' CI-upper'] += res.iloc[:, 0] + res = utils.unmelt(res) + base = self.compute_change_base(table, split_by, execute, mode) + return self.add_base_to_res(res, base) + + expanded, _ = utils.get_fully_expanded_equivalent_metric_tree(self) + if self != expanded: + return expanded.compute_on_sql_mixed_mode(table, split_by, execute, mode) + + # The filter has been taken care of in preaggregation. + expanded.where = None + expanded = utils.push_filters_to_leaf(expanded) + all_split_by = ( + split_by + + list(utils.get_extra_split_by(expanded, return_superset=True)) + + [expanded.unit] ) - std = self.compute_on_children(replicates, split_by) - point_est = self.compute_child_sql( - table, split_by, execute, True, mode=mode) - res = point_est.join(utils.melt(std)) - if self.confidence: - res[self.prefix + - ' CI-lower'] = res.iloc[:, 0] - res[self.prefix + ' CI-lower'] - res[self.prefix + ' CI-upper'] += res.iloc[:, 0] - res = utils.unmelt(res) - base = self.compute_change_base(table, split_by, execute, mode) - return self.add_base_to_res(res, base) + leaf = utils.get_leaf_metrics(expanded) + cols = [ + l.get_sql_columns(l.where_).set_alias(get_preaggregated_metric_var(l)) + for l in leaf + ] + preagg = sql.Sql(cols, table, self.where_, all_split_by) + equiv = get_preaggregated_metric_tree(expanded) + equiv.unit = sql.Column(equiv.unit).alias + split_by = sql.Columns(split_by).aliases + for m in equiv.traverse(): + if isinstance(m, metrics.Metric): + m.extra_index = sql.Columns(m.extra_index).aliases + m.extra_split_by = sql.Columns(m.extra_split_by).aliases + if isinstance(equiv, Bootstrap): + # When each unit only has one row after preaggregation, we sample by + # rows. 
+ if not utils.get_extra_split_by(equiv, return_superset=True): + equiv.unit = None + else: + equiv.has_local_filter = any([l.where for l in leaf]) + return equiv.compute_on_sql_mixed_mode(preagg, split_by, execute, mode) def compute_children_sql(self, table, @@ -1756,6 +2097,20 @@ def compute_children_sql(self, raise NotImplementedError def to_sql(self, table, split_by=None): + """Generates SQL query for the metric. + + The SQL generation is actually delegated to get_sql_and_with_clause(). This + function does the preaggregation when possible. See the docs of + Bootstrap.compute_slices() about the details of the preaggregation. The + structure of this function is similar to compute_on_sql_mixed_mode(). + + Args: + table: The table we want to query from. + split_by: The columns that we use to split the data. + + Returns: + The query that does Jackknife/Bootstrap. + """ if not isinstance(self, (Jackknife, Bootstrap)): raise NotImplementedError split_by = [split_by] if isinstance(split_by, str) else list(split_by or []) @@ -1896,6 +2251,8 @@ def get_sql_and_with_clause( ): has_base_vals = True base_metric = copy.deepcopy(child.children[0]) + if isinstance(child, (CUPED, PrePostChange)): + base_metric = base_metric[0] if child.where: base_metric.add_where(child.where_) base, with_data = base_metric.get_sql_and_with_clause( @@ -2157,7 +2514,83 @@ def get_stderrs(bucket_estimates): def compute_children_sql( self, table, split_by, execute, mode=None, batch_size=None ): - """Compute the children on leave-one-out data in SQL.""" + """Compute the children on leave-one-out data in SQL. + + When + 1. the data have been preaggregated, which means all the leaf Metrics are + Sum and Count, + 2. batch_size is None, + we compute all the leaf nodes on the preaggregated date, grouped by all the + split_by columns we ever need, including self.unit. Then we cut the corner + to get the leave-one-out estimates. See the doc of compute_slices() for more + details. + Otherwise, if batch_size is None or 1, we iterate unique units in the data. + In iteration k, we compute the child on + 'SELECT * FROM table WHERE unit != k' to get the leave-k-out estimate. + If batch_size is larger than 1, in each iteration, we compute the child on + SELECT + * + FROM UNNEST([1, 2, ..., batch_size]) AS meterstick_resample_idx + JOIN + table + ON meterstick_resample_idx != unit), split by meterstick_resample_idx in + addition. + At last we concat the estimates. + + Args: + table: The table we want to query from. + split_by: The columns that we use to split the data. + execute: A function that can executes a SQL query and returns a DataFrame. + mode: It's always 'mixed' or 'magic' otherwise we won't be here. + batch_size: The number of units we handle in one iteration. + + Returns: + A DataFrame contains all the leave-one-out estimates. Each row is a child + metric * split_by slice and each column is an estimate. + """ + if self.has_been_preaggregated and not batch_size: + all_split_by = ( + split_by + + [self.unit] + + list(utils.get_extra_split_by(self, return_superset=True)) + ) + all_split_by_no_unit = split_by + list( + utils.get_extra_split_by(self, return_superset=True) + ) + filter_in_leaf = utils.push_filters_to_leaf(self) + leafs = utils.get_leaf_metrics(filter_in_leaf) + for m in leafs: # filters have been handled in preaggregation + m.where = None + # Make sure the output column names are the same as the var so we can + # compute_child on it later. 
has_been_preaggregated being True means that + # all leaf metrics are Sum or Count. + leafs = copy.deepcopy(leafs) + for m in leafs: + m.name = m.var + leafs = metrics.MetricList(tuple(set(leafs))) + if len(leafs) == 1: + leafs.name = leafs.children[0].name + bucket_res = self.compute_util_metric_on_sql( + leafs, table, all_split_by, execute, mode=mode + ) + + if all_split_by_no_unit: + total = bucket_res.groupby( + level=all_split_by_no_unit, observed=True + ).sum() + else: + total = bucket_res.sum() + bucket_res = bucket_res.fillna(0) + bucket_res = utils.adjust_slices_for_loo(bucket_res, split_by, bucket_res) + loo = total - bucket_res + if all_split_by_no_unit: + # The levels might get messed up. + loo = loo.reorder_levels(all_split_by) + res = filter_in_leaf.children[0].compute_on( + loo, split_by + [self.unit], melted=True + ) + return [res.unstack(self.unit)] + batch_size = batch_size or 1 slice_and_units = sql.Sql( sql.Columns(split_by + [self.unit], distinct=True), @@ -2394,8 +2827,24 @@ def get_samples(self, df, split_by=None): def compute_children_sql( self, table, split_by, execute, mode=None, batch_size=None ): - """Compute the children on resampled data in SQL.""" - batch_size = batch_size or 1000 + """Compute the children on resampled data in SQL. + + We compute the child on bootstrapped data in a batched way. We bootstrap for + batch_size in one iteration. Namely, it's equivalent to compute self but + setting n_replicates to batch_size. + + Args: + table: The table we want to query from. + split_by: The columns that we use to split the data. + execute: A function that can executes a SQL query and returns a DataFrame. + mode: It's always 'mixed' or 'magic' otherwise we won't be here. + batch_size: The number of units we handle in one iteration. + + Returns: + A DataFrame contains all the bootstrap estimates. Each row is a child + metric * split_by slice and each column is an estimate. + """ + batch_size = batch_size or self.n_replicates global_filter = utils.get_global_filter(self) util_metric = copy.deepcopy(self) util_metric.n_replicates = batch_size diff --git a/operations_test.py b/operations_test.py index 2e85b56..27c7c5c 100644 --- a/operations_test.py +++ b/operations_test.py @@ -695,13 +695,21 @@ def test_mh_fail_on_nonratio_metric(self): ( 'PrePostChange', operations.PrePostChange( - 'grp', 0, metrics.Sum('x'), metrics.Sum('y'), 'cookie' + 'grp', + 0, + metrics.Sum('x'), + [metrics.Sum('y'), metrics.Sum('y') ** 2], + 'cookie', ), ), ( 'CUPED', operations.CUPED( - 'grp', 0, metrics.Sum('x'), metrics.Sum('y'), 'cookie' + 'grp', + 0, + metrics.Sum('x'), + [metrics.Sum('y'), metrics.Sum('y') ** 2], + 'cookie', ), ), ( @@ -2054,9 +2062,17 @@ def test_poissonbootstrap_unit_cache_across_samples(self, enable_opt): def set_up_metric(m): m = copy.deepcopy(m) if not m.children: - m = m( - metrics.MetricList((metrics.Ratio('x', 'y'), metrics.Ratio('y', 'x'))) - ) + if isinstance(m, (operations.CUPED, operations.PrePostChange)): + m = m( + metrics.MetricList(( + metrics.Ratio('x', 'y'), + metrics.MetricList((metrics.Ratio('y', 'x'), metrics.Sum('y'))), + )) + ) + else: + m = m( + metrics.MetricList((metrics.Ratio('x', 'y'), metrics.Ratio('y', 'x'))) + ) return m