diff --git a/crates/proof-of-sql/src/base/database/group_by_util.rs b/crates/proof-of-sql/src/base/database/group_by_util.rs index 94aec1f7a..50713c644 100644 --- a/crates/proof-of-sql/src/base/database/group_by_util.rs +++ b/crates/proof-of-sql/src/base/database/group_by_util.rs @@ -17,6 +17,12 @@ pub struct AggregatedColumns<'a, S: Scalar> { pub group_by_columns: Vec>, /// Resulting sums of the groups for the columns in `sum_columns_in`. pub sum_columns: Vec<&'a [S]>, + /// Resulting maxima of the groups for the columns in `max_columns_in`. Note that for empty groups + /// the result will be `None`. + pub max_columns: Vec<&'a [Option]>, + /// Resulting minima of the groups for the columns in `min_columns_in`. Note that for empty groups + /// the result will be `None`. + pub min_columns: Vec<&'a [Option]>, /// The number of rows in each group. pub count_column: &'a [i64], } @@ -28,7 +34,8 @@ pub enum AggregateColumnsError { /// This is a function that gives the result of a group by query similar to the following: /// ```sql -/// SELECT , , ..., SUM(), SUM(), ..., COUNT(*) +/// SELECT , , ..., SUM(), SUM(), ..., +/// MAX(), ..., MIN(), ..., COUNT(*) /// WHERE selection GROUP BY , , ... /// ``` /// @@ -38,17 +45,20 @@ pub fn aggregate_columns<'a, S: Scalar>( alloc: &'a Bump, group_by_columns_in: &[Column<'a, S>], sum_columns_in: &[Column], + max_columns_in: &[Column], + min_columns_in: &[Column], selection_column_in: &[bool], ) -> Result, AggregateColumnsError> { - for col in group_by_columns_in { - if col.len() != selection_column_in.len() { - return Err(AggregateColumnsError::ColumnLengthMismatch); - } - } - for col in sum_columns_in { - if col.len() != selection_column_in.len() { - return Err(AggregateColumnsError::ColumnLengthMismatch); - } + // Check that all the columns have the same length + let len = selection_column_in.len(); + if group_by_columns_in + .iter() + .chain(sum_columns_in.iter()) + .chain(max_columns_in.iter()) + .chain(min_columns_in.iter()) + .any(|col| col.len() != len) + { + return Err(AggregateColumnsError::ColumnLengthMismatch); } // `filtered_indexes` is a vector of indexes of the rows that are selected. We sort this vector @@ -84,12 +94,22 @@ pub fn aggregate_columns<'a, S: Scalar>( sum_aggregate_column_by_index_counts(alloc, column, &counts, &filtered_indexes) })); + let max_columns_out = Vec::from_iter(max_columns_in.iter().map(|column| { + max_aggregate_column_by_index_counts(alloc, column, &counts, &filtered_indexes) + })); + + let min_columns_out = Vec::from_iter(min_columns_in.iter().map(|column| { + min_aggregate_column_by_index_counts(alloc, column, &counts, &filtered_indexes) + })); + // Cast the counts to something compatible with BigInt. let count_column_out = alloc.alloc_slice_fill_iter(counts.into_iter().map(|c| c as i64)); Ok(AggregatedColumns { group_by_columns: group_by_columns_out, sum_columns: sum_columns_out, + max_columns: max_columns_out, + min_columns: min_columns_out, count_column: count_column_out, }) } @@ -121,6 +141,68 @@ pub(crate) fn sum_aggregate_column_by_index_counts<'a, S: Scalar>( } } +/// Returns a slice with the lifetime of `alloc` that contains the grouped maxima of `column`. +/// The `counts` slice contains the number of elements in each group and the `indexes` slice +/// contains the indexes of the elements in `column`. +/// +/// See [`max_aggregate_slice_by_index_counts`] for an example. This is a helper wrapper around that function. +pub(crate) fn max_aggregate_column_by_index_counts<'a, S: Scalar>( + alloc: &'a Bump, + column: &Column, + counts: &[usize], + indexes: &[usize], +) -> &'a [Option] { + match column { + Column::Boolean(col) => max_aggregate_slice_by_index_counts(alloc, col, counts, indexes), + Column::SmallInt(col) => max_aggregate_slice_by_index_counts(alloc, col, counts, indexes), + Column::Int(col) => max_aggregate_slice_by_index_counts(alloc, col, counts, indexes), + Column::BigInt(col) => max_aggregate_slice_by_index_counts(alloc, col, counts, indexes), + Column::Int128(col) => max_aggregate_slice_by_index_counts(alloc, col, counts, indexes), + Column::Decimal75(_, _, col) => { + max_aggregate_slice_by_index_counts(alloc, col, counts, indexes) + } + Column::TimestampTZ(_, _, col) => { + max_aggregate_slice_by_index_counts(alloc, col, counts, indexes) + } + Column::Scalar(col) => max_aggregate_slice_by_index_counts(alloc, col, counts, indexes), + // The following should never be reached because the `MAX` function can't be applied to varchar. + Column::VarChar(_) => { + unreachable!("MAX can not be applied to varchar") + } + } +} + +/// Returns a slice with the lifetime of `alloc` that contains the grouped minima of `column`. +/// The `counts` slice contains the number of elements in each group and the `indexes` slice +/// contains the indexes of the elements in `column`. +/// +/// See [`min_aggregate_slice_by_index_counts`] for an example. This is a helper wrapper around that function. +pub(crate) fn min_aggregate_column_by_index_counts<'a, S: Scalar>( + alloc: &'a Bump, + column: &Column, + counts: &[usize], + indexes: &[usize], +) -> &'a [Option] { + match column { + Column::Boolean(col) => min_aggregate_slice_by_index_counts(alloc, col, counts, indexes), + Column::SmallInt(col) => min_aggregate_slice_by_index_counts(alloc, col, counts, indexes), + Column::Int(col) => min_aggregate_slice_by_index_counts(alloc, col, counts, indexes), + Column::BigInt(col) => min_aggregate_slice_by_index_counts(alloc, col, counts, indexes), + Column::Int128(col) => min_aggregate_slice_by_index_counts(alloc, col, counts, indexes), + Column::Decimal75(_, _, col) => { + min_aggregate_slice_by_index_counts(alloc, col, counts, indexes) + } + Column::TimestampTZ(_, _, col) => { + min_aggregate_slice_by_index_counts(alloc, col, counts, indexes) + } + Column::Scalar(col) => min_aggregate_slice_by_index_counts(alloc, col, counts, indexes), + // The following should never be reached because the `MIN` function can't be applied to varchar. + Column::VarChar(_) => { + unreachable!("MIN can not be applied to varchar") + } + } +} + /// Returns a slice with the lifetime of `alloc` that contains the grouped sums of `slice`. /// The `counts` slice contains the number of elements in each group and the `indexes` slice /// contains the indexes of the elements in `slice`. @@ -182,7 +264,6 @@ where /// let result = max_aggregate_slice_by_index_counts(&alloc, slice_a, counts, indexes); /// assert_eq!(result, expected); /// ``` -#[allow(dead_code)] pub(crate) fn max_aggregate_slice_by_index_counts<'a, S, T>( alloc: &'a Bump, slice: &[T], @@ -226,7 +307,6 @@ where /// let result = min_aggregate_slice_by_index_counts(&alloc, slice_a, counts, indexes); /// assert_eq!(result, expected); /// ``` -#[allow(dead_code)] pub(crate) fn min_aggregate_slice_by_index_counts<'a, S, T>( alloc: &'a Bump, slice: &[T], @@ -243,7 +323,7 @@ where indexes[start..index] .iter() .map(|i| S::from(&slice[*i])) - .max_by(|x, y| x.signed_cmp(y)) + .min_by(|x, y| x.signed_cmp(y)) })) } diff --git a/crates/proof-of-sql/src/base/database/group_by_util_test.rs b/crates/proof-of-sql/src/base/database/group_by_util_test.rs index 843bd42e7..1bbd656fb 100644 --- a/crates/proof-of-sql/src/base/database/group_by_util_test.rs +++ b/crates/proof-of-sql/src/base/database/group_by_util_test.rs @@ -18,7 +18,7 @@ fn we_can_aggregate_empty_columns() { let sum_columns = &[column_c.clone(), column_d.clone()]; let selection = &[]; let alloc = Bump::new(); - let aggregate_result = aggregate_columns(&alloc, group_by, sum_columns, selection) + let aggregate_result = aggregate_columns(&alloc, group_by, sum_columns, &[], &[], selection) .expect("Aggregation should succeed"); assert_eq!( aggregate_result.group_by_columns, @@ -28,6 +28,89 @@ fn we_can_aggregate_empty_columns() { assert_eq!(aggregate_result.count_column, &[0i64; 0]); } +#[test] +fn we_can_aggregate_columns_with_empty_group_by_and_no_rows_selected() { + let slice_c = &[100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111]; + let slice_d = &[200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211]; + let selection = &[false; 12]; + let scals_d: Vec = slice_d.iter().map(|s| s.into()).collect(); + let column_c = Column::Int128(slice_c); + let column_d = Column::Scalar(&scals_d); + let group_by = &[]; + let sum_columns = &[column_c.clone(), column_d.clone()]; + let max_columns = &[column_c.clone(), column_d.clone()]; + let min_columns = &[column_c.clone(), column_d.clone()]; + let alloc = Bump::new(); + let aggregate_result = aggregate_columns( + &alloc, + group_by, + sum_columns, + min_columns, + max_columns, + selection, + ) + .expect("Aggregation should succeed"); + let expected_group_by_result = &[]; + let expected_sum_result = &[&[], &[]]; + let expected_max_result = &[&[], &[]]; + let expected_min_result = &[&[], &[]]; + let expected_count_result: &[i64] = &[]; + assert_eq!(aggregate_result.group_by_columns, expected_group_by_result); + assert_eq!(aggregate_result.sum_columns, expected_sum_result); + assert_eq!(aggregate_result.count_column, expected_count_result); + assert_eq!(aggregate_result.max_columns, expected_max_result); + assert_eq!(aggregate_result.min_columns, expected_min_result); +} + +#[test] +fn we_can_aggregate_columns_with_empty_group_by() { + let slice_c = &[100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111]; + let slice_d = &[200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211]; + let selection = &[ + false, true, true, true, true, true, true, true, true, true, true, true, + ]; + let scals_d: Vec = slice_d.iter().map(|s| s.into()).collect(); + let column_c = Column::Int128(slice_c); + let column_d = Column::Scalar(&scals_d); + let group_by = &[]; + let sum_columns = &[column_c.clone(), column_d.clone()]; + let max_columns = &[column_c.clone(), column_d.clone()]; + let min_columns = &[column_c.clone(), column_d.clone()]; + let alloc = Bump::new(); + let aggregate_result = aggregate_columns( + &alloc, + group_by, + sum_columns, + min_columns, + max_columns, + selection, + ) + .expect("Aggregation should succeed"); + let expected_group_by_result = &[]; + let expected_sum_result = &[ + &[Curve25519Scalar::from( + 101 + 102 + 103 + 104 + 105 + 106 + 107 + 108 + 109 + 110 + 111, + )], + &[Curve25519Scalar::from( + 201 + 202 + 203 + 204 + 205 + 206 + 207 + 208 + 209 + 210 + 211, + )], + ]; + let expected_max_result = &[ + &[Some(Curve25519Scalar::from(111))], + &[Some(Curve25519Scalar::from(211))], + ]; + let expected_min_result = &[ + &[Some(Curve25519Scalar::from(101))], + &[Some(Curve25519Scalar::from(201))], + ]; + let expected_count_result = &[11]; + assert_eq!(aggregate_result.group_by_columns, expected_group_by_result); + assert_eq!(aggregate_result.sum_columns, expected_sum_result); + assert_eq!(aggregate_result.count_column, expected_count_result); + assert_eq!(aggregate_result.max_columns, expected_max_result); + assert_eq!(aggregate_result.min_columns, expected_min_result); +} + #[test] fn we_can_aggregate_columns() { let slice_a = &[3, 3, 3, 2, 2, 1, 1, 2, 2, 3, 3, 3]; @@ -47,9 +130,18 @@ fn we_can_aggregate_columns() { let column_d = Column::Scalar(&scals_d); let group_by = &[column_a.clone(), column_b.clone()]; let sum_columns = &[column_c.clone(), column_d.clone()]; + let max_columns = &[column_c.clone(), column_d.clone()]; + let min_columns = &[column_c.clone(), column_d.clone()]; let alloc = Bump::new(); - let aggregate_result = aggregate_columns(&alloc, group_by, sum_columns, selection) - .expect("Aggregation should succeed"); + let aggregate_result = aggregate_columns( + &alloc, + group_by, + sum_columns, + min_columns, + max_columns, + selection, + ) + .expect("Aggregation should succeed"); let scals_res = [ Curve25519Scalar::from("Cat"), Curve25519Scalar::from("Dog"), @@ -80,10 +172,48 @@ fn we_can_aggregate_columns() { Curve25519Scalar::from(202 + 210), ], ]; + let expected_max_result = &[ + &[ + Some(Curve25519Scalar::from(105)), + Some(Curve25519Scalar::from(106)), + Some(Curve25519Scalar::from(107)), + Some(Curve25519Scalar::from(108)), + Some(Curve25519Scalar::from(111)), + Some(Curve25519Scalar::from(110)), + ], + &[ + Some(Curve25519Scalar::from(205)), + Some(Curve25519Scalar::from(206)), + Some(Curve25519Scalar::from(207)), + Some(Curve25519Scalar::from(208)), + Some(Curve25519Scalar::from(211)), + Some(Curve25519Scalar::from(210)), + ], + ]; + let expected_min_result = &[ + &[ + Some(Curve25519Scalar::from(105)), + Some(Curve25519Scalar::from(106)), + Some(Curve25519Scalar::from(103)), + Some(Curve25519Scalar::from(104)), + Some(Curve25519Scalar::from(101)), + Some(Curve25519Scalar::from(102)), + ], + &[ + Some(Curve25519Scalar::from(205)), + Some(Curve25519Scalar::from(206)), + Some(Curve25519Scalar::from(203)), + Some(Curve25519Scalar::from(204)), + Some(Curve25519Scalar::from(201)), + Some(Curve25519Scalar::from(202)), + ], + ]; let expected_count_result = &[1, 1, 2, 2, 3, 2]; assert_eq!(aggregate_result.group_by_columns, expected_group_by_result); assert_eq!(aggregate_result.sum_columns, expected_sum_result); assert_eq!(aggregate_result.count_column, expected_count_result); + assert_eq!(aggregate_result.max_columns, expected_max_result); + assert_eq!(aggregate_result.min_columns, expected_min_result); } #[test] @@ -264,6 +394,7 @@ fn we_can_compare_indexes_by_columns_for_scalar_columns() { assert_eq!(compare_indexes_by_columns(columns, 6, 9), Ordering::Equal); } +// SUM slices #[test] fn we_can_sum_aggregate_slice_by_counts_for_empty_slice() { let slice_a: &[i64; 0] = &[]; @@ -289,7 +420,39 @@ fn we_can_sum_aggregate_slice_by_counts_with_empty_result() { } #[test] -fn we_can_sum_aggregate_slice_by_counts() { +fn we_can_sum_aggregate_slice_by_counts_with_all_empty_groups() { + let slice_a = &[ + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, + ]; + let indexes = &[]; + let counts = &[0, 0, 0]; + let expected = &[Curve25519Scalar::from(0); 3]; + let alloc = Bump::new(); + let result: &[Curve25519Scalar] = + sum_aggregate_slice_by_index_counts(&alloc, slice_a, counts, indexes); + assert_eq!(result, expected); +} + +#[test] +fn we_can_sum_aggregate_slice_by_counts_with_some_empty_group() { + let slice_a = &[ + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, + ]; + let indexes = &[12, 11, 1, 10, 2, 3, 4]; + let counts = &[3, 4, 0]; + let expected = &[ + Curve25519Scalar::from(112 + 111 + 101), + Curve25519Scalar::from(110 + 102 + 103 + 104), + Curve25519Scalar::from(0), + ]; + let alloc = Bump::new(); + let result: &[Curve25519Scalar] = + sum_aggregate_slice_by_index_counts(&alloc, slice_a, counts, indexes); + assert_eq!(result, expected); +} + +#[test] +fn we_can_sum_aggregate_slice_by_counts_without_empty_groups() { let slice_a = &[ 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, ]; @@ -349,3 +512,243 @@ fn we_can_sum_aggregate_columns_by_counts() { let result = sum_aggregate_column_by_index_counts(&alloc, &columns_c, counts, indexes); assert_eq!(result, expected); } + +// MAX slices +#[test] +fn we_can_max_aggregate_slice_by_counts_for_empty_slice() { + let slice_a: &[i64; 0] = &[]; + let indexes = &[]; + let counts = &[]; + let expected: &[Option; 0] = &[]; + let alloc = Bump::new(); + let result: &[Option] = + max_aggregate_slice_by_index_counts(&alloc, slice_a, counts, indexes); + assert_eq!(result, expected); +} + +#[test] +fn we_can_max_aggregate_slice_by_counts_with_empty_result() { + let slice_a = &[100, 101, 102, 103, 104, 105, 106, 107, 108, 109]; + let indexes = &[]; + let counts = &[]; + let expected: &[Option; 0] = &[]; + let alloc = Bump::new(); + let result: &[Option] = + max_aggregate_slice_by_index_counts(&alloc, slice_a, counts, indexes); + assert_eq!(result, expected); +} + +#[test] +fn we_can_max_aggregate_slice_by_counts_with_all_empty_groups() { + let slice_a = &[ + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, + ]; + let indexes = &[]; + let counts = &[0, 0, 0]; + let expected = &[None; 3]; + let alloc = Bump::new(); + let result: &[Option] = + max_aggregate_slice_by_index_counts(&alloc, slice_a, counts, indexes); + assert_eq!(result, expected); +} + +#[test] +fn we_can_max_aggregate_slice_by_counts_with_some_empty_group() { + let slice_a = &[ + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, + ]; + let indexes = &[12, 11, 1, 10, 2, 3, 4]; + let counts = &[3, 4, 0]; + let expected = &[ + Some(Curve25519Scalar::from(112)), + Some(Curve25519Scalar::from(110)), + None, + ]; + let alloc = Bump::new(); + let result: &[Option] = + max_aggregate_slice_by_index_counts(&alloc, slice_a, counts, indexes); + assert_eq!(result, expected); +} + +#[test] +fn we_can_max_aggregate_slice_by_counts_without_empty_groups() { + let slice_a = &[ + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, + ]; + let indexes = &[12, 11, 1, 10, 2, 3, 6, 14, 13, 9]; + let counts = &[3, 3, 4]; + let expected = &[ + Some(Curve25519Scalar::from(112)), + Some(Curve25519Scalar::from(110)), + Some(Curve25519Scalar::from(114)), + ]; + let alloc = Bump::new(); + let result: &[Option] = + max_aggregate_slice_by_index_counts(&alloc, slice_a, counts, indexes); + assert_eq!(result, expected); +} + +#[test] +fn we_can_max_aggregate_columns_by_counts_for_empty_column() { + let slice_a: &[i64; 0] = &[]; + let column_a = Column::BigInt::(slice_a); + let indexes = &[]; + let counts = &[]; + let expected: &[Option; 0] = &[]; + let alloc = Bump::new(); + let result: &[Option] = + max_aggregate_column_by_index_counts(&alloc, &column_a, counts, indexes); + assert_eq!(result, expected); +} + +#[test] +fn we_can_max_aggregate_columns_by_counts() { + let slice_a = &[ + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, + ]; + let slice_b = &[ + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, + ]; + let slice_c = &[ + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, + ]; + let scals_c: Vec = slice_c.iter().map(|s| s.into()).collect(); + let column_a = Column::BigInt::(slice_a); + let columns_b = Column::Int128::(slice_b); + let columns_c = Column::Scalar(&scals_c); + let indexes = &[12, 11, 1, 10, 2, 3, 6, 14, 13, 9]; + let counts = &[3, 3, 4, 0]; + let expected = &[ + Some(Curve25519Scalar::from(112)), + Some(Curve25519Scalar::from(110)), + Some(Curve25519Scalar::from(114)), + None, + ]; + let alloc = Bump::new(); + let result = max_aggregate_column_by_index_counts(&alloc, &column_a, counts, indexes); + assert_eq!(result, expected); + let result = max_aggregate_column_by_index_counts(&alloc, &columns_b, counts, indexes); + assert_eq!(result, expected); + let result = max_aggregate_column_by_index_counts(&alloc, &columns_c, counts, indexes); + assert_eq!(result, expected); +} + +// MIN slices +#[test] +fn we_can_min_aggregate_slice_by_counts_for_empty_slice() { + let slice_a: &[i64; 0] = &[]; + let indexes = &[]; + let counts = &[]; + let expected: &[Option; 0] = &[]; + let alloc = Bump::new(); + let result: &[Option] = + min_aggregate_slice_by_index_counts(&alloc, slice_a, counts, indexes); + assert_eq!(result, expected); +} + +#[test] +fn we_can_min_aggregate_slice_by_counts_with_empty_result() { + let slice_a = &[100, 101, 102, 103, 104, 105, 106, 107, 108, 109]; + let indexes = &[]; + let counts = &[]; + let expected: &[Option; 0] = &[]; + let alloc = Bump::new(); + let result: &[Option] = + min_aggregate_slice_by_index_counts(&alloc, slice_a, counts, indexes); + assert_eq!(result, expected); +} + +#[test] +fn we_can_min_aggregate_slice_by_counts_with_all_empty_groups() { + let slice_a = &[ + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, + ]; + let indexes = &[]; + let counts = &[0, 0, 0]; + let expected = &[None; 3]; + let alloc = Bump::new(); + let result: &[Option] = + min_aggregate_slice_by_index_counts(&alloc, slice_a, counts, indexes); + assert_eq!(result, expected); +} + +#[test] +fn we_can_min_aggregate_slice_by_counts_with_some_empty_group() { + let slice_a = &[ + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, + ]; + let indexes = &[12, 11, 1, 10, 2, 3, 4]; + let counts = &[3, 4, 0]; + let expected = &[ + Some(Curve25519Scalar::from(101)), + Some(Curve25519Scalar::from(102)), + None, + ]; + let alloc = Bump::new(); + let result: &[Option] = + min_aggregate_slice_by_index_counts(&alloc, slice_a, counts, indexes); + assert_eq!(result, expected); +} + +#[test] +fn we_can_min_aggregate_slice_by_counts_without_empty_groups() { + let slice_a = &[ + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, + ]; + let indexes = &[12, 11, 1, 10, 2, 3, 6, 14, 13, 9]; + let counts = &[3, 3, 4]; + let expected = &[ + Some(Curve25519Scalar::from(101)), + Some(Curve25519Scalar::from(102)), + Some(Curve25519Scalar::from(106)), + ]; + let alloc = Bump::new(); + let result: &[Option] = + min_aggregate_slice_by_index_counts(&alloc, slice_a, counts, indexes); + assert_eq!(result, expected); +} + +#[test] +fn we_can_min_aggregate_columns_by_counts_for_empty_column() { + let slice_a: &[i64; 0] = &[]; + let column_a = Column::BigInt::(slice_a); + let indexes = &[]; + let counts = &[]; + let expected: &[Option; 0] = &[]; + let alloc = Bump::new(); + let result: &[Option] = + min_aggregate_column_by_index_counts(&alloc, &column_a, counts, indexes); + assert_eq!(result, expected); +} + +#[test] +fn we_can_min_aggregate_columns_by_counts() { + let slice_a = &[ + 100, -101, 102, -103, 104, -105, 106, -107, 108, -109, 110, -111, 112, -113, 114, -115, + ]; + let slice_b = &[ + 100, -101, 102, -103, 104, -105, 106, -107, 108, -109, 110, -111, 112, -113, 114, -115, + ]; + let slice_c = &[ + 100, -101, 102, -103, 104, -105, 106, -107, 108, -109, 110, -111, 112, -113, 114, -115, + ]; + let scals_c: Vec = slice_c.iter().map(|s| s.into()).collect(); + let column_a = Column::BigInt::(slice_a); + let columns_b = Column::Int128::(slice_b); + let columns_c = Column::Scalar(&scals_c); + let indexes = &[12, 11, 1, 10, 2, 3, 6, 14, 13, 9]; + let counts = &[3, 3, 4, 0]; + let expected = &[ + Some(Curve25519Scalar::from(-111)), + Some(Curve25519Scalar::from(-103)), + Some(Curve25519Scalar::from(-113)), + None, + ]; + let alloc = Bump::new(); + let result = min_aggregate_column_by_index_counts(&alloc, &column_a, counts, indexes); + assert_eq!(result, expected); + let result = min_aggregate_column_by_index_counts(&alloc, &columns_b, counts, indexes); + assert_eq!(result, expected); + let result = min_aggregate_column_by_index_counts(&alloc, &columns_c, counts, indexes); + assert_eq!(result, expected); +} diff --git a/crates/proof-of-sql/src/sql/ast/group_by_expr.rs b/crates/proof-of-sql/src/sql/ast/group_by_expr.rs index 48019a115..8f3aa95c3 100644 --- a/crates/proof-of-sql/src/sql/ast/group_by_expr.rs +++ b/crates/proof-of-sql/src/sql/ast/group_by_expr.rs @@ -232,7 +232,8 @@ impl ProverEvaluate for GroupByExpr { group_by_columns: group_by_result_columns, sum_columns: sum_result_columns, count_column, - } = aggregate_columns(alloc, &group_by_columns, &sum_columns, selection) + .. + } = aggregate_columns(alloc, &group_by_columns, &sum_columns, &[], &[], selection) .expect("columns should be aggregatable"); // 3. set indexes builder.set_result_indexes(Indexes::Dense(0..(count_column.len() as u64))); @@ -278,7 +279,8 @@ impl ProverEvaluate for GroupByExpr { group_by_columns: group_by_result_columns, sum_columns: sum_result_columns, count_column, - } = aggregate_columns(alloc, &group_by_columns, &sum_columns, selection) + .. + } = aggregate_columns(alloc, &group_by_columns, &sum_columns, &[], &[], selection) .expect("columns should be aggregatable"); let alpha = builder.consume_post_result_challenge();