From 723a595528e945c0ebc59a62ece2e24e90627764 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 18 Jul 2024 16:32:09 -0400 Subject: [PATCH] Minor: avoid a clone in type coercion (#11530) * Minor: avoid a clone in type coercion * Fix test --- .../optimizer/src/analyzer/type_coercion.rs | 18 ++++++++---------- datafusion/sqllogictest/test_files/misc.slt | 4 ++++ 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 337492d1a55b..50fb1b8193ce 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -84,7 +84,7 @@ impl AnalyzerRule for TypeCoercion { /// Assumes that children have already been optimized fn analyze_internal( external_schema: &DFSchema, - mut plan: LogicalPlan, + plan: LogicalPlan, ) -> Result> { // get schema representing all available input fields. This is used for data type // resolution only, so order does not matter here @@ -103,15 +103,13 @@ fn analyze_internal( // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where t2.c2=t1.c3) schema.merge(external_schema); - if let LogicalPlan::Filter(filter) = &mut plan { - if let Ok(new_predicate) = filter - .predicate - .clone() - .cast_to(&DataType::Boolean, filter.input.schema()) - { - filter.predicate = new_predicate; - } - } + // Coerce filter predicates to boolean (handles `WHERE NULL`) + let plan = if let LogicalPlan::Filter(mut filter) = plan { + filter.predicate = filter.predicate.cast_to(&DataType::Boolean, &schema)?; + LogicalPlan::Filter(filter) + } else { + plan + }; let mut expr_rewrite = TypeCoercionRewriter::new(&schema); diff --git a/datafusion/sqllogictest/test_files/misc.slt b/datafusion/sqllogictest/test_files/misc.slt index 9f4710eb9bcc..9bd3023b56f7 100644 --- a/datafusion/sqllogictest/test_files/misc.slt +++ b/datafusion/sqllogictest/test_files/misc.slt @@ -30,6 +30,10 @@ query I select 1 where NULL ---- +# Where clause does not accept non boolean and has nice error message +query error Cannot create filter with non\-boolean predicate 'Utf8\("foo"\)' returning Utf8 +select 1 where 'foo' + query I select 1 where NULL and 1 = 1 ----