Skip to content

Commit

Permalink
feat: add MAX and MIN to AggregateColumns (#94)
Browse files Browse the repository at this point in the history
# Rationale for this change
We need to add `max` and `min` to `AggregateColumns` for postprocessing,
even though they are not provable yet.
<!--
Why are you proposing this change? If this is already explained clearly
in the linked Jira ticket then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.
-->

# What changes are included in this PR?
- add max and min to `AggregateColumns`
<!--
There is no need to duplicate the description in the ticket here but it
is sometimes worth providing a summary of the individual changes in this
PR.
-->

# Are these changes tested?
Yes.
<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example,
are they covered by existing tests)?
-->
  • Loading branch information
iajoiner authored Aug 9, 2024
1 parent 268eba4 commit 810949b
Show file tree
Hide file tree
Showing 3 changed files with 504 additions and 19 deletions.
106 changes: 93 additions & 13 deletions crates/proof-of-sql/src/base/database/group_by_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@ pub struct AggregatedColumns<'a, S: Scalar> {
pub group_by_columns: Vec<Column<'a, S>>,
/// Resulting sums of the groups for the columns in `sum_columns_in`.
pub sum_columns: Vec<&'a [S]>,
/// Resulting maxima of the groups for the columns in `max_columns_in`. Note that for empty groups
/// the result will be `None`.
pub max_columns: Vec<&'a [Option<S>]>,
/// Resulting minima of the groups for the columns in `min_columns_in`. Note that for empty groups
/// the result will be `None`.
pub min_columns: Vec<&'a [Option<S>]>,
/// The number of rows in each group.
pub count_column: &'a [i64],
}
Expand All @@ -28,7 +34,8 @@ pub enum AggregateColumnsError {

/// This is a function that gives the result of a group by query similar to the following:
/// ```sql
/// SELECT <group_by[0]>, <group_by[1]>, ..., SUM(<sum_columns[0]>), SUM(<sum_columns[1]>), ..., COUNT(*)
/// SELECT <group_by[0]>, <group_by[1]>, ..., SUM(<sum_columns[0]>), SUM(<sum_columns[1]>), ...,
/// MAX(<max_columns[0]>), ..., MIN(<min_columns[0]>), ..., COUNT(*)
/// WHERE selection GROUP BY <group_by[0]>, <group_by[1]>, ...
/// ```
///
Expand All @@ -38,17 +45,20 @@ pub fn aggregate_columns<'a, S: Scalar>(
alloc: &'a Bump,
group_by_columns_in: &[Column<'a, S>],
sum_columns_in: &[Column<S>],
max_columns_in: &[Column<S>],
min_columns_in: &[Column<S>],
selection_column_in: &[bool],
) -> Result<AggregatedColumns<'a, S>, AggregateColumnsError> {
for col in group_by_columns_in {
if col.len() != selection_column_in.len() {
return Err(AggregateColumnsError::ColumnLengthMismatch);
}
}
for col in sum_columns_in {
if col.len() != selection_column_in.len() {
return Err(AggregateColumnsError::ColumnLengthMismatch);
}
// Check that all the columns have the same length
let len = selection_column_in.len();
if group_by_columns_in
.iter()
.chain(sum_columns_in.iter())
.chain(max_columns_in.iter())
.chain(min_columns_in.iter())
.any(|col| col.len() != len)
{
return Err(AggregateColumnsError::ColumnLengthMismatch);
}

// `filtered_indexes` is a vector of indexes of the rows that are selected. We sort this vector
Expand Down Expand Up @@ -84,12 +94,22 @@ pub fn aggregate_columns<'a, S: Scalar>(
sum_aggregate_column_by_index_counts(alloc, column, &counts, &filtered_indexes)
}));

let max_columns_out = Vec::from_iter(max_columns_in.iter().map(|column| {
max_aggregate_column_by_index_counts(alloc, column, &counts, &filtered_indexes)
}));

let min_columns_out = Vec::from_iter(min_columns_in.iter().map(|column| {
min_aggregate_column_by_index_counts(alloc, column, &counts, &filtered_indexes)
}));

// Cast the counts to something compatible with BigInt.
let count_column_out = alloc.alloc_slice_fill_iter(counts.into_iter().map(|c| c as i64));

Ok(AggregatedColumns {
group_by_columns: group_by_columns_out,
sum_columns: sum_columns_out,
max_columns: max_columns_out,
min_columns: min_columns_out,
count_column: count_column_out,
})
}
Expand Down Expand Up @@ -121,6 +141,68 @@ pub(crate) fn sum_aggregate_column_by_index_counts<'a, S: Scalar>(
}
}

/// Computes the grouped maxima of `column`, returning a slice with the lifetime of `alloc`.
///
/// `counts` holds the number of elements in each group, and `indexes` holds the indexes of the
/// elements in `column`.
///
/// This is a helper wrapper that dispatches on the column variant and forwards the underlying
/// slice to [`max_aggregate_slice_by_index_counts`]; see that function for an example.
///
/// # Panics
///
/// Panics if `column` is a [`Column::VarChar`], because `MAX` can't be applied to varchar.
pub(crate) fn max_aggregate_column_by_index_counts<'a, S: Scalar>(
    alloc: &'a Bump,
    column: &Column<S>,
    counts: &[usize],
    indexes: &[usize],
) -> &'a [Option<S>] {
    match column {
        // This arm should never be reached because the `MAX` function can't be applied to varchar.
        Column::VarChar(..) => {
            unreachable!("MAX can not be applied to varchar")
        }
        Column::Boolean(values) => {
            max_aggregate_slice_by_index_counts(alloc, values, counts, indexes)
        }
        Column::SmallInt(values) => {
            max_aggregate_slice_by_index_counts(alloc, values, counts, indexes)
        }
        Column::Int(values) => {
            max_aggregate_slice_by_index_counts(alloc, values, counts, indexes)
        }
        Column::BigInt(values) => {
            max_aggregate_slice_by_index_counts(alloc, values, counts, indexes)
        }
        Column::Int128(values) => {
            max_aggregate_slice_by_index_counts(alloc, values, counts, indexes)
        }
        // Only the trailing data slice of these variants takes part in the aggregation.
        Column::Decimal75(.., values) => {
            max_aggregate_slice_by_index_counts(alloc, values, counts, indexes)
        }
        Column::TimestampTZ(.., values) => {
            max_aggregate_slice_by_index_counts(alloc, values, counts, indexes)
        }
        Column::Scalar(values) => {
            max_aggregate_slice_by_index_counts(alloc, values, counts, indexes)
        }
    }
}

/// Computes the grouped minima of `column`, returning a slice with the lifetime of `alloc`.
///
/// `counts` holds the number of elements in each group, and `indexes` holds the indexes of the
/// elements in `column`.
///
/// This is a helper wrapper that dispatches on the column variant and forwards the underlying
/// slice to [`min_aggregate_slice_by_index_counts`]; see that function for an example.
///
/// # Panics
///
/// Panics if `column` is a [`Column::VarChar`], because `MIN` can't be applied to varchar.
pub(crate) fn min_aggregate_column_by_index_counts<'a, S: Scalar>(
    alloc: &'a Bump,
    column: &Column<S>,
    counts: &[usize],
    indexes: &[usize],
) -> &'a [Option<S>] {
    match column {
        // This arm should never be reached because the `MIN` function can't be applied to varchar.
        Column::VarChar(..) => {
            unreachable!("MIN can not be applied to varchar")
        }
        Column::Boolean(values) => {
            min_aggregate_slice_by_index_counts(alloc, values, counts, indexes)
        }
        Column::SmallInt(values) => {
            min_aggregate_slice_by_index_counts(alloc, values, counts, indexes)
        }
        Column::Int(values) => {
            min_aggregate_slice_by_index_counts(alloc, values, counts, indexes)
        }
        Column::BigInt(values) => {
            min_aggregate_slice_by_index_counts(alloc, values, counts, indexes)
        }
        Column::Int128(values) => {
            min_aggregate_slice_by_index_counts(alloc, values, counts, indexes)
        }
        // Only the trailing data slice of these variants takes part in the aggregation.
        Column::Decimal75(.., values) => {
            min_aggregate_slice_by_index_counts(alloc, values, counts, indexes)
        }
        Column::TimestampTZ(.., values) => {
            min_aggregate_slice_by_index_counts(alloc, values, counts, indexes)
        }
        Column::Scalar(values) => {
            min_aggregate_slice_by_index_counts(alloc, values, counts, indexes)
        }
    }
}

/// Returns a slice with the lifetime of `alloc` that contains the grouped sums of `slice`.
/// The `counts` slice contains the number of elements in each group and the `indexes` slice
/// contains the indexes of the elements in `slice`.
Expand Down Expand Up @@ -182,7 +264,6 @@ where
/// let result = max_aggregate_slice_by_index_counts(&alloc, slice_a, counts, indexes);
/// assert_eq!(result, expected);
/// ```
#[allow(dead_code)]
pub(crate) fn max_aggregate_slice_by_index_counts<'a, S, T>(
alloc: &'a Bump,
slice: &[T],
Expand Down Expand Up @@ -226,7 +307,6 @@ where
/// let result = min_aggregate_slice_by_index_counts(&alloc, slice_a, counts, indexes);
/// assert_eq!(result, expected);
/// ```
#[allow(dead_code)]
pub(crate) fn min_aggregate_slice_by_index_counts<'a, S, T>(
alloc: &'a Bump,
slice: &[T],
Expand All @@ -243,7 +323,7 @@ where
indexes[start..index]
.iter()
.map(|i| S::from(&slice[*i]))
.max_by(|x, y| x.signed_cmp(y))
.min_by(|x, y| x.signed_cmp(y))
}))
}

Expand Down
Loading

0 comments on commit 810949b

Please sign in to comment.