From c32980a9c2e5a2b00f69b35bcd57f5ea1c1c2019 Mon Sep 17 00:00:00 2001 From: Evan Hicks Date: Mon, 15 Jan 2024 14:27:56 -0500 Subject: [PATCH] feat(mql): Support single type formula queries (#5376) Support single type formula queries using the -If suffix. This embeds all the filters for a formula parameter into the SELECT clause, by amending the aggregate with `-If` and collapsing all the filters into the new expression. --- requirements.txt | 2 +- snuba/query/mql/parser.py | 252 +++++- tests/query/parser/test_formula_mql_query.py | 757 +++++++++++++++++++ tests/test_metrics_mql_api.py | 124 ++- 4 files changed, 1092 insertions(+), 43 deletions(-) create mode 100644 tests/query/parser/test_formula_mql_query.py diff --git a/requirements.txt b/requirements.txt index 155cc0668c3..ac78a97ffdf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -28,7 +28,7 @@ sentry-redis-tools==0.1.7 sentry-relay==0.8.39 sentry-sdk==1.28.0 simplejson==3.17.6 -snuba-sdk==2.0.19 +snuba-sdk==2.0.20 structlog==22.3.0 structlog-sentry==2.0.0 sql-metadata==2.6.0 diff --git a/snuba/query/mql/parser.py b/snuba/query/mql/parser.py index b1a635875f8..ecf5eac7e00 100644 --- a/snuba/query/mql/parser.py +++ b/snuba/query/mql/parser.py @@ -1,8 +1,8 @@ from __future__ import annotations import logging -from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from dataclasses import dataclass, replace +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union import sentry_sdk from parsimonious.nodes import Node, NodeVisitor @@ -49,18 +49,27 @@ # The parser returns a bunch of different types, so create a single aggregate type to # capture everything. -MQLSTUFF = Dict[str, Union[str, List[SelectedExpression], List[Expression]]] +MQLSTUFF = Dict[str, Union[str, list[SelectedExpression], list[Expression]]] logger = logging.getLogger("snuba.mql.parser") @dataclass class InitialParseResult: expression: SelectedExpression | None = None - groupby: List[SelectedExpression] | None = None - conditions: List[Expression] | None = None + formula: str | None = None + parameters: list[InitialParseResult] | None = None + groupby: list[SelectedExpression] | None = None + conditions: list[Expression] | None = None mri: str | None = None public_name: str | None = None - metric_id: int | None = None + + +ARITHMETIC_OPERATORS_MAPPING = { + "+": "plus", + "-": "minus", + "*": "multiply", + "/": "divide", +} class MQLVisitor(NodeVisitor): # type: ignore @@ -76,12 +85,106 @@ def visit_expression( Any, ], ) -> InitialParseResult: - # zero_or_more_others is used for formulas, which aren't supported yet - args, zero_or_more_others = children - return args + term, zero_or_more_others = children + if zero_or_more_others: + _, term_operator, _, coefficient, *_ = zero_or_more_others[0] + return self._visit_formula(term_operator, term, coefficient) + return term def visit_expr_op(self, node: Node, children: Sequence[Any]) -> Any: - raise InvalidQueryException("Arithmetic function not supported yet") + return ARITHMETIC_OPERATORS_MAPPING[node.text] + + def visit_term_op(self, node: Node, children: Sequence[Any]) -> Any: + return ARITHMETIC_OPERATORS_MAPPING[node.text] + + def _build_timeseries_formula_param( + self, param: InitialParseResult + ) -> SelectedExpression: + """ + Timeseries inside a formula need to have three things done to them: + 1. Add the -If suffix to their aggregate function + 2. Put all the filters for the timeseries inside the aggregate expression + 3. Add a metric_id condition to the conditions in the aggregate + + Given an input parse result from `sum(mri){x:y}`, this should output + an expression like `sumIf(value, x = y AND metric_id = mri)`. + """ + assert param.expression is not None + exp = param.expression.expression + assert isinstance(exp, (FunctionCall, CurriedFunctionCall)) + + conditions = param.conditions or [] + metric_id_condition = binary_condition( + ConditionFunctions.EQ, + Column(None, None, "metric_id"), + Literal(None, param.mri or param.public_name), + ) + conditions.append(metric_id_condition) + value_column = exp.parameters[0] + if isinstance(exp, FunctionCall): + return SelectedExpression( + None, + FunctionCall( + None, + f"{exp.function_name}If", + parameters=( + value_column, + combine_and_conditions(conditions), + ), + ), + ) + else: + return SelectedExpression( + None, + CurriedFunctionCall( + None, + FunctionCall( + exp.internal_function.alias, + f"{exp.internal_function.function_name}If", + exp.internal_function.parameters, + ), + ( + value_column, + combine_and_conditions(conditions), + ), + ), + ) + + def _visit_formula( + self, + term_operator: str, + term: InitialParseResult, + coefficient: InitialParseResult, + ) -> InitialParseResult: + # TODO: If the formula has filters/group by, where do those appear? + + # If the parameters of the query are timeseries, extract the expressions from the result + if isinstance(term, InitialParseResult) and term.expression is not None: + term = replace(term, expression=self._build_timeseries_formula_param(term)) + if ( + isinstance(coefficient, InitialParseResult) + and coefficient.expression is not None + ): + coefficient = replace( + coefficient, + expression=self._build_timeseries_formula_param(coefficient), + ) + + if ( + isinstance(term, InitialParseResult) + and isinstance(coefficient, InitialParseResult) + and term.groupby != coefficient.groupby + ): + raise InvalidQueryException( + "All terms in a formula must have the same groupby" + ) + + return InitialParseResult( + expression=None, + formula=term_operator, + parameters=[term, coefficient], + groupby=term.groupby, + ) def visit_term( self, @@ -90,11 +193,10 @@ def visit_term( ) -> InitialParseResult: term, zero_or_more_others = children if zero_or_more_others: - raise InvalidQueryException("Arithmetic function not supported yet") - return term + _, term_operator, _, coefficient, *_ = zero_or_more_others[0] + return self._visit_formula(term_operator, term, coefficient) - def visit_term_op(self, node: Node, children: Sequence[Any]) -> str: - raise InvalidQueryException("Arithmetic function not supported yet") + return term def visit_coefficient( self, @@ -120,10 +222,37 @@ def visit_filter( if packed_filters: assert isinstance(packed_filters, list) _, _, filter_expr, *_ = packed_filters[0] - if target.conditions is not None: - target.conditions = target.conditions + [filter_expr] + if target.formula is not None: + + def pushdown_filter(param: InitialParseResult) -> InitialParseResult: + if param.formula is not None: + parameters = param.parameters or [] + for p in parameters: + pushdown_filter(p) + elif param.expression is not None: + exp = param.expression.expression + assert isinstance(exp, (FunctionCall, CurriedFunctionCall)) + exp = replace( + exp, + parameters=( + exp.parameters[0], + binary_condition("and", filter_expr, exp.parameters[1]), + ), + ) + param.expression = replace(param.expression, expression=exp) + else: + raise InvalidQueryException("Could not parse formula") + + return param + + if target.parameters is not None: + for param in target.parameters: + pushdown_filter(param) else: - target.conditions = [filter_expr] + if target.conditions is not None: + target.conditions = target.conditions + [filter_expr] + else: + target.conditions = [filter_expr] if packed_groupbys: assert isinstance(packed_groupbys, list) @@ -158,7 +287,7 @@ def visit_filter_term(self, node: Node, children: Sequence[Any]) -> Any: def visit_filter_factor( self, node: Node, - children: Tuple[Sequence[Union[str, Sequence[str]]], Any], + children: Tuple[Sequence[Union[str, Sequence[str]]] | FunctionCall, Any], ) -> FunctionCall: factor, *_ = children if isinstance(factor, FunctionCall): @@ -484,33 +613,74 @@ def parse_mql_query_body( """ exp_tree = MQL_GRAMMAR.parse(body) parsed: InitialParseResult = MQLVisitor().visit(exp_tree) - if not parsed.expression: - raise ParsingException("No expression specified in MQL query") + if not parsed.expression and not parsed.formula: + raise ParsingException( + "No aggregate/expression or formula specified in MQL query" + ) - selected_columns = [parsed.expression] - if parsed.groupby: - selected_columns.extend(parsed.groupby) - groupby = [g.expression for g in parsed.groupby] if parsed.groupby else None + if parsed.formula: - id_value = parsed.metric_id or parsed.mri or parsed.public_name - metric_id_condition = binary_condition( - ConditionFunctions.EQ, - Column(None, None, "metric_id"), - Literal(None, id_value), - ) - if parsed.conditions: - conditions = combine_and_conditions( - [metric_id_condition, *parsed.conditions] + def extract_expression(param: InitialParseResult | Any) -> Expression: + if not isinstance(param, InitialParseResult): + return Literal(None, param) + elif param.expression is not None: + return param.expression.expression + elif param.formula: + parameters = param.parameters or [] + return FunctionCall( + None, + param.formula, + tuple(extract_expression(p) for p in parameters), + ) + else: + raise InvalidQueryException("Could not parse formula") + + parameters = parsed.parameters or [] + selected_columns = [ + SelectedExpression( + name=AGGREGATE_ALIAS, + expression=FunctionCall( + alias=AGGREGATE_ALIAS, + function_name=parsed.formula, + parameters=tuple(extract_expression(p) for p in parameters), + ), + ) + ] + if parsed.groupby: + selected_columns.extend(parsed.groupby) + groupby = [g.expression for g in parsed.groupby] if parsed.groupby else None + query = LogicalQuery( + from_clause=None, + selected_columns=selected_columns, + groupby=groupby, + ) + if parsed.expression: + selected_columns = [parsed.expression] + if parsed.groupby: + selected_columns.extend(parsed.groupby) + groupby = [g.expression for g in parsed.groupby] if parsed.groupby else None + + metric_value = parsed.mri or parsed.public_name + conditions: list[Expression] = [ + binary_condition( + ConditionFunctions.EQ, + Column(None, None, "metric_id"), + Literal(None, metric_value), + ) + ] + if parsed.conditions: + conditions.extend(parsed.conditions) + + final_conditions = ( + combine_and_conditions(conditions) if conditions else None ) - else: - conditions = metric_id_condition - query = LogicalQuery( - from_clause=None, - selected_columns=selected_columns, - condition=conditions, - groupby=groupby, - ) + query = LogicalQuery( + from_clause=None, + selected_columns=selected_columns, + condition=final_conditions, + groupby=groupby, + ) except Exception as e: raise e diff --git a/tests/query/parser/test_formula_mql_query.py b/tests/query/parser/test_formula_mql_query.py new file mode 100644 index 00000000000..cd677f6ec0f --- /dev/null +++ b/tests/query/parser/test_formula_mql_query.py @@ -0,0 +1,757 @@ +from __future__ import annotations + +import re +from datetime import datetime + +import pytest + +from snuba.datasets.entities.entity_key import EntityKey +from snuba.datasets.entities.factory import get_entity +from snuba.datasets.factory import get_dataset +from snuba.query import OrderBy, OrderByDirection, SelectedExpression +from snuba.query.conditions import binary_condition +from snuba.query.data_source.simple import Entity as QueryEntity +from snuba.query.dsl import divide, multiply, plus +from snuba.query.expressions import ( + Column, + CurriedFunctionCall, + FunctionCall, + Literal, + SubscriptableReference, +) +from snuba.query.logical import Query +from snuba.query.mql.parser import parse_mql_query + +# Commonly used expressions +from_distributions = QueryEntity( + EntityKey.GENERIC_METRICS_DISTRIBUTIONS, + get_entity(EntityKey.GENERIC_METRICS_DISTRIBUTIONS).get_data_model(), +) + +time_expression = FunctionCall( + "_snuba_time", + "toStartOfInterval", + ( + Column("_snuba_timestamp", None, "timestamp"), + FunctionCall(None, "toIntervalSecond", (Literal(None, 60),)), + Literal(None, "Universal"), + ), +) + +formula_condition = FunctionCall( + None, + "and", + ( + FunctionCall( + None, + "equals", + ( + Column( + "_snuba_granularity", + None, + "granularity", + ), + Literal(None, 60), + ), + ), + FunctionCall( + None, + "and", + ( + FunctionCall( + None, + "in", + ( + Column( + "_snuba_project_id", + None, + "project_id", + ), + FunctionCall( + None, + "tuple", + (Literal(None, 11),), + ), + ), + ), + FunctionCall( + None, + "and", + ( + FunctionCall( + None, + "in", + ( + Column( + "_snuba_org_id", + None, + "org_id", + ), + FunctionCall( + None, + "tuple", + (Literal(None, 1),), + ), + ), + ), + FunctionCall( + None, + "and", + ( + FunctionCall( + None, + "equals", + ( + Column( + "_snuba_use_case_id", + None, + "use_case_id", + ), + Literal(None, "transactions"), + ), + ), + FunctionCall( + None, + "and", + ( + FunctionCall( + None, + "greaterOrEquals", + ( + Column( + "_snuba_timestamp", + None, + "timestamp", + ), + Literal( + None, + datetime(2023, 11, 23, 18, 30), + ), + ), + ), + FunctionCall( + None, + "less", + ( + Column( + "_snuba_timestamp", + None, + "timestamp", + ), + Literal( + None, + datetime( + 2023, + 11, + 23, + 22, + 30, + ), + ), + ), + ), + ), + ), + ), + ), + ), + ), + ), + ), + ), +) +mql_context = { + "entity": "generic_metrics_distributions", + "start": "2023-11-23T18:30:00", + "end": "2023-11-23T22:30:00", + "rollup": { + "granularity": 60, + "interval": 60, + "with_totals": "False", + "orderby": None, + }, + "scope": { + "org_ids": [1], + "project_ids": [11], + "use_case_id": "transactions", + }, + "indexer_mappings": { + "d:transactions/duration@millisecond": 123456, + "status_code": 222222, + "transaction": 333333, + }, + "limit": None, + "offset": None, +} + + +def timeseries( + agg: str, metric_id: int, condition: FunctionCall | None = None +) -> FunctionCall: + metric_condition = FunctionCall( + None, + "equals", + ( + Column( + "_snuba_metric_id", + None, + "metric_id", + ), + Literal(None, metric_id), + ), + ) + if condition: + metric_condition = FunctionCall( + None, + "and", + ( + condition, + metric_condition, + ), + ) + + return FunctionCall( + None, + agg, + ( + Column("_snuba_value", None, "value"), + metric_condition, + ), + ) + + +def tag_column(tag: str) -> SubscriptableReference: + tag_val = mql_context.get("indexer_mappings").get(tag) # type: ignore + return SubscriptableReference( + alias=f"_snuba_tags_raw[{tag_val}]", + column=Column( + alias="_snuba_tags_raw", + table_name=None, + column_name="tags_raw", + ), + key=Literal(alias=None, value=f"{tag_val}"), + ) + + +def test_simple_formula() -> None: + query_body = "sum(`d:transactions/duration@millisecond`){status_code:200} / sum(`d:transactions/duration@millisecond`)" + expected_selected = SelectedExpression( + "aggregate_value", + divide( + timeseries( + "sumIf", + 123456, + binary_condition( + "equals", tag_column("status_code"), Literal(None, "200") + ), + ), + timeseries("sumIf", 123456), + "_snuba_aggregate_value", + ), + ) + expected = Query( + from_distributions, + selected_columns=[ + expected_selected, + SelectedExpression( + "time", + time_expression, + ), + ], + groupby=[time_expression], + condition=formula_condition, + order_by=[ + OrderBy( + direction=OrderByDirection.ASC, + expression=time_expression, + ) + ], + limit=1000, + offset=0, + ) + + generic_metrics = get_dataset( + "generic_metrics", + ) + query, _ = parse_mql_query(str(query_body), mql_context, generic_metrics) + eq, reason = query.equals(expected) + assert eq, reason + + +def test_groupby() -> None: + query_body = "sum(`d:transactions/duration@millisecond`){status_code:200} by transaction / sum(`d:transactions/duration@millisecond`) by transaction" + expected_selected = SelectedExpression( + "aggregate_value", + divide( + timeseries( + "sumIf", + 123456, + binary_condition( + "equals", tag_column("status_code"), Literal(None, "200") + ), + ), + timeseries("sumIf", 123456), + "_snuba_aggregate_value", + ), + ) + expected = Query( + from_distributions, + selected_columns=[ + expected_selected, + SelectedExpression("transaction", tag_column("transaction")), + SelectedExpression( + "time", + time_expression, + ), + ], + groupby=[tag_column("transaction"), time_expression], + condition=formula_condition, + order_by=[ + OrderBy( + direction=OrderByDirection.ASC, + expression=time_expression, + ) + ], + limit=1000, + offset=0, + ) + + generic_metrics = get_dataset( + "generic_metrics", + ) + query, _ = parse_mql_query(str(query_body), mql_context, generic_metrics) + eq, reason = query.equals(expected) + assert eq, reason + + +def test_curried_aggregate() -> None: + query_body = "quantiles(0.5)(`d:transactions/duration@millisecond`){status_code:200} by transaction / sum(`d:transactions/duration@millisecond`) by transaction" + expected_selected = SelectedExpression( + "aggregate_value", + divide( + CurriedFunctionCall( + alias=None, + internal_function=FunctionCall( + None, "quantilesIf", (Literal(None, 0.5),) + ), + parameters=( + Column("_snuba_value", None, "value"), + FunctionCall( + None, + "and", + ( + binary_condition( + "equals", + tag_column("status_code"), + Literal(None, "200"), + ), + FunctionCall( + None, + "equals", + ( + Column( + "_snuba_metric_id", + None, + "metric_id", + ), + Literal(None, 123456), + ), + ), + ), + ), + ), + ), + timeseries("sumIf", 123456), + "_snuba_aggregate_value", + ), + ) + expected = Query( + from_distributions, + selected_columns=[ + expected_selected, + SelectedExpression("transaction", tag_column("transaction")), + SelectedExpression( + "time", + time_expression, + ), + ], + groupby=[tag_column("transaction"), time_expression], + condition=formula_condition, + order_by=[ + OrderBy( + direction=OrderByDirection.ASC, + expression=time_expression, + ) + ], + limit=1000, + offset=0, + ) + + generic_metrics = get_dataset( + "generic_metrics", + ) + query, _ = parse_mql_query(str(query_body), mql_context, generic_metrics) + eq, reason = query.equals(expected) + assert eq, reason + + +def test_bracketing_rules() -> None: + query_body = "sum(`d:transactions/duration@millisecond`) / ((max(`d:transactions/duration@millisecond`) + avg(`d:transactions/duration@millisecond`)) * min(`d:transactions/duration@millisecond`))" + expected_selected = SelectedExpression( + "aggregate_value", + divide( + timeseries("sumIf", 123456), + multiply( + plus( + timeseries("maxIf", 123456), + timeseries("avgIf", 123456), + ), + timeseries("minIf", 123456), + ), + "_snuba_aggregate_value", + ), + ) + expected = Query( + from_distributions, + selected_columns=[ + expected_selected, + SelectedExpression( + "time", + time_expression, + ), + ], + groupby=[time_expression], + condition=formula_condition, + order_by=[ + OrderBy( + direction=OrderByDirection.ASC, + expression=time_expression, + ) + ], + limit=1000, + offset=0, + ) + + generic_metrics = get_dataset( + "generic_metrics", + ) + query, _ = parse_mql_query(str(query_body), mql_context, generic_metrics) + eq, reason = query.equals(expected) + assert eq, reason + + +def test_mismatch_groupby() -> None: + query_body = "sum(`d:transactions/duration@millisecond`){status_code:200} by transaction / sum(`d:transactions/duration@millisecond`) by status_code" + + generic_metrics = get_dataset( + "generic_metrics", + ) + with pytest.raises( + Exception, + match=re.escape("All terms in a formula must have the same groupby"), + ): + parse_mql_query(str(query_body), mql_context, generic_metrics) + + +def test_formula_filters() -> None: + query_body = "(sum(`d:transactions/duration@millisecond`) / max(`d:transactions/duration@millisecond`)){status_code:200}" + expected_selected = SelectedExpression( + "aggregate_value", + divide( + timeseries( + "sumIf", + 123456, + binary_condition( + "equals", tag_column("status_code"), Literal(None, "200") + ), + ), + timeseries( + "maxIf", + 123456, + binary_condition( + "equals", tag_column("status_code"), Literal(None, "200") + ), + ), + "_snuba_aggregate_value", + ), + ) + expected = Query( + from_distributions, + selected_columns=[ + expected_selected, + SelectedExpression( + "time", + time_expression, + ), + ], + groupby=[time_expression], + condition=formula_condition, + order_by=[ + OrderBy( + direction=OrderByDirection.ASC, + expression=time_expression, + ) + ], + limit=1000, + offset=0, + ) + + generic_metrics = get_dataset( + "generic_metrics", + ) + query, _ = parse_mql_query(str(query_body), mql_context, generic_metrics) + eq, reason = query.equals(expected) + assert eq, reason + + +def test_formula_groupby() -> None: + query_body = "(sum(`d:transactions/duration@millisecond`) / max(`d:transactions/duration@millisecond`)){status_code:200} by transaction" + expected_selected = SelectedExpression( + "aggregate_value", + divide( + timeseries( + "sumIf", + 123456, + binary_condition( + "equals", tag_column("status_code"), Literal(None, "200") + ), + ), + timeseries( + "maxIf", + 123456, + binary_condition( + "equals", tag_column("status_code"), Literal(None, "200") + ), + ), + "_snuba_aggregate_value", + ), + ) + expected = Query( + from_distributions, + selected_columns=[ + expected_selected, + SelectedExpression( + name="transaction", + expression=tag_column("transaction"), + ), + SelectedExpression( + "time", + time_expression, + ), + ], + groupby=[tag_column("transaction"), time_expression], + condition=formula_condition, + order_by=[ + OrderBy( + direction=OrderByDirection.ASC, + expression=time_expression, + ) + ], + limit=1000, + offset=0, + ) + + generic_metrics = get_dataset( + "generic_metrics", + ) + query, _ = parse_mql_query(str(query_body), mql_context, generic_metrics) + eq, reason = query.equals(expected) + assert eq, reason + + +def test_formula_scalar_value() -> None: + query_body = "(sum(`d:transactions/duration@millisecond`) / sum(`d:transactions/duration@millisecond`)) + 100" + expected_selected = SelectedExpression( + "aggregate_value", + plus( + divide( + timeseries("sumIf", 123456), + timeseries("sumIf", 123456), + ), + Literal(None, 100), + "_snuba_aggregate_value", + ), + ) + expected = Query( + from_distributions, + selected_columns=[ + expected_selected, + SelectedExpression( + "time", + time_expression, + ), + ], + groupby=[time_expression], + condition=formula_condition, + order_by=[ + OrderBy( + direction=OrderByDirection.ASC, + expression=time_expression, + ) + ], + limit=1000, + offset=0, + ) + + generic_metrics = get_dataset( + "generic_metrics", + ) + query, _ = parse_mql_query(str(query_body), mql_context, generic_metrics) + eq, reason = query.equals(expected) + assert eq, reason + + +@pytest.mark.xfail(reason="Not implemented yet") # type: ignore +def test_arbitrary_functions() -> None: + query_body = "apdex(sum(`d:transactions/duration@millisecond`), 123) / max(`d:transactions/duration@millisecond`)" + + # Note: This expected selected might not be correct, depending on exactly how we build this + expected_selected = SelectedExpression( + "aggregate_value", + divide( + FunctionCall( + None, + "apdex", + ( + Literal(None, "d:transactions/duration@millisecond"), + Literal(None, 123), + ), + ), + timeseries("maxIf", 123456), + "_snuba_aggregate_value", + ), + ) + expected = Query( + from_distributions, + selected_columns=[ + expected_selected, + SelectedExpression( + "time", + time_expression, + ), + ], + groupby=[time_expression], + condition=formula_condition, + order_by=[ + OrderBy( + direction=OrderByDirection.ASC, + expression=time_expression, + ) + ], + limit=1000, + offset=0, + ) + + generic_metrics = get_dataset( + "generic_metrics", + ) + query, _ = parse_mql_query(str(query_body), mql_context, generic_metrics) + eq, reason = query.equals(expected) + assert eq, reason + + +@pytest.mark.xfail(reason="Not implemented yet") # type: ignore +def test_arbitrary_functions_with_formula() -> None: + query_body = "apdex(sum(`d:transactions/duration@millisecond`) / max(`d:transactions/duration@millisecond`), 123)" + + # Note: This expected selected might not be correct, depending on exactly how we build this + expected_selected = SelectedExpression( + "aggregate_value", + divide( + FunctionCall( + None, + "apdex", + ( + Literal(None, "d:transactions/duration@millisecond"), + Literal(None, 123), + ), + ), + timeseries("maxIf", 123456), + "_snuba_aggregate_value", + ), + ) + expected = Query( + from_distributions, + selected_columns=[ + expected_selected, + SelectedExpression( + "time", + time_expression, + ), + ], + groupby=[time_expression], + condition=formula_condition, + order_by=[ + OrderBy( + direction=OrderByDirection.ASC, + expression=time_expression, + ) + ], + limit=1000, + offset=0, + ) + + generic_metrics = get_dataset( + "generic_metrics", + ) + query, _ = parse_mql_query(str(query_body), mql_context, generic_metrics) + eq, reason = query.equals(expected) + assert eq, reason + + +@pytest.mark.xfail(reason="Not implemented yet = needs snuba-sdk>2.0.20") # type: ignore +def test_arbitrary_functions_with_formula_and_filters() -> None: + query_body = 'apdex(sum(`d:transactions/duration@millisecond`) / max(`d:transactions/duration@millisecond`), 500){dist:["dist1", "dist2"]}' + + # Note: This expected selected might not be correct, depending on exactly how we build this + expected_selected = SelectedExpression( + "aggregate_value", + FunctionCall( + "_snuba_aggregate_value", + "apdex", + ( + divide( + timeseries("sumIf", 123456), + timeseries("maxIf", 123456), + ), + Literal(None, 500), + ), + ), + ) + expected = Query( + from_distributions, + selected_columns=[ + expected_selected, + SelectedExpression( + "time", + time_expression, + ), + ], + groupby=[time_expression], + condition=binary_condition( + "and", + formula_condition, + binary_condition( + "in", + tag_column("dist"), + FunctionCall( + None, "array", (Literal(None, "dist1"), Literal(None, "dist2")) + ), + ), + ), + order_by=[ + OrderBy( + direction=OrderByDirection.ASC, + expression=time_expression, + ) + ], + limit=1000, + offset=0, + ) + + generic_metrics = get_dataset( + "generic_metrics", + ) + query, _ = parse_mql_query(str(query_body), mql_context, generic_metrics) + eq, reason = query.equals(expected) + assert eq, reason diff --git a/tests/test_metrics_mql_api.py b/tests/test_metrics_mql_api.py index c669a44b767..509def36589 100644 --- a/tests/test_metrics_mql_api.py +++ b/tests/test_metrics_mql_api.py @@ -8,10 +8,12 @@ import pytest import simplejson as json from snuba_sdk import ( + ArithmeticOperator, Column, Condition, Direction, Flags, + Formula, Metric, MetricsQuery, MetricsScope, @@ -454,7 +456,9 @@ def test_dots_in_mri_names(self) -> None: end=self.end_time, rollup=Rollup(interval=60, totals=None, orderby=None, granularity=60), scope=MetricsScope( - org_ids=[1], project_ids=[1], use_case_id="transactions" + org_ids=[self.org_id], + project_ids=self.project_ids, + use_case_id="transactions", ), indexer_mappings={ "d:transactions/measurements.indexer_batch.payloads.len@none": DISTRIBUTIONS.metric_id, @@ -515,3 +519,121 @@ def test_crazy_characters(self) -> None: ).serialize_mql(), ) assert response.status_code == 200 + + def test_simple_formula(self) -> None: + query = MetricsQuery( + query=Formula( + ArithmeticOperator.PLUS.value, + [ + Timeseries( + metric=Metric( + "transaction.duration", + TRANSACTION_MRI, + DISTRIBUTIONS.metric_id, + DISTRIBUTIONS.entity, + ), + aggregate="avg", + ), + Timeseries( + metric=Metric( + "transaction.duration", + TRANSACTION_MRI, + DISTRIBUTIONS.metric_id, + DISTRIBUTIONS.entity, + ), + aggregate="avg", + ), + ], + ), + start=self.start_time, + end=self.end_time, + rollup=Rollup(interval=60, totals=None, orderby=None, granularity=60), + scope=MetricsScope( + org_ids=[self.org_id], + project_ids=self.project_ids, + use_case_id=USE_CASE_ID, + ), + indexer_mappings={ + TRANSACTION_MRI: DISTRIBUTIONS.metric_id, + "status_code": resolve_str("status_code"), + }, + ) + + response = self.app.post( + self.mql_route, + data=Request( + dataset=DATASET, + app_id="test", + query=query, + flags=Flags(debug=True), + tenant_ids={"referrer": "tests", "organization_id": self.org_id}, + ).serialize_mql(), + ) + assert response.status_code == 200, response.data + data = json.loads(response.data) + assert len(data["data"]) == 180, data + + @pytest.mark.xfail(reason="Needs snuba-sdk 2.0.21 or later") + def test_complex_formula(self) -> None: + query = MetricsQuery( + query=Formula( + ArithmeticOperator.DIVIDE.value, + [ + Timeseries( + metric=Metric( + "transaction.duration", + TRANSACTION_MRI, + DISTRIBUTIONS.metric_id, + DISTRIBUTIONS.entity, + ), + aggregate="quantiles", + aggregate_params=[0.5], + filters=[ + Condition( + Column("status_code"), + Op.IN, + ["200"], + ) + ], + groupby=[Column("transaction")], + ), + Timeseries( + metric=Metric( + "transaction.duration", + TRANSACTION_MRI, + DISTRIBUTIONS.metric_id, + DISTRIBUTIONS.entity, + ), + aggregate="avg", + groupby=[Column("transaction")], + ), + ], + ), + start=self.start_time, + end=self.end_time, + rollup=Rollup(interval=60, totals=None, orderby=None, granularity=60), + scope=MetricsScope( + org_ids=[self.org_id], + project_ids=self.project_ids, + use_case_id=USE_CASE_ID, + ), + indexer_mappings={ + TRANSACTION_MRI: DISTRIBUTIONS.metric_id, + "status_code": resolve_str("status_code"), + "transaction": resolve_str("transaction"), + }, + ) + + response = self.app.post( + self.mql_route, + data=Request( + dataset=DATASET, + app_id="test", + query=query, + flags=Flags(debug=True), + tenant_ids={"referrer": "tests", "organization_id": self.org_id}, + ).serialize_mql(), + ) + assert response.status_code == 200, response.data + data = json.loads(response.data) + assert len(data["data"]) == 180, data