Skip to content

Commit

Permalink
Add ScalarValue::try_as_str to get str value from logical strings
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb committed Jan 17, 2025
1 parent 9403448 commit 3b5f623
Show file tree
Hide file tree
Showing 11 changed files with 132 additions and 129 deletions.
44 changes: 44 additions & 0 deletions datafusion/common/src/scalar/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2849,6 +2849,50 @@ impl ScalarValue {
ScalarValue::from(value).cast_to(target_type)
}

/// Returns the `&str` held by this `ScalarValue` if it is of a logical string type
///
/// The nested `Option` distinguishes three cases:
/// * `None`: this `ScalarValue` is not a logical string type
/// * `Some(None)`: this `ScalarValue` is a logical string type holding the
///   `NULL` value
/// * `Some(Some(&str))`: this `ScalarValue` is a non null logical string
///
/// Note you can use [`Option::flatten`] to check for non null logical
/// strings.
///
/// For example, [`ScalarValue::Utf8`], [`ScalarValue::LargeUtf8`],
/// [`ScalarValue::Utf8View`], and [`ScalarValue::Dictionary`] with a logical
/// string value all store strings and can be accessed as `&str` using this
/// method.
///
/// # Example: logical strings
/// ```
/// # use datafusion_common::ScalarValue;
/// // Non strings return None
/// let scalar = ScalarValue::from(42);
/// assert_eq!(scalar.try_as_str(), None);
/// // Non null logical string returns Some(Some(&str))
/// let scalar = ScalarValue::from("hello");
/// assert_eq!(scalar.try_as_str(), Some(Some("hello")));
/// // Null logical string returns Some(None)
/// let scalar = ScalarValue::Utf8(None);
/// assert_eq!(scalar.try_as_str(), Some(None));
/// ```
///
/// # Example: use [`Option::flatten`] to check for non-null logical strings
/// ```
/// # use datafusion_common::ScalarValue;
/// // Non null logical string returns Some(Some(&str))
/// let scalar = ScalarValue::from("hello");
/// assert_eq!(scalar.try_as_str().flatten(), Some("hello"));
/// ```
pub fn try_as_str(&self) -> Option<Option<&str>> {
    let v = match self {
        ScalarValue::Utf8(v) => v,
        ScalarValue::LargeUtf8(v) => v,
        ScalarValue::Utf8View(v) => v,
        // A dictionary is a logical string exactly when its wrapped value is,
        // so delegate to the inner scalar
        ScalarValue::Dictionary(_, v) => return v.try_as_str(),
        _ => return None,
    };
    // Borrow the inner `String` as `&str` without cloning; a null string
    // becomes `Some(None)`
    Some(v.as_deref())
}

/// Try to cast this value to a ScalarValue of type `data_type`
pub fn cast_to(&self, target_type: &DataType) -> Result<Self> {
self.cast_to_with_options(target_type, &DEFAULT_CAST_OPTIONS)
Expand Down
18 changes: 5 additions & 13 deletions datafusion/core/tests/sql/path_partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,10 +218,11 @@ async fn parquet_distinct_partition_col() -> Result<()> {
assert_eq!(min_limit, resulting_limit);

let s = ScalarValue::try_from_array(results[0].column(1), 0)?;
let month = match extract_as_utf(&s) {
Some(month) => month,
s => panic!("Expected month as Dict(_, Utf8) found {s:?}"),
};
assert!(
matches!(s.data_type(), DataType::Dictionary(_, v) if v.as_ref() == &DataType::Utf8),
"Expected month as Dict(_, Utf8) found {s:?}"
);
let month = s.try_as_str().flatten().unwrap();

let sql_on_partition_boundary = format!(
"SELECT month from t where month = '{}' LIMIT {}",
Expand All @@ -241,15 +242,6 @@ async fn parquet_distinct_partition_col() -> Result<()> {
Ok(())
}

/// Returns a copy of the string stored in a `ScalarValue::Dictionary` whose
/// wrapped value is `ScalarValue::Utf8`.
///
/// Any other scalar yields `None`. The inner `Option<String>` is returned
/// as-is, so a dictionary wrapping `Utf8(None)` also yields `None`.
fn extract_as_utf(v: &ScalarValue) -> Option<String> {
    // Only dictionary-encoded scalars are considered
    let ScalarValue::Dictionary(_, inner) = v else {
        return None;
    };
    match inner.as_ref() {
        ScalarValue::Utf8(text) => text.clone(),
        _ => None,
    }
}

#[tokio::test]
async fn csv_filter_with_file_col() -> Result<()> {
let ctx = SessionContext::new_with_config(
Expand Down
15 changes: 7 additions & 8 deletions datafusion/functions-aggregate/src/string_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,15 +108,14 @@ impl AggregateUDFImpl for StringAgg {

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
if let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() {
return match lit.value() {
ScalarValue::Utf8(Some(delimiter))
| ScalarValue::LargeUtf8(Some(delimiter)) => {
Ok(Box::new(StringAggAccumulator::new(delimiter.as_str())))
return match lit.value().try_as_str() {
Some(Some(delimiter)) => {
Ok(Box::new(StringAggAccumulator::new(delimiter)))
}
Some(None) => Ok(Box::new(StringAggAccumulator::new(""))),
None => {
not_impl_err!("StringAgg not supported for delimiter {}", lit.value())
}
ScalarValue::Utf8(None)
| ScalarValue::LargeUtf8(None)
| ScalarValue::Null => Ok(Box::new(StringAggAccumulator::new(""))),
e => not_impl_err!("StringAgg not supported for delimiter {}", e),
};
}

Expand Down
8 changes: 3 additions & 5 deletions datafusion/functions/src/crypto/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,9 @@ pub fn digest(args: &[ColumnarValue]) -> Result<ColumnarValue> {
);
}
let digest_algorithm = match &args[1] {
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8View(Some(method))
| ScalarValue::Utf8(Some(method))
| ScalarValue::LargeUtf8(Some(method)) => method.parse::<DigestAlgorithm>(),
other => exec_err!("Unsupported data type {other:?} for function digest"),
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
Some(Some(method)) => method.parse::<DigestAlgorithm>(),
_ => exec_err!("Unsupported data type {scalar:?} for function digest"),
},
ColumnarValue::Array(_) => {
internal_err!("Digest using dynamically decided method is not yet supported")
Expand Down
33 changes: 10 additions & 23 deletions datafusion/functions/src/datetime/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,12 @@ where
))),
other => exec_err!("Unsupported data type {other:?} for function {name}"),
},
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8View(a)
| ScalarValue::LargeUtf8(a)
| ScalarValue::Utf8(a) => {
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
Some(a) => {
let result = a.as_ref().map(|x| op(x)).transpose()?;
Ok(ColumnarValue::Scalar(S::scalar(result)))
}
other => exec_err!("Unsupported data type {other:?} for function {name}"),
_ => exec_err!("Unsupported data type {scalar:?} for function {name}"),
},
}
}
Expand Down Expand Up @@ -270,10 +268,8 @@ where
}
},
// if the first argument is a scalar utf8 all arguments are expected to be scalar utf8
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8View(a)
| ScalarValue::LargeUtf8(a)
| ScalarValue::Utf8(a) => {
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
Some(a) => {
let a = a.as_ref();
// ASK: Why do we trust `a` to be non-null at this point?
let a = unwrap_or_internal_err!(a);
Expand All @@ -291,7 +287,7 @@ where
};

if let Some(s) = x {
match op(a.as_str(), s.as_str()) {
match op(a, s.as_str()) {
Ok(r) => {
ret = Some(Ok(ColumnarValue::Scalar(S::scalar(Some(
op2(r),
Expand Down Expand Up @@ -408,19 +404,10 @@ where
DataType::Utf8 => Ok(a.as_string::<i32>().value(pos)),
other => exec_err!("Unexpected type encountered '{other}'"),
},
ColumnarValue::Scalar(s) => match s {
ScalarValue::Utf8View(a)
| ScalarValue::LargeUtf8(a)
| ScalarValue::Utf8(a) => {
if let Some(v) = a {
Ok(v.as_str())
} else {
continue;
}
}
other => {
exec_err!("Unexpected scalar type encountered '{other}'")
}
ColumnarValue::Scalar(s) => match s.try_as_str() {
Some(Some(v)) => Ok(v),
Some(None) => continue, // null string
None => exec_err!("Unexpected scalar type encountered '{s}'"),
},
}?;

Expand Down
16 changes: 6 additions & 10 deletions datafusion/functions/src/encoding/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -546,12 +546,10 @@ fn encode(args: &[ColumnarValue]) -> Result<ColumnarValue> {
);
}
let encoding = match &args[1] {
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(Some(method)) | ScalarValue::Utf8View(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => {
method.parse::<Encoding>()
}
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
Some(Some(method)) => method.parse::<Encoding>(),
_ => not_impl_err!(
"Second argument to encode must be a constant: Encode using dynamically decided method is not yet supported"
"Second argument to encode must be non null constant string: Encode using dynamically decided method is not yet supported. Got {scalar:?}"
),
},
ColumnarValue::Array(_) => not_impl_err!(
Expand All @@ -572,12 +570,10 @@ fn decode(args: &[ColumnarValue]) -> Result<ColumnarValue> {
);
}
let encoding = match &args[1] {
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(Some(method)) | ScalarValue::Utf8View(Some(method)) | ScalarValue::LargeUtf8(Some(method)) => {
method.parse::<Encoding>()
}
ColumnarValue::Scalar(scalar) => match scalar.try_as_str() {
Some(Some(method))=> method.parse::<Encoding>(),
_ => not_impl_err!(
"Second argument to decode must be a utf8 constant: Decode using dynamically decided method is not yet supported"
"Second argument to decode must be a non null constant string: Decode using dynamically decided method is not yet supported. Got {scalar:?}"
),
},
ColumnarValue::Array(_) => not_impl_err!(
Expand Down
20 changes: 9 additions & 11 deletions datafusion/functions/src/string/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,18 +134,16 @@ impl ScalarUDFImpl for ConcatFunc {
if array_len.is_none() {
let mut result = String::new();
for arg in args {
match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(v)))
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(v))) => {
result.push_str(v);
}
ColumnarValue::Scalar(ScalarValue::Utf8(None))
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
other => plan_err!(
let ColumnarValue::Scalar(scalar) = arg else {
return internal_err!("concat expected scalar value, got {arg:?}");
};

match scalar.try_as_str() {
Some(Some(v)) => result.push_str(v),
Some(None) => {} // null literal
None => plan_err!(
"Concat function does not support scalar type {:?}",
other
scalar
)?,
}
}
Expand Down
62 changes: 34 additions & 28 deletions datafusion/functions/src/string/concat_ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,48 +124,54 @@ impl ScalarUDFImpl for ConcatWsFunc {

// Scalar
if array_len.is_none() {
let sep = match &args[0] {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => s,
ColumnarValue::Scalar(ScalarValue::Utf8(None))
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {
let ColumnarValue::Scalar(scalar) = &args[0] else {
// loop above checks for all args being scalar
unreachable!()
};
let sep = match scalar.try_as_str() {
Some(Some(s)) => s,
Some(None) => {
// null literal string
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
}
_ => unreachable!(),
None => return internal_err!("Expected string literal, got {scalar:?}"),
};

let mut result = String::new();
let iter = &mut args[1..].iter();

for arg in iter.by_ref() {
match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => {
// iterator over Option<str>
let iter = &mut args[1..].iter().map(|arg| {
let ColumnarValue::Scalar(scalar) = arg else {
// loop above checks for all args being scalar
unreachable!()
};
scalar.try_as_str()
});

// append first non null arg
for scalar in iter.by_ref() {
match scalar {
Some(Some(s)) => {
result.push_str(s);
break;
}
ColumnarValue::Scalar(ScalarValue::Utf8(None))
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
_ => unreachable!(),
Some(None) => {} // null literal string
None => {
return internal_err!("Expected string literal, got {scalar:?}")
}
}
}

for arg in iter.by_ref() {
match arg {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(s)))
| ColumnarValue::Scalar(ScalarValue::Utf8View(Some(s)))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(s))) => {
// handle subsequent non null args
for scalar in iter.by_ref() {
match scalar {
Some(Some(s)) => {
result.push_str(sep);
result.push_str(s);
}
ColumnarValue::Scalar(ScalarValue::Utf8(None))
| ColumnarValue::Scalar(ScalarValue::Utf8View(None))
| ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
_ => unreachable!(),
Some(None) => {} // null literal string
None => {
return internal_err!("Expected string literal, got {scalar:?}")
}
}
}

Expand Down
7 changes: 1 addition & 6 deletions datafusion/optimizer/src/unwrap_cast_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -475,12 +475,7 @@ fn try_cast_string_literal(
lit_value: &ScalarValue,
target_type: &DataType,
) -> Option<ScalarValue> {
let string_value = match lit_value {
ScalarValue::Utf8(s) | ScalarValue::LargeUtf8(s) | ScalarValue::Utf8View(s) => {
s.clone()
}
_ => return None,
};
let string_value = lit_value.try_as_str()?.map(|s| s.to_string());
let scalar_value = match target_type {
DataType::Utf8 => ScalarValue::Utf8(string_value),
DataType::LargeUtf8 => ScalarValue::LargeUtf8(string_value),
Expand Down
21 changes: 6 additions & 15 deletions datafusion/physical-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -251,22 +251,13 @@ macro_rules! compute_utf8_flag_op_scalar {
.downcast_ref::<$ARRAYTYPE>()
.expect("compute_utf8_flag_op_scalar failed to downcast array");

let string_value = match $RIGHT {
ScalarValue::Utf8(Some(string_value)) | ScalarValue::LargeUtf8(Some(string_value)) => string_value,
ScalarValue::Dictionary(_, value) => {
match *value {
ScalarValue::Utf8(Some(string_value)) | ScalarValue::LargeUtf8(Some(string_value)) => string_value,
other => return internal_err!(
"compute_utf8_flag_op_scalar failed to cast dictionary value {} for operation '{}'",
other, stringify!($OP)
)
}
},
let string_value = match $RIGHT.try_as_str() {
Some(Some(string_value)) => string_value,
// null literal or non string
_ => return internal_err!(
"compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'",
$RIGHT, stringify!($OP)
)

"compute_utf8_flag_op_scalar failed to cast literal value {} for operation '{}'",
$RIGHT, stringify!($OP)
)
};

let flag = $FLAG.then_some("i");
Expand Down
Loading

0 comments on commit 3b5f623

Please sign in to comment.