Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(eap-api): support conditional aggregations in SELECT #6870

Merged
merged 15 commits into from
Feb 18, 2025
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ python-rapidjson==1.8
redis==4.5.4
sentry-arroyo==2.19.12
sentry-kafka-schemas==1.0.4
sentry-protos==0.1.58
sentry-protos==0.1.61
sentry-redis-tools==0.3.0
sentry-relay==0.9.5
sentry-sdk==2.18.0
Expand Down
56 changes: 53 additions & 3 deletions snuba/web/rpc/v1/endpoint_trace_item_table.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import uuid
from typing import Type

from sentry_protos.snuba.v1.attribute_conditional_aggregation_pb2 import (
AttributeConditionalAggregation,
)
from sentry_protos.snuba.v1.endpoint_trace_item_table_pb2 import (
AggregationComparisonFilter,
AggregationFilter,
Column,
TraceItemTableRequest,
TraceItemTableResponse,
Expand All @@ -26,8 +31,8 @@ def _apply_label_to_column(column: Column) -> None:
if column.HasField("key"):
column.label = column.key.name

elif column.HasField("aggregation"):
column.label = column.aggregation.label
elif column.HasField("conditional_aggregation"):
column.label = column.conditional_aggregation.label

for column in in_msg.columns:
_apply_label_to_column(column)
Expand All @@ -43,7 +48,9 @@ def _validate_select_and_groupby(in_msg: TraceItemTableRequest) -> None:
[c.key.name for c in in_msg.columns if c.HasField("key")]
)
grouped_by_columns = set([c.name for c in in_msg.group_by])
aggregation_present = any([c for c in in_msg.columns if c.HasField("aggregation")])
aggregation_present = any(
[c for c in in_msg.columns if c.HasField("conditional_aggregation")]
)
if non_aggregted_columns != grouped_by_columns and aggregation_present:
raise BadSnubaRPCRequestException(
f"Non aggregated columns should be in group_by. non_aggregated_columns: {non_aggregted_columns}, grouped_by_columns: {grouped_by_columns}"
Expand Down Expand Up @@ -80,6 +87,48 @@ def _transform_request(request: TraceItemTableRequest) -> TraceItemTableRequest:
return SparseAggregateAttributeTransformer(request).transform()


def convert_to_conditional_aggregation(in_msg: TraceItemTableRequest) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How different would this function be for the TimeSeries endpoint?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. leave a comment explaining why this is being done. It's not immediately clear unless you are the user
  2. Test this function independently. It has clearly defined inputs and outputs

def _add_conditional_aggregation(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see if you can replace in all the cases so you don't have the split behavior with Column and AggregationComparsionFilter

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

replace is done with input.ClearField("aggregation"), but I still need the separate isinstance(input, Column) and isinstance(input, AggregationFilter) because of mypy. Also I feel like it makes _convert more readable(?)

input: Column | AggregationComparisonFilter,
) -> None:
aggregation = input.aggregation
input.conditional_aggregation.CopyFrom(
AttributeConditionalAggregation(
aggregate=aggregation.aggregate,
key=aggregation.key,
label=aggregation.label,
extrapolation_mode=aggregation.extrapolation_mode,
)
)

def _convert(input: Column | AggregationFilter) -> None:
if isinstance(input, Column):
if input.HasField("aggregation"):
_add_conditional_aggregation(input)

if input.HasField("formula"):
_convert(input.formula.left)
_convert(input.formula.right)

if isinstance(input, AggregationFilter):
if input.HasField("and_filter"):
for aggregation_filter in input.and_filter.filters:
_convert(aggregation_filter)
if input.HasField("or_filter"):
for aggregation_filter in input.or_filter.filters:
_convert(aggregation_filter)
if input.HasField("comparison_filter"):
if input.comparison_filter.HasField("aggregation"):
_add_conditional_aggregation(input.comparison_filter)

for column in in_msg.columns:
_convert(column)
for ob in in_msg.order_by:
_convert(ob.column)
if in_msg.HasField("aggregation_filter"):
_convert(in_msg.aggregation_filter)


class EndpointTraceItemTable(
RPCEndpoint[TraceItemTableRequest, TraceItemTableResponse]
):
Expand All @@ -103,6 +152,7 @@ def response_class(cls) -> Type[TraceItemTableResponse]:
return TraceItemTableResponse

def _execute(self, in_msg: TraceItemTableRequest) -> TraceItemTableResponse:
convert_to_conditional_aggregation(in_msg)
in_msg = _apply_labels_to_columns(in_msg)
_validate_select_and_groupby(in_msg)
_validate_order_by(in_msg)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def aggregation_filter_to_expression(agg_filter: AggregationFilter) -> Expressio
f"Unsupported aggregation filter op: {AggregationComparisonFilter.Op.Name(agg_filter.comparison_filter.op)}"
)
return op_expr(
aggregation_to_expression(agg_filter.comparison_filter.aggregation),
aggregation_to_expression(
agg_filter.comparison_filter.conditional_aggregation
),
agg_filter.comparison_filter.val,
)
case "and_filter":
Expand Down Expand Up @@ -124,11 +126,13 @@ def _convert_order_by(
expression=attribute_key_to_expression(x.column.key),
)
)
elif x.column.HasField("aggregation"):
elif x.column.HasField("conditional_aggregation"):
res.append(
OrderBy(
direction=direction,
expression=aggregation_to_expression(x.column.aggregation),
expression=aggregation_to_expression(
x.column.conditional_aggregation
),
)
)
elif x.column.HasField("formula"):
Expand All @@ -146,15 +150,17 @@ def _get_reliability_context_columns(column: Column) -> list[SelectedExpression]
extrapolated aggregates need to request extra columns to calculate the reliability of the result.
this function returns the list of columns that need to be requested.
"""
if not column.HasField("aggregation"):
if not (column.HasField("conditional_aggregation")):
return []

if (
column.aggregation.extrapolation_mode
column.conditional_aggregation.extrapolation_mode
== ExtrapolationMode.EXTRAPOLATION_MODE_SAMPLE_WEIGHTED
):
context_columns = []
confidence_interval_column = get_confidence_interval_column(column.aggregation)
confidence_interval_column = get_confidence_interval_column(
column.conditional_aggregation
)
if confidence_interval_column is not None:
context_columns.append(
SelectedExpression(
Expand All @@ -163,8 +169,10 @@ def _get_reliability_context_columns(column: Column) -> list[SelectedExpression]
)
)

average_sample_rate_column = get_average_sample_rate_column(column.aggregation)
count_column = get_count_column(column.aggregation)
average_sample_rate_column = get_average_sample_rate_column(
column.conditional_aggregation
)
count_column = get_count_column(column.conditional_aggregation)
context_columns.append(
SelectedExpression(
name=average_sample_rate_column.alias,
Expand All @@ -191,8 +199,8 @@ def _column_to_expression(column: Column) -> Expression:
"""
if column.HasField("key"):
return attribute_key_to_expression(column.key)
elif column.HasField("aggregation"):
function_expr = aggregation_to_expression(column.aggregation)
elif column.HasField("conditional_aggregation"):
function_expr = aggregation_to_expression(column.conditional_aggregation)
# aggregation label may not be set and the column label takes priority anyways.
function_expr = replace(function_expr, alias=column.label)
return function_expr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,13 +115,15 @@ def _convert_order_by(
expression=attribute_key_to_expression(x.column.key),
)
)
elif x.column.HasField("aggregation"):
elif x.column.HasField("conditional_aggregation"):
res.append(
OrderBy(
direction=direction,
expression=aggregation_to_expression(
x.column.aggregation,
attribute_key_to_expression(x.column.aggregation.key),
x.column.conditional_aggregation,
attribute_key_to_expression(
x.column.conditional_aggregation.key
),
),
)
)
Expand All @@ -145,10 +147,10 @@ def _build_query(request: TraceItemTableRequest) -> Query:
selected_columns.append(
SelectedExpression(name=column.label, expression=key_col)
)
elif column.HasField("aggregation"):
elif column.HasField("conditional_aggregation"):
function_expr = aggregation_to_expression(
column.aggregation,
attribute_key_to_expression(column.aggregation.key),
column.conditional_aggregation,
attribute_key_to_expression(column.conditional_aggregation.key),
)
# aggregation label may not be set and the column label takes priority anyways.
function_expr = replace(function_expr, alias=column.label)
Expand Down
Loading