Skip to content

Enable optimizations in the 'mixed' mode of Jackknife/Bootstrap.compute_on_sql. In Jackknife/Bootstrap we rewrite leaf metrics to Sum and Count, then preaggregate when possible. This already happens in the default mode of `compute_on_sql`; this change enables the same trick for the mixed mode. #199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions meterstick_demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18617,31 +18617,31 @@
"execution_count": null,
"metadata": {
"executionInfo": {
"elapsed": 213,
"elapsed": 59,
"status": "ok",
"timestamp": 1684897933930,
"timestamp": 1750186450282,
"user": {
"displayName": "",
"userId": ""
"displayName": "Xunmo Yang",
"userId": "12474546967758012552"
},
"user_tz": 420
},
"id": "nzewljtaTLOz",
"outputId": "8c5f37ea-ec98-419d-c42d-29ed1ca36ea5"
"id": "eoHY1kVlPbSL",
"outputId": "e4f42347-809c-492d-ec9a-a72103fd86ef"
},
"outputs": [
{
"data": {
"text/plain": [
"SELECT\n",
" grp,\n",
" SUM(IF(Y \u003e 0, X, NULL)) AS sum_X,\n",
" SUM(IF(Y \u003e 0, X, 0)) AS sum_X,\n",
" SUM(X) AS sum_X_1\n",
"FROM T\n",
"GROUP BY grp"
]
},
"execution_count": 172,
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
Expand Down
34 changes: 25 additions & 9 deletions metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,8 +659,19 @@ def compute_through_sql(self, table, split_by, execute, mode):
pass
if self.where:
table = sql.Sql(None, table, self.where)
res = self.compute_on_sql_mixed_mode(table, split_by, execute, mode)
return self.to_series_or_number_if_not_operation(res)
try:
res = self.compute_on_sql_mixed_mode(table, split_by, execute, mode)
return self.to_series_or_number_if_not_operation(res)
except NotImplementedError:
raise
except Exception as e: # pylint: disable=broad-except
if mode:
raise ValueError(
'Please see the root cause of the failure above. You can try'
' `mode=None` to see if it helps.'
) from e
else:
raise

def to_series_or_number_if_not_operation(self, df):
return self.to_series_or_number(df) if not self.is_operation else df
Expand Down Expand Up @@ -814,8 +825,11 @@ def compute_on_sql_mixed_mode(self, table, split_by, execute, mode=None):
children = self.compute_children_sql(table, split_by, execute, mode)
return self.compute_on_children(children, split_by)

def compute_children_sql(self, table, split_by, execute, mode=None):
def compute_children_sql(
self, table, split_by, execute, mode, *args, **kwargs
):
"""The return should be similar to compute_children()."""
del args, kwargs # unused
children = []
for c in self.children:
if not isinstance(c, Metric):
Expand Down Expand Up @@ -2286,7 +2300,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes,
val,
SUM(weight) AS weight
FROM T
WHERE val IS NOT NULL AND weight IS NOT NULL
WHERE val IS NOT NULL AND weight IS NOT NULL AND weight != 0
GROUP BY split_by, val),
QuantileWeights AS (SELECT
split_by,
Expand All @@ -2296,7 +2310,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes,
- 0.5 * weight,
SUM(weight) OVER (PARTITION BY split_by)) AS weight
FROM AggregatedQuantileWeights
WHERE weight IS NOT NULL
WHERE weight IS NOT NULL AND weight != 0
ORDER BY split_by, val),
PairedQuantileWeights AS (SELECT
split_by,
Expand Down Expand Up @@ -2346,9 +2360,11 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes,
deduped_weight_sql = sql.Sql(
cols,
table,
sql.Filters(global_filter).add(
(f'{self.var} IS NOT NULL', f'{self.weight} IS NOT NULL')
),
sql.Filters(global_filter).add((
f'{self.var} IS NOT NULL',
f'{self.weight} IS NOT NULL',
f'{self.weight} != 0',
)),
split_by_and_value,
)
deduped_weight_alias = with_data.merge(
Expand All @@ -2372,7 +2388,7 @@ def get_sql_and_with_clause(self, table, split_by, global_filter, indexes,
normalized_weights_sql = sql.Sql(
cols,
deduped_weight_alias,
where=f'{w} IS NOT NULL',
where=(f'{w} IS NOT NULL', f'{w} != 0'),
orderby=split_by_and_value,
)
normalized_weights_alias = with_data.merge(
Expand Down
2 changes: 2 additions & 0 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1078,6 +1078,8 @@ def compute_on_sql_magic_mode(self, table, split_by, execute):
raise ValueError("Magic mode doesn't support class_weight!")
if self.intercept_scaling != 1:
raise ValueError('intercept_scaling is not supported in magic mode!')
if not self.y:
raise ValueError('y is not set!')

y = self.y.to_sql(table, self.group_by + split_by)
n_y = metrics.Count(y.columns[-1].alias, distinct=True)
Expand Down
Loading