diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index de14d3a01037..6237d3e9bcf4 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -2849,6 +2849,50 @@ impl ScalarValue { ScalarValue::from(value).cast_to(target_type) } + /// Returns the `&str` representation of a `ScalarValue` of logical string type + /// + /// Returns `None` if this `ScalarValue` is not a logical string type, and + /// `Some(None)` if it is a logical string type holding the `NULL` value. + /// + /// Note you can use [`Option::flatten`] to check for non-null logical + /// strings. + /// + /// For example, [`ScalarValue::Utf8`], [`ScalarValue::LargeUtf8`], and + /// [`ScalarValue::Dictionary`] with a logical string value all store + /// strings that can be accessed as `&str` using this method. + /// + /// # Example: logical strings + /// ``` + /// # use datafusion_common::ScalarValue; + /// // non strings return None + /// let scalar = ScalarValue::from(42); + /// assert_eq!(scalar.try_as_str(), None); + /// // Non null logical string returns Some(Some(&str)) + /// let scalar = ScalarValue::from("hello"); + /// assert_eq!(scalar.try_as_str(), Some(Some("hello"))); + /// // Null logical string returns Some(None) + /// let scalar = ScalarValue::Utf8(None); + /// assert_eq!(scalar.try_as_str(), Some(None)); + /// ``` + /// + /// # Example: use [`Option::flatten`] to check for non-null logical strings + /// ``` + /// # use datafusion_common::ScalarValue; + /// // Non null logical string returns Some(Some(&str)) + /// let scalar = ScalarValue::from("hello"); + /// assert_eq!(scalar.try_as_str().flatten(), Some("hello")); + /// ``` + pub fn try_as_str(&self) -> Option<Option<&str>> { + let v = match self { + ScalarValue::Utf8(v) => v, + ScalarValue::LargeUtf8(v) => v, + ScalarValue::Utf8View(v) => v, + ScalarValue::Dictionary(_, v) => return v.try_as_str(), + _ => return None, + }; + Some(v.as_ref().map(|v| v.as_str())) + } + /// Try to cast this value to a 
ScalarValue of type `data_type` pub fn cast_to(&self, target_type: &DataType) -> Result { self.cast_to_with_options(target_type, &DEFAULT_CAST_OPTIONS) diff --git a/datafusion/core/tests/sql/path_partition.rs b/datafusion/core/tests/sql/path_partition.rs index 441af1639d9b..c4fa4c509aa8 100644 --- a/datafusion/core/tests/sql/path_partition.rs +++ b/datafusion/core/tests/sql/path_partition.rs @@ -218,10 +218,11 @@ async fn parquet_distinct_partition_col() -> Result<()> { assert_eq!(min_limit, resulting_limit); let s = ScalarValue::try_from_array(results[0].column(1), 0)?; - let month = match extract_as_utf(&s) { - Some(month) => month, - s => panic!("Expected month as Dict(_, Utf8) found {s:?}"), - }; + assert!( + matches!(s.data_type(), DataType::Dictionary(_, v) if v.as_ref() == &DataType::Utf8), + "Expected month as Dict(_, Utf8) found {s:?}" + ); + let month = s.try_as_str().flatten().unwrap(); let sql_on_partition_boundary = format!( "SELECT month from t where month = '{}' LIMIT {}", @@ -241,15 +242,6 @@ async fn parquet_distinct_partition_col() -> Result<()> { Ok(()) } -fn extract_as_utf(v: &ScalarValue) -> Option { - if let ScalarValue::Dictionary(_, v) = v { - if let ScalarValue::Utf8(v) = v.as_ref() { - return v.clone(); - } - } - None -} - #[tokio::test] async fn csv_filter_with_file_col() -> Result<()> { let ctx = SessionContext::new_with_config( diff --git a/datafusion/functions-aggregate/src/string_agg.rs b/datafusion/functions-aggregate/src/string_agg.rs index 7643b44e11d5..0cd403cff428 100644 --- a/datafusion/functions-aggregate/src/string_agg.rs +++ b/datafusion/functions-aggregate/src/string_agg.rs @@ -108,15 +108,14 @@ impl AggregateUDFImpl for StringAgg { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { if let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::() { - return match lit.value() { - ScalarValue::Utf8(Some(delimiter)) - | ScalarValue::LargeUtf8(Some(delimiter)) => { - 
Ok(Box::new(StringAggAccumulator::new(delimiter.as_str()))) + return match lit.value().try_as_str() { + Some(Some(delimiter)) => { + Ok(Box::new(StringAggAccumulator::new(delimiter))) + } + Some(None) => Ok(Box::new(StringAggAccumulator::new(""))), + None => { + not_impl_err!("StringAgg not supported for delimiter {}", lit.value()) } - ScalarValue::Utf8(None) - | ScalarValue::LargeUtf8(None) - | ScalarValue::Null => Ok(Box::new(StringAggAccumulator::new(""))), - e => not_impl_err!("StringAgg not supported for delimiter {}", e), }; } diff --git a/datafusion/functions/src/crypto/basic.rs b/datafusion/functions/src/crypto/basic.rs index 74dc5d517c2b..860c68bc93f4 100644 --- a/datafusion/functions/src/crypto/basic.rs +++ b/datafusion/functions/src/crypto/basic.rs @@ -121,11 +121,9 @@ pub fn digest(args: &[ColumnarValue]) -> Result { ); } let digest_algorithm = match &args[1] { - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8View(Some(method)) - | ScalarValue::Utf8(Some(method)) - | ScalarValue::LargeUtf8(Some(method)) => method.parse::(), - other => exec_err!("Unsupported data type {other:?} for function digest"), + ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { + Some(Some(method)) => method.parse::(), + _ => exec_err!("Unsupported data type {scalar:?} for function digest"), }, ColumnarValue::Array(_) => { internal_err!("Digest using dynamically decided method is not yet supported") diff --git a/datafusion/functions/src/datetime/common.rs b/datafusion/functions/src/datetime/common.rs index c674ae09ecb3..fd9f37d8052c 100644 --- a/datafusion/functions/src/datetime/common.rs +++ b/datafusion/functions/src/datetime/common.rs @@ -211,14 +211,12 @@ where ))), other => exec_err!("Unsupported data type {other:?} for function {name}"), }, - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8View(a) - | ScalarValue::LargeUtf8(a) - | ScalarValue::Utf8(a) => { + ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { + 
Some(a) => { let result = a.as_ref().map(|x| op(x)).transpose()?; Ok(ColumnarValue::Scalar(S::scalar(result))) } - other => exec_err!("Unsupported data type {other:?} for function {name}"), + _ => exec_err!("Unsupported data type {scalar:?} for function {name}"), }, } } @@ -270,10 +268,8 @@ where } }, // if the first argument is a scalar utf8 all arguments are expected to be scalar utf8 - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8View(a) - | ScalarValue::LargeUtf8(a) - | ScalarValue::Utf8(a) => { + ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { + Some(a) => { let a = a.as_ref(); // ASK: Why do we trust `a` to be non-null at this point? let a = unwrap_or_internal_err!(a); @@ -291,7 +287,7 @@ where }; if let Some(s) = x { - match op(a.as_str(), s.as_str()) { + match op(a, s.as_str()) { Ok(r) => { ret = Some(Ok(ColumnarValue::Scalar(S::scalar(Some( op2(r), @@ -408,19 +404,10 @@ where DataType::Utf8 => Ok(a.as_string::().value(pos)), other => exec_err!("Unexpected type encountered '{other}'"), }, - ColumnarValue::Scalar(s) => match s { - ScalarValue::Utf8View(a) - | ScalarValue::LargeUtf8(a) - | ScalarValue::Utf8(a) => { - if let Some(v) = a { - Ok(v.as_str()) - } else { - continue; - } - } - other => { - exec_err!("Unexpected scalar type encountered '{other}'") - } + ColumnarValue::Scalar(s) => match s.try_as_str() { + Some(Some(v)) => Ok(v), + Some(None) => continue, // null string + None => exec_err!("Unexpected scalar type encountered '{s}'"), }, }?; diff --git a/datafusion/functions/src/encoding/inner.rs b/datafusion/functions/src/encoding/inner.rs index 31a2ce0f83fd..a5338ff76592 100644 --- a/datafusion/functions/src/encoding/inner.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -546,12 +546,10 @@ fn encode(args: &[ColumnarValue]) -> Result { ); } let encoding = match &args[1] { - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(Some(method)) | ScalarValue::Utf8View(Some(method)) | 
ScalarValue::LargeUtf8(Some(method)) => { - method.parse::() - } + ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { + Some(Some(method)) => method.parse::(), _ => not_impl_err!( - "Second argument to encode must be a constant: Encode using dynamically decided method is not yet supported" + "Second argument to encode must be non null constant string: Encode using dynamically decided method is not yet supported. Got {scalar:?}" ), }, ColumnarValue::Array(_) => not_impl_err!( @@ -572,12 +570,10 @@ fn decode(args: &[ColumnarValue]) -> Result { ); } let encoding = match &args[1] { - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8(Some(method)) | ScalarValue::Utf8View(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => { - method.parse::() - } + ColumnarValue::Scalar(scalar) => match scalar.try_as_str() { + Some(Some(method))=> method.parse::(), _ => not_impl_err!( - "Second argument to decode must be a utf8 constant: Decode using dynamically decided method is not yet supported" + "Second argument to decode must be a non null constant string: Decode using dynamically decided method is not yet supported. 
Got {scalar:?}" ), }, ColumnarValue::Array(_) => not_impl_err!( diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 87ac979bc057..9ce732efa0c7 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -134,18 +134,16 @@ impl ScalarUDFImpl for ConcatFunc { if array_len.is_none() { let mut result = String::new(); for arg in args { - match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) - | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v))) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(v))) => { - result.push_str(v); - } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) - | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {} - other => plan_err!( + let ColumnarValue::Scalar(scalar) = arg else { + return internal_err!("concat expected scalar value, got {arg:?}"); + }; + + match scalar.try_as_str() { + Some(Some(v)) => result.push_str(v), + Some(None) => {} // null literal + None => plan_err!( "Concat function does not support scalar type {:?}", - other + scalar )?, } } diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 821bb48a8ce2..12556cfec0bc 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -124,48 +124,54 @@ impl ScalarUDFImpl for ConcatWsFunc { // Scalar if array_len.is_none() { - let sep = match &args[0] { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) - | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => s, - ColumnarValue::Scalar(ScalarValue::Utf8(None)) - | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => { + let ColumnarValue::Scalar(scalar) = &args[0] else { + // loop above checks for all args being scalar + unreachable!() + }; + 
let sep = match scalar.try_as_str() { + Some(Some(s)) => s, + Some(None) => { + // null literal string return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None))); } - _ => unreachable!(), + None => return internal_err!("Expected string literal, got {scalar:?}"), }; let mut result = String::new(); - let iter = &mut args[1..].iter(); - - for arg in iter.by_ref() { - match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) - | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => { + // iterator over Option<Option<&str>> + let iter = &mut args[1..].iter().map(|arg| { + let ColumnarValue::Scalar(scalar) = arg else { + // loop above checks for all args being scalar + unreachable!() + }; + scalar.try_as_str() + }); + + // append first non null arg + for scalar in iter.by_ref() { + match scalar { + Some(Some(s)) => { result.push_str(s); break; } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) - | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {} - _ => unreachable!(), + Some(None) => {} // null literal string + None => { + return internal_err!("Expected string literal, got {scalar:?}") + } } } - for arg in iter.by_ref() { - match arg { - ColumnarValue::Scalar(ScalarValue::Utf8(Some(s))) - | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s))) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => { + // handle subsequent non null args + for scalar in iter.by_ref() { + match scalar { + Some(Some(s)) => { result.push_str(sep); result.push_str(s); } - ColumnarValue::Scalar(ScalarValue::Utf8(None)) - | ColumnarValue::Scalar(ScalarValue::Utf8View(None)) - | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {} - _ => unreachable!(), + Some(None) => {} // null literal string + None => { + return internal_err!("Expected string literal, got {scalar:?}") + } } } diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs 
b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 8cba2c88e244..892d450ba85b 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -475,12 +475,7 @@ fn try_cast_string_literal( lit_value: &ScalarValue, target_type: &DataType, ) -> Option { - let string_value = match lit_value { - ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) | ScalarValue::Utf8View(s) => { - s.clone() - } - _ => return None, - }; + let string_value = lit_value.try_as_str()?.map(|s| s.to_string()); let scalar_value = match target_type { DataType::Utf8 => ScalarValue::Utf8(string_value), DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value), diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 2ab53b214d7f..8aa45063c84f 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -251,22 +251,13 @@ macro_rules! 
compute_utf8_flag_op_scalar { .downcast_ref::<$ARRAYTYPE>() .expect("compute_utf8_flag_op_scalar failed to downcast array"); - let string_value = match $RIGHT { - ScalarValue::Utf8(Some(string_value)) | ScalarValue::LargeUtf8(Some(string_value)) => string_value, - ScalarValue::Dictionary(_, value) => { - match *value { - ScalarValue::Utf8(Some(string_value)) | ScalarValue::LargeUtf8(Some(string_value)) => string_value, - other => return internal_err!( - "compute_utf8_flag_op_scalar failed to cast dictionary value {} for operation '{}'", - other, stringify!($OP) - ) - } - }, + let string_value = match $RIGHT.try_as_str() { + Some(Some(string_value)) => string_value, + // null literal or non string _ => return internal_err!( - "compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'", - $RIGHT, stringify!($OP) - ) - + "compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'", + $RIGHT, stringify!($OP) + ) }; let flag = $FLAG.then_some("i"); diff --git a/datafusion/physical-optimizer/src/pruning.rs b/datafusion/physical-optimizer/src/pruning.rs index c16ed306efdf..30c6e7fb4b32 100644 --- a/datafusion/physical-optimizer/src/pruning.rs +++ b/datafusion/physical-optimizer/src/pruning.rs @@ -1639,17 +1639,12 @@ fn build_like_match( // column LIKE '%foo%' => min <= '' && '' <= max => true // column LIKE 'foo' => min <= 'foo' && 'foo' <= max - fn unpack_string(s: &ScalarValue) -> Option<&String> { - match s { - ScalarValue::Utf8(Some(s)) => Some(s), - ScalarValue::LargeUtf8(Some(s)) => Some(s), - ScalarValue::Utf8View(Some(s)) => Some(s), - ScalarValue::Dictionary(_, value) => unpack_string(value), - _ => None, - } + /// returns the string literal of the scalar value if it is a string + fn unpack_string(s: &ScalarValue) -> Option<&str> { + s.try_as_str().flatten() } - fn extract_string_literal(expr: &Arc) -> Option<&String> { + fn extract_string_literal(expr: &Arc) -> Option<&str> { if let Some(lit) = 
expr.as_any().downcast_ref::() { let s = unpack_string(lit.value())?; return Some(s); @@ -1681,7 +1676,9 @@ fn build_like_match( (lower_bound_lit, upper_bound_lit) } else { // the like expression is a literal and can be converted into a comparison - let bound = Arc::new(phys_expr::Literal::new(ScalarValue::Utf8(Some(s.clone())))); + let bound = Arc::new(phys_expr::Literal::new(ScalarValue::Utf8(Some( + s.to_string(), + )))); (Arc::clone(&bound), bound) }; let lower_bound_expr = Arc::new(phys_expr::BinaryExpr::new(