Skip to content

Commit

Permalink
Support alternate format for Utf8 unparsing (CHAR) (apache#11494)
Browse files Browse the repository at this point in the history
* Add dialect param to use CHAR instead of TEXT for Utf8 unparsing for MySQL (apache#12)

* Configurable data type instead of flag for Utf8 unparsing

* Fix type in comment
  • Loading branch information
sgrebnov authored Jul 17, 2024
1 parent de0765a commit b0925c8
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 3 deletions.
52 changes: 51 additions & 1 deletion datafusion/sql/src/unparser/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use regex::Regex;
use sqlparser::keywords::ALL_KEYWORDS;
use sqlparser::{ast, keywords::ALL_KEYWORDS};

/// `Dialect` to use for Unparsing
///
Expand Down Expand Up @@ -45,6 +45,17 @@ pub trait Dialect {
fn interval_style(&self) -> IntervalStyle {
IntervalStyle::PostgresVerbose
}

// The SQL type to use for Arrow Utf8 unparsing
// Most dialects use VARCHAR, but some, like MySQL, require CHAR
fn utf8_cast_dtype(&self) -> ast::DataType {
ast::DataType::Varchar(None)
}
// The SQL type to use for Arrow LargeUtf8 unparsing
// Most dialects use TEXT, but some, like MySQL, require CHAR
fn large_utf8_cast_dtype(&self) -> ast::DataType {
ast::DataType::Text
}
}

/// `IntervalStyle` to use for unparsing
Expand Down Expand Up @@ -103,6 +114,14 @@ impl Dialect for MySqlDialect {
fn interval_style(&self) -> IntervalStyle {
IntervalStyle::MySQL
}

fn utf8_cast_dtype(&self) -> ast::DataType {
ast::DataType::Char(None)
}

fn large_utf8_cast_dtype(&self) -> ast::DataType {
ast::DataType::Char(None)
}
}

pub struct SqliteDialect {}
Expand All @@ -118,6 +137,8 @@ pub struct CustomDialect {
supports_nulls_first_in_sort: bool,
use_timestamp_for_date64: bool,
interval_style: IntervalStyle,
utf8_cast_dtype: ast::DataType,
large_utf8_cast_dtype: ast::DataType,
}

impl Default for CustomDialect {
Expand All @@ -127,6 +148,8 @@ impl Default for CustomDialect {
supports_nulls_first_in_sort: true,
use_timestamp_for_date64: false,
interval_style: IntervalStyle::SQLStandard,
utf8_cast_dtype: ast::DataType::Varchar(None),
large_utf8_cast_dtype: ast::DataType::Text,
}
}
}
Expand Down Expand Up @@ -158,6 +181,14 @@ impl Dialect for CustomDialect {
fn interval_style(&self) -> IntervalStyle {
self.interval_style
}

fn utf8_cast_dtype(&self) -> ast::DataType {
self.utf8_cast_dtype.clone()
}

fn large_utf8_cast_dtype(&self) -> ast::DataType {
self.large_utf8_cast_dtype.clone()
}
}

/// `CustomDialectBuilder` to build `CustomDialect` using builder pattern
Expand All @@ -179,6 +210,8 @@ pub struct CustomDialectBuilder {
supports_nulls_first_in_sort: bool,
use_timestamp_for_date64: bool,
interval_style: IntervalStyle,
utf8_cast_dtype: ast::DataType,
large_utf8_cast_dtype: ast::DataType,
}

impl Default for CustomDialectBuilder {
Expand All @@ -194,6 +227,8 @@ impl CustomDialectBuilder {
supports_nulls_first_in_sort: true,
use_timestamp_for_date64: false,
interval_style: IntervalStyle::PostgresVerbose,
utf8_cast_dtype: ast::DataType::Varchar(None),
large_utf8_cast_dtype: ast::DataType::Text,
}
}

Expand All @@ -203,6 +238,8 @@ impl CustomDialectBuilder {
supports_nulls_first_in_sort: self.supports_nulls_first_in_sort,
use_timestamp_for_date64: self.use_timestamp_for_date64,
interval_style: self.interval_style,
utf8_cast_dtype: self.utf8_cast_dtype,
large_utf8_cast_dtype: self.large_utf8_cast_dtype,
}
}

Expand Down Expand Up @@ -235,4 +272,17 @@ impl CustomDialectBuilder {
self.interval_style = interval_style;
self
}

pub fn with_utf8_cast_dtype(mut self, utf8_cast_dtype: ast::DataType) -> Self {
self.utf8_cast_dtype = utf8_cast_dtype;
self
}

pub fn with_large_utf8_cast_dtype(
mut self,
large_utf8_cast_dtype: ast::DataType,
) -> Self {
self.large_utf8_cast_dtype = large_utf8_cast_dtype;
self
}
}
34 changes: 32 additions & 2 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1275,8 +1275,8 @@ impl Unparser<'_> {
DataType::BinaryView => {
not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
}
DataType::Utf8 => Ok(ast::DataType::Varchar(None)),
DataType::LargeUtf8 => Ok(ast::DataType::Text),
DataType::Utf8 => Ok(self.dialect.utf8_cast_dtype()),
DataType::LargeUtf8 => Ok(self.dialect.large_utf8_cast_dtype()),
DataType::Utf8View => {
not_impl_err!("Unsupported DataType: conversion: {data_type:?}")
}
Expand Down Expand Up @@ -1936,4 +1936,34 @@ mod tests {
assert_eq!(actual, expected);
}
}

#[test]
fn custom_dialect_use_char_for_utf8_cast() -> Result<()> {
let default_dialect = CustomDialectBuilder::default().build();
let mysql_custom_dialect = CustomDialectBuilder::new()
.with_utf8_cast_dtype(ast::DataType::Char(None))
.with_large_utf8_cast_dtype(ast::DataType::Char(None))
.build();

for (dialect, data_type, identifier) in [
(&default_dialect, DataType::Utf8, "VARCHAR"),
(&default_dialect, DataType::LargeUtf8, "TEXT"),
(&mysql_custom_dialect, DataType::Utf8, "CHAR"),
(&mysql_custom_dialect, DataType::LargeUtf8, "CHAR"),
] {
let unparser = Unparser::new(dialect);

let expr = Expr::Cast(Cast {
expr: Box::new(col("a")),
data_type,
});
let ast = unparser.expr_to_sql(&expr)?;

let actual = format!("{}", ast);
let expected = format!(r#"CAST(a AS {identifier})"#);

assert_eq!(actual, expected);
}
Ok(())
}
}

0 comments on commit b0925c8

Please sign in to comment.