Skip to content

Commit

Permalink
refactor!: split and simplify owned_column_operation.rs && remove u…
Browse files Browse the repository at this point in the history
…nused functions from `slice_operation.rs` (#359)

Please be sure to look over the pull request guidelines here:
https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md#submit-pr.

# Please go through the following checklist
- [x] The PR title and commit messages adhere to guidelines here:
https://github.com/spaceandtimelabs/sxt-proof-of-sql/blob/main/CONTRIBUTING.md.
In particular `!` is used if and only if at least one breaking change
has been introduced.
- [x] I have run the ci check script with `source
scripts/run_ci_checks.sh`.

# Rationale for this change
`owned_column_operation.rs` is extremely tedious hence we need to
simplify it.
<!--
Why are you proposing this change? If this is already explained clearly
in the linked issue then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.

 Example:
 Add `NestedLoopJoinExec`.
 Closes #345.

Since we added `HashJoinExec` in #323 it has been possible to do
provable inner joins. However performance is not satisfactory in some
cases. Hence we need to fix the problem by implement
`NestedLoopJoinExec` and speed up the code
 for `HashJoinExec`.
-->

# What changes are included in this PR?
- split out arithmetic operations into `column_arithmetic_operation.rs`
and unify them.
- split out comparison operations into `column_comparison_operation.rs`
and unify them.
- remove unused functions.
<!--
There is no need to duplicate the description in the ticket here but it
is sometimes worth providing a summary of the individual changes in this
PR.

Example:
- Add `NestedLoopJoinExec`.
- Speed up `HashJoinExec`.
- Route joins to `NestedLoopJoinExec` if the outer input is sufficiently
small.
-->

# Are these changes tested?
<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code

If tests are not included in your PR, please explain why (for example,
are they covered by existing tests)?

Example:
Yes.
-->
Yes.
  • Loading branch information
iajoiner authored Nov 12, 2024
2 parents a0ffccb + 11596ef commit c60c143
Show file tree
Hide file tree
Showing 6 changed files with 787 additions and 1,630 deletions.
290 changes: 290 additions & 0 deletions crates/proof-of-sql/src/base/database/column_arithmetic_operation.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
use super::{ColumnOperationError, ColumnOperationResult};
use crate::base::{
database::{
slice_decimal_operation::{
try_add_decimal_columns, try_divide_decimal_columns, try_multiply_decimal_columns,
try_subtract_decimal_columns,
},
slice_operation::{
try_add, try_div, try_mul, try_slice_binary_op, try_slice_binary_op_left_upcast,
try_slice_binary_op_right_upcast, try_sub,
},
ColumnType, OwnedColumn,
},
math::decimal::Precision,
scalar::Scalar,
};
use alloc::{string::ToString, vec::Vec};
use core::fmt::Debug;
use num_bigint::BigInt;
use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub};

pub trait ArithmeticOp {
fn op<T>(l: &T, r: &T) -> ColumnOperationResult<T>
where
T: Debug + CheckedDiv + CheckedMul + CheckedAdd + CheckedSub;
fn decimal_op<S, T0, T1>(
lhs: &[T0],
rhs: &[T1],
left_column_type: ColumnType,
right_column_type: ColumnType,
) -> ColumnOperationResult<(Precision, i8, Vec<S>)>
where
S: Scalar + From<T0> + From<T1>,
T0: Copy + Debug + Into<BigInt>,
T1: Copy + Debug + Into<BigInt>;

#[allow(clippy::too_many_lines)]
fn owned_column_element_wise_arithmetic<S: Scalar>(
lhs: &OwnedColumn<S>,
rhs: &OwnedColumn<S>,
) -> ColumnOperationResult<OwnedColumn<S>> {
if lhs.len() != rhs.len() {
return Err(ColumnOperationError::DifferentColumnLength {
len_a: lhs.len(),
len_b: rhs.len(),
});
}
match (&lhs, &rhs) {
(OwnedColumn::TinyInt(lhs), OwnedColumn::TinyInt(rhs)) => Ok(OwnedColumn::TinyInt(
try_slice_binary_op(lhs, rhs, Self::op)?,
)),
(OwnedColumn::TinyInt(lhs), OwnedColumn::SmallInt(rhs)) => Ok(OwnedColumn::SmallInt(
try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::TinyInt(lhs), OwnedColumn::Int(rhs)) => Ok(OwnedColumn::Int(
try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::TinyInt(lhs), OwnedColumn::BigInt(rhs)) => Ok(OwnedColumn::BigInt(
try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::TinyInt(lhs), OwnedColumn::Int128(rhs)) => Ok(OwnedColumn::Int128(
try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::TinyInt(lhs_values), OwnedColumn::Decimal75(_, _, rhs_values)) => {
let (new_precision, new_scale, new_values) =
Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?;
Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values))
}

(OwnedColumn::SmallInt(lhs), OwnedColumn::TinyInt(rhs)) => Ok(OwnedColumn::SmallInt(
try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::SmallInt(lhs), OwnedColumn::SmallInt(rhs)) => Ok(OwnedColumn::SmallInt(
try_slice_binary_op(lhs, rhs, Self::op)?,
)),
(OwnedColumn::SmallInt(lhs), OwnedColumn::Int(rhs)) => Ok(OwnedColumn::Int(
try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::SmallInt(lhs), OwnedColumn::BigInt(rhs)) => Ok(OwnedColumn::BigInt(
try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::SmallInt(lhs), OwnedColumn::Int128(rhs)) => Ok(OwnedColumn::Int128(
try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::SmallInt(lhs_values), OwnedColumn::Decimal75(_, _, rhs_values)) => {
let (new_precision, new_scale, new_values) =
Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?;
Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values))
}

(OwnedColumn::Int(lhs), OwnedColumn::TinyInt(rhs)) => Ok(OwnedColumn::Int(
try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::Int(lhs), OwnedColumn::SmallInt(rhs)) => Ok(OwnedColumn::Int(
try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::Int(lhs), OwnedColumn::Int(rhs)) => {
Ok(OwnedColumn::Int(try_slice_binary_op(lhs, rhs, Self::op)?))
}
(OwnedColumn::Int(lhs), OwnedColumn::BigInt(rhs)) => Ok(OwnedColumn::BigInt(
try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::Int(lhs), OwnedColumn::Int128(rhs)) => Ok(OwnedColumn::Int128(
try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::Int(lhs_values), OwnedColumn::Decimal75(_, _, rhs_values)) => {
let (new_precision, new_scale, new_values) =
Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?;
Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values))
}

(OwnedColumn::BigInt(lhs), OwnedColumn::TinyInt(rhs)) => Ok(OwnedColumn::BigInt(
try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::BigInt(lhs), OwnedColumn::SmallInt(rhs)) => Ok(OwnedColumn::BigInt(
try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::BigInt(lhs), OwnedColumn::Int(rhs)) => Ok(OwnedColumn::BigInt(
try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::BigInt(lhs), OwnedColumn::BigInt(rhs)) => Ok(OwnedColumn::BigInt(
try_slice_binary_op(lhs, rhs, Self::op)?,
)),
(OwnedColumn::BigInt(lhs), OwnedColumn::Int128(rhs)) => Ok(OwnedColumn::Int128(
try_slice_binary_op_left_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::BigInt(lhs_values), OwnedColumn::Decimal75(_, _, rhs_values)) => {
let (new_precision, new_scale, new_values) =
Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?;
Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values))
}

(OwnedColumn::Int128(lhs), OwnedColumn::TinyInt(rhs)) => Ok(OwnedColumn::Int128(
try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::Int128(lhs), OwnedColumn::SmallInt(rhs)) => Ok(OwnedColumn::Int128(
try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::Int128(lhs), OwnedColumn::Int(rhs)) => Ok(OwnedColumn::Int128(
try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::Int128(lhs), OwnedColumn::BigInt(rhs)) => Ok(OwnedColumn::Int128(
try_slice_binary_op_right_upcast(lhs, rhs, Self::op)?,
)),
(OwnedColumn::Int128(lhs), OwnedColumn::Int128(rhs)) => Ok(OwnedColumn::Int128(
try_slice_binary_op(lhs, rhs, Self::op)?,
)),
(OwnedColumn::Int128(lhs_values), OwnedColumn::Decimal75(_, _, rhs_values)) => {
let (new_precision, new_scale, new_values) =
Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?;
Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values))
}

(OwnedColumn::Decimal75(_, _, lhs_values), OwnedColumn::TinyInt(rhs_values)) => {
let (new_precision, new_scale, new_values) =
Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?;
Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values))
}
(OwnedColumn::Decimal75(_, _, lhs_values), OwnedColumn::SmallInt(rhs_values)) => {
let (new_precision, new_scale, new_values) =
Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?;
Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values))
}
(OwnedColumn::Decimal75(_, _, lhs_values), OwnedColumn::Int(rhs_values)) => {
let (new_precision, new_scale, new_values) =
Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?;
Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values))
}
(OwnedColumn::Decimal75(_, _, lhs_values), OwnedColumn::BigInt(rhs_values)) => {
let (new_precision, new_scale, new_values) =
Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?;
Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values))
}
(OwnedColumn::Decimal75(_, _, lhs_values), OwnedColumn::Int128(rhs_values)) => {
let (new_precision, new_scale, new_values) =
Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?;
Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values))
}
(
OwnedColumn::Decimal75(_, _, lhs_values),
OwnedColumn::Decimal75(_, _, rhs_values),
) => {
let (new_precision, new_scale, new_values) =
Self::decimal_op(lhs_values, rhs_values, lhs.column_type(), rhs.column_type())?;
Ok(OwnedColumn::Decimal75(new_precision, new_scale, new_values))
}
_ => Err(ColumnOperationError::BinaryOperationInvalidColumnType {
operator: "ArithmeticOp".to_string(),
left_type: lhs.column_type(),
right_type: rhs.column_type(),
}),
}
}
}

pub struct AddOp {}
impl ArithmeticOp for AddOp {
fn op<T>(l: &T, r: &T) -> ColumnOperationResult<T>
where
T: CheckedAdd + Debug,
{
try_add(l, r)
}

fn decimal_op<S, T0, T1>(
lhs: &[T0],
rhs: &[T1],
left_column_type: ColumnType,
right_column_type: ColumnType,
) -> ColumnOperationResult<(Precision, i8, Vec<S>)>
where
S: Scalar + From<T0> + From<T1>,
T0: Copy,
T1: Copy,
{
try_add_decimal_columns(lhs, rhs, left_column_type, right_column_type)
}
}

pub struct SubOp {}
impl ArithmeticOp for SubOp {
fn op<T>(l: &T, r: &T) -> ColumnOperationResult<T>
where
T: CheckedSub + Debug,
{
try_sub(l, r)
}

fn decimal_op<S, T0, T1>(
lhs: &[T0],
rhs: &[T1],
left_column_type: ColumnType,
right_column_type: ColumnType,
) -> ColumnOperationResult<(Precision, i8, Vec<S>)>
where
S: Scalar + From<T0> + From<T1>,
T0: Copy,
T1: Copy,
{
try_subtract_decimal_columns(lhs, rhs, left_column_type, right_column_type)
}
}

pub struct MulOp {}
impl ArithmeticOp for MulOp {
fn op<T>(l: &T, r: &T) -> ColumnOperationResult<T>
where
T: CheckedMul + Debug,
{
try_mul(l, r)
}

fn decimal_op<S, T0, T1>(
lhs: &[T0],
rhs: &[T1],
left_column_type: ColumnType,
right_column_type: ColumnType,
) -> ColumnOperationResult<(Precision, i8, Vec<S>)>
where
S: Scalar + From<T0> + From<T1>,
T0: Copy,
T1: Copy,
{
try_multiply_decimal_columns(lhs, rhs, left_column_type, right_column_type)
}
}

pub struct DivOp {}
impl ArithmeticOp for DivOp {
fn op<T>(l: &T, r: &T) -> ColumnOperationResult<T>
where
T: CheckedDiv + Debug,
{
try_div(l, r)
}

fn decimal_op<S, T0, T1>(
lhs: &[T0],
rhs: &[T1],
left_column_type: ColumnType,
right_column_type: ColumnType,
) -> ColumnOperationResult<(Precision, i8, Vec<S>)>
where
S: Scalar + From<T0> + From<T1>,
T0: Copy + Debug + Into<BigInt>,
T1: Copy + Debug + Into<BigInt>,
{
try_divide_decimal_columns(lhs, rhs, left_column_type, right_column_type)
}
}
Loading

0 comments on commit c60c143

Please sign in to comment.