From c20229fe3427d12bf9e5dcc7259bd6d694f4b685 Mon Sep 17 00:00:00 2001 From: Kould <2435992353@qq.com> Date: Sun, 3 Mar 2024 20:45:46 +0800 Subject: [PATCH] fix: range detection error when nesting `and` & `or` (#150) --- Cargo.toml | 4 +- src/expression/simplify.rs | 268 ++++++++++++++++-- .../rule/normalization/simplification.rs | 201 +++++-------- tests/slt/dummy.slt | 5 + tests/slt/where_by_index.slt | 11 +- 5 files changed, 333 insertions(+), 156 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1173be7c..74230488 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ [package] name = "fnck_sql" -version = "0.0.1-alpha.11.fix2" +version = "0.0.1-alpha.11.fix3" edition = "2021" authors = ["Kould ", "Xwg "] description = "Fast Insert OLTP SQL DBMS" @@ -52,7 +52,7 @@ ahash = "0.8.3" lazy_static = "1.4.0" comfy-table = "7.0.1" bytes = "1.5.0" -kip_db = "0.1.2-alpha.25" +kip_db = "0.1.2-alpha.25.fix1" rust_decimal = "1" csv = "1" regex = "1.10.2" diff --git a/src/expression/simplify.rs b/src/expression/simplify.rs index dc684af9..24779e1d 100644 --- a/src/expression/simplify.rs +++ b/src/expression/simplify.rs @@ -21,9 +21,7 @@ pub enum ConstantBinary { Eq(ValueRef), NotEq(ValueRef), - // ConstantBinary in And can only be Scope\Eq\NotEq And(Vec), - // ConstantBinary in Or can only be Scope\Eq\NotEq\And Or(Vec), } @@ -147,26 +145,32 @@ impl ConstantBinary { } pub fn scope_aggregation(&mut self) -> Result<(), DatabaseError> { + // Tips: Only single-level `And` and `Or` match self { - // `Or` is allowed to contain And, `Scope`, `Eq/NotEq` - // Tips: Only single-level `And` + // `Or` is allowed to contain `And`, `Scope`, `Eq/NotEq` ConstantBinary::Or(binaries) => { - for binary in mem::take(binaries) { + for mut binary in mem::take(binaries) { + binary.scope_aggregation()?; match binary { - ConstantBinary::And(mut and_binaries) => { - Self::and_scope_aggregation(&mut and_binaries)?; - binaries.append(&mut and_binaries); - } - ConstantBinary::Or(_) => { - unreachable!("`Or` does not allow nested `Or`") - } + ConstantBinary::And(mut and_binaries) => binaries.append(&mut and_binaries), + ConstantBinary::Or(_) => unreachable!("`Or` does not allow nested `Or`"), binary => binaries.push(binary), } } Self::or_scope_aggregation(binaries); } - // `And` is allowed to contain Scope, `Eq/NotEq` + // `And` is allowed to contain `Or`, Scope, `Eq/NotEq` ConstantBinary::And(binaries) => { + for mut binary in mem::take(binaries) { + binary.scope_aggregation()?; + match binary { + ConstantBinary::And(_) => unreachable!("`And` does not allow nested `And`"), + ConstantBinary::Or(or_binaries) => { + binaries.append(&mut ConstantBinary::Or(or_binaries).rearrange()?); + } + binary => binaries.push(binary), + } + } Self::and_scope_aggregation(binaries)?; } _ => (), @@ -258,7 +262,7 @@ impl ConstantBinary { } } ConstantBinary::Or(_) | ConstantBinary::And(_) => { - return Err(DatabaseError::InvalidType) + unreachable!() } } } @@ -303,12 +307,12 @@ impl ConstantBinary { let mut scope_margin = None; - let sort_op = |binary: &&mut ConstantBinary| match binary { + let sort_op = |binary: &&ConstantBinary| match binary { ConstantBinary::NotEq(_) => 2, ConstantBinary::Eq(_) => 1, _ => 3, }; - for binary in binaries.iter_mut().sorted_by_key(sort_op) { + for binary in binaries.iter().sorted_by_key(sort_op) { if matches!(scope_margin, Some((Bound::Unbounded, Bound::Unbounded))) { break; } @@ -413,10 +417,7 @@ impl ConstantBinary { } *binaries = merge_scopes .into_iter() - .map(|(min, max)| ConstantBinary::Scope { - min: min.clone(), - max: max.clone(), - }) + .map(|(min, max)| ConstantBinary::Scope { min, max }) .chain(eqs.into_iter().map(ConstantBinary::Eq)) .collect_vec(); } @@ -1027,9 +1028,16 @@ impl ScalarExpression { } (ConstantBinary::Or(mut binaries), binary) | (binary, ConstantBinary::Or(mut binaries)) => { - binaries.push(binary); + if op == &BinaryOperator::And { + Ok(Some(ConstantBinary::And(vec![ + binary, + ConstantBinary::Or(binaries), + ]))) + } else { + binaries.push(binary); - Ok(Some(ConstantBinary::Or(binaries))) + Ok(Some(ConstantBinary::Or(binaries))) + } } (left, right) => match op { BinaryOperator::And => Ok(Some(ConstantBinary::And(vec![left, right]))), @@ -1355,7 +1363,7 @@ mod test { } #[test] - fn test_scope_aggregation_eq_noteq_cover() -> Result<(), DatabaseError> { + fn test_scope_aggregation_and_eq_noteq_cover() -> Result<(), DatabaseError> { let val_0 = Arc::new(DataValue::Int32(Some(0))); let val_1 = Arc::new(DataValue::Int32(Some(1))); let val_2 = Arc::new(DataValue::Int32(Some(2))); @@ -1423,7 +1431,7 @@ mod test { } #[test] - fn test_scope_aggregation_mixed() -> Result<(), DatabaseError> { + fn test_scope_aggregation_and_mixed() -> Result<(), DatabaseError> { let val_0 = Arc::new(DataValue::Int32(Some(0))); let val_1 = Arc::new(DataValue::Int32(Some(1))); let val_2 = Arc::new(DataValue::Int32(Some(2))); @@ -1540,6 +1548,218 @@ mod test { Ok(()) } + #[test] + fn test_scope_aggregation_or_eq_noteq_cover() -> Result<(), DatabaseError> { + let val_0 = Arc::new(DataValue::Int32(Some(0))); + let val_1 = Arc::new(DataValue::Int32(Some(1))); + let val_2 = Arc::new(DataValue::Int32(Some(2))); + let val_3 = Arc::new(DataValue::Int32(Some(3))); + + let mut binary = ConstantBinary::Or(vec![ + ConstantBinary::Eq(val_0.clone()), + ConstantBinary::NotEq(val_1.clone()), + ConstantBinary::Eq(val_2.clone()), + ConstantBinary::NotEq(val_3.clone()), + ConstantBinary::NotEq(val_0.clone()), + ConstantBinary::NotEq(val_1.clone()), + ConstantBinary::NotEq(val_2.clone()), + ConstantBinary::NotEq(val_3.clone()), + ]); + + binary.scope_aggregation()?; + + assert_eq!( + binary, + ConstantBinary::Or(vec![ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Unbounded + }]) + ); + + Ok(()) + } + + #[test] + fn test_scope_aggregation_or_and_mixed() -> Result<(), DatabaseError> { + let val_0 = Arc::new(DataValue::Int32(Some(0))); + let val_1 = Arc::new(DataValue::Int32(Some(1))); + let val_2 = Arc::new(DataValue::Int32(Some(2))); + let val_3 = Arc::new(DataValue::Int32(Some(3))); + + let mut binary = ConstantBinary::Or(vec![ + ConstantBinary::And(vec![ + ConstantBinary::Scope { + min: Bound::Excluded(val_0.clone()), + max: Bound::Included(val_2.clone()), + }, + ConstantBinary::Scope { + min: Bound::Included(val_0.clone()), + max: Bound::Excluded(val_2.clone()), + }, + ]), + ConstantBinary::And(vec![ + ConstantBinary::Scope { + min: Bound::Excluded(val_1.clone()), + max: Bound::Included(val_3.clone()), + }, + ConstantBinary::Scope { + min: Bound::Included(val_1.clone()), + max: Bound::Excluded(val_3.clone()), + }, + ]), + ]); + + binary.scope_aggregation()?; + + assert_eq!( + binary, + ConstantBinary::Or(vec![ConstantBinary::Scope { + min: Bound::Excluded(val_0.clone()), + max: Bound::Excluded(val_3.clone()), + },]) + ); + + Ok(()) + } + + #[test] + fn test_scope_aggregation_and_converse() -> Result<(), DatabaseError> { + let val_0 = Arc::new(DataValue::Int32(Some(0))); + let val_3 = Arc::new(DataValue::Int32(Some(3))); + + let mut binary = ConstantBinary::And(vec![ + ConstantBinary::Scope { + min: Bound::Included(val_3.clone()), + max: Bound::Unbounded, + }, + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Included(val_0.clone()), + }, + ]); + + binary.scope_aggregation()?; + + assert_eq!(binary, ConstantBinary::And(vec![])); + + Ok(()) + } + + #[test] + fn test_scope_aggregation_or_converse() -> Result<(), DatabaseError> { + let val_0 = Arc::new(DataValue::Int32(Some(0))); + let val_3 = Arc::new(DataValue::Int32(Some(3))); + + let mut binary = ConstantBinary::Or(vec![ + ConstantBinary::Scope { + min: Bound::Included(val_3.clone()), + max: Bound::Unbounded, + }, + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Included(val_0.clone()), + }, + ]); + + binary.scope_aggregation()?; + + assert_eq!( + binary, + ConstantBinary::Or(vec![ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Unbounded + }]) + ); + + Ok(()) + } + + #[test] + fn test_scope_aggregation_or_scopes() -> Result<(), DatabaseError> { + let val_0 = Arc::new(DataValue::Int32(Some(0))); + let val_1 = Arc::new(DataValue::Int32(Some(1))); + let val_2 = Arc::new(DataValue::Int32(Some(2))); + let val_3 = Arc::new(DataValue::Int32(Some(3))); + + let mut binary = ConstantBinary::Or(vec![ + ConstantBinary::Scope { + min: Bound::Excluded(val_0.clone()), + max: Bound::Included(val_2.clone()), + }, + ConstantBinary::Scope { + min: Bound::Included(val_0.clone()), + max: Bound::Excluded(val_2.clone()), + }, + ConstantBinary::Scope { + min: Bound::Excluded(val_1.clone()), + max: Bound::Included(val_3.clone()), + }, + ConstantBinary::Scope { + min: Bound::Included(val_1.clone()), + max: Bound::Excluded(val_3.clone()), + }, + ]); + + binary.scope_aggregation()?; + + assert_eq!( + binary, + ConstantBinary::Or(vec![ConstantBinary::Scope { + min: Bound::Included(val_0.clone()), + max: Bound::Included(val_3.clone()), + },]) + ); + + Ok(()) + } + + #[test] + fn test_scope_aggregation_or_and_mixed_1() -> Result<(), DatabaseError> { + let val_5 = Arc::new(DataValue::Int32(Some(5))); + let val_6 = Arc::new(DataValue::Int32(Some(6))); + let val_8 = Arc::new(DataValue::Int32(Some(8))); + let val_12 = Arc::new(DataValue::Int32(Some(12))); + + let mut binary = ConstantBinary::Or(vec![ + ConstantBinary::Eq(val_5.clone()), + ConstantBinary::And(vec![ + ConstantBinary::Scope { + min: Bound::Excluded(val_5.clone()), + max: Bound::Unbounded, + }, + ConstantBinary::Or(vec![ + ConstantBinary::Scope { + min: Bound::Excluded(val_6.clone()), + max: Bound::Unbounded, + }, + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(val_8.clone()), + }, + ]), + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(val_12.clone()), + }, + ]), + ]); + + binary.scope_aggregation()?; + + assert_eq!( + binary, + ConstantBinary::Or(vec![ + ConstantBinary::Scope { + min: Bound::Excluded(val_5.clone()), + max: Bound::Excluded(val_12.clone()), + }, + ConstantBinary::Eq(val_5) + ]) + ); + + Ok(()) + } + #[test] fn test_scope_aggregation_or_lower_unbounded() -> Result<(), DatabaseError> { let val_0 = Arc::new(DataValue::Int32(Some(2))); diff --git a/src/optimizer/rule/normalization/simplification.rs b/src/optimizer/rule/normalization/simplification.rs index 08e4dc69..6162da14 100644 --- a/src/optimizer/rule/normalization/simplification.rs +++ b/src/optimizer/rule/normalization/simplification.rs @@ -306,6 +306,23 @@ mod test { Ok(()) } + fn plan_filter(plan: LogicalPlan, expr: &str) -> Result, DatabaseError> { + let best_plan = HepOptimizer::new(plan.clone()) + .batch( + "test_simplify_filter".to_string(), + HepBatchStrategy::once_topdown(), + vec![NormalizationRuleImpl::SimplifyFilter], + ) + .find_best::(None)?; + if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { + println!("{expr}: {:#?}", filter_op); + + Ok(Some(filter_op)) + } else { + Ok(None) + } + } + #[tokio::test] async fn test_simplify_filter_multiple_column() -> Result<(), DatabaseError> { // c1 + 1 < -1 => c1 < -2 @@ -319,27 +336,10 @@ mod test { // c1 > 0 let plan_4 = select_sql_run("select * from t1 where c1 + 1 > 1 and -c2 > 1").await?; - let op = |plan: LogicalPlan, expr: &str| -> Result, DatabaseError> { - let best_plan = HepOptimizer::new(plan.clone()) - .batch( - "test_simplify_filter".to_string(), - HepBatchStrategy::once_topdown(), - vec![NormalizationRuleImpl::SimplifyFilter], - ) - .find_best::(None)?; - if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { - println!("{expr}: {:#?}", filter_op); - - Ok(Some(filter_op)) - } else { - Ok(None) - } - }; - - let op_1 = op(plan_1, "-(c1 + 1) > 1 and -(1 - c2) > 1")?.unwrap(); - let op_2 = op(plan_2, "-(1 - c1) > 1 and -(c2 + 1) > 1")?.unwrap(); - let op_3 = op(plan_3, "-c1 > 1 and c2 + 1 > 1")?.unwrap(); - let op_4 = op(plan_4, "c1 + 1 > 1 and -c2 > 1")?.unwrap(); + let op_1 = plan_filter(plan_1, "-(c1 + 1) > 1 and -(1 - c2) > 1")?.unwrap(); + let op_2 = plan_filter(plan_2, "-(1 - c1) > 1 and -(c2 + 1) > 1")?.unwrap(); + let op_3 = plan_filter(plan_3, "-c1 > 1 and c2 + 1 > 1")?.unwrap(); + let op_4 = plan_filter(plan_4, "c1 + 1 > 1 and -c2 > 1")?.unwrap(); let cb_1_c1 = op_1.predicate.convert_binary("t1", &0).unwrap(); println!("op_1 => c1: {:#?}", cb_1_c1); @@ -429,24 +429,7 @@ mod test { // c1 + 1 < -1 => c1 < -2 let plan_1 = select_sql_run("select * from t1 where c1 > c2 or c1 > 1").await?; - let op = |plan: LogicalPlan, expr: &str| -> Result, DatabaseError> { - let best_plan = HepOptimizer::new(plan.clone()) - .batch( - "test_simplify_filter".to_string(), - HepBatchStrategy::once_topdown(), - vec![NormalizationRuleImpl::SimplifyFilter], - ) - .find_best::(None)?; - if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { - println!("{expr}: {:#?}", filter_op); - - Ok(Some(filter_op)) - } else { - Ok(None) - } - }; - - let op_1 = op(plan_1, "c1 > c2 or c1 > 1")?.unwrap(); + let op_1 = plan_filter(plan_1, "c1 > c2 or c1 > 1")?.unwrap(); let cb_1_c1 = op_1.predicate.convert_binary("t1", &0).unwrap(); println!("op_1 => c1: {:#?}", cb_1_c1); @@ -460,24 +443,7 @@ mod test { { let plan_1 = select_sql_run("select * from t1 where c1 = 4 and c1 > c2 or c1 > 1").await?; - let op = |plan: LogicalPlan, expr: &str| -> Result, DatabaseError> { - let best_plan = HepOptimizer::new(plan.clone()) - .batch( - "test_simplify_filter".to_string(), - HepBatchStrategy::once_topdown(), - vec![NormalizationRuleImpl::SimplifyFilter], - ) - .find_best::(None)?; - if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { - println!("{expr}: {:#?}", filter_op); - - Ok(Some(filter_op)) - } else { - Ok(None) - } - }; - - let op_1 = op(plan_1, "c1 = 4 and c2 > c1 or c1 > 1")?.unwrap(); + let op_1 = plan_filter(plan_1, "c1 = 4 and c2 > c1 or c1 > 1")?.unwrap(); let cb_1_c1 = op_1.predicate.convert_binary("t1", &0).unwrap(); println!("op_1 => c1: {:#?}", cb_1_c1); @@ -496,27 +462,55 @@ mod test { } #[tokio::test] - async fn test_simplify_filter_column_is_null() -> Result<(), DatabaseError> { - let plan_1 = select_sql_run("select * from t1 where c1 is null").await?; + async fn test_simplify_filter_and_or_mixed() -> Result<(), DatabaseError> { + let plan_1 = select_sql_run( + "select * from t1 where c1 = 5 or (c1 > 5 and (c1 > 6 or c1 < 8) and c1 < 12)", + ) + .await?; + + let op_1 = plan_filter( + plan_1, + "c1 = 5 or (c1 > 5 and (c1 > 6 or c1 < 8) and c1 < 12)", + )? + .unwrap(); - let op = |plan: LogicalPlan, expr: &str| -> Result, DatabaseError> { - let best_plan = HepOptimizer::new(plan.clone()) - .batch( - "test_simplify_filter".to_string(), - HepBatchStrategy::once_topdown(), - vec![NormalizationRuleImpl::SimplifyFilter], - ) - .find_best::(None)?; - if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { - println!("{expr}: {:#?}", filter_op); + let cb_1_c1 = op_1.predicate.convert_binary("t1", &0).unwrap(); + println!("op_1 => c1: {:#?}", cb_1_c1); + assert_eq!( + cb_1_c1, + Some(ConstantBinary::Or(vec![ + ConstantBinary::Eq(Arc::new(DataValue::Int32(Some(5)))), + ConstantBinary::And(vec![ + ConstantBinary::Scope { + min: Bound::Excluded(Arc::new(DataValue::Int32(Some(5)))), + max: Bound::Unbounded, + }, + ConstantBinary::Or(vec![ + ConstantBinary::Scope { + min: Bound::Excluded(Arc::new(DataValue::Int32(Some(6)))), + max: Bound::Unbounded, + }, + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(8)))), + } + ]), + ConstantBinary::Scope { + min: Bound::Unbounded, + max: Bound::Excluded(Arc::new(DataValue::Int32(Some(12)))), + } + ]) + ])) + ); - Ok(Some(filter_op)) - } else { - Ok(None) - } - }; + Ok(()) + } - let op_1 = op(plan_1, "c1 is null")?.unwrap(); + #[tokio::test] + async fn test_simplify_filter_column_is_null() -> Result<(), DatabaseError> { + let plan_1 = select_sql_run("select * from t1 where c1 is null").await?; + + let op_1 = plan_filter(plan_1, "c1 is null")?.unwrap(); let cb_1_c1 = op_1.predicate.convert_binary("t1", &0).unwrap(); println!("op_1 => c1: {:#?}", cb_1_c1); @@ -529,24 +523,7 @@ mod test { async fn test_simplify_filter_column_is_not_null() -> Result<(), DatabaseError> { let plan_1 = select_sql_run("select * from t1 where c1 is not null").await?; - let op = |plan: LogicalPlan, expr: &str| -> Result, DatabaseError> { - let best_plan = HepOptimizer::new(plan.clone()) - .batch( - "test_simplify_filter".to_string(), - HepBatchStrategy::once_topdown(), - vec![NormalizationRuleImpl::SimplifyFilter], - ) - .find_best::(None)?; - if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { - println!("{expr}: {:#?}", filter_op); - - Ok(Some(filter_op)) - } else { - Ok(None) - } - }; - - let op_1 = op(plan_1, "c1 is not null")?.unwrap(); + let op_1 = plan_filter(plan_1, "c1 is not null")?.unwrap(); let cb_1_c1 = op_1.predicate.convert_binary("t1", &0).unwrap(); println!("op_1 => c1: {:#?}", cb_1_c1); @@ -562,24 +539,7 @@ mod test { async fn test_simplify_filter_column_in() -> Result<(), DatabaseError> { let plan_1 = select_sql_run("select * from t1 where c1 in (1, 2, 3)").await?; - let op = |plan: LogicalPlan, expr: &str| -> Result, DatabaseError> { - let best_plan = HepOptimizer::new(plan.clone()) - .batch( - "test_simplify_filter".to_string(), - HepBatchStrategy::once_topdown(), - vec![NormalizationRuleImpl::SimplifyFilter], - ) - .find_best::(None)?; - if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { - println!("{expr}: {:#?}", filter_op); - - Ok(Some(filter_op)) - } else { - Ok(None) - } - }; - - let op_1 = op(plan_1, "c1 in (1, 2, 3)")?.unwrap(); + let op_1 = plan_filter(plan_1, "c1 in (1, 2, 3)")?.unwrap(); let cb_1_c1 = op_1.predicate.convert_binary("t1", &0).unwrap(); println!("op_1 => c1: {:#?}", cb_1_c1); @@ -599,24 +559,7 @@ mod test { async fn test_simplify_filter_column_not_in() -> Result<(), DatabaseError> { let plan_1 = select_sql_run("select * from t1 where c1 not in (1, 2, 3)").await?; - let op = |plan: LogicalPlan, expr: &str| -> Result, DatabaseError> { - let best_plan = HepOptimizer::new(plan.clone()) - .batch( - "test_simplify_filter".to_string(), - HepBatchStrategy::once_topdown(), - vec![NormalizationRuleImpl::SimplifyFilter], - ) - .find_best::(None)?; - if let Operator::Filter(filter_op) = best_plan.childrens[0].clone().operator { - println!("{expr}: {:#?}", filter_op); - - Ok(Some(filter_op)) - } else { - Ok(None) - } - }; - - let op_1 = op(plan_1, "c1 not in (1, 2, 3)")?.unwrap(); + let op_1 = plan_filter(plan_1, "c1 not in (1, 2, 3)")?.unwrap(); let cb_1_c1 = op_1.predicate.convert_binary("t1", &0).unwrap(); println!("op_1 => c1: {:#?}", cb_1_c1); diff --git a/tests/slt/dummy.slt b/tests/slt/dummy.slt index 7367c309..1ff5e0c0 100644 --- a/tests/slt/dummy.slt +++ b/tests/slt/dummy.slt @@ -11,6 +11,11 @@ SELECT 'a' ---- a +query B +SELECT 1.01=1.01 +---- +true + query B SELECT NULL=NULL ---- diff --git a/tests/slt/where_by_index.slt b/tests/slt/where_by_index.slt index 22d82f85..3afa8160 100644 --- a/tests/slt/where_by_index.slt +++ b/tests/slt/where_by_index.slt @@ -120,4 +120,13 @@ select * from t1 where (id >= 0 or id <= 3) and (id >= 9 or id <= 12) limit 10; 18 19 20 21 22 23 24 25 26 -27 28 29 \ No newline at end of file +27 28 29 + +query IIT +select * from t1 where id = 5 or (id > 5 and (id > 6 or id < 8) and id < 12); +---- +6 7 8 +9 10 11 + +statement ok +drop table t1; \ No newline at end of file