diff --git a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs index 4f68869d9..0c1cfd965 100644 --- a/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs +++ b/crates/proof-of-sql/src/sql/parse/query_expr_tests.rs @@ -54,25 +54,6 @@ pub fn schema_accessor_from_table_ref_with_schema( TestSchemaAccessor::new(indexmap! {table => schema}) } -fn get_test_accessor() -> (TableRef, TestSchemaAccessor) { - let table = "sxt.t".parse().unwrap(); - let accessor = schema_accessor_from_table_ref_with_schema( - table, - indexmap! { - "s".parse().unwrap() => ColumnType::VarChar, - "i".parse().unwrap() => ColumnType::BigInt, - "d".parse().unwrap() => ColumnType::Int128, - "s0".parse().unwrap() => ColumnType::VarChar, - "i0".parse().unwrap() => ColumnType::BigInt, - "d0".parse().unwrap() => ColumnType::Int128, - "s1".parse().unwrap() => ColumnType::VarChar, - "i1".parse().unwrap() => ColumnType::BigInt, - "d1".parse().unwrap() => ColumnType::Int128, - }, - ); - (table, accessor) -} - #[test] fn we_can_convert_an_ast_with_one_column() { let t = "sxt.sxt_tab".parse().unwrap(); @@ -1128,8 +1109,17 @@ fn we_can_group_by_without_using_aggregate_functions() { #[test] fn group_by_expressions_are_parsed_before_an_order_by_referencing_an_aggregate_alias_result() { let query_text = - "select max(i) max_sal, i0 d, count(i0) from sxt.t group by i0, i1 order by max_sal"; - let (t, accessor) = get_test_accessor(); + "select max(salary) max_sal, department_budget d, count(department_budget) from sxt.employees group by department_budget, tax order by max_sal"; + + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "department_budget".parse().unwrap() => ColumnType::BigInt, + "salary".parse().unwrap() => ColumnType::BigInt, + "tax".parse().unwrap() => ColumnType::BigInt, + }, + ); let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let query = @@ -1138,20 +1128,20 @@ fn group_by_expressions_are_parsed_before_an_order_by_referencing_an_aggregate_a let expected_query = QueryExpr::new( filter( vec![ - col_expr_plan(t, "i", &accessor), - col_expr_plan(t, "i0", &accessor), - col_expr_plan(t, "i1", &accessor), + col_expr_plan(t, "department_budget", &accessor), + col_expr_plan(t, "salary", &accessor), + col_expr_plan(t, "tax", &accessor), ], tab(t), const_bool(true), ), vec![ group_by_postprocessing( - &["i0", "i1"], + &["department_budget", "tax"], &[ - aliased_expr(max(col("i")), "max_sal"), - aliased_expr(col("i0"), "d"), - aliased_expr(count(col("i0")), "__count__"), + aliased_expr(max(col("salary")), "max_sal"), + aliased_expr(col("department_budget"), "d"), + aliased_expr(count(col("department_budget")), "__count__"), ], ), orders(&["max_sal"], &[Asc]), @@ -1240,8 +1230,14 @@ fn group_by_column_cannot_be_a_column_result_alias() { #[test] fn we_can_have_aggregate_functions_without_a_group_by_clause() { - let query_text = "select count(s) from sxt.t"; - let (t, accessor) = get_test_accessor(); + let query_text = "select count(name) from sxt.employees"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "name".parse().unwrap() => ColumnType::VarChar, + }, + ); let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let ast = @@ -1401,8 +1397,17 @@ fn we_can_use_the_same_result_columns_with_different_aliases_and_associate_it_wi #[test] fn we_can_use_multiple_group_by_clauses_with_multiple_agg_and_non_agg_exprs() { - let (t, accessor) = get_test_accessor(); - let query_text = "select i d1, max(i1), i d2, sum(i0) sum_bonus, count(s) count_s from sxt.t group by i, i0, i"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "bonus".parse().unwrap() => ColumnType::BigInt, + "name".parse().unwrap() => ColumnType::VarChar, + "salary".parse().unwrap() => ColumnType::BigInt, + "tax".parse().unwrap() => ColumnType::BigInt, + }, + ); + let query_text = "select salary d1, max(tax), salary d2, sum(bonus) sum_bonus, count(name) count_s from sxt.employees group by salary, bonus, salary"; let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let ast = @@ -1410,18 +1415,18 @@ fn we_can_use_multiple_group_by_clauses_with_multiple_agg_and_non_agg_exprs() { let expected_ast = QueryExpr::new( filter( - cols_expr_plan(t, &["i", "i0", "i1", "s"], &accessor), + cols_expr_plan(t, &["bonus", "name", "salary", "tax"], &accessor), tab(t), const_bool(true), ), vec![group_by_postprocessing( - &["i", "i0", "i"], + &["salary", "bonus", "salary"], &[ - aliased_expr(col("i"), "d1"), - aliased_expr(max(col("i1")), "__max__"), - aliased_expr(col("i"), "d2"), - aliased_expr(sum(col("i0")), "sum_bonus"), - aliased_expr(count(col("s")), "count_s"), + aliased_expr(col("salary"), "d1"), + aliased_expr(max(col("tax")), "__max__"), + aliased_expr(col("salary"), "d2"), + aliased_expr(sum(col("bonus")), "sum_bonus"), + aliased_expr(count(col("name")), "count_s"), ], )], ); @@ -1567,12 +1572,19 @@ fn we_can_parse_arithmetic_expression_within_aggregations_in_the_result_expr() { #[test] fn we_cannot_use_non_grouped_columns_outside_agg() { - let (t, accessor) = get_test_accessor(); + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "salary".parse().unwrap() => ColumnType::BigInt, + "name".parse().unwrap() => ColumnType::VarChar, + }, + ); let identifier_not_in_agg_queries = vec![ - "select i from sxt.t group by s", - "select sum(i), i from sxt.t group by s", - "select min(i) + i from sxt.t group by s", - "select 2 * i, min(i) from sxt.t group by s", + "select salary from sxt.employees group by name", + "select sum(salary), salary from sxt.employees group by name", + "select min(salary) + salary from sxt.employees group by name", + "select 2 * salary, min(salary) from sxt.employees group by name", ]; for query_text in &identifier_not_in_agg_queries { @@ -1589,9 +1601,9 @@ fn we_cannot_use_non_grouped_columns_outside_agg() { } let invalid_group_by_queries = vec![ - "select 2 * i, min(i) from sxt.t", - "select sum(i), i from sxt.t", - "select max(i) + 2 * i from sxt.t", + "select 2 * salary, min(salary) from sxt.employees", + "select sum(salary), salary from sxt.employees", + "select max(salary) + 2 * salary from sxt.employees", ]; for query_text in &invalid_group_by_queries { @@ -1608,11 +1620,23 @@ fn we_cannot_use_non_grouped_columns_outside_agg() { #[test] fn varchar_column_is_not_compatible_with_integer_column() { - let bigint_to_varchar_queries = vec!["select -123 * s from sxt.t", "select i - s from sxt.t"]; - let (t, accessor) = get_test_accessor(); + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "salary".parse().unwrap() => ColumnType::BigInt, + "name".parse().unwrap() => ColumnType::VarChar, + }, + ); + + let bigint_to_varchar_queries = vec![ + "select -123 * name from sxt.employees", + "select salary - name from sxt.employees", + ]; + let varchar_to_bigint_queries = vec![ - "select s from sxt.t where 'abc' = i", - "select s from sxt.t where 'abc' != i", + "select name from sxt.employees where 'abc' = salary", + "select name from sxt.employees where 'abc' != salary", ]; for query_text in &bigint_to_varchar_queries { @@ -1646,8 +1670,16 @@ fn varchar_column_is_not_compatible_with_integer_column() { #[test] fn arithmetic_operations_are_not_allowed_with_varchar_column() { - let (t, accessor) = get_test_accessor(); - let query_text = "select s - s1 from sxt.t"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "name".parse().unwrap() => ColumnType::VarChar, + "position".parse().unwrap() => ColumnType::VarChar, + }, + ); + + let query_text = "select name - position from sxt.employees"; let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let result = QueryExpr::::try_new(intermediate_ast, t.schema_id(), &accessor); @@ -1662,8 +1694,14 @@ fn arithmetic_operations_are_not_allowed_with_varchar_column() { #[test] fn varchar_column_is_not_allowed_within_numeric_aggregations() { - let (t, accessor) = get_test_accessor(); - let sum_query = "select sum(s) from sxt.t"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "name".parse().unwrap() => ColumnType::VarChar, + }, + ); + let sum_query = "select sum(name) from sxt.employees"; let intermediate_ast = SelectStatementParser::new().parse(sum_query).unwrap(); let result = QueryExpr::::try_new(intermediate_ast, t.schema_id(), &accessor); @@ -1673,7 +1711,7 @@ fn varchar_column_is_not_allowed_within_numeric_aggregations() { if expression == "cannot use expression of type 'varchar' with numeric aggregation function 'sum'" )); - let max_query = "select max(s) from sxt.t"; + let max_query = "select max(name) from sxt.employees"; let intermediate_ast = SelectStatementParser::new().parse(max_query).unwrap(); let result = QueryExpr::::try_new(intermediate_ast, t.schema_id(), &accessor); @@ -1683,7 +1721,7 @@ fn varchar_column_is_not_allowed_within_numeric_aggregations() { if expression == "cannot use expression of type 'varchar' with numeric aggregation function 'max'" )); - let min_query = "select min(s) from sxt.t"; + let min_query = "select min(name) from sxt.employees"; let intermediate_ast = SelectStatementParser::new().parse(min_query).unwrap(); let result = QueryExpr::::try_new(intermediate_ast, t.schema_id(), &accessor); @@ -1696,8 +1734,14 @@ fn varchar_column_is_not_allowed_within_numeric_aggregations() { #[test] fn group_by_with_bigint_column_is_valid() { - let (t, accessor) = get_test_accessor(); - let query_text = "select i from sxt.t group by i"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "salary".parse().unwrap() => ColumnType::BigInt, + }, + ); + let query_text = "select salary from sxt.employees group by salary"; let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let query = @@ -1705,13 +1749,13 @@ fn group_by_with_bigint_column_is_valid() { let expected_query = QueryExpr::new( filter( - cols_expr_plan(t, &["i"], &accessor), + cols_expr_plan(t, &["salary"], &accessor), tab(t), const_bool(true), ), vec![group_by_postprocessing( - &["i"], - &[aliased_expr(col("i"), "i")], + &["salary"], + &[aliased_expr(col("salary"), "salary")], )], ); assert_eq!(query, expected_query); @@ -1719,8 +1763,14 @@ fn group_by_with_bigint_column_is_valid() { #[test] fn group_by_with_decimal_column_is_valid() { - let (t, accessor) = get_test_accessor(); - let query_text = "select d from sxt.t group by d"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "salary".parse().unwrap() => ColumnType::Int128, + }, + ); + let query_text = "select salary from sxt.employees group by salary"; let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let query = @@ -1728,13 +1778,13 @@ fn group_by_with_decimal_column_is_valid() { let expected_query = QueryExpr::new( filter( - cols_expr_plan(t, &["d"], &accessor), + cols_expr_plan(t, &["salary"], &accessor), tab(t), const_bool(true), ), vec![group_by_postprocessing( - &["d"], - &[aliased_expr(col("d"), "d")], + &["salary"], + &[aliased_expr(col("salary"), "salary")], )], ); assert_eq!(query, expected_query); @@ -1742,8 +1792,14 @@ fn group_by_with_decimal_column_is_valid() { #[test] fn group_by_with_varchar_column_is_valid() { - let (t, accessor) = get_test_accessor(); - let query_text = "select s from sxt.t group by s"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "name".parse().unwrap() => ColumnType::VarChar, + }, + ); + let query_text = "select name from sxt.employees group by name"; let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let query = @@ -1751,13 +1807,13 @@ fn group_by_with_varchar_column_is_valid() { let expected_query = QueryExpr::new( filter( - cols_expr_plan(t, &["s"], &accessor), + cols_expr_plan(t, &["name"], &accessor), tab(t), const_bool(true), ), vec![group_by_postprocessing( - &["s"], - &[aliased_expr(col("s"), "s")], + &["name"], + &[aliased_expr(col("name"), "name")], )], ); assert_eq!(query, expected_query); @@ -1765,8 +1821,16 @@ fn group_by_with_varchar_column_is_valid() { #[test] fn we_can_use_arithmetic_outside_agg_expressions_while_using_group_by() { - let (t, accessor) = get_test_accessor(); - let query_text = "select 2 * i + sum(i) - i1 from sxt.t group by i, i1"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "salary".parse().unwrap() => ColumnType::BigInt, + "tax".parse().unwrap() => ColumnType::BigInt, + }, + ); + let query_text = + "select 2 * salary + sum(salary) - tax from sxt.employees group by salary, tax"; let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let query = @@ -1774,20 +1838,26 @@ fn we_can_use_arithmetic_outside_agg_expressions_while_using_group_by() { let expected_query = QueryExpr::new( filter( - cols_expr_plan(t, &["i", "i1"], &accessor), + cols_expr_plan(t, &["salary", "tax"], &accessor), tab(t), const_bool(true), ), vec![ group_by_postprocessing( - &["i", "i1"], + &["salary", "tax"], &[aliased_expr( - psub(padd(pmul(lit(2), col("i")), sum(col("i"))), col("i1")), + psub( + padd(pmul(lit(2), col("salary")), sum(col("salary"))), + col("tax"), + ), "__expr__", )], ), select_expr(&[aliased_expr( - psub(padd(pmul(lit(2), col("i")), col("__col_agg_0")), col("i1")), + psub( + padd(pmul(lit(2), col("salary")), col("__col_agg_0")), + col("tax"), + ), "__expr__", )]), ], @@ -1797,8 +1867,15 @@ fn we_can_use_arithmetic_outside_agg_expressions_while_using_group_by() { #[test] fn we_can_use_arithmetic_outside_agg_expressions_without_using_group_by() { - let (t, accessor) = get_test_accessor(); - let query_text = "select 7 + max(i) as max_i, min(i + 777 * d) * -5 as min_d from t"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "salary".parse().unwrap() => ColumnType::BigInt, + "bonus".parse().unwrap() => ColumnType::Int128, + }, + ); + let query_text = "select 7 + max(salary) as max_i, min(salary + 777 * bonus) * -5 as min_d from sxt.employees"; let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let ast = @@ -1806,7 +1883,7 @@ fn we_can_use_arithmetic_outside_agg_expressions_without_using_group_by() { let expected_ast = QueryExpr::new( filter( - cols_expr_plan(t, &["d", "i"], &accessor), + cols_expr_plan(t, &["bonus", "salary"], &accessor), tab(t), const_bool(true), ), @@ -1814,9 +1891,12 @@ fn we_can_use_arithmetic_outside_agg_expressions_without_using_group_by() { group_by_postprocessing( &[], &[ - aliased_expr(padd(lit(7), max(col("i"))), "max_i"), + aliased_expr(padd(lit(7), max(col("salary"))), "max_i"), aliased_expr( - pmul(min(padd(col("i"), pmul(lit(777), col("d")))), lit(-5)), + pmul( + min(padd(col("salary"), pmul(lit(777), col("bonus")))), + lit(-5), + ), "min_d", ), ], @@ -1832,8 +1912,17 @@ fn we_can_use_arithmetic_outside_agg_expressions_without_using_group_by() { #[test] fn count_aggregation_always_have_integer_type() { - let (t, accessor) = get_test_accessor(); - let query_text = "select 7 + count(s) as cs, count(i) * -5 as ci, count(d) from t"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "name".parse().unwrap() => ColumnType::VarChar, + "salary".parse().unwrap() => ColumnType::BigInt, + "tax".parse().unwrap() => ColumnType::Int128, + }, + ); + let query_text = + "select 7 + count(name) as cs, count(salary) * -5 as ci, count(tax) from sxt.employees"; let intermediate_ast = SelectStatementParser::new().parse(query_text).unwrap(); let ast = @@ -1841,7 +1930,7 @@ fn count_aggregation_always_have_integer_type() { let expected_ast = QueryExpr::new( filter( - cols_expr_plan(t, &["d", "i", "s"], &accessor), + cols_expr_plan(t, &["name", "salary", "tax"], &accessor), tab(t), const_bool(true), ), @@ -1849,9 +1938,9 @@ fn count_aggregation_always_have_integer_type() { group_by_postprocessing( &[], &[ - aliased_expr(padd(lit(7), count(col("s"))), "cs"), - aliased_expr(pmul(count(col("i")), lit(-5)), "ci"), - aliased_expr(count(col("d")), "__count__"), + aliased_expr(padd(lit(7), count(col("name"))), "cs"), + aliased_expr(pmul(count(col("salary")), lit(-5)), "ci"), + aliased_expr(count(col("tax")), "__count__"), ], ), select_expr(&[ @@ -1866,17 +1955,41 @@ fn count_aggregation_always_have_integer_type() { #[test] fn select_wildcard_is_valid_with_group_by_exprs() { - let columns = ["s", "i", "d", "s0", "i0", "d0", "s1", "i1", "d1"]; + let columns = [ + "employee_name", + "base_salary", + "annual_bonus", + "manager_name", + "manager_salary", + "manager_bonus", + "department_name", + "department_budget", + "department_headcount", + ]; let sorted_columns = columns.iter().sorted().collect::>(); let aliased_exprs = columns .iter() .map(|c| aliased_expr(col(c), c)) .collect::>(); - let (t, accessor) = get_test_accessor(); - let table_name = "sxt.t"; + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "employee_name".parse().unwrap() => ColumnType::VarChar, + "base_salary".parse().unwrap() => ColumnType::BigInt, + "annual_bonus".parse().unwrap() => ColumnType::Int128, + "manager_name".parse().unwrap() => ColumnType::VarChar, + "manager_salary".parse().unwrap() => ColumnType::BigInt, + "manager_bonus".parse().unwrap() => ColumnType::Int128, + "department_name".parse().unwrap() => ColumnType::VarChar, + "department_budget".parse().unwrap() => ColumnType::BigInt, + "department_headcount".parse().unwrap() => ColumnType::Int128, + }, + ); + let query_text = format!( "SELECT * FROM {} GROUP BY {}", - table_name, + "sxt.employees", columns.join(", ") ); @@ -1901,10 +2014,19 @@ fn select_wildcard_is_valid_with_group_by_exprs() { #[test] fn nested_aggregations_are_not_supported() { let supported_agg = ["max", "min", "sum", "count"]; - let (t, accessor) = get_test_accessor(); + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "salary".parse().unwrap() => ColumnType::BigInt, + }, + ); for perm_aggs in supported_agg.iter().permutations(2) { - let query_text = format!("SELECT {}({}(i)) FROM t", perm_aggs[0], perm_aggs[1]); + let query_text = format!( + "SELECT {}({}(salary)) FROM sxt.employees", + perm_aggs[0], perm_aggs[1] + ); let intermediate_ast = SelectStatementParser::new().parse(&query_text).unwrap(); let result = @@ -1922,8 +2044,17 @@ fn nested_aggregations_are_not_supported() { #[test] fn select_group_and_order_by_preserve_the_column_order_reference() { const N: usize = 4; - let (t, accessor) = get_test_accessor(); - let base_cols: [&str; N] = ["i", "i0", "i1", "s"]; // sorted because of `select: [cols = ... ]` + let t = "sxt.employees".parse().unwrap(); + let accessor = schema_accessor_from_table_ref_with_schema( + t, + indexmap! { + "salary".parse().unwrap() => ColumnType::BigInt, + "department".parse().unwrap() => ColumnType::BigInt, + "tax".parse().unwrap() => ColumnType::BigInt, + "name".parse().unwrap() => ColumnType::VarChar, + }, + ); + let base_cols: [&str; N] = ["salary", "department", "tax", "name"]; // sorted because of `select: [cols = ... ]` let base_ordering = [Asc, Desc, Asc, Desc]; for (idx, perm_cols) in base_cols .into_iter()