From b4c055d0f1cee4352153de1fc8bb18bac50b1c98 Mon Sep 17 00:00:00 2001 From: David Blajda Date: Sat, 9 Dec 2023 21:09:50 -0500 Subject: [PATCH] fix: respect case sensitivity on operations --- .../src/delta_datafusion/expr.rs | 5 +- .../src/delta_datafusion/mod.rs | 174 +++++++++++++++++- .../deltalake-core/src/operations/delete.rs | 4 +- crates/deltalake-core/src/operations/merge.rs | 110 ++++++++++- .../deltalake-core/src/operations/update.rs | 87 ++++++++- 5 files changed, 365 insertions(+), 15 deletions(-) diff --git a/crates/deltalake-core/src/delta_datafusion/expr.rs b/crates/deltalake-core/src/delta_datafusion/expr.rs index e0c284df92..f9275832a1 100644 --- a/crates/deltalake-core/src/delta_datafusion/expr.rs +++ b/crates/deltalake-core/src/delta_datafusion/expr.rs @@ -42,6 +42,8 @@ use sqlparser::tokenizer::Tokenizer; use crate::{DeltaResult, DeltaTableError}; +use super::DeltaParserOptions; + pub(crate) struct DeltaContextProvider<'a> { state: &'a SessionState, } @@ -97,7 +99,8 @@ pub(crate) fn parse_predicate_expression( })?; let context_provider = DeltaContextProvider { state: df_state }; - let sql_to_rel = SqlToRel::new(&context_provider); + let sql_to_rel = + SqlToRel::new_with_options(&context_provider, DeltaParserOptions::default().into()); Ok(sql_to_rel.sql_to_expr(sql, schema, &mut Default::default())?) } diff --git a/crates/deltalake-core/src/delta_datafusion/mod.rs b/crates/deltalake-core/src/delta_datafusion/mod.rs index 973d575904..3917347888 100644 --- a/crates/deltalake-core/src/delta_datafusion/mod.rs +++ b/crates/deltalake-core/src/delta_datafusion/mod.rs @@ -42,7 +42,7 @@ use datafusion::datasource::physical_plan::{ }; use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::{listing::PartitionedFile, MemTable, TableProvider, TableType}; -use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; +use datafusion::execution::context::{SessionConfig, SessionContext, SessionState, TaskContext}; use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::execution::FunctionRegistry; use datafusion::optimizer::utils::conjunction; @@ -65,6 +65,7 @@ use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; use datafusion_proto::logical_plan::LogicalExtensionCodec; use datafusion_proto::physical_plan::PhysicalExtensionCodec; +use datafusion_sql::planner::ParserOptions; use log::error; use object_store::ObjectMeta; use serde::{Deserialize, Serialize}; @@ -1494,6 +1495,111 @@ pub async fn find_files<'a>( } } +/// A wrapper for sql_parser's ParserOptions to capture sane default table defaults +pub struct DeltaParserOptions { + inner: ParserOptions, +} + +impl Default for DeltaParserOptions { + fn default() -> Self { + DeltaParserOptions { + inner: ParserOptions { + enable_ident_normalization: false, + ..ParserOptions::default() + }, + } + } +} + +impl From for ParserOptions { + fn from(value: DeltaParserOptions) -> Self { + value.inner + } +} + +/// A wrapper for Deltafusion's SessionConfig to capture sane default table defaults +pub struct DeltaSessionConfig { + inner: SessionConfig, +} + +impl Default for DeltaSessionConfig { + fn default() -> Self { + DeltaSessionConfig { + inner: SessionConfig::default() + .set_bool("datafusion.sql_parser.enable_ident_normalization", false), + } + } +} + +impl From for SessionConfig { + fn from(value: DeltaSessionConfig) -> Self { + value.inner + } +} + +/// A wrapper for Deltafusion's SessionContext to capture sane default table defaults +pub struct DeltaSessionContext { + inner: SessionContext, +} + +impl Default for DeltaSessionContext { + fn default() -> Self { + DeltaSessionContext { + inner: SessionContext::new_with_config(DeltaSessionConfig::default().into()), + } + } +} + +impl From for SessionContext { + fn from(value: DeltaSessionContext) -> Self { + value.inner + } +} + +/// A wrapper for Deltafusion's Column to preserve case-sensitivity during string conversion +pub struct DeltaColumn { + inner: Column, +} + +impl From<&str> for DeltaColumn { + fn from(c: &str) -> Self { + DeltaColumn { + inner: Column::from_qualified_name_ignore_case(c), + } + } +} + +/// Create a column, cloning the string +impl From<&String> for DeltaColumn { + fn from(c: &String) -> Self { + DeltaColumn { + inner: Column::from_qualified_name_ignore_case(c), + } + } +} + +/// Create a column, reusing the existing string +impl From for DeltaColumn { + fn from(c: String) -> Self { + DeltaColumn { + inner: Column::from_qualified_name_ignore_case(c), + } + } +} + +impl From for Column { + fn from(value: DeltaColumn) -> Self { + value.inner + } +} + +/// Create a column, resuing the existing datafusion column +impl From for DeltaColumn { + fn from(c: Column) -> Self { + DeltaColumn { inner: c } + } +} + #[cfg(test)] mod tests { use crate::writer::test_utils::get_delta_schema; @@ -1804,4 +1910,70 @@ mod tests { ]; assert_batches_sorted_eq!(&expected, &actual); } + + #[tokio::test] + async fn delta_scan_case_sensitive() { + let schema = Arc::new(ArrowSchema::new(vec![ + Field::new("moDified", DataType::Utf8, true), + Field::new("ID", DataType::Utf8, true), + Field::new("vaLue", DataType::Int32, true), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(arrow::array::StringArray::from(vec![ + "2021-02-01", + "2021-02-01", + "2021-02-02", + "2021-02-02", + ])), + Arc::new(arrow::array::StringArray::from(vec!["A", "B", "C", "D"])), + Arc::new(arrow::array::Int32Array::from(vec![1, 10, 20, 100])), + ], + ) + .unwrap(); + // write some data + let table = crate::DeltaOps::new_in_memory() + .write(vec![batch.clone()]) + .with_save_mode(crate::protocol::SaveMode::Append) + .await + .unwrap(); + + let config = DeltaScanConfigBuilder::new().build(&table.state).unwrap(); + let log = table.log_store(); + + let provider = DeltaTableProvider::try_new(table.state, log, config).unwrap(); + let ctx: SessionContext = DeltaSessionContext::default().into(); + ctx.register_table("test", Arc::new(provider)).unwrap(); + + let df = ctx + .sql("select ID, moDified, vaLue from test") + .await + .unwrap(); + let actual = df.collect().await.unwrap(); + let expected = vec![ + "+----+------------+-------+", + "| ID | moDified | vaLue |", + "+----+------------+-------+", + "| A | 2021-02-01 | 1 |", + "| B | 2021-02-01 | 10 |", + "| C | 2021-02-02 | 20 |", + "| D | 2021-02-02 | 100 |", + "+----+------------+-------+", + ]; + assert_batches_sorted_eq!(&expected, &actual); + + /* TODO: Datafusion doesn't have any options to prevent case-sensitivity with the col func */ + /* + let df = ctx + .table("test") + .await + .unwrap() + .select(vec![col("ID"), col("moDified"), col("vaLue")]) + .unwrap(); + let actual = df.collect().await.unwrap(); + assert_batches_sorted_eq!(&expected, &actual); + */ + } } diff --git a/crates/deltalake-core/src/operations/delete.rs b/crates/deltalake-core/src/operations/delete.rs index 31486183be..15fea91ee9 100644 --- a/crates/deltalake-core/src/operations/delete.rs +++ b/crates/deltalake-core/src/operations/delete.rs @@ -37,7 +37,7 @@ use serde_json::Value; use super::datafusion_utils::Expression; use super::transaction::PROTOCOL; use crate::delta_datafusion::expr::fmt_expr_to_sql; -use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder}; +use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder, DeltaSessionContext}; use crate::errors::{DeltaResult, DeltaTableError}; use crate::kernel::{Action, Add, Remove}; use crate::operations::transaction::commit; @@ -280,7 +280,7 @@ impl std::future::IntoFuture for DeleteBuilder { PROTOCOL.can_write_to(&this.snapshot)?; let state = this.state.unwrap_or_else(|| { - let session = SessionContext::new(); + let session: SessionContext = DeltaSessionContext::default().into(); // If a user provides their own their DF state then they must register the store themselves register_store(this.log_store.clone(), session.runtime_env()); diff --git a/crates/deltalake-core/src/operations/merge.rs b/crates/deltalake-core/src/operations/merge.rs index 433e9cda43..7b03965747 100644 --- a/crates/deltalake-core/src/operations/merge.rs +++ b/crates/deltalake-core/src/operations/merge.rs @@ -65,7 +65,9 @@ use super::transaction::{commit, PROTOCOL}; use crate::delta_datafusion::expr::{fmt_expr_to_sql, parse_predicate_expression}; use crate::delta_datafusion::logical::MetricObserver; use crate::delta_datafusion::physical::{find_metric_node, MetricObserverExec}; -use crate::delta_datafusion::{register_store, DeltaScanConfig, DeltaTableProvider}; +use crate::delta_datafusion::{ + register_store, DeltaColumn, DeltaScanConfig, DeltaSessionConfig, DeltaTableProvider, +}; use crate::kernel::{Action, Remove}; use crate::logstore::LogStoreRef; use crate::operations::write::write_execution_plan; @@ -391,12 +393,12 @@ impl UpdateBuilder { /// How a column from the target table should be updated. /// In the match case the expression may contain both source and target columns. /// In the source not match case the expression may only contain target columns - pub fn update, E: Into>( + pub fn update, E: Into>( mut self, column: C, expression: E, ) -> Self { - self.updates.insert(column.into(), expression.into()); + self.updates.insert(column.into().into(), expression.into()); self } } @@ -419,8 +421,12 @@ impl InsertBuilder { /// Which values to insert into the target tables. If a target column is not /// specified then null is inserted. - pub fn set, E: Into>(mut self, column: C, expression: E) -> Self { - self.set.insert(column.into(), expression.into()); + pub fn set, E: Into>( + mut self, + column: C, + expression: E, + ) -> Self { + self.set.insert(column.into().into(), expression.into()); self } } @@ -888,7 +894,10 @@ async fn execute( .end()?; let name = "__delta_rs_c_".to_owned() + delta_field.name(); - write_projection.push(col(name.clone()).alias(delta_field.name())); + write_projection.push( + Expr::Column(Column::from_qualified_name_ignore_case(name.clone())) + .alias(delta_field.name()), + ); new_columns = new_columns.with_column(&name, case)?; } @@ -1112,7 +1121,8 @@ impl std::future::IntoFuture for MergeBuilder { let state = this.state.unwrap_or_else(|| { //TODO: Datafusion's Hashjoin has some memory issues. Running with all cores results in a OoM. Can be removed when upstream improvemetns are made. - let config = SessionConfig::new().with_target_partitions(1); + let config: SessionConfig = DeltaSessionConfig::default().into(); + let config = config.with_target_partitions(1); let session = SessionContext::new_with_config(config); // If a user provides their own their DF state then they must register the store themselves @@ -1149,6 +1159,9 @@ impl std::future::IntoFuture for MergeBuilder { #[cfg(test)] mod tests { + use crate::kernel::DataType; + use crate::kernel::PrimitiveType; + use crate::kernel::StructField; use crate::operations::DeltaOps; use crate::protocol::*; use crate::writer::test_utils::datafusion::get_data; @@ -1159,6 +1172,8 @@ mod tests { use crate::DeltaTable; use arrow::datatypes::Schema as ArrowSchema; use arrow::record_batch::RecordBatch; + use arrow_schema::DataType as ArrowDataType; + use arrow_schema::Field; use datafusion::assert_batches_sorted_eq; use datafusion::prelude::DataFrame; use datafusion::prelude::SessionContext; @@ -1919,4 +1934,85 @@ mod tests { let actual = get_data(&table).await; assert_batches_sorted_eq!(&expected, &actual); } + + #[tokio::test] + async fn test_merge_case_sensitive() { + let schema = vec![ + StructField::new( + "Id".to_string(), + DataType::Primitive(PrimitiveType::String), + true, + ), + StructField::new( + "vAlue".to_string(), + DataType::Primitive(PrimitiveType::Integer), + true, + ), + StructField::new( + "mOdifieD".to_string(), + DataType::Primitive(PrimitiveType::String), + true, + ), + ]; + + let arrow_schema = Arc::new(ArrowSchema::new(vec![ + Field::new("Id", ArrowDataType::Utf8, true), + Field::new("vAlue", ArrowDataType::Int32, true), + Field::new("mOdifieD", ArrowDataType::Utf8, true), + ])); + + let table = DeltaOps::new_in_memory() + .create() + .with_columns(schema) + .await + .unwrap(); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&arrow_schema.clone()), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::StringArray::from(vec![ + "2021-02-02", + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + let source = ctx.read_batch(batch).unwrap(); + + let table = write_data(table, &arrow_schema).await; + assert_eq!(table.version(), 1); + assert_eq!(table.get_file_uris().count(), 1); + + let (table, _metrics) = DeltaOps(table) + .merge(source, "target.Id = source.Id") + .with_source_alias("source") + .with_target_alias("target") + .when_not_matched_insert(|insert| { + insert + .set("Id", "source.Id") + .set("vAlue", "source.vAlue + 1") + .set("mOdifieD", "source.mOdifieD") + }) + .unwrap() + .await + .unwrap(); + + let expected = vec![ + "+----+-------+------------+", + "| Id | vAlue | mOdifieD |", + "+----+-------+------------+", + "| A | 1 | 2021-02-01 |", + "| B | 10 | 2021-02-01 |", + "| C | 10 | 2021-02-02 |", + "| D | 100 | 2021-02-02 |", + "| X | 31 | 2023-07-04 |", + "+----+-------+------------+", + ]; + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + } } diff --git a/crates/deltalake-core/src/operations/update.rs b/crates/deltalake-core/src/operations/update.rs index b7ff813712..78319c00b0 100644 --- a/crates/deltalake-core/src/operations/update.rs +++ b/crates/deltalake-core/src/operations/update.rs @@ -46,7 +46,9 @@ use serde_json::Value; use super::datafusion_utils::Expression; use super::transaction::{commit, PROTOCOL}; use super::write::write_execution_plan; -use crate::delta_datafusion::{expr::fmt_expr_to_sql, physical::MetricObserverExec}; +use crate::delta_datafusion::{ + expr::fmt_expr_to_sql, physical::MetricObserverExec, DeltaColumn, DeltaSessionContext, +}; use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder}; use crate::kernel::{Action, Remove}; use crate::logstore::LogStoreRef; @@ -115,12 +117,12 @@ impl UpdateBuilder { } /// Perform an additional update expression during the operaton - pub fn with_update, E: Into>( + pub fn with_update, E: Into>( mut self, column: S, expression: E, ) -> Self { - self.updates.insert(column.into(), expression.into()); + self.updates.insert(column.into().into(), expression.into()); self } @@ -431,7 +433,7 @@ impl std::future::IntoFuture for UpdateBuilder { PROTOCOL.can_write_to(&this.snapshot)?; let state = this.state.unwrap_or_else(|| { - let session = SessionContext::new(); + let session: SessionContext = DeltaSessionContext::default().into(); // If a user provides their own their DF state then they must register the store themselves register_store(this.log_store.clone(), session.runtime_env()); @@ -462,6 +464,10 @@ impl std::future::IntoFuture for UpdateBuilder { #[cfg(test)] mod tests { + use crate::kernel::DataType as DeltaDataType; + use crate::kernel::PrimitiveType; + use crate::kernel::StructField; + use crate::kernel::StructType; use crate::operations::DeltaOps; use crate::writer::test_utils::datafusion::get_data; use crate::writer::test_utils::datafusion::write_batch; @@ -470,9 +476,11 @@ mod tests { }; use crate::DeltaConfigKey; use crate::DeltaTable; + use arrow::datatypes::Schema as ArrowSchema; use arrow::datatypes::{Field, Schema}; use arrow::record_batch::RecordBatch; use arrow_array::Int32Array; + use arrow_schema::DataType; use datafusion::assert_batches_sorted_eq; use datafusion::prelude::*; use serde_json::json; @@ -730,6 +738,77 @@ mod tests { assert_batches_sorted_eq!(&expected, &actual); } + #[tokio::test] + async fn test_update_case_sensitive() { + let schema = StructType::new(vec![ + StructField::new( + "Id".to_string(), + DeltaDataType::Primitive(PrimitiveType::String), + true, + ), + StructField::new( + "ValUe".to_string(), + DeltaDataType::Primitive(PrimitiveType::Integer), + true, + ), + StructField::new( + "mOdified".to_string(), + DeltaDataType::Primitive(PrimitiveType::String), + true, + ), + ]); + + let arrow_schema = Arc::new(ArrowSchema::new(vec![ + Field::new("Id", DataType::Utf8, true), + Field::new("ValUe", DataType::Int32, true), + Field::new("mOdified", DataType::Utf8, true), + ])); + + let batch = RecordBatch::try_new( + Arc::clone(&arrow_schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["A", "B", "A", "A"])), + Arc::new(arrow::array::Int32Array::from(vec![1, 10, 10, 100])), + Arc::new(arrow::array::StringArray::from(vec![ + "2021-02-02", + "2021-02-02", + "2021-02-03", + "2021-02-03", + ])), + ], + ) + .unwrap(); + + let table = DeltaOps::new_in_memory() + .create() + .with_columns(schema.fields().clone()) + .await + .unwrap(); + let table = write_batch(table, batch).await; + + let (table, _metrics) = DeltaOps(table) + .update() + .with_predicate("mOdified = '2021-02-03'") + .with_update("mOdified", "'2023-05-14'") + .with_update("Id", "'C'") + .await + .unwrap(); + + let expected = vec![ + "+----+-------+------------+", + "| Id | ValUe | mOdified |", + "+----+-------+------------+", + "| A | 1 | 2021-02-02 |", + "| B | 10 | 2021-02-02 |", + "| C | 10 | 2023-05-14 |", + "| C | 100 | 2023-05-14 |", + "+----+-------+------------+", + ]; + + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + } + #[tokio::test] async fn test_update_null() { let table = prepare_values_table().await;