diff --git a/meterstick_custom_metrics.ipynb b/meterstick_custom_metrics.ipynb
index 93b0ebb..2d5f40c 100644
--- a/meterstick_custom_metrics.ipynb
+++ b/meterstick_custom_metrics.ipynb
@@ -2233,7 +2233,7 @@
 },
 "source": [
 "##Custom Operation\n",
-"Writing a custom `Operation` is more complex. Typically an `Operation` needs to compute some util `Metric`s. A common one is its child `Metric`. The tricky part is how to make sure the additional computations interact correctly with the cache. First take a look at the Caching section below to understand how caching works in `Meterstick`. Then here is a decision tree to help you.\n",
+"Writing a custom `Operation` is more complex. Typically an `Operation` needs to compute some util `Metric`s. A common one is its child `Metric`. The tricky part is how to make sure the additional computations interact correctly with the cache. First take a look at the Caching section above to understand how caching works in `Meterstick`. Then here is a decision tree to help you.\n",
 "\n",
 "\n",
 " +-----------------------------------+ \n",
@@ -2272,6 +2272,39 @@
 "1. We don't check if the input data is consistent when using caching, so users need to make sure the util `Metric` is computed on the same data as other `Metric`s are. If the util `Metric` is computed on the data passed to the method you are overriding, it's safe to use the recommended methods below that will save the result to the cache.\n",
 "1. A very common scenario is that an `Operation` needs to compute the child `Metric` first. Use `compute_child` or `compute_child_sql` to do so. Oftentimes the child `Metric` needs to be computed with an extended `split_by`, for example, `PercentChange('grp', 'base', Sum(x)).compute_on(data, 'foo')` will need to compute `Sum(x).compute_on(data, ['foo', 'grp'])` first. The recommended way is to register the extra dimensions, `grp` in the example, in `__init__()`. Then the default `compute_child` and `compute_child_sql` will return the result of the child `Metric` you want, and you only need to implement `compute_on_children` (see the sketch after this list).\n",
 "1. The extra dimensions are stored in `self.extra_split_by`. There is another attribute `extra_index` which stores the indexes the `Operation` adds. When unspecified, it will be set to the value of `self.extra_split_by`. For complex `Operation`s the two can differ. For example, in the computation of `MH(condition_column, baseline_key, stratified_by, child)`, we need to compute `child.compute_on(df, split_by + [condition_column, stratified_by])` so the `extra_split_by` is `[condition_column, stratified_by]`. However, `stratified_by` won't show up in the final result so you need to explicitly set the `extra_index` to `condition_column`.\n",
+"1. If you need to store part of the extra dimensions in another attribute, make it a property that is dynamically computed from `extra_split_by` or `extra_index`. Do NOT assign it statically in `__init__`. For example, in `MH`, we don't do\n",
+"```\n",
+"def __init__(self, ..., stratified_by, ...):\n",
+"  ...\n",
+"  self.stratified_by = stratified_by\n",
+"```\n",
+"Instead, we do\n",
+"```\n",
+"@property\n",
+"def stratified_by(self):\n",
+"  return self.extra_split_by[len(self.extra_index):]\n",
+"@stratified_by.setter\n",
+"def stratified_by(self, stratified_by):\n",
+"  self.extra_split_by[len(self.extra_index):] = stratified_by\n",
+"```\n",
+"The reason is that in SQL generation, if the dimensions in the attribute have any special characters, they will get sanitized, so queries based on them need to be adjusted. We will take care of `extra_split_by` and `extra_index`, but we don't have knowledge about your attributes, so if they are static the SQL query might be invalid.\n",
 "1. When you need to manually construct the extended `split_by`, make the extra dimensions come after the original `split_by`. That's how we do it for all built-in `Operation`s so the caching can be maximized.\n",
 "1. Try to vectorize the `Operation` as much as possible. If it's hard to vectorize, often you can at least compute the child `Metric` in a vectorized way by calling `compute_child`. Then implement `compute(self, df_slice)`, which handles a slice of the data returned by the child `Metric`. See `CumulativeDistribution` below for an example.\n",
 "1. When you need to compute a util `Metric` other than the child `Metric`, use `compute_util_metric_on` or `compute_util_metric_on_sql`. `compute_child` and `compute_child_sql` are just wrappers of them for the child `Metric`.\n",
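+"To make the recipe concrete, here is a minimal sketch of a hypothetical `ShareOfTotal` operation that computes the child `Metric`'s share of the total within each `split_by` slice. `ShareOfTotal` is not a built-in `Operation`, and the exact `Operation.__init__` signature may differ in your version, so treat this as an illustration of the recipe rather than production code. It registers the extra dimension in `__init__()` and only implements `compute_on_children`:\n",
+"```\n",
+"class ShareOfTotal(Operation):\n",
+"  # Hypothetical example, not part of Meterstick.\n",
+"  def __init__(self, over, child=None, **kwargs):\n",
+"    # Registering `over` here makes the default compute_child() compute\n",
+"    # child.compute_on(df, split_by + [over]) for us.\n",
+"    super(ShareOfTotal, self).__init__(\n",
+"        child, '{} Share of Total', over, **kwargs)\n",
+"\n",
+"  def compute_on_children(self, children, split_by):\n",
+"    # `children` is the child Metric computed with the extended split_by.\n",
+"    if not split_by:\n",
+"      return children / children.sum()\n",
+"    # Normalize within each of the original split_by slices.\n",
+"    return children / children.groupby(level=split_by).transform('sum')\n",
+"```\n",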
+ """ + extra_idx = self.extra_index[:] + children_idx = [ + c.get_extra_idx(return_superset) + for c in self.children + if utils.is_metric(c) + ] + if len(set(children_idx)) > 1: + if not return_superset: + raise ValueError('Incompatible indexes!') + children_idx_superset = set() + children_idx_superset.update(*children_idx) + children_idx = [list(children_idx_superset)] + if children_idx: + extra_idx += list(children_idx[0]) + return tuple(extra_idx) + def traverse(self, include_self=True, include_constants=False): ms = [self] if include_self else list(self.children) while ms: @@ -1291,7 +1317,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, The global with_data which holds all datasources we need in the WITH clause. """ - utils.get_extra_idx(self) # Check if indexes are compatible. + self.get_extra_idx() # Check if indexes are compatible. local_filter = ( sql.Filters(self.where_).add(local_filter).remove(global_filter) ) diff --git a/models.py b/models.py index 5e67a10..4e3fdb4 100644 --- a/models.py +++ b/models.py @@ -17,6 +17,7 @@ from __future__ import division from __future__ import print_function +import copy import itertools from typing import List, Optional, Sequence, Text, Union @@ -75,23 +76,13 @@ def __init__( raise ValueError( 'y must be a 1D array but is %iD!' % operations.count_features(y) ) - self.group_by = [group_by] if isinstance(group_by, str) else group_by or [] - if isinstance(x, Sequence): - x = metrics.MetricList(x) + if isinstance(x, metrics.Metric): + x = [x] child = None if x and y: - self.x = x - self.y = y - child = metrics.MetricList((y, x)) + child = metrics.MetricList([y] + x) self.model = model - self.k = operations.count_features(x) self.model_name = model_name - if not name and x and y: - x_names = ( - [m.name for m in x] if isinstance(x, metrics.MetricList) else [x.name] - ) - name = '%s(%s ~ %s)' % (model_name, y.name, ' + '.join(x_names)) - name_tmpl = '%s Coefficient: {}' % name additional_fingerprint_attrs = ( [additional_fingerprint_attrs] if isinstance(additional_fingerprint_attrs, str) @@ -99,7 +90,7 @@ def __init__( ) super(Model, self).__init__( child, - name_tmpl, + None, group_by, [], name=name, @@ -132,6 +123,17 @@ def compute(self, df): def compute_through_sql(self, table, split_by, execute, mode): try: + if ( + not mode + and isinstance(self, (LinearRegression, Ridge)) + and not self.normalize + and self.k > 5 + ): + print( + 'INFO: SQL generation for your Model can be slow because the number' + ' of features > 5. Try compute_on_sql(mode="mixed") (for small' + ' data) or compute_on_sql(mode="magic") (for large data).' 
+        )
       if mode == 'magic':
         if self.where:
           table = sql.Sql(None, table, self.where_)
@@ -155,22 +157,66 @@ def compute_through_sql(self, table, split_by, execute, mode):
   def compute_on_sql_magic_mode(self, table, split_by, execute):
     raise NotImplementedError
 
-  def __call__(self, child):
-    if not isinstance(child, metrics.MetricList):
-      raise ValueError(f'Model can only take a MetricList but got {child}!')
-    model = super(Model, self).__call__(child)
-    model.y = child[0]
-    model.x = metrics.MetricList(child[1:])
-    model.k = operations.count_features(model.x)
-    x_names = [m.name for m in model.x]
-    model.name = '%s(%s ~ %s)' % (
-        model.model_name,
-        model.y.name,
+  @property
+  def y(self):
+    if not self.children or not isinstance(
+        self.children[0], metrics.MetricList
+    ):
+      raise ValueError('y must be a Metric!')
+    return self.children[0][0]
+
+  @property
+  def x(self):
+    if not self.children or not isinstance(
+        self.children[0], metrics.MetricList
+    ):
+      raise ValueError('x must be a MetricList!')
+    return metrics.MetricList(self.children[0][1:])
+
+  @property
+  def k(self):
+    return operations.count_features(self.x)
+
+  @property
+  def name(self):
+    if self.name_:  # pytype: disable=attribute-error
+      return self.name_  # pytype: disable=attribute-error
+    if not self.children:
+      return self.model_name
+    x_names = [m.name for m in self.x]
+    return '%s(%s ~ %s)' % (
+        self.model_name,
+        self.y.name,
         ' + '.join(x_names),
     )
-    model.name_tmpl = model.name + ' Coefficient: {}'
+
+  @name.setter
+  def name(self, name):
+    self.name_ = name
+
+  @property
+  def name_tmpl(self):
+    if self.name_tmpl_:  # pytype: disable=attribute-error
+      return self.name_tmpl_  # pytype: disable=attribute-error
+    return self.name + ' Coefficient: {}'
+
+  @name_tmpl.setter
+  def name_tmpl(self, name_tmpl):
+    self.name_tmpl_ = name_tmpl
+
+  @property
+  def group_by(self):
+    return self.extra_split_by
+
+  def __call__(self, child: metrics.Metric):
+    model = copy.deepcopy(self) if self.children else self
+    model.children = (child,)
     return model
 
+  def get_extra_idx(self, return_superset=False):
+    # Model blocks the propagation of extra split_by from the descendants.
+    return ()
+
 
 class LinearRegression(Model):
   """A class that can fit a linear regression."""
@@ -193,6 +239,21 @@ def __init__(
         y, x, group_by, model, 'OLS', where, name, fit_intercept, normalize
     )
 
+  def get_sql_and_with_clause(self, table, split_by, global_filter, indexes,
+                              local_filter, with_data):
+    return Ridge(
+        self.y,
+        self.x,
+        self.group_by,
+        0,
+        self.fit_intercept,
+        self.normalize,
+        self.where_,
+        self.name,
+    ).get_sql_and_with_clause(
+        table, split_by, global_filter, indexes, local_filter, with_data
+    )
+
   def compute_on_sql_magic_mode(self, table, split_by, execute):
     return Ridge(
         self.y,
@@ -251,11 +312,95 @@ def __init__(
     )
     self.alpha = alpha
 
+  def get_sql_and_with_clause(self, table, split_by, global_filter, indexes,
+                              local_filter, with_data):
+    """Gets the SQL query and WITH clause.
+
+    First we get the query that computes all the elements of X'X and X'y. This
+    step is the same as in the 'magic' mode. Then we get the elements of
+    (X'X)^(-1)*(X'y) by doing symbolic computation in SymPy, and translate
+    them to SQL queries.
+
+    Args:
+      table: The table we want to query from.
+      split_by: The columns that we use to split the data.
+      global_filter: The sql.Filters that can be applied to the whole Metric
+        tree.
+      indexes: The columns that we shouldn't apply any arithmetic operation.
+ local_filter: The sql.Filters that have been accumulated so far. + with_data: A global variable that contains all the WITH clauses we need. + + Returns: + The SQL instance for metric, without the WITH clause component. + The global with_data which holds all datasources we need in the WITH + clause. + """ + if self.normalize: + raise NotImplementedError( + 'SQL generator is not implemented for models with normalization.' + ) + + import sympy # pylint: disable=g-import-not-at-top + + xs, sufficient_stats, _, _ = get_sufficient_stats_elements_sql( + self, + table, + split_by, + None, + False, + self.alpha, + global_filter, + indexes, + local_filter, + with_data, + ) + with_data.merge(sufficient_stats.with_data) + sufficient_stats.with_data = None + sufficient_stats_table = sql.Datasource( + sufficient_stats, 'SufficientStatElements' + ) + sufficient_stats_alias = with_data.merge(sufficient_stats_table) + n = len(xs) + bool(self.fit_intercept) + split_by = sql.Columns(split_by.aliases) + sufficient_stats_cols = [ + c for c in sufficient_stats.columns.aliases if c not in split_by + ] + n_x_t_x_elements = n * (n + 1) // 2 - bool(self.fit_intercept) + x_t_x_elements = sufficient_stats_cols[:n_x_t_x_elements] + x_t_y_elements = sufficient_stats_cols[ + n_x_t_x_elements : n_x_t_x_elements + n + ] + if self.fit_intercept: + x_t_x_elements = [1] + x_t_x_elements + penalty = 0 + if isinstance(self, Ridge) and self.alpha: + n_obs = sufficient_stats_cols[-1] + # We use AVG() to compute x_t_x so the penalty needs to be scaled. + penalty = self.alpha / sympy.Symbol(n_obs) + coefs = utils.get_ridge_coefficients( + x_t_x_elements, x_t_y_elements, self.fit_intercept, penalty + ) + xs = xs.raw_aliases + cols = sql.Columns(split_by) + if self.fit_intercept: + xs = ['intercept'] + xs + for x, c in zip(xs, coefs): + # ccode prints x**2 to pow(x, 2) which works in SQL. + cols.add( + [sql.Column(sympy.printing.ccode(c), alias=self.name_tmpl.format(x))] + ) + return sql.Sql(cols, sufficient_stats_alias), with_data + def compute_on_sql_magic_mode(self, table, split_by, execute): # Never normalize for the sufficient_stats. Normalization is handled in # compute_ridge_coefs() instead. xs, sufficient_stats, _, _ = get_sufficient_stats_elements( - self, table, split_by, execute, normalize=False, include_n_obs=True + self, + table, + split_by, + execute, + normalize=False, + include_n_obs=self.alpha, ) return apply_algorithm_to_sufficient_stats_elements( sufficient_stats, split_by, compute_ridge_coefs, xs, self @@ -267,7 +412,6 @@ def get_sufficient_stats_elements( table, split_by, execute, - fit_intercept=None, normalize=None, include_n_obs=False, ): @@ -278,7 +422,6 @@ def get_sufficient_stats_elements( table: The table we want to query from. split_by: The columns that we use to split the data. execute: A function that can executes a SQL query and returns a DataFrame. - fit_intercept: If to include intercept in the model. normalize: If to normalize the X. Note that only has effect when m.fit_intercept is True, which is consistent to sklearn. include_n_obs: If to include the number of observations in the return. @@ -303,43 +446,112 @@ def get_sufficient_stats_elements( norms: Nonempty only when normalize. A pd.DataFrame which holds the l2-norm values of all centered-x columns. 
""" - fit_intercept = m.fit_intercept if fit_intercept is None else fit_intercept if normalize is None: normalize = m.normalize and m.fit_intercept + xs_cols, sufficient_stats_elements, avg_x, norms = ( + get_sufficient_stats_elements_sql( + m, + table, + split_by, + execute, + normalize, + include_n_obs, + ) + ) + sufficient_stats_elements = execute(str(sufficient_stats_elements)) + if normalize: + col_names = list(sufficient_stats_elements.columns) + avg_x_names = [f'x{i}' for i in range(len(xs_cols))] + sufficient_stats_elements[avg_x_names] = 0 + sufficient_stats_elements = sufficient_stats_elements[ + col_names[: len(split_by)] + avg_x_names + col_names[len(split_by) :] + ] + return xs_cols, sufficient_stats_elements, avg_x, norms + + +def get_sufficient_stats_elements_sql( + m, + table, + split_by, + execute, + normalize=None, + include_n_obs=False, + global_filter=None, + indexes=None, + local_filter=None, + with_data=None, +): + """Generates the SQL columns for the elements of X'X and X'y. + + Args: + m: A Model instance. + table: The table we want to query from. + split_by: The columns that we use to split the data. + execute: A function that can executes a SQL query and returns a DataFrame. + normalize: If to normalize the X. Note that only has effect when + m.fit_intercept is True, which is consistent to sklearn. + include_n_obs: If to include the number of observations in the return. + global_filter: The sql.Filters that can be applied to the whole Metric + tree. + indexes: The columns that we shouldn't apply any arithmetic operation. + local_filter: The sql.Filters that have been accumulated so far. + with_data: A global variable that contains all the WITH clauses we need. + + Returns: + xs: A list of the column names of x1, x2, ... + sufficient_stats_elements: A SQL query that has all unique elements of + sufficient stats. Each row corresponds to one slice in split_by. The + columns are + split_by, + avg(x0), avg(x1), ..., # if fit_intercept + avg(x0 * x0), avg(x0 * x1), avg(x0 * x2), avg(x1 * x2), ..., + avg(y), # if fit_intercept + avg(x0 * y), avg(x1 * y), ..., + n_observation # if include_n_obs. + The column are named as + split_by, x0, x1,..., x0x0, x0x1,..., y, x0y, x1y,..., n_obs. + avg_x: Nonempty only when normalize. A pd.DataFrame which holds the + avg(x0), avg(x1), ... of the UNNORMALIZED x. + Don't confuse it with the ones in the sufficient_stats_elements, which are + the average of normalized x, which are just 0s. + norms: Nonempty only when normalize. A pd.DataFrame which holds the l2-norm + values of all centered-x columns. 
+ """ table, with_data, xs_cols, y, avg_x, norms = get_data( - m, table, split_by, execute, normalize + m, + table, + split_by, + execute, + normalize, + global_filter, + indexes, + local_filter, + with_data, ) xs = xs_cols.aliases - x_t_x = [] - x_t_y = [] - if m.fit_intercept: - if not normalize: - x_t_x = [sql.Column(f'AVG({x})', alias=f'x{i}') for i, x in enumerate(xs)] - x_t_y = [sql.Column(f'AVG({y})', alias='y')] - for i, x1 in enumerate(xs): - for j, x2 in enumerate(xs[i:]): - x_t_x.append(sql.Column(f'AVG({x1} * {x2})', alias=f'x{i}x{i + j}')) - x_t_y += [ - sql.Column(f'AVG({x} * {y})', alias=f'x{i}y') for i, x in enumerate(xs) - ] + x_t_x, x_t_y = utils.get_x_t_x_and_x_t_y_cols( + xs, y, '', m.fit_intercept, normalize + ) cols = sql.Columns(x_t_x + x_t_y) if include_n_obs: cols.add(sql.Column('COUNT(*)', alias='n_obs')) sufficient_stats_elements = sql.Sql( cols, table, groupby=sql.Columns(split_by).aliases, with_data=with_data ) - sufficient_stats_elements = execute(str(sufficient_stats_elements)) - if normalize: - col_names = list(sufficient_stats_elements.columns) - avg_x_names = [f'x{i}' for i in range(len(xs))] - sufficient_stats_elements[avg_x_names] = 0 - sufficient_stats_elements = sufficient_stats_elements[ - col_names[: len(split_by)] + avg_x_names + col_names[len(split_by) :] - ] return xs_cols, sufficient_stats_elements, avg_x, norms -def get_data(m, table, split_by, execute, normalize=False): +def get_data( + m, + table, + split_by, + execute, + normalize=False, + global_filter=None, + indexes=None, + local_filter=None, + with_data=None, +): """Retrieves the data that the model will be fit on. We compute a Model by first computing its children, and then fitting @@ -356,6 +568,12 @@ def get_data(m, table, split_by, execute, normalize=False): split_by: The columns that we use to split the data. execute: A function that can executes a SQL query and returns a DataFrame. normalize: If the Model normalizes x. + global_filter: The sql.Filters that can be applied to the whole Metric + tree. + indexes: The columns that we shouldn't apply any arithmetic operation. + local_filter: The sql.Filters that have been accumulated so far. + with_data: The WITH clause that holds all necessary subqueries so we can + query the `table`. Returns: table: A string representing the table name which we can query from. The @@ -370,9 +588,20 @@ def get_data(m, table, split_by, execute, normalize=False): norms: Nonempty only when normalize is True. A pd.DataFrame which holds the l2-norm values of all centered-x columns. """ - data = m.children[0].to_sql(table, split_by + m.group_by) - with_data = data.with_data - data.with_data = None + # All filters are global when getting data to fit. 
+ global_filter = sql.Filters(global_filter).add(local_filter).add(m.where_) + if indexes is None: + indexes = sql.Columns(split_by) + data, with_data = m.children[0].get_sql_and_with_clause( + table, + sql.Columns(split_by).add(m.extra_split_by), + global_filter, + sql.Columns(indexes) + .add(m.extra_split_by) + .add(m.children[0].get_extra_idx()), + sql.Filters(), + sql.Datasources(with_data), + ) table = with_data.merge(sql.Datasource(data, 'DataToFit')) y = data.columns[-m.k - 1].alias xs_cols = sql.Columns(data.columns[-m.k :]) @@ -488,7 +717,7 @@ def compute_ridge_coefs(sufficient_stats, xs, m): if fit_intercept and m.normalize: return compute_coef_for_normalize_ridge(sufficient_stats, xs, m) x_t_x, x_t_y = construct_matrix_from_elements(sufficient_stats, fit_intercept) - if isinstance(m, Ridge): + if isinstance(m, Ridge) and m.alpha: n_obs = sufficient_stats['n_obs'] penalty = np.identity(len(x_t_y)) if fit_intercept: diff --git a/models_test.py b/models_test.py index b5e4798..5fbe915 100644 --- a/models_test.py +++ b/models_test.py @@ -56,6 +56,40 @@ def test_model(self, model, sklearn_model, name): ]) pd.testing.assert_frame_equal(output, expected) + def test_model_on_operations(self, model, sklearn_model, name): + del name # unused + s = metrics.Ratio('X1', 'Y') + s2 = metrics.Sum('Y') + pct = operations.PercentChange('grp1', 'A', s, include_base=True) + ab = operations.AbsoluteChange('grp1', 'A', s, include_base=True) + mh = operations.MH('grp1', 'A', 'grp2', s, include_base=True) + prepost = operations.PrePostChange( + 'grp1', 'A', s, s2, 'grp2', include_base=True + ) + cuped = operations.CUPED('grp1', 'A', s, s2, 'grp2', include_base=True) + all_changes = metrics.MetricList((pct, ab, mh, prepost, cuped)) + m1 = model(pct, [ab, mh, prepost, cuped], name='foo') + m2 = model(name='foo')(all_changes) + + output1 = m1.compute_on(DF) + output2 = m2.compute_on(DF) + + data_to_fit = all_changes.compute_on(DF) + model = sklearn_model().fit(data_to_fit.iloc[:, 1:], data_to_fit.iloc[:, 0]) + expected = pd.DataFrame([[model.intercept_] + list(model.coef_)]) + expected.columns = ['foo Coefficient: intercept'] + [ + f'foo Coefficient: sum(X1) / sum(Y) {c}' + for c in ( + 'Absolute Change', + 'MH Ratio', + 'PrePost Percent Change', + 'CUPED Change', + ) + ] + + pd.testing.assert_frame_equal(output1, expected) + pd.testing.assert_frame_equal(output2, expected) + def test_melted(self, model, sklearn_model, name): del sklearn_model, name # unused m = model(metrics.Sum('Y'), metrics.Sum('X1'), 'grp1') diff --git a/operations.py b/operations.py index f98710c..42b35a1 100644 --- a/operations.py +++ b/operations.py @@ -374,7 +374,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, child_table = sql.Datasource(dist_sql, 'CumulativeDistributionRaw') child_table_alias = with_data.merge(child_table) columns = sql.Columns(indexes.aliases) - order = list(utils.get_extra_idx(self)) + order = list(self.get_extra_idx()) order = [ sql.Column(self.get_ordered_col(sql.Column(o).alias), auto_alias=False) for o in order @@ -487,8 +487,6 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, The global with_data which holds all datasources we need in the WITH clause. 
""" - if not isinstance(self, (PercentChange, AbsoluteChange)): - raise ValueError('Not a PercentChange nor AbsoluteChange!') cond_cols = sql.Columns(self.extra_index) raw_table_sql, with_data = self.get_change_raw_sql( table, split_by, global_filter, indexes, local_filter, with_data @@ -509,33 +507,29 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes, base_table_alias = with_data.merge(base_table) cond = None if self.include_base else sql.Filters([f'NOT ({base_cond})']) - if isinstance(self, AbsoluteChange): - col_tmp = f'{raw_table_alias}.%(r)s - {base_table_alias}.%(b)s' - else: - col_tmp = ( - sql.SAFE_DIVIDE.format( - numer=f'{raw_table_alias}.%(r)s', - denom=f'{base_table_alias}.%(b)s', - ) - + ' * 100 - 100' - ) - columns = sql.Columns() - val_col_len = len(raw_table_sql.all_columns) - len(indexes) + col_tmp = self.get_col_tmp(raw_table_alias, base_table_alias) + columns = [] for r, b in zip( - raw_table_sql.all_columns[-val_col_len:], - base_value.columns[-val_col_len:], + raw_table_sql.all_columns[::-1], + base_value.columns[::-1], ): + if r.alias in sql.Columns(utils.get_extra_split_by(self)).aliases: + break col = sql.Column( col_tmp % {'r': r.alias, 'b': b.alias}, alias=self.name_tmpl.format(r.alias_raw), ) - columns.add(col) + columns = [col] + columns using = indexes.difference(cond_cols) join = '' if using else 'CROSS' - return sql.Sql( - sql.Columns(indexes.aliases).add(columns), - sql.Join(raw_table_alias, base_table_alias, join=join, using=using), - cond), with_data + return ( + sql.Sql( + sql.Columns(indexes.aliases).add(columns), + sql.Join(raw_table_alias, base_table_alias, join=join, using=using), + cond, + ), + with_data, + ) def get_change_raw_sql( self, table, split_by, global_filter, indexes, local_filter, with_data @@ -550,6 +544,23 @@ def get_change_raw_sql( ) return raw_table_sql, with_data + def get_col_tmp(self, raw_table_alias, base_table_alias): + """Gets a string template to compute the comparison between columns. + + The template needs to use "%(r)s" to represent the column from + raw_table_alias and "%(b)s" to represent that from base_table_alias. + For example, AbsoluteChange returns + f'{raw_table_alias}.%(r)s - {base_table_alias}.%(b)s'. + + Args: + raw_table_alias: The alias of the raw table for comparison. + base_table_alias: The alias of the base table for comparison. + + Returns: + A string template to compute the comparison between two columns. + """ + raise NotImplementedError + class PercentChange(Comparison): """Percent change estimator on a Metric. @@ -592,6 +603,15 @@ def compute_on_children(self, children, split_by): res = res[~idx_to_match.isin([self.baseline_key])] return res + def get_col_tmp(self, raw_table_alias, base_table_alias): + return ( + sql.SAFE_DIVIDE.format( + numer=f'{raw_table_alias}.%(r)s', + denom=f'{base_table_alias}.%(b)s', + ) + + ' * 100 - 100' + ) + class AbsoluteChange(Comparison): """Absolute change estimator on a Metric. @@ -635,6 +655,9 @@ def compute_on_children(self, children, split_by): res = res[~idx_to_match.isin([self.baseline_key])] return res + def get_col_tmp(self, raw_table_alias, base_table_alias): + return f'{raw_table_alias}.%(r)s - {base_table_alias}.%(b)s' + class PrePostChange(PercentChange): """PrePost Percent change estimator on a Metric. 
@@ -1155,12 +1178,13 @@ def __init__(self, include_base: bool = False, name_tmpl: Text = '{} MH Ratio', **kwargs): - self.stratified_by = stratified_by if isinstance(stratified_by, - list) else [stratified_by] + stratified_by = ( + stratified_by if isinstance(stratified_by, list) else [stratified_by] + ) condition_column = [condition_column] if isinstance( condition_column, str) else condition_column super(MH, self).__init__( - condition_column + self.stratified_by, + condition_column + stratified_by, baseline_key, metric, include_base, @@ -1168,6 +1192,17 @@ def __init__(self, extra_index=condition_column, **kwargs) + @property + def stratified_by(self): + return self.extra_split_by[len(self.extra_index):] + + @stratified_by.setter + def stratified_by(self, stratified_by): + stratified_by = ( + stratified_by if isinstance(stratified_by, list) else [stratified_by] + ) + self.extra_split_by[len(self.extra_index):] = stratified_by + def check_is_ratio(self, metric, allow_metriclist=True): if isinstance(metric, metrics.MetricList) and allow_metriclist: for m in metric: @@ -3504,8 +3539,6 @@ def modify_descendants_for_jackknife_fast( if isinstance(metric, Operation): metric.extra_index = sql.Columns(metric.extra_index).aliases metric.extra_split_by = sql.Columns(metric.extra_split_by).aliases - if isinstance(metric, MH): - metric.stratified_by = sql.Column(metric.stratified_by).alias new_children = [] for m in metric.children: diff --git a/requirements.txt b/requirements.txt index 1e941b5..672bc79 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,5 +3,6 @@ six numpy>=1.25 scipy>=1.9.3 sklearn>=1.0.2 +sympy>=1.12 pandas>=2.0.3 pydot>=1.4.2 \ No newline at end of file diff --git a/sql.py b/sql.py index 0240477..762f7ff 100644 --- a/sql.py +++ b/sql.py @@ -412,6 +412,10 @@ def __init__(self, columns=None, distinct=None): # pylint: disable=super-init-n def aliases(self): return [c.alias for c in self] + @property + def raw_aliases(self): + return [c.alias_raw for c in self] + @property def original_columns(self): # Returns the original Column instances added. diff --git a/utils.py b/utils.py index fcd23ae..3e54246 100644 --- a/utils.py +++ b/utils.py @@ -24,6 +24,7 @@ from typing import Iterable, List, Optional, Text, Union from meterstick import sql +import numpy as np import pandas as pd @@ -145,32 +146,6 @@ def apply_name_tmpl(name_tmpl, res, melted=False): return res -def get_extra_idx(metric, return_superset=False): - """Collects the extra indexes added by Operations for the metric tree. - - Args: - metric: A Metric instance. - return_superset: If to return the superset of extra indexes if metric has - incompatible indexes. - - Returns: - A tuple of column names which are just the index of metric.compute_on(df). - """ - extra_idx = metric.extra_index[:] - children_idx = [ - get_extra_idx(c, return_superset) for c in metric.children if is_metric(c) - ] - if len(set(children_idx)) > 1: - if not return_superset: - raise ValueError('Incompatible indexes!') - children_idx_superset = set() - children_idx_superset.update(*children_idx) - children_idx = [list(children_idx_superset)] - if children_idx: - extra_idx += list(children_idx[0]) - return tuple(extra_idx) - - def get_extra_split_by(metric, return_superset=False): """Collects the extra split_by added by Operations for the metric tree. 
@@ -701,3 +676,88 @@ def pcollection_to_df_via_file_io(
   if cleanup:
     os.remove(f)
   return pd.concat(res, ignore_index=True)
+
+
+def get_x_t_x_and_x_t_y_cols(
+    xs: List[str], y: str, prefix='', fit_intercept=True, normalize=False
+):
+  """Computes the x_t_x and x_t_y elements.
+
+  When solving LinearRegression or Ridge using sufficient stats, we need to
+  construct SQL columns for X'X and X'Y. This function takes the SQL columns of
+  X and y, and outputs SQL columns for the elements of X'X and X'Y.
+
+  Args:
+    xs: A list of column names of the features.
+    y: The column name of y.
+    prefix: A prefix to be added to the alias of the generated SQL columns.
+    fit_intercept: If the model in question fits an intercept.
+    normalize: If the model in question normalizes x.
+
+  Returns:
+    The SQL columns for the elements of X'X and X'Y. The elements of X'X are
+      avg(x0), avg(x1), ...,  # if fit_intercept
+      avg(x0 * x0), avg(x0 * x1), avg(x0 * x2), avg(x1 * x2), ....
+    The elements of X'Y are
+      avg(y),  # if fit_intercept
+      avg(x0 * y), avg(x1 * y), ....
+    Note that when fit_intercept, the return cannot be directly fed to
+    get_ridge_coefficients() below. You need to prepend a '1' to x_t_x.
+  """
+  x_t_x = []
+  x_t_y = []
+  if fit_intercept:
+    if not normalize:
+      x_t_x = [
+          sql.Column(f'AVG({x})', alias=f'{prefix}x{i}')
+          for i, x in enumerate(xs)
+      ]
+    x_t_y = [sql.Column(f'AVG({y})', alias=f'{prefix}y')]
+  for i, x1 in enumerate(xs):
+    for j, x2 in enumerate(xs[i:]):
+      x_t_x.append(
+          sql.Column(f'AVG({x1} * {x2})', alias=f'{prefix}x{i}x{i + j}')
+      )
+  x_t_y += [
+      sql.Column(f'AVG({x} * {y})', alias=f'{prefix}x{i}y')
+      for i, x in enumerate(xs)
+  ]
+  return x_t_x, x_t_y
+
+
+def get_ridge_coefficients(
+    x_t_x_elements, x_t_y_elements, fit_intercept=True, penalty=0
+):
+  """Computes coefficients of Ridge regression.
+
+  Args:
+    x_t_x_elements: The SQL column names of the elements of X'X. If not
+      fit_intercept, it's the 1st return of get_x_t_x_and_x_t_y_cols. If
+      fit_intercept, it's the 1st return of get_x_t_x_and_x_t_y_cols with '1'
+      prepended.
+    x_t_y_elements: The SQL column names of the elements of X'Y. It's the 2nd
+      return of get_x_t_x_and_x_t_y_cols.
+    fit_intercept: If the model in question fits an intercept.
+    penalty: The penalty of the Ridge regression.
+
+  Returns:
+    (X'X)^(-1)(X'Y) as a SymPy matrix.
+  """
+  import sympy  # pylint: disable=g-import-not-at-top
+  n = len(x_t_y_elements)
+  x_t_x = np.empty([n, n], dtype=object)
+  x_t_x[np.triu_indices(n)] = x_t_x_elements
+  x_t_x[np.tril_indices(n)] = x_t_x.T[np.tril_indices(n)]
+  x_t_x = sympy.Matrix(x_t_x)
+  if penalty:
+    iden = np.identity(n)
+    if fit_intercept:
+      iden[0, 0] = 0
+    x_t_x += penalty * sympy.Matrix(iden)
+  # Do not use x_t_x.inv(). It's very slow, see
+  # https://stackoverflow.com/questions/75553096/why-is-sympy-matrix-inv-slow.
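+  # adjugate(A) / det(A) equals inv(A) but is much cheaper to evaluate
+  # symbolically for the small matrices we build here.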
+ x_t_x_inv = x_t_x.adjugate() / x_t_x.det() + x_t_y = sympy.Matrix(x_t_y_elements) + return x_t_x_inv * x_t_y diff --git a/utils_test.py b/utils_test.py index 2fb55f7..b2d5082 100644 --- a/utils_test.py +++ b/utils_test.py @@ -20,6 +20,7 @@ from absl.testing import absltest from meterstick import metrics from meterstick import operations +from meterstick import sql from meterstick import utils import numpy as np import pandas as pd @@ -32,36 +33,37 @@ def test_adjust_slices_for_loo_no_splitby_no_operation_unit_filled(self): df = pd.DataFrame({'unit': list('abc'), 'x': range(1, 4)}) bucket_res = df[df.unit != 'a'].groupby('unit').sum() output = utils.adjust_slices_for_loo(bucket_res, [], df) - expected = pd.DataFrame({'x': [0, 2, 3]}, - index=pd.Index(list('abc'), name='unit')) + expected = pd.DataFrame( + {'x': [0, 2, 3]}, index=pd.Index(list('abc'), name='unit') + ) testing.assert_frame_equal(output, expected) def test_adjust_slices_for_loo_no_splitby_operation(self): - df = pd.DataFrame({ - 'unit': list('abb'), - 'grp': list('bbc'), - 'x': range(1, 4) - }) + df = pd.DataFrame( + {'unit': list('abb'), 'grp': list('bbc'), 'x': range(1, 4)} + ) bucket_res = df[df.unit != 'a'].groupby(['unit', 'grp']).sum() output = utils.adjust_slices_for_loo(bucket_res, [], df) - expected = pd.DataFrame({'x': [0, 0]}, - index=pd.MultiIndex.from_tuples( - (('a', 'b'), ('a', 'c')), - names=('unit', 'grp'))) + expected = pd.DataFrame( + {'x': [0, 0]}, + index=pd.MultiIndex.from_tuples( + (('a', 'b'), ('a', 'c')), names=('unit', 'grp') + ), + ) testing.assert_frame_equal(output, expected) def test_adjust_slices_for_loo_splitby_no_operation(self): - df = pd.DataFrame({ - 'unit': list('abc'), - 'grp': list('abb'), - 'x': range(1, 4) - }) + df = pd.DataFrame( + {'unit': list('abc'), 'grp': list('abb'), 'x': range(1, 4)} + ) bucket_res = df[df.grp != 'b'].groupby(['grp', 'unit']).sum() output = utils.adjust_slices_for_loo(bucket_res, ['grp'], df) - expected = pd.DataFrame({'x': [1, 0, 0]}, - index=pd.MultiIndex.from_tuples( - (('a', 'a'), ('b', 'b'), ('b', 'c')), - names=('grp', 'unit'))) + expected = pd.DataFrame( + {'x': [1, 0, 0]}, + index=pd.MultiIndex.from_tuples( + (('a', 'a'), ('b', 'b'), ('b', 'c')), names=('grp', 'unit') + ), + ) testing.assert_frame_equal(output, expected) def test_adjust_slices_for_loo_splitby_operation(self): @@ -69,22 +71,25 @@ def test_adjust_slices_for_loo_splitby_operation(self): 'grp': list('aaabbb'), 'op': ['x'] * 2 + ['y'] * 2 + ['z'] * 2, 'unit': [1, 2, 3, 2, 3, 2], - 'x': range(1, 7) + 'x': range(1, 7), }) bucket_res = df[df.unit != 1].groupby(['grp', 'unit', 'op']).sum() output = utils.adjust_slices_for_loo(bucket_res, ['grp'], df) - expected = pd.DataFrame({'x': [0, 0, 0, 0, 6, 0, 5]}, - index=pd.MultiIndex.from_tuples( - ( - ('a', 1, 'x'), - ('a', 1, 'y'), - ('a', 2, 'y'), - ('a', 3, 'x'), - ('b', 2, 'z'), - ('b', 3, 'y'), - ('b', 3, 'z'), - ), - names=('grp', 'unit', 'op'))) + expected = pd.DataFrame( + {'x': [0, 0, 0, 0, 6, 0, 5]}, + index=pd.MultiIndex.from_tuples( + ( + ('a', 1, 'x'), + ('a', 1, 'y'), + ('a', 2, 'y'), + ('a', 3, 'x'), + ('b', 2, 'z'), + ('b', 3, 'y'), + ('b', 3, 'z'), + ), + names=('grp', 'unit', 'op'), + ), + ) testing.assert_frame_equal(output, expected) def test_one_level_column_and_no_splitby_melt(self): @@ -103,173 +108,183 @@ def test_one_level_value_column_and_no_splitby_unmelt(self): def test_one_level_not_value_column_and_no_splitby_unmelt(self): melted = pd.DataFrame({'Baz': [1, 2]}, index=['foo', 'bar']) melted.index.name = 'Metric' - 
expected = pd.DataFrame([[1, 2]], - columns=pd.MultiIndex.from_product( - [['foo', 'bar'], ['Baz']], - names=['Metric', None])) + expected = pd.DataFrame( + [[1, 2]], + columns=pd.MultiIndex.from_product( + [['foo', 'bar'], ['Baz']], names=['Metric', None] + ), + ) testing.assert_frame_equal(expected, utils.unmelt(melted)) def test_one_level_column_and_single_splitby_melt(self): unmelted = pd.DataFrame( - data={ - 'foo': [0, 1], - 'bar': [2, 3] - }, + data={'foo': [0, 1], 'bar': [2, 3]}, columns=['foo', 'bar'], - index=['B', 'A']) + index=['B', 'A'], + ) unmelted.index.name = 'grp' - expected = pd.DataFrame({'Value': range(4)}, - index=pd.MultiIndex.from_product( - (['foo', 'bar'], ['B', 'A']), - names=['Metric', 'grp'])) + expected = pd.DataFrame( + {'Value': range(4)}, + index=pd.MultiIndex.from_product( + (['foo', 'bar'], ['B', 'A']), names=['Metric', 'grp'] + ), + ) expected.index.name = 'Metric' testing.assert_frame_equal(expected, utils.melt(unmelted)) def test_one_level_column_and_single_splitby_unmelt(self): expected = pd.DataFrame( - data={ - 'foo': [0, 1], - 'bar': [2, 3] - }, + data={'foo': [0, 1], 'bar': [2, 3]}, columns=['foo', 'bar'], - index=['B', 'A']) + index=['B', 'A'], + ) expected.index.name = 'grp' expected.columns.name = 'Metric' - melted = pd.DataFrame({'Value': range(4)}, - index=pd.MultiIndex.from_product( - (['foo', 'bar'], ['B', 'A']), - names=['Metric', 'grp'])) + melted = pd.DataFrame( + {'Value': range(4)}, + index=pd.MultiIndex.from_product( + (['foo', 'bar'], ['B', 'A']), names=['Metric', 'grp'] + ), + ) melted.index.name = 'Metric' testing.assert_frame_equal(expected, utils.unmelt(melted)) def test_one_level_column_and_multiple_splitby_melt(self): unmelted = pd.DataFrame( - data={ - 'foo': range(4), - 'bar': range(4, 8) - }, + data={'foo': range(4), 'bar': range(4, 8)}, columns=['foo', 'bar'], - index=pd.MultiIndex.from_product((['B', 'A'], ['US', 'non-US']), - names=['grp', 'country'])) - expected = pd.DataFrame({'Value': range(8)}, - index=pd.MultiIndex.from_product( - (['foo', 'bar'], ['B', 'A'], ['US', 'non-US']), - names=['Metric', 'grp', 'country'])) + index=pd.MultiIndex.from_product( + (['B', 'A'], ['US', 'non-US']), names=['grp', 'country'] + ), + ) + expected = pd.DataFrame( + {'Value': range(8)}, + index=pd.MultiIndex.from_product( + (['foo', 'bar'], ['B', 'A'], ['US', 'non-US']), + names=['Metric', 'grp', 'country'], + ), + ) expected.index.name = 'Metric' testing.assert_frame_equal(expected, utils.melt(unmelted)) def test_one_level_column_and_multiple_splitby_unmelt(self): - melted = pd.DataFrame({'Value': range(8)}, - index=pd.MultiIndex.from_product( - (['foo', 'bar'], ['B', 'A'], ['US', 'non-US']), - names=['Metric', 'grp', 'country'])) + melted = pd.DataFrame( + {'Value': range(8)}, + index=pd.MultiIndex.from_product( + (['foo', 'bar'], ['B', 'A'], ['US', 'non-US']), + names=['Metric', 'grp', 'country'], + ), + ) expected = pd.DataFrame( - data={ - 'foo': range(4), - 'bar': range(4, 8) - }, + data={'foo': range(4), 'bar': range(4, 8)}, columns=['foo', 'bar'], - index=pd.MultiIndex.from_product((['B', 'A'], ['US', 'non-US']), - names=['grp', 'country'])) + index=pd.MultiIndex.from_product( + (['B', 'A'], ['US', 'non-US']), names=['grp', 'country'] + ), + ) expected.columns.name = 'Metric' testing.assert_frame_equal(expected, utils.unmelt(melted)) def test_multiple_index_columns_and_no_splitby_melt(self): - unmelted = pd.DataFrame([[1, 2, 3, 4]], - columns=pd.MultiIndex.from_product( - (['foo', 'bar'], ['Value', 'SE']))) + unmelted = 
pd.DataFrame( + [[1, 2, 3, 4]], + columns=pd.MultiIndex.from_product((['foo', 'bar'], ['Value', 'SE'])), + ) expected = pd.DataFrame( - data={ - 'Value': [1, 3], - 'SE': [2, 4] - }, + data={'Value': [1, 3], 'SE': [2, 4]}, index=['foo', 'bar'], - columns=['Value', 'SE']) + columns=['Value', 'SE'], + ) expected.index.name = 'Metric' testing.assert_frame_equal(expected, utils.melt(unmelted)) def test_multiple_index_columns_and_no_splitby_unmelt(self): melted = pd.DataFrame( - data={ - 'Value': [1, 3], - 'SE': [2, 4] - }, + data={'Value': [1, 3], 'SE': [2, 4]}, index=['foo', 'bar'], - columns=['Value', 'SE']) + columns=['Value', 'SE'], + ) melted.index.name = 'Metric' - expected = pd.DataFrame([[1, 2, 3, 4]], - columns=pd.MultiIndex.from_product( - (['foo', 'bar'], ['Value', 'SE']))) + expected = pd.DataFrame( + [[1, 2, 3, 4]], + columns=pd.MultiIndex.from_product((['foo', 'bar'], ['Value', 'SE'])), + ) expected.columns.names = ['Metric', None] testing.assert_frame_equal(expected, utils.unmelt(melted)) def test_multiple_index_column_and_single_splitby_melt(self): - unmelted = pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]], - columns=pd.MultiIndex.from_product( - (['foo', 'bar'], ['Value', 'SE'])), - index=['B', 'A']) + unmelted = pd.DataFrame( + [[1, 2, 3, 4], [5, 6, 7, 8]], + columns=pd.MultiIndex.from_product((['foo', 'bar'], ['Value', 'SE'])), + index=['B', 'A'], + ) unmelted.index.name = 'grp' expected = pd.DataFrame( - data={ - 'Value': [1, 5, 3, 7], - 'SE': [2, 6, 4, 8] - }, - index=pd.MultiIndex.from_product((['foo', 'bar'], ['B', 'A']), - names=['Metric', 'grp']), - columns=['Value', 'SE']) + data={'Value': [1, 5, 3, 7], 'SE': [2, 6, 4, 8]}, + index=pd.MultiIndex.from_product( + (['foo', 'bar'], ['B', 'A']), names=['Metric', 'grp'] + ), + columns=['Value', 'SE'], + ) testing.assert_frame_equal(expected, utils.melt(unmelted)) def test_multiple_index_column_and_single_splitby_unmelt(self): melted = pd.DataFrame( - data={ - 'Value': [1, 5, 3, 7], - 'SE': [2, 6, 4, 8] - }, - index=pd.MultiIndex.from_product((['foo', 'bar'], ['B', 'A']), - names=['Metric', 'grp']), - columns=['Value', 'SE']) - expected = pd.DataFrame([[1, 2, 3, 4], [5, 6, 7, 8]], - columns=pd.MultiIndex.from_product( - (['foo', 'bar'], ['Value', 'SE'])), - index=['B', 'A']) + data={'Value': [1, 5, 3, 7], 'SE': [2, 6, 4, 8]}, + index=pd.MultiIndex.from_product( + (['foo', 'bar'], ['B', 'A']), names=['Metric', 'grp'] + ), + columns=['Value', 'SE'], + ) + expected = pd.DataFrame( + [[1, 2, 3, 4], [5, 6, 7, 8]], + columns=pd.MultiIndex.from_product((['foo', 'bar'], ['Value', 'SE'])), + index=['B', 'A'], + ) expected.index.name = 'grp' expected.columns.names = ['Metric', None] testing.assert_frame_equal(expected, utils.unmelt(melted)) def test_multiple_index_column_and_multiple_splitby_melt(self): unmelted = pd.DataFrame( - [range(4), range(4, 8), - range(8, 12), range(12, 16)], + [range(4), range(4, 8), range(8, 12), range(12, 16)], columns=pd.MultiIndex.from_product((['foo', 'bar'], ['Value', 'SE'])), - index=pd.MultiIndex.from_product((['B', 'A'], ['US', 'non-US']), - names=['grp', 'country'])) + index=pd.MultiIndex.from_product( + (['B', 'A'], ['US', 'non-US']), names=['grp', 'country'] + ), + ) expected = pd.DataFrame( data={ 'Value': [0, 4, 8, 12, 2, 6, 10, 14], - 'SE': [1, 5, 9, 13, 3, 7, 11, 15] + 'SE': [1, 5, 9, 13, 3, 7, 11, 15], }, index=pd.MultiIndex.from_product( (['foo', 'bar'], ['B', 'A'], ['US', 'non-US']), - names=['Metric', 'grp', 'country']), - columns=['Value', 'SE']) + names=['Metric', 'grp', 'country'], + ), + 
columns=['Value', 'SE'], + ) testing.assert_frame_equal(expected, utils.melt(unmelted)) def test_multiple_index_column_and_multiple_splitby_unmelt(self): melted = pd.DataFrame( data={ 'Value': [0, 4, 8, 12, 2, 6, 10, 14], - 'SE': [1, 5, 9, 13, 3, 7, 11, 15] + 'SE': [1, 5, 9, 13, 3, 7, 11, 15], }, index=pd.MultiIndex.from_product( (['foo', 'bar'], ['B', 'A'], ['US', 'non-US']), - names=['Metric', 'grp', 'country']), - columns=['Value', 'SE']) + names=['Metric', 'grp', 'country'], + ), + columns=['Value', 'SE'], + ) expected = pd.DataFrame( - [range(4), range(4, 8), - range(8, 12), range(12, 16)], + [range(4), range(4, 8), range(8, 12), range(12, 16)], columns=pd.MultiIndex.from_product((['foo', 'bar'], ['Value', 'SE'])), - index=pd.MultiIndex.from_product((['B', 'A'], ['US', 'non-US']), - names=['grp', 'country'])) + index=pd.MultiIndex.from_product( + (['B', 'A'], ['US', 'non-US']), names=['grp', 'country'] + ), + ) expected.columns.names = ['Metric', None] testing.assert_frame_equal(expected, utils.unmelt(melted)) @@ -288,29 +303,6 @@ def test_remove_empty_level(self): actual = utils.remove_empty_level(df) testing.assert_frame_equal(expected, actual) - def test_get_extra_idx(self): - mh = operations.MH('foo', 'f', 'bar', metrics.Ratio('a', 'b')) - ab = operations.AbsoluteChange('foo', 'f', metrics.Sum('c')) - m = operations.Jackknife('unit', metrics.MetricList((mh, ab))) - self.assertEqual(utils.get_extra_idx(m), ('foo',)) - - def test_get_extra_idx_return_superset(self): - s = metrics.Sum('x') - m = metrics.MetricList(( - operations.AbsoluteChange('g', 0, s), - operations.AbsoluteChange('g2', 1, s), - )) - actual = utils.get_extra_idx(m, True) - self.assertEqual(set(actual), set(('g', 'g2'))) - - def test_get_extra_idx_raises(self): - mh = operations.MH('foo', 'f', 'bar', metrics.Ratio('a', 'b')) - ab = operations.AbsoluteChange('baz', 'f', metrics.Sum('c')) - m = operations.Jackknife('unit', metrics.MetricList((mh, ab))) - with self.assertRaises(ValueError) as cm: - utils.get_extra_idx(m) - self.assertEqual(str(cm.exception), 'Incompatible indexes!') - def test_get_extra_split_by(self): mh = operations.MH('foo', 'f', 'bar', metrics.Ratio('a', 'b')) m = operations.AbsoluteChange('unit', 'a', mh) @@ -352,11 +344,9 @@ def test_get_equivalent_metric_with_df(self): expected = metrics.Sum('meterstick_tmp:(x * y)') expected.where = 'a' expected.name = 'foo' - expected_df = pd.DataFrame({ - 'x': [1, 2], - 'y': [2, 3], - 'meterstick_tmp:(x * y)': [2, 6] - }) + expected_df = pd.DataFrame( + {'x': [1, 2], 'y': [2, 3], 'meterstick_tmp:(x * y)': [2, 6]} + ) self.assertEqual(output, expected) testing.assert_frame_equal(df, expected_df) @@ -424,6 +414,60 @@ def test_get_leaf_metrics_include_constants(self): expected = [metrics.Sum('x'), metrics.Sum('y'), metrics.Sum('c'), 1] self.assertEqual(output, expected) + def test_get_x_t_x_and_x_t_y_cols_one_x(self): + xs = ['a'] + y = 'y' + actual = utils.get_x_t_x_and_x_t_y_cols(xs, y, 'foo_') + x_t_x = sql.Columns([ + sql.Column('AVG(a)', alias='foo_x0'), + sql.Column('AVG(a * a)', alias='foo_x0x0'), + ]) + x_t_y = sql.Columns([ + sql.Column('AVG(y)', alias='foo_y'), + sql.Column('AVG(a * y)', alias='foo_x0y'), + ]) + self.assertEqual(sql.Columns(actual[0]), x_t_x) + self.assertEqual(sql.Columns(actual[1]), x_t_y) + + def test_get_x_t_x_and_x_t_y_cols_multiple_xs(self): + xs = ['a', 'b'] + y = 'y' + actual = utils.get_x_t_x_and_x_t_y_cols(xs, y) + x_t_x = sql.Columns([ + sql.Column('AVG(a)', alias='x0'), + sql.Column('AVG(b)', alias='x1'), + 
sql.Column('AVG(a * a)', alias='x0x0'), + sql.Column('AVG(a * b)', alias='x0x1'), + sql.Column('AVG(b * b)', alias='x1x1'), + ]) + x_t_y = sql.Columns([ + sql.Column('AVG(y)', alias='y'), + sql.Column('AVG(a * y)', alias='x0y'), + sql.Column('AVG(b * y)', alias='x1y'), + ]) + self.assertEqual(sql.Columns(actual[0]), x_t_x) + self.assertEqual(sql.Columns(actual[1]), x_t_y) + + def test_get_x_t_x_and_x_t_y_cols_no_intercept(self): + xs = ['a'] + y = 'y' + actual = utils.get_x_t_x_and_x_t_y_cols(xs, y, fit_intercept=False) + x_t_x = sql.Columns([sql.Column('AVG(a * a)', alias='x0x0')]) + x_t_y = sql.Columns([sql.Column('AVG(a * y)', alias='x0y')]) + self.assertEqual(sql.Columns(actual[0]), x_t_x) + self.assertEqual(sql.Columns(actual[1]), x_t_y) + + def test_get_x_t_x_and_x_t_y_cols_normalize(self): + xs = ['a'] + y = 'y' + actual = utils.get_x_t_x_and_x_t_y_cols(xs, y, normalize=True) + x_t_x = sql.Columns([sql.Column('AVG(a * a)', alias='x0x0')]) + x_t_y = sql.Columns( + [sql.Column('AVG(y)', alias='y'), sql.Column('AVG(a * y)', alias='x0y')] + ) + self.assertEqual(sql.Columns(actual[0]), x_t_x) + self.assertEqual(sql.Columns(actual[1]), x_t_y) + if __name__ == '__main__': absltest.main()