diff --git a/meterstick_demo.ipynb b/meterstick_demo.ipynb index 0df0b92..979c8fc 100644 --- a/meterstick_demo.ipynb +++ b/meterstick_demo.ipynb @@ -18578,7 +18578,7 @@ "source": [ "#SQL\n", "\n", - "You can easily get SQL query for all built-in Metrics and Operations by calling `to_sql(sql_table, split_by)`.\n", + "You can easily get SQL query for all built-in Metrics and Operations by calling `to_sql(sql_table, split_by, create_tmp_table_for_volatile_fn=None)`.\n", "\n", "You can also directly execute the query by calling `compute_on_sql(sql_table, split_by, execute, melted)`,\n", "\n", @@ -18586,6 +18586,14 @@ "\n", "The dialect it uses is the [standard SQL](https://cloud.google.com/bigquery/docs/reference/standard-sql) in Google Cloud's BigQuery.\n", "\n", + "The choice of `create_tmp_table_for_volatile_fn` depends on your SQL engine. If query\n", + "```\n", + "WITH T AS (SELECT RAND() AS r)\n", + "SELECT t1.r - t2.r AS d\n", + "FROM T t1 CROSS JOIN T t2\n", + "```\n", + "does NOT always return 0 on your engine, set `create_tmp_table_for_volatile_fn` to `True`.\n", + "\n", "Additionally, `compute_on_sql` also takes a `mode` arg. It can be `None` (default and recommended), `'mixed'` or `'magic'`. The mode controls how we split the computation between SQL and Python. For example, for a Metric with descendants, we can compute everything in SQL (if applicable), or the children in SQL and the parent in Python, or grandchildren in SQL and the rest in Python. The default `None` mode maximizes the SQL usage, namely, everything can be computed in SQL is computed in SQL. The `mixed` mode does the opposite. It minimizes the SQL usage, namely, only leaf Metrics are computed in SQL. The advantage of the `sql` mode is that SQL is usually faster and can handle larger data than Python. On the other hand, as all the `Metric`s computed in Python will be cached, the `mixed` mode will cache all levels of `Metric`s in the `Metric` tree. 
As a result, if you have a complex `Metric` that has many duplicated leaf `Metric`s, the `mixed` mode could be faster.\n", "\n", "There is another `magic` mode that only applies to `Model`s. The mode computes sufficient statistics in SQL then uses them to solve the coefficients in Python. It's faster than the regular modes when fitting `Model`s on large data.\n", diff --git a/metrics.py b/metrics.py index 430ec2e..5e803f7 100644 --- a/metrics.py +++ b/metrics.py @@ -93,8 +93,10 @@ def compute_on_beam( # pylint: enable=g-long-lambda -def to_sql(table, split_by=None): - return lambda metric: metric.to_sql(table, split_by) +def to_sql(table, split_by=None, create_tmp_table_for_volatile_fn=None): + return lambda metric: metric.to_sql( + table, split_by, create_tmp_table_for_volatile_fn + ) # Classes we built so caching across instances can be enabled with confidence. @@ -675,7 +677,25 @@ def to_series_or_number(self, df): def compute_on_sql_sql_mode(self, table, split_by=None, execute=None): """Executes the query from to_sql() and process the result.""" - query = self.to_sql(table, split_by) + query = self.to_sql(table, split_by, False) + # We try to avoid using CREATE TEMP TABLE when possible. It's only used when + # - the query contains RAND(); + # - the execute doesn't evaluate RAND() only once in the WITH clause; + # - ALLOW_TEMP_TABLE is True.
+ if sql.ALLOW_TEMP_TABLE and 'RAND' in str(query): + query_with_tmp_table = self.to_sql(table, split_by, True) + if str(query) != str( + query_with_tmp_table + ) and not sql.rand_run_only_once_in_with_clause(execute): + try: + execute('CREATE OR REPLACE TEMP TABLE T AS (SELECT 42 AS ans);') + sql.TEMP_TABLE_SUPPORTED = True + query = self.to_sql(table, split_by, True) + except Exception: # pylint: disable=broad-except + sql.TEMP_TABLE_SUPPORTED = False + raise NotImplementedError # pylint: disable=raise-missing-from + finally: + sql.TEMP_TABLE_SUPPORTED = None res = execute(str(query)) extra_idx = list(utils.get_extra_idx(self, return_superset=True)) indexes = split_by + extra_idx if split_by else extra_idx @@ -688,8 +708,39 @@ def compute_on_sql_sql_mode(self, table, split_by=None, execute=None): res.sort_values(split_by, kind='mergesort', inplace=True) return res - def to_sql(self, table, split_by: Optional[Union[Text, List[Text]]] = None): - """Generates SQL query for the metric.""" + def to_sql( + self, + table, + split_by: Optional[Union[Text, List[Text]]] = None, + create_tmp_table_for_volatile_fn=None, + ): + """Generates SQL query for the metric. + + Args: + table: The table or subquery we want to query from. + split_by: The columns that we use to split the data. + create_tmp_table_for_volatile_fn: When generating the query, we assume + that volatile functions like RAND() in the WITH clause behave as if they + are evaluated only once. Unfortunately, not all engines behave like + that. In those cases, we need to CREATE TEMP TABLE to materialize the + subqueries that have volatile functions, so that the same result is used + in all places. An example is + WITH T AS (SELECT RAND() AS r) + SELECT t1.r - t2.r AS d + FROM T t1 CROSS JOIN T t2. 
+ If it doesn't always evaluate to 0, then this arg should be True, and + we will put all subqueries that + 1) have volatile functions and + 2) are referenced in the same query multiple times, + into CREATE TEMP TABLE statements. + Note that this arg has no effect if sql.ALLOW_TEMP_TABLE is False. + When you use compute_on_sql or compute_on_beam, this arg is + automatically decided based on your `execute` function. + + Returns: + The SQL query for the metric as a SQL instance, which is similar to a str. + Calling str() on it will get the query as a string. + """ global_filter = utils.get_global_filter(self) indexes = sql.Columns(split_by).add( utils.get_extra_idx(self, return_superset=True) ) @@ -706,6 +757,17 @@ def to_sql(self, table, split_by: Optional[Union[Text, List[Text]]] = None): global_filter, indexes, sql.Filters(), with_data) query.with_data = with_data + create_tmp_table = ( + sql.ALLOW_TEMP_TABLE + if create_tmp_table_for_volatile_fn is None + else create_tmp_table_for_volatile_fn + ) + if not create_tmp_table: + return query + # None means we don't know yet so we only check for False.
+ if sql.TEMP_TABLE_SUPPORTED is False: # pylint: disable=g-bool-id-comparison + raise NotImplementedError # to fall back to the mixed mode + with_data.temp_tables = sql.get_temp_tables(with_data) return query def get_sql_and_with_clause(self, table: sql.Datasource, diff --git a/operations.py b/operations.py index 7cd42d8..5e880fb 100644 --- a/operations.py +++ b/operations.py @@ -2052,7 +2052,7 @@ def compute_children_sql(self, """The return should be similar to compute_children().""" raise NotImplementedError - def to_sql(self, table, split_by=None): + def to_sql(self, table, split_by=None, create_tmp_table_for_volatile_fn=None): if not isinstance(self, (Jackknife, Bootstrap)): raise NotImplementedError split_by = [split_by] if isinstance(split_by, str) else list(split_by or []) @@ -2060,15 +2060,19 @@ def to_sql(self, table, split_by=None): self._is_root_node = True if self.has_been_preaggregated or not self.can_precompute(): if not self.where: - return super(MetricWithCI, self).to_sql(table, split_by) + return super(MetricWithCI, self).to_sql( + table, split_by, create_tmp_table_for_volatile_fn + ) table = sql.Sql(None, table, self.where) self_no_filter = copy.deepcopy(self) self_no_filter.where = None - return self_no_filter.to_sql(table, split_by) + return self_no_filter.to_sql( + table, split_by, create_tmp_table_for_volatile_fn + ) expanded, _ = utils.get_fully_expanded_equivalent_metric_tree(self) if self != expanded: - return expanded.to_sql(table, split_by) + return expanded.to_sql(table, split_by, create_tmp_table_for_volatile_fn) expanded.where = None # The filter has been taken care of in preaggregation expanded = utils.push_filters_to_leaf(expanded) @@ -2097,7 +2101,7 @@ def to_sql(self, table, split_by=None): equiv.unit = None else: equiv.has_local_filter = any([l.where for l in leaf]) - return equiv.to_sql(preagg, split_by) + return equiv.to_sql(preagg, split_by, create_tmp_table_for_volatile_fn) def get_sql_and_with_clause( self, table, split_by, 
global_filter, indexes, local_filter, with_data diff --git a/pyproject.toml b/pyproject.toml index a61a6f9..cc22eab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "meterstick" -version = "1.5.2" +version = "1.5.3" authors = [ { name="Xunmo Yang", email="xunmo@google.com" }, { name="Dennis Sun", email="dlsun@google.com" }, diff --git a/sql.py b/sql.py index 0240477..07e5dfb 100644 --- a/sql.py +++ b/sql.py @@ -25,6 +25,11 @@ SAFE_DIVIDE = 'IF(({denom}) = 0, NULL, ({numer}) / ({denom}))' +# Whether to use CREATE TEMP TABLE. Setting it to False disables CREATE TEMP TABLE +# even when it's needed. +ALLOW_TEMP_TABLE = True +# Whether the engine supports CREATE TEMP TABLE +TEMP_TABLE_SUPPORTED = None def is_compatible(sql0, sql1): @@ -67,6 +72,80 @@ def add_suffix(alias): return alias + '_1' +def rand_run_only_once_in_with_clause(execute): + """Check if the RAND() is only evaluated once in the WITH clause.""" + d = execute( + '''WITH T AS (SELECT RAND() AS r) + SELECT t1.r - t2.r AS d + FROM T t1 CROSS JOIN T t2''' + ) + return bool(d.iloc[0, 0] == 0) + + +def dep_on_rand_table(query, rand_tables): + """Returns whether a SQL query depends on any stochastic table in rand_tables.""" + for rand_table in rand_tables: + if re.search(r'\b%s\b' % rand_table, str(query)): + return True + return False + + +def get_temp_tables(with_data: 'Datasources'): + """Gets all the subquery tables that need to be materialized. + + When generating the query, we assume that volatile functions like RAND() in + the WITH clause behave as if they are evaluated only once. Unfortunately, not + all engines behave like that. In those cases, we need to CREATE TEMP TABLE to + materialize the subqueries that have volatile functions, so that the same + result is used in all places. An example is + WITH T AS (SELECT RAND() AS r) + SELECT t1.r - t2.r AS d + FROM T t1 CROSS JOIN T t2. + If it doesn't always evaluate to 0, we need to create a temp table for T.
+ A subquery needs to be materialized if + 1. it depends on any stochastic table + (e.g. RAND()) and + 2. the random column is referenced in the same query multiple times. + #2 is hard to check so we check if the stochastic table is referenced in the + same query multiple times instead. + An exception is the BootstrapRandomChoices table, which refers to a stochastic + table twice but only one refers to the stochastic column, so we don't need to + materialize it. + This function finds all the subquery tables in the WITH clause that need to be + materialized by + 1. finding all the stochastic tables, + 2. finding all the tables that depend, even indirectly, on a stochastic table, + 3. finding all the tables in #2 that are referenced in the same query multiple + times. + + Args: + with_data: The with clause. + + Returns: + A set of table names that need to be materialized. + """ + tmp_tables = set() + for rand_table in with_data: + query = with_data[rand_table] + if 'RAND' not in str(query): + continue + dep_on_rand = set([rand_table]) + for alias in with_data: + if dep_on_rand_table(with_data[alias].from_data, dep_on_rand): + dep_on_rand.add(alias) + for t in dep_on_rand: + from_data = with_data[t].from_data + if isinstance(from_data, Join) and not t.startswith( + 'BootstrapRandomChoices' + ): + if dep_on_rand_table(from_data.ds1, dep_on_rand) and dep_on_rand_table( + from_data.ds2, dep_on_rand + ): + tmp_tables.add(rand_table) + break + return tmp_tables + + def get_alias(c): return getattr(c, 'alias_raw', c) @@ -571,6 +650,7 @@ class Datasources(SqlComponents): def __init__(self, datasources=None): super(Datasources, self).__init__() self.children = collections.OrderedDict() + self.temp_tables = set() self.add(datasources) @property @@ -663,7 +743,7 @@ def add(self, children: Union[Datasource, Iterable[Datasource]]): return if not isinstance(children, Datasource): raise ValueError('Not a Datasource!') - alias, table = children.alias, children.table + alias, table = 
children.alias, children.table, if alias not in self.children: if table not in self.children.values(): self.children[alias] = table @@ -676,6 +756,23 @@ def add(self, children: Union[Datasource, Iterable[Datasource]]): children.alias = add_suffix(alias) return self.add(children) + def add_temp_table(self, table: Union[str, 'Sql', Join, Datasource]): + """Marks alias and all its data dependencies as temp tables.""" + if isinstance(table, str): + self.temp_tables.add(table) + if table in self.children: + self.add_temp_table(self.children[table]) + return + if isinstance(table, Join): + self.add_temp_table(table.ds1) + self.add_temp_table(table.ds2) + return + if isinstance(table, Datasource): + return self.add_temp_table(table.table) + if isinstance(table, Sql): + return self.add_temp_table(table.from_data) + return self + def extend(self, other: 'Datasources'): """Merge other to self. Adjust the query if a new alias is needed.""" datasources = list(other.datasources) @@ -691,7 +788,18 @@ def extend(self, other: 'Datasources'): return self def __str__(self): - return ',\n'.join((d.get_expression('WITH') for d in self.datasources if d)) + temp_tables = [] + with_tables = [] + for d in self.datasources: + expression = d.get_expression('WITH') + if d.alias in self.temp_tables: + temp_tables.append(f'CREATE OR REPLACE TEMP TABLE {expression};') + else: + with_tables.append(expression) + res = '\n'.join(temp_tables) + if with_tables: + res += '\nWITH\n' + ',\n'.join(with_tables) + return res.strip() class Sql(SqlComponent): @@ -766,7 +874,7 @@ def merge(self, other: 'Sql'): return True def __str__(self): - with_clause = 'WITH\n%s' % self.with_data if self.with_data else None + with_clause = str(self.with_data) if self.with_data else None all_columns = self.all_columns or '*' select_clause = f'SELECT\n{all_columns}' from_clause = ('FROM %s'