Skip to content

Commit

Permalink
extrapolation test
Browse files Browse the repository at this point in the history
  • Loading branch information
Rachel Chen authored and Rachel Chen committed Feb 12, 2025
1 parent b03e3e8 commit abc85b2
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 35 deletions.
113 changes: 84 additions & 29 deletions snuba/web/rpc/v1/resolvers/common/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,17 @@
_FLOATING_POINT_PRECISION = 9


def _get_condition_in_aggregation(
aggregation: AttributeAggregation | AttributeConditionalAggregation,
) -> Expression:
condition_in_aggregation: Expression = literal(True)
if isinstance(aggregation, AttributeConditionalAggregation):
condition_in_aggregation = trace_item_filters_to_expression(
aggregation.filter, attribute_key_to_expression
)
return condition_in_aggregation


@dataclass(frozen=True)
class ExtrapolationContext(ABC):
value: float
Expand Down Expand Up @@ -253,7 +264,8 @@ def from_alias(alias: str) -> "CustomColumnInformation":


def get_attribute_confidence_interval_alias(
aggregation: AttributeAggregation, additional_metadata: dict[str, str] = {}
aggregation: AttributeAggregation | AttributeConditionalAggregation,
additional_metadata: dict[str, str] = {},
) -> str | None:
function_alias_map = {
Function.FUNCTION_COUNT: "count",
Expand All @@ -280,36 +292,49 @@ def get_attribute_confidence_interval_alias(
return None


def get_average_sample_rate_column(aggregation: AttributeAggregation) -> Expression:
def get_average_sample_rate_column(
aggregation: AttributeAggregation | AttributeConditionalAggregation,
) -> Expression:
alias = CustomColumnInformation(
custom_column_id="average_sample_rate",
referenced_column=aggregation.label,
metadata={},
).to_alias()
field = attribute_key_to_expression(aggregation.key)
condition_in_aggregation = _get_condition_in_aggregation(aggregation)
return f.divide(
f.countIf(field, get_field_existence_expression(field)),
f.countIf(
field,
and_cond(get_field_existence_expression(field), condition_in_aggregation),
),
f.sumIf(
sampling_weight_column,
get_field_existence_expression(field),
and_cond(get_field_existence_expression(field), condition_in_aggregation),
),
alias=alias,
)


def _get_count_column_alias(aggregation: AttributeAggregation) -> str:
def _get_count_column_alias(
aggregation: AttributeAggregation | AttributeConditionalAggregation,
) -> str:
return CustomColumnInformation(
custom_column_id="count",
referenced_column=aggregation.label,
metadata={},
).to_alias()


def get_count_column(aggregation: AttributeAggregation) -> Expression:
def get_count_column(
aggregation: AttributeAggregation | AttributeConditionalAggregation,
) -> Expression:
field = attribute_key_to_expression(aggregation.key)
return f.countIf(
field,
get_field_existence_expression(field),
and_cond(
get_field_existence_expression(field),
_get_condition_in_aggregation(aggregation),
),
alias=_get_count_column_alias(aggregation),
)

Expand All @@ -333,7 +358,7 @@ def _get_possible_percentiles(


def _get_possible_percentiles_expression(
aggregation: AttributeAggregation,
aggregation: AttributeAggregation | AttributeConditionalAggregation,
percentile: float,
granularity: float = 0.005,
width: float = 0.1,
Expand Down Expand Up @@ -368,11 +393,7 @@ def get_extrapolated_function(
sampling_weight_column = column("sampling_weight")
alias = aggregation.label if aggregation.label else None
alias_dict = {"alias": alias} if alias else {}
condition_in_aggregation: Expression = literal(True)
if isinstance(aggregation, AttributeConditionalAggregation):
condition_in_aggregation = trace_item_filters_to_expression(
aggregation.filter, attribute_key_to_expression
)
condition_in_aggregation = _get_condition_in_aggregation(aggregation)
function_map_sample_weighted: dict[
Function.ValueType, CurriedFunctionCall | FunctionCall
] = {
Expand Down Expand Up @@ -467,7 +488,7 @@ def get_extrapolated_function(


def get_confidence_interval_column(
aggregation: AttributeAggregation,
aggregation: AttributeAggregation | AttributeConditionalAggregation,
) -> Expression | None:
"""
Returns the expression for calculating the upper confidence limit for a given aggregation. If the aggregation cannot be extrapolated, returns None.
Expand All @@ -478,6 +499,8 @@ def get_confidence_interval_column(
alias = get_attribute_confidence_interval_alias(aggregation)
alias_dict = {"alias": alias} if alias else {}

condition_in_aggregation = _get_condition_in_aggregation(aggregation)

function_map_confidence_interval = {
# confidence interval = Z \cdot \sqrt{-log{(\frac{\sum_{i=1}^n \frac{1}{w_i}}{n})} \cdot \sum_{i=1}^n w_i^2 - w_i}
# ┌─────────────────────────┐
Expand All @@ -499,7 +522,10 @@ def get_confidence_interval_column(
f.multiply(sampling_weight_column, sampling_weight_column),
sampling_weight_column,
),
get_field_existence_expression(field),
and_cond(
get_field_existence_expression(field),
condition_in_aggregation,
),
),
)
),
Expand Down Expand Up @@ -527,7 +553,10 @@ def get_confidence_interval_column(
sampling_weight_column,
f.multiply(field, field),
),
get_field_existence_expression(field),
and_cond(
get_field_existence_expression(field),
condition_in_aggregation,
),
),
f.divide(
f.multiply(
Expand All @@ -537,7 +566,10 @@ def get_confidence_interval_column(
),
f.sumIf(
f.multiply(sampling_weight_column, field),
get_field_existence_expression(field),
and_cond(
get_field_existence_expression(field),
condition_in_aggregation,
),
),
),
column(f"{alias}_N"),
Expand All @@ -546,10 +578,19 @@ def get_confidence_interval_column(
f.multiply(
f.sumIf(
sampling_weight_column,
get_field_existence_expression(field),
and_cond(
get_field_existence_expression(field),
condition_in_aggregation,
),
alias=f"{alias}_N",
),
f.countIf(field, get_field_existence_expression(field)),
f.countIf(
field,
and_cond(
get_field_existence_expression(field),
condition_in_aggregation,
),
),
),
)
),
Expand All @@ -576,17 +617,26 @@ def get_confidence_interval_column(
sampling_weight_column,
f.multiply(field, field),
),
get_field_existence_expression(field),
and_cond(
get_field_existence_expression(field),
condition_in_aggregation,
),
),
f.divide(
f.multiply(
f.sumIf(
f.multiply(sampling_weight_column, field),
get_field_existence_expression(field),
and_cond(
get_field_existence_expression(field),
condition_in_aggregation,
),
),
f.sumIf(
f.multiply(sampling_weight_column, field),
get_field_existence_expression(field),
and_cond(
get_field_existence_expression(field),
condition_in_aggregation,
),
),
),
column(f"{alias}_N"),
Expand All @@ -595,10 +645,19 @@ def get_confidence_interval_column(
f.multiply(
f.sumIf(
sampling_weight_column,
get_field_existence_expression(field),
and_cond(
get_field_existence_expression(field),
condition_in_aggregation,
),
alias=f"{alias}_N",
),
f.countIf(field, get_field_existence_expression(field)),
f.countIf(
field,
and_cond(
get_field_existence_expression(field),
condition_in_aggregation,
),
),
),
)
),
Expand Down Expand Up @@ -662,11 +721,7 @@ def aggregation_to_expression(
field = field or attribute_key_to_expression(aggregation.key)
alias = aggregation.label if aggregation.label else None
alias_dict = {"alias": alias} if alias else {}
condition_in_aggregation: Expression = literal(True)
if isinstance(aggregation, AttributeConditionalAggregation):
condition_in_aggregation = trace_item_filters_to_expression(
aggregation.filter, attribute_key_to_expression
)
condition_in_aggregation = _get_condition_in_aggregation(aggregation)
function_map: dict[Function.ValueType, CurriedFunctionCall | FunctionCall] = {
Function.FUNCTION_SUM: f.sumIfOrNull(
field,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1821,8 +1821,6 @@ def test_conditional_aggregation_in_select(self, setup_teardown: Any) -> None:
],
)
response = EndpointTraceItemTable().execute(message)
print("responseee", response)
breakpoint()
assert response.column_values == [
TraceItemColumnValues(
attribute_name="kylestag",
Expand Down
Loading

0 comments on commit abc85b2

Please sign in to comment.