diff --git a/crates/deltalake-core/src/delta_datafusion/mod.rs b/crates/deltalake-core/src/delta_datafusion/mod.rs index d3883eac8b..4dff659b9c 100644 --- a/crates/deltalake-core/src/delta_datafusion/mod.rs +++ b/crates/deltalake-core/src/delta_datafusion/mod.rs @@ -1078,6 +1078,12 @@ impl DeltaDataChecker { self } + /// Add the specified set of constraints to the current DeltaDataChecker's constraints + pub fn with_extra_constraints(mut self, constraints: Vec) -> Self { + self.constraints.extend(constraints); + self + } + /// Create a new DeltaDataChecker pub fn new(snapshot: &DeltaTableState) -> Self { let metadata = snapshot.metadata(); diff --git a/crates/deltalake-core/src/operations/write.rs b/crates/deltalake-core/src/operations/write.rs index 6da3b18ecb..a81d8a476d 100644 --- a/crates/deltalake-core/src/operations/write.rs +++ b/crates/deltalake-core/src/operations/write.rs @@ -1,3 +1,4 @@ +#![allow(unused)] //! Used to write [RecordBatch]es into a delta table. //! //! New Table Semantics @@ -33,21 +34,31 @@ use arrow_array::RecordBatch; use arrow_cast::can_cast_types; use arrow_schema::{DataType, Fields, SchemaRef as ArrowSchemaRef}; use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; +use datafusion::physical_expr::create_physical_expr; +use datafusion::physical_plan::filter::FilterExec; use datafusion::physical_plan::{memory::MemoryExec, ExecutionPlan}; +use datafusion_common::DFSchema; +use datafusion_expr::Expr; +use datafusion_proto::protobuf::Constraint; use futures::future::BoxFuture; use futures::StreamExt; use parquet::file::properties::WriterProperties; +use super::datafusion_utils::Expression; use super::transaction::PROTOCOL; use super::writer::{DeltaWriter, WriterConfig}; use super::{transaction::commit, CreateBuilder}; +use crate::delta_datafusion::expr::fmt_expr_to_sql; use crate::delta_datafusion::DeltaDataChecker; +use crate::delta_datafusion::{find_files, register_store, DeltaScanBuilder}; use crate::errors::{DeltaResult, DeltaTableError}; use crate::kernel::{Action, Add, Metadata, Remove, StructType}; use crate::logstore::LogStoreRef; +use crate::operations::delete::DeleteBuilder; use crate::protocol::{DeltaOperation, SaveMode}; use crate::storage::ObjectStoreRef; use crate::table::state::DeltaTableState; +use crate::table::Constraint as DeltaConstraint; use crate::writer::record_batch::divide_by_partition_values; use crate::writer::utils::PartitionPath; use crate::DeltaTable; @@ -81,7 +92,7 @@ impl From for DeltaTableError { } /// Write data into a DeltaTable -#[derive(Debug, Clone)] +// #[derive(Debug, Clone)] pub struct WriteBuilder { /// A snapshot of the to-be-loaded table's state snapshot: DeltaTableState, @@ -96,7 +107,8 @@ pub struct WriteBuilder { /// Column names for table partitioning partition_columns: Option>, /// When using `Overwrite` mode, replace data that matches a predicate - predicate: Option, + // predicate: Option, + predicate: Option, /// Size above which we will write a buffered parquet file to disk. target_file_size: Option, /// Number of records to be written in single batch to underlying writer @@ -156,7 +168,7 @@ impl WriteBuilder { } /// When using `Overwrite` mode, replace data that matches a predicate - pub fn with_replace_where(mut self, predicate: impl Into) -> Self { + pub fn with_replace_where(mut self, predicate: impl Into) -> Self { self.predicate = Some(predicate.into()); self } @@ -294,7 +306,8 @@ impl WriteBuilder { } #[allow(clippy::too_many_arguments)] -pub(crate) async fn write_execution_plan( +async fn write_execution_plan_with_predicate( + predicate: Option, snapshot: &DeltaTableState, state: SessionState, plan: Arc, @@ -314,6 +327,14 @@ pub(crate) async fn write_execution_plan( }; let checker = DeltaDataChecker::new(snapshot); + let checker = match predicate { + Some(pred) => { + // TODO: get the name of the outer-most column? `*` will also work but would be slower + let chk = DeltaConstraint::new("*", &fmt_expr_to_sql(&pred)?); + checker.with_extra_constraints(vec![chk]) + } + _ => checker, + }; // Write data to disk let mut tasks = vec![]; @@ -359,6 +380,138 @@ pub(crate) async fn write_execution_plan( .collect::>()) } +#[allow(clippy::too_many_arguments)] +pub(crate) async fn write_execution_plan( + snapshot: &DeltaTableState, + state: SessionState, + plan: Arc, + partition_columns: Vec, + object_store: ObjectStoreRef, + target_file_size: Option, + write_batch_size: Option, + writer_properties: Option, + safe_cast: bool, + overwrite_schema: bool, +) -> DeltaResult> { + write_execution_plan_with_predicate( + None, + snapshot, + state, + plan, + partition_columns, + object_store, + target_file_size, + write_batch_size, + writer_properties, + safe_cast, + overwrite_schema, + ) + .await +} + +async fn execute_non_empty_expr( + snapshot: &DeltaTableState, + log_store: LogStoreRef, + state: &SessionState, + expression: &Expr, + rewrite: &[Add], + writer_properties: Option, +) -> DeltaResult> { + // For each identified file perform a parquet scan + filter + limit (1) + count. + // If returned count is not zero then append the file to be rewritten and removed from the log. Otherwise do nothing to the file. + + let input_schema = snapshot.input_schema()?; + let input_dfschema: DFSchema = input_schema.clone().as_ref().clone().try_into()?; + + let table_partition_cols = snapshot + .metadata() + .ok_or(DeltaTableError::NoMetadata)? + .partition_columns + .clone(); + + let scan = DeltaScanBuilder::new(snapshot, log_store.clone(), state) + .with_files(rewrite) + .build() + .await?; + let scan = Arc::new(scan); + + // Apply the negation of the filter and rewrite files + let negated_expression = Expr::Not(Box::new(Expr::IsTrue(Box::new(expression.clone())))); + + let predicate_expr = create_physical_expr( + &negated_expression, + &input_dfschema, + &input_schema, + state.execution_props(), + )?; + let filter: Arc = + Arc::new(FilterExec::try_new(predicate_expr, scan.clone())?); + + // We don't want to verify the predicate against existing data + let add_actions = write_execution_plan( + snapshot, + state.clone(), + filter.clone(), + table_partition_cols.clone(), + log_store.object_store(), + Some(snapshot.table_config().target_file_size() as usize), + None, + writer_properties, + false, + false, + ) + .await?; + + Ok(add_actions) +} + +// This should only be called wth a valid predicate +async fn prepare_predicate_actions( + predicate: Expr, + log_store: LogStoreRef, + snapshot: &DeltaTableState, + state: &SessionState, + writer_properties: Option, + deletion_timestamp: i64, +) -> DeltaResult> { + let candidates = + find_files(snapshot, log_store.clone(), state, Some(predicate.clone())).await?; + + let add = if candidates.partition_scan { + Vec::new() + } else { + let add = execute_non_empty_expr( + snapshot, + log_store, + &state, + &predicate, + &candidates.candidates, + writer_properties, + ) + .await?; + add + }; + let remove = candidates.candidates; + + let mut actions: Vec = add.into_iter().map(Action::Add).collect(); + + for action in remove { + actions.push(Action::Remove(Remove { + path: action.path, + deletion_timestamp: Some(deletion_timestamp), + data_change: true, + extended_file_metadata: Some(true), + partition_values: Some(action.partition_values), + size: Some(action.size), + deletion_vector: action.deletion_vector, + tags: None, + base_row_id: action.base_row_id, + default_row_commit_version: action.default_row_commit_version, + })) + } + Ok(actions) +} + impl std::future::IntoFuture for WriteBuilder { type Output = DeltaResult; type IntoFuture = BoxFuture<'static, Self::Output>; @@ -462,19 +615,35 @@ impl std::future::IntoFuture for WriteBuilder { Some(state) => state, None => { let ctx = SessionContext::new(); + register_store(this.log_store.clone(), ctx.runtime_env()); ctx.state() } }; - let add_actions = write_execution_plan( + let (predicate_str, predicate) = match this.predicate { + Some(predicate) => { + let pred = match predicate { + Expression::DataFusion(expr) => expr, + Expression::String(s) => { + this.snapshot.parse_predicate_expression(s, &state)? + } + }; + (Some(fmt_expr_to_sql(&pred)?), Some(pred)) + } + _ => (None, None), + }; + + // Here we need to validate if the new data conforms to a predicate if one is provided + let add_actions = write_execution_plan_with_predicate( + predicate.clone(), &this.snapshot, - state, + state.clone(), plan, partition_columns.clone(), this.log_store.object_store().clone(), this.target_file_size, this.write_batch_size, - this.writer_properties, + this.writer_properties.clone(), this.safe_cast, this.overwrite_schema, ) @@ -501,36 +670,43 @@ impl std::future::IntoFuture for WriteBuilder { let metadata_action = Metadata::try_from(metadata)?; actions.push(Action::Metadata(metadata_action)); } - // This should never error, since now() will always be larger than UNIX_EPOCH + let deletion_timestamp = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_millis() as i64; - let to_remove_action = |add: &Add| { - Action::Remove(Remove { - path: add.path.clone(), - deletion_timestamp: Some(deletion_timestamp), - data_change: true, - extended_file_metadata: Some(false), - partition_values: Some(add.partition_values.clone()), - size: Some(add.size), - // TODO add file metadata to remove action (tags missing) - tags: None, - deletion_vector: add.deletion_vector.clone(), - base_row_id: add.base_row_id, - default_row_commit_version: add.default_row_commit_version, - }) - }; - - match this.predicate { - Some(_pred) => { - return Err(DeltaTableError::Generic( - "Overwriting data based on predicate is not yet implemented" - .to_string(), - )); + match predicate { + Some(pred) => { + let predicate_actions = prepare_predicate_actions( + pred, + this.log_store.clone(), + &this.snapshot, + &state, + this.writer_properties, + deletion_timestamp, + ) + .await?; + if !predicate_actions.is_empty() { + actions.extend(predicate_actions); + } } _ => { + let to_remove_action = |add: &Add| { + Action::Remove(Remove { + path: add.path.clone(), + deletion_timestamp: Some(deletion_timestamp), + data_change: true, + extended_file_metadata: Some(false), + partition_values: Some(add.partition_values.clone()), + size: Some(add.size), + // TODO add file metadata to remove action (tags missing) + tags: None, + deletion_vector: add.deletion_vector.clone(), + base_row_id: add.base_row_id, + default_row_commit_version: add.default_row_commit_version, + }) + }; let remove_actions = this .snapshot .files() @@ -552,7 +728,7 @@ impl std::future::IntoFuture for WriteBuilder { } else { None }, - predicate: this.predicate, + predicate: predicate_str, }, &this.snapshot, this.app_metadata, @@ -564,10 +740,9 @@ impl std::future::IntoFuture for WriteBuilder { // then again, having only some tombstones may be misleading. this.snapshot .merge(DeltaTableState::from_actions(actions, version)?, true, true); + let table = DeltaTable::new_with_state(this.log_store, this.snapshot); - // TODO should we build checkpoints based on config? - - Ok(DeltaTable::new_with_state(this.log_store, this.snapshot)) + Ok(table) }) } } @@ -600,7 +775,7 @@ mod tests { use crate::writer::test_utils::datafusion::get_data; use crate::writer::test_utils::datafusion::write_batch; use crate::writer::test_utils::{ - get_delta_schema, get_delta_schema_with_nested_struct, get_record_batch, + get_arrow_schema, get_delta_schema, get_delta_schema_with_nested_struct, get_record_batch, get_record_batch_with_nested_struct, setup_table_with_configuration, }; use crate::DeltaConfigKey; @@ -608,6 +783,7 @@ mod tests { use arrow::datatypes::Schema as ArrowSchema; use arrow_array::{Int32Array, StringArray, TimestampMicrosecondArray}; use arrow_schema::{DataType, TimeUnit}; + use datafusion::prelude::*; use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; use serde_json::{json, Value}; @@ -960,4 +1136,95 @@ mod tests { assert_batches_eq!(&expected, &data); } + + #[tokio::test] + async fn test_replace_where() { + // TODO: add tests for replaceWhere with partition column + let tmp_dir = tempdir::TempDir::new("test").unwrap(); + let tmp_path = std::fs::canonicalize(tmp_dir.path()).unwrap(); + + let schema = get_arrow_schema(&None); + // let table = DeltaOps::new_in_memory() + // let ops = DeltaOps::try_from_uri("/tmp/issue_1957_delta_rs").await.unwrap(); + // let table = ops + let table_path = tmp_path.as_os_str().to_str().unwrap(); + let ops = DeltaOps::try_from_uri(table_path).await.unwrap(); + + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["A", "B", "C", "C"])), + Arc::new(arrow::array::Int32Array::from(vec![0, 20, 10, 100])), + Arc::new(arrow::array::StringArray::from(vec![ + "2021-02-02", + "2021-02-03", + "2021-02-02", + "2021-02-04", + ])), + ], + ) + .unwrap(); + + // write some data + // let table = DeltaOps(table) + let table = ops + .write(vec![batch]) + .with_save_mode(SaveMode::Append) + .await + .unwrap(); + assert_eq!(table.version(), 0); + + let batch_add = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["C"])), + Arc::new(arrow::array::Int32Array::from(vec![50])), + Arc::new(arrow::array::StringArray::from(vec!["2023-01-01"])), + ], + ) + .unwrap(); + + let table = DeltaOps(table) + .write(vec![batch_add]) + .with_save_mode(SaveMode::Overwrite) + .with_replace_where(col("id").eq(lit("C"))) + .await + .unwrap(); + assert_eq!(table.version(), 1); + + let expected = [ + "+----+-------+------------+", + "| id | value | modified |", + "+----+-------+------------+", + "| A | 0 | 2021-02-02 |", + "| B | 20 | 2021-02-03 |", + "| C | 50 | 2023-01-01 |", + "+----+-------+------------+", + ]; + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + + let batch_fail = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["D"])), + Arc::new(arrow::array::Int32Array::from(vec![1000])), + Arc::new(arrow::array::StringArray::from(vec!["2023-01-01"])), + ], + ) + .unwrap(); + + let table = DeltaOps(table) + .write(vec![batch_fail]) + .with_save_mode(SaveMode::Overwrite) + .with_replace_where(col("id").eq(lit("C"))) + .await; + assert!(table.is_err()); + + let table = crate::open_table(table_path).await.unwrap(); + assert_eq!(table.get_latest_version().await.unwrap(), 1); + + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + } } diff --git a/python/src/lib.rs b/python/src/lib.rs index 55a7442281..b5842e547a 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1208,7 +1208,7 @@ fn write_to_deltalake( builder = builder.with_description(description); }; - if let Some(predicate) = &predicate { + if let Some(predicate) = predicate { builder = builder.with_replace_where(predicate); };