Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(eap-api): support conditional aggregations in SELECT #6870

Merged
merged 15 commits into from
Feb 18, 2025
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ python-rapidjson==1.8
redis==4.5.4
sentry-arroyo==2.19.12
sentry-kafka-schemas==1.0.4
sentry-protos==0.1.58
sentry-protos==0.1.61
sentry-redis-tools==0.3.0
sentry-relay==0.9.5
sentry-sdk==2.18.0
Expand Down
2 changes: 2 additions & 0 deletions snuba/web/db_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def execute_query(
# Apply clickhouse query setting overrides
clickhouse_query_settings.update(query_settings.get_clickhouse_settings())

print("formatted_queryyyy", formatted_query)

result = reader.execute(
formatted_query,
clickhouse_query_settings,
Expand Down
11 changes: 10 additions & 1 deletion snuba/web/rpc/v1/endpoint_trace_item_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ def _apply_label_to_column(column: Column) -> None:
if column.HasField("key"):
column.label = column.key.name

elif column.HasField("conditional_aggregation"):
column.label = column.conditional_aggregation.label

elif column.HasField("aggregation"):
column.label = column.aggregation.label

Expand All @@ -43,7 +46,13 @@ def _validate_select_and_groupby(in_msg: TraceItemTableRequest) -> None:
[c.key.name for c in in_msg.columns if c.HasField("key")]
)
grouped_by_columns = set([c.name for c in in_msg.group_by])
aggregation_present = any([c for c in in_msg.columns if c.HasField("aggregation")])
aggregation_present = any(
[
c
for c in in_msg.columns
if (c.HasField("aggregation") or c.HasField("conditional_aggregation"))
]
)
if non_aggregted_columns != grouped_by_columns and aggregation_present:
raise BadSnubaRPCRequestException(
f"Non aggregated columns should be in group_by. non_aggregated_columns: {non_aggregted_columns}, grouped_by_columns: {grouped_by_columns}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ def _get_reliability_context_columns(column: Column) -> list[SelectedExpression]
extrapolated aggregates need to request extra columns to calculate the reliability of the result.
this function returns the list of columns that need to be requested.
"""
if not column.HasField("aggregation"):
if not (
column.HasField("aggregation") or column.HasField("conditional_aggregation")
):
return []

if (
Expand Down Expand Up @@ -191,6 +193,11 @@ def _column_to_expression(column: Column) -> Expression:
"""
if column.HasField("key"):
return attribute_key_to_expression(column.key)
elif column.HasField("conditional_aggregation"):
function_expr = aggregation_to_expression(column.conditional_aggregation)
# aggregation label may not be set and the column label takes priority anyways.
function_expr = replace(function_expr, alias=column.label)
return function_expr
elif column.HasField("aggregation"):
function_expr = aggregation_to_expression(column.aggregation)
# aggregation label may not be set and the column label takes priority anyways.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,18 @@ def _convert_order_by(
expression=attribute_key_to_expression(x.column.key),
)
)
elif x.column.HasField("conditional_aggregation"):
res.append(
OrderBy(
direction=direction,
expression=aggregation_to_expression(
x.column.conditional_aggregation,
attribute_key_to_expression(
x.column.conditional_aggregation.key
),
),
)
)
elif x.column.HasField("aggregation"):
res.append(
OrderBy(
Expand Down Expand Up @@ -145,6 +157,16 @@ def _build_query(request: TraceItemTableRequest) -> Query:
selected_columns.append(
SelectedExpression(name=column.label, expression=key_col)
)
elif column.HasField("conditional_aggregation"):
function_expr = aggregation_to_expression(
column.conditional_aggregation,
attribute_key_to_expression(column.conditional_aggregation.key),
)
# aggregation label may not be set and the column label takes priority anyways.
function_expr = replace(function_expr, alias=column.label)
selected_columns.append(
SelectedExpression(name=column.label, expression=function_expr)
)
elif column.HasField("aggregation"):
function_expr = aggregation_to_expression(
column.aggregation,
Expand Down
92 changes: 62 additions & 30 deletions snuba/web/rpc/v1/resolvers/common/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from functools import cached_property
from typing import Any, Dict, List, Optional

from sentry_protos.snuba.v1.attribute_conditional_aggregation_pb2 import (
AttributeConditionalAggregation,
)
from sentry_protos.snuba.v1.trace_item_attribute_pb2 import (
AttributeAggregation,
ExtrapolationMode,
Expand All @@ -16,9 +19,12 @@

from snuba.query.dsl import CurriedFunctions as cf
from snuba.query.dsl import Functions as f
from snuba.query.dsl import column
from snuba.query.dsl import and_cond, column, literal
from snuba.query.expressions import CurriedFunctionCall, Expression, FunctionCall
from snuba.web.rpc.common.common import get_field_existence_expression
from snuba.web.rpc.common.common import (
get_field_existence_expression,
trace_item_filters_to_expression,
)
from snuba.web.rpc.common.exceptions import BadSnubaRPCRequestException
from snuba.web.rpc.v1.resolvers.R_eap_spans.common.common import (
attribute_key_to_expression,
Expand Down Expand Up @@ -356,86 +362,103 @@ def _get_possible_percentiles_expression(


def get_extrapolated_function(
aggregation: AttributeAggregation,
aggregation: AttributeAggregation | AttributeConditionalAggregation,
field: Expression,
) -> CurriedFunctionCall | FunctionCall | None:
sampling_weight_column = column("sampling_weight")
alias = aggregation.label if aggregation.label else None
alias_dict = {"alias": alias} if alias else {}
condition_in_aggregation: Expression = literal(True)
if isinstance(aggregation, AttributeConditionalAggregation):
condition_in_aggregation = trace_item_filters_to_expression(
aggregation.filter, attribute_key_to_expression
)
function_map_sample_weighted: dict[
Function.ValueType, CurriedFunctionCall | FunctionCall
] = {
Function.FUNCTION_SUM: f.sumIfOrNull(
f.multiply(field, sampling_weight_column),
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
**alias_dict,
),
Function.FUNCTION_AVERAGE: f.divide(
f.sumIfOrNull(
f.multiply(field, sampling_weight_column),
get_field_existence_expression(field),
and_cond(
get_field_existence_expression(field), condition_in_aggregation
),
),
f.sumIfOrNull(
sampling_weight_column,
get_field_existence_expression(field),
and_cond(
get_field_existence_expression(field), condition_in_aggregation
),
),
**alias_dict,
),
Function.FUNCTION_AVG: f.divide(
f.sumIfOrNull(
f.multiply(field, sampling_weight_column),
get_field_existence_expression(field),
and_cond(
get_field_existence_expression(field), condition_in_aggregation
),
),
f.sumIfOrNull(
sampling_weight_column,
get_field_existence_expression(field),
and_cond(
get_field_existence_expression(field), condition_in_aggregation
),
),
**alias_dict,
),
Function.FUNCTION_COUNT: f.sumIfOrNull(
sampling_weight_column,
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
**alias_dict,
),
Function.FUNCTION_P50: cf.quantileTDigestWeightedIfOrNull(0.5)(
field,
sampling_weight_column,
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
**alias_dict,
),
Function.FUNCTION_P75: cf.quantileTDigestWeightedIfOrNull(0.75)(
field,
sampling_weight_column,
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
**alias_dict,
),
Function.FUNCTION_P90: cf.quantileTDigestWeightedIfOrNull(0.9)(
field,
sampling_weight_column,
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
**alias_dict,
),
Function.FUNCTION_P95: cf.quantileTDigestWeightedIfOrNull(0.95)(
field,
sampling_weight_column,
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
**alias_dict,
),
Function.FUNCTION_P99: cf.quantileTDigestWeightedIfOrNull(0.99)(
field,
sampling_weight_column,
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
**alias_dict,
),
Function.FUNCTION_MAX: f.maxIfOrNull(
field, get_field_existence_expression(field), **alias_dict
field,
and_cond(get_field_existence_expression(field), condition_in_aggregation),
**alias_dict,
),
Function.FUNCTION_MIN: f.minIfOrNull(
field, get_field_existence_expression(field), **alias_dict
field,
and_cond(get_field_existence_expression(field), condition_in_aggregation),
**alias_dict,
),
Function.FUNCTION_UNIQ: f.uniqIfOrNull(
field,
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
**alias_dict,
),
}
Expand Down Expand Up @@ -633,56 +656,65 @@ def calculate_reliability(


def aggregation_to_expression(
aggregation: AttributeAggregation, field: Expression | None = None
aggregation: AttributeAggregation | AttributeConditionalAggregation,
field: Expression | None = None,
) -> Expression:
field = field or attribute_key_to_expression(aggregation.key)
alias = aggregation.label if aggregation.label else None
alias_dict = {"alias": alias} if alias else {}
condition_in_aggregation: Expression = literal(True)
if isinstance(aggregation, AttributeConditionalAggregation):
condition_in_aggregation = trace_item_filters_to_expression(
aggregation.filter, attribute_key_to_expression
)
function_map: dict[Function.ValueType, CurriedFunctionCall | FunctionCall] = {
Function.FUNCTION_SUM: f.sumIfOrNull(
field,
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
),
Function.FUNCTION_AVERAGE: f.avgIfOrNull(
field, get_field_existence_expression(field)
field,
and_cond(get_field_existence_expression(field), condition_in_aggregation),
),
Function.FUNCTION_COUNT: f.countIfOrNull(
field, get_field_existence_expression(field)
field,
and_cond(get_field_existence_expression(field), condition_in_aggregation),
),
Function.FUNCTION_P50: cf.quantileIfOrNull(0.5)(
field,
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
),
Function.FUNCTION_P75: cf.quantileIfOrNull(0.75)(
field,
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
),
Function.FUNCTION_P90: cf.quantileIfOrNull(0.9)(
field,
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
),
Function.FUNCTION_P95: cf.quantileIfOrNull(0.95)(
field,
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
),
Function.FUNCTION_P99: cf.quantileIfOrNull(0.99)(
field,
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
),
Function.FUNCTION_AVG: f.avgIfOrNull(
field, get_field_existence_expression(field)
field,
and_cond(get_field_existence_expression(field), condition_in_aggregation),
),
Function.FUNCTION_MAX: f.maxIfOrNull(
field,
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
),
Function.FUNCTION_MIN: f.minIfOrNull(
field,
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
),
Function.FUNCTION_UNIQ: f.uniqIfOrNull(
field,
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
),
}

Expand Down
6 changes: 4 additions & 2 deletions snuba/web/rpc/v1/resolvers/common/trace_item_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ def convert_results(
converters[column.label] = lambda x: AttributeValue(val_float=float(x))
elif column.key.type == AttributeKey.TYPE_DOUBLE:
converters[column.label] = lambda x: AttributeValue(val_double=float(x))
elif column.HasField("aggregation"):
elif column.HasField("aggregation") or column.HasField(
"conditional_aggregation"
):
converters[column.label] = lambda x: AttributeValue(val_double=float(x))
elif column.HasField("formula"):
converters[column.label] = lambda x: AttributeValue(val_double=float(x))
else:
raise BadSnubaRPCRequestException(
"column is not one of: attribute, aggregation, or formula"
"column is not one of: attribute, (conditional) aggregation, or formula"
)

res: defaultdict[str, TraceItemColumnValues] = defaultdict(TraceItemColumnValues)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class SparseAggregateAttributeTransformer:
bird | 64
chicken | 12

why? because the columns for chicken and bird don't have the attribute "wing.count"
why? because the columns for dog and cat don't have the attribute "wing.count"
but by default it sets it to 0 when the attribute is not present.

This class tranforms TraceItemTableRequest adding a filter `hasAttribute("wing.count")`
Expand All @@ -43,6 +43,8 @@ def transform(self) -> TraceItemTableRequest:
for column in self.req.columns:
if column.WhichOneof("column") == "aggregation":
agg_keys.append(column.aggregation.key)
if column.WhichOneof("column") == "conditional_aggregation":
agg_keys.append(column.conditional_aggregation.key)

if len(agg_keys) == 0:
return self.req
Expand Down
Loading