diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 93b3cbdc3e..ff650eff98 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -7,6 +7,9 @@ on: branches: [main, "rust-v*"] merge_group: +env: + DEFAULT_FEATURES: "azure,datafusion,s3,gcs,glue,hdfs " + jobs: format: runs-on: ubuntu-latest @@ -17,12 +20,37 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: default - toolchain: stable + toolchain: '1.80' override: true - name: Format run: cargo fmt -- --check + coverage: + runs-on: ubuntu-latest + env: + CARGO_TERM_COLOR: always + steps: + - uses: actions/checkout@v4 + - name: Install rust + uses: actions-rs/toolchain@v1 + with: + profile: default + toolchain: '1.80' + override: true + - name: Install cargo-llvm-cov + uses: taiki-e/install-action@cargo-llvm-cov + - uses: Swatinem/rust-cache@v2 + - name: Generate code coverage + run: cargo llvm-cov --features ${DEFAULT_FEATURES} --workspace --codecov --output-path codecov.json -- --skip read_table_version_hdfs + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + files: codecov.json + fail_ci_if_error: true + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + build: strategy: fail-fast: false @@ -39,17 +67,17 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: default - toolchain: stable + toolchain: '1.80' override: true - name: build and lint with clippy - run: cargo clippy --features azure,datafusion,s3,gcs,glue,hdfs --tests + run: cargo clippy --features ${{ env.DEFAULT_FEATURES }} --tests - name: Spot-check build for native-tls features run: cargo clippy --no-default-features --features azure,datafusion,s3-native-tls,gcs,glue --tests - name: Check docs - run: cargo doc --features azure,datafusion,s3,gcs,glue,hdfs + run: cargo doc --features ${{ env.DEFAULT_FEATURES }} - name: Check no default features (except rustls) run: cargo check --no-default-features --features rustls @@ -76,11 +104,11 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: default - toolchain: "stable" + toolchain: '1.80' override: true - name: Run tests - run: cargo test --verbose --features datafusion,azure + run: cargo test --verbose --features ${{ env.DEFAULT_FEATURES }} integration_test: name: Integration Tests @@ -110,7 +138,7 @@ jobs: uses: actions-rs/toolchain@v1 with: profile: default - toolchain: stable + toolchain: '1.80' override: true # Install Java and Hadoop for HDFS integration tests @@ -130,7 +158,7 @@ jobs: - name: Run tests with rustls (default) run: | - cargo test --features integration_test,azure,s3,gcs,datafusion,hdfs + cargo test --features integration_test,${{ env.DEFAULT_FEATURES }} - name: Run tests with native-tls run: | diff --git a/Cargo.toml b/Cargo.toml index e8cb698318..74a4c73597 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,7 +5,7 @@ resolver = "2" [workspace.package] authors = ["Qingping Hou "] -rust-version = "1.75" +rust-version = "1.80" keywords = ["deltalake", "delta", "datalake"] readme = "README.md" edition = "2021" diff --git a/crates/aws/Cargo.toml b/crates/aws/Cargo.toml index c47065dce4..e79d92a3d2 100644 --- a/crates/aws/Cargo.toml +++ b/crates/aws/Cargo.toml @@ -33,7 +33,7 @@ tokio = { workspace = true } regex = { workspace = true } uuid = { workspace = true, features = ["serde", "v4"] } url = { workspace = true } -backoff = { version = "0.4", features = [ "tokio" ] } +backon = { version = "1",default-features = false, features = [ "tokio-sleep" ] } hyper-tls = { version = "0.5", optional = true } [dev-dependencies] diff --git 
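
The Cargo.toml hunk above swaps the unmaintained `backoff` crate for `backon` 1.x with only the `tokio-sleep` feature. As a reference point for the retry changes in `crates/aws/src/lib.rs` below, here is a minimal, hedged sketch of the backon pattern being adopted; `flaky_call`, `TransientError`, and the tokio runtime setup are illustrative stand-ins, not part of this patch:

```rust
use std::time::Duration;

use backon::{ExponentialBuilder, Retryable};

// Illustrative error type: `.0` marks whether the failure is worth retrying.
#[derive(Debug)]
struct TransientError(bool);

// Stand-in for a throttled service call (e.g. a DynamoDB request).
async fn flaky_call() -> Result<u64, TransientError> {
    Err(TransientError(true))
}

#[tokio::main]
async fn main() {
    // Mirrors the builder used in `DynamoDbLockClient::retry` below.
    let backoff = ExponentialBuilder::default()
        .with_factor(2.0)
        .with_max_delay(Duration::from_secs(15));

    // `Retryable` turns the closure into a retryable future; `when` decides
    // which errors are retried, replacing backoff's transient/permanent split.
    let result = (|| async { flaky_call().await })
        .retry(backoff)
        .when(|err: &TransientError| err.0)
        .await;

    println!("{result:?}");
}
```

The design difference this implies for the lock client: instead of wrapping each error in `backoff::Error::transient`/`permanent` inside the operation closure, the closure returns the raw `SdkError` and a separate `when` predicate plus a final `map_err` classify it, which is exactly the shape of the rewritten methods that follow.
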
a/crates/aws/src/lib.rs b/crates/aws/src/lib.rs index 720a1e6a07..187462cb12 100644 --- a/crates/aws/src/lib.rs +++ b/crates/aws/src/lib.rs @@ -7,6 +7,7 @@ pub mod logstore; mod native; pub mod storage; use aws_config::SdkConfig; +use aws_sdk_dynamodb::error::SdkError; use aws_sdk_dynamodb::{ operation::{ create_table::CreateTableError, delete_item::DeleteItemError, get_item::GetItemError, @@ -283,28 +284,28 @@ impl DynamoDbLockClient { version: i64, ) -> Result, LockClientError> { let item = self - .retry(|| async { - match self - .dynamodb_client - .get_item() - .consistent_read(true) - .table_name(&self.config.lock_table_name) - .set_key(Some(self.get_primary_key(version, table_path))) - .send() - .await - { - Ok(x) => Ok(x), - Err(sdk_err) => match sdk_err.as_service_error() { - Some(GetItemError::ProvisionedThroughputExceededException(_)) => { - Err(backoff::Error::transient( - LockClientError::ProvisionedThroughputExceeded, - )) - } - _ => Err(backoff::Error::permanent(sdk_err.into())), - }, + .retry( + || async { + self.dynamodb_client + .get_item() + .consistent_read(true) + .table_name(&self.config.lock_table_name) + .set_key(Some(self.get_primary_key(version, table_path))) + .send() + .await + }, + |err| match err.as_service_error() { + Some(GetItemError::ProvisionedThroughputExceededException(_)) => true, + _ => false, + }, + ) + .await + .map_err(|err| match err.as_service_error() { + Some(GetItemError::ProvisionedThroughputExceededException(_)) => { + LockClientError::ProvisionedThroughputExceeded } - }) - .await?; + _ => err.into(), + })?; item.item.as_ref().map(CommitEntry::try_from).transpose() } @@ -314,36 +315,38 @@ impl DynamoDbLockClient { table_path: &str, entry: &CommitEntry, ) -> Result<(), LockClientError> { - self.retry(|| async { - let item = create_value_map(entry, table_path); - match self - .dynamodb_client - .put_item() - .condition_expression(constants::CONDITION_EXPR_CREATE.as_str()) - .table_name(self.get_lock_table_name()) - .set_item(Some(item)) - .send() - .await - { - Ok(_) => Ok(()), - Err(err) => match err.as_service_error() { - Some(PutItemError::ProvisionedThroughputExceededException(_)) => Err( - backoff::Error::transient(LockClientError::ProvisionedThroughputExceeded), - ), - Some(PutItemError::ConditionalCheckFailedException(_)) => Err( - backoff::Error::permanent(LockClientError::VersionAlreadyExists { - table_path: table_path.to_owned(), - version: entry.version, - }), - ), - Some(PutItemError::ResourceNotFoundException(_)) => Err( - backoff::Error::permanent(LockClientError::LockTableNotFound), - ), - _ => Err(backoff::Error::permanent(err.into())), - }, + self.retry( + || async { + let item = create_value_map(entry, table_path); + let _ = self + .dynamodb_client + .put_item() + .condition_expression(constants::CONDITION_EXPR_CREATE.as_str()) + .table_name(self.get_lock_table_name()) + .set_item(Some(item)) + .send() + .await?; + Ok(()) + }, + |err: &SdkError<_, _>| match err.as_service_error() { + Some(PutItemError::ProvisionedThroughputExceededException(_)) => true, + _ => false, + }, + ) + .await + .map_err(|err| match err.as_service_error() { + Some(PutItemError::ProvisionedThroughputExceededException(_)) => { + LockClientError::ProvisionedThroughputExceeded } + Some(PutItemError::ConditionalCheckFailedException(_)) => { + LockClientError::VersionAlreadyExists { + table_path: table_path.to_owned(), + version: entry.version, + } + } + Some(PutItemError::ResourceNotFoundException(_)) => LockClientError::LockTableNotFound, + _ => err.into(), 
}) - .await } /// Get the latest entry (entry with highest version). @@ -365,33 +368,33 @@ impl DynamoDbLockClient { limit: i64, ) -> Result, LockClientError> { let query_result = self - .retry(|| async { - match self - .dynamodb_client - .query() - .table_name(self.get_lock_table_name()) - .consistent_read(true) - .limit(limit.try_into().unwrap_or(i32::MAX)) - .scan_index_forward(false) - .key_condition_expression(format!("{} = :tn", constants::ATTR_TABLE_PATH)) - .set_expression_attribute_values(Some( - maplit::hashmap!(":tn".into() => string_attr(table_path)), - )) - .send() - .await - { - Ok(result) => Ok(result), - Err(sdk_err) => match sdk_err.as_service_error() { - Some(QueryError::ProvisionedThroughputExceededException(_)) => { - Err(backoff::Error::transient( - LockClientError::ProvisionedThroughputExceeded, - )) - } - _ => Err(backoff::Error::permanent(sdk_err.into())), - }, + .retry( + || async { + self.dynamodb_client + .query() + .table_name(self.get_lock_table_name()) + .consistent_read(true) + .limit(limit.try_into().unwrap_or(i32::MAX)) + .scan_index_forward(false) + .key_condition_expression(format!("{} = :tn", constants::ATTR_TABLE_PATH)) + .set_expression_attribute_values(Some( + maplit::hashmap!(":tn".into() => string_attr(table_path)), + )) + .send() + .await + }, + |err: &SdkError<_, _>| match err.as_service_error() { + Some(QueryError::ProvisionedThroughputExceededException(_)) => true, + _ => false, + }, + ) + .await + .map_err(|err| match err.as_service_error() { + Some(QueryError::ProvisionedThroughputExceededException(_)) => { + LockClientError::ProvisionedThroughputExceeded } - }) - .await?; + _ => err.into(), + })?; query_result .items @@ -412,35 +415,44 @@ impl DynamoDbLockClient { .duration_since(SystemTime::UNIX_EPOCH) .unwrap() .as_secs(); - self.retry(|| async { - match self - .dynamodb_client - .update_item() - .table_name(self.get_lock_table_name()) - .set_key(Some(self.get_primary_key(version, table_path))) - .update_expression("SET complete = :c, expireTime = :e".to_owned()) - .set_expression_attribute_values(Some(maplit::hashmap! { - ":c".to_owned() => string_attr("true"), - ":e".to_owned() => num_attr(seconds_since_epoch), - ":f".into() => string_attr("false"), - })) - .condition_expression(constants::CONDITION_UPDATE_INCOMPLETE) - .send() - .await - { - Ok(_) => Ok(UpdateLogEntryResult::UpdatePerformed), - Err(err) => match err.as_service_error() { - Some(UpdateItemError::ProvisionedThroughputExceededException(_)) => Err( - backoff::Error::transient(LockClientError::ProvisionedThroughputExceeded), - ), - Some(UpdateItemError::ConditionalCheckFailedException(_)) => { - Ok(UpdateLogEntryResult::AlreadyCompleted) - } - _ => Err(backoff::Error::permanent(err.into())), + let res = self + .retry( + || async { + let _ = self + .dynamodb_client + .update_item() + .table_name(self.get_lock_table_name()) + .set_key(Some(self.get_primary_key(version, table_path))) + .update_expression("SET complete = :c, expireTime = :e".to_owned()) + .set_expression_attribute_values(Some(maplit::hashmap! 
{ + ":c".to_owned() => string_attr("true"), + ":e".to_owned() => num_attr(seconds_since_epoch), + ":f".into() => string_attr("false"), + })) + .condition_expression(constants::CONDITION_UPDATE_INCOMPLETE) + .send() + .await?; + Ok(()) }, - } - }) - .await + |err: &SdkError<_, _>| match err.as_service_error() { + Some(UpdateItemError::ProvisionedThroughputExceededException(_)) => true, + _ => false, + }, + ) + .await; + + match res { + Ok(()) => Ok(UpdateLogEntryResult::UpdatePerformed), + Err(err) => match err.as_service_error() { + Some(UpdateItemError::ProvisionedThroughputExceededException(_)) => { + Err(LockClientError::ProvisionedThroughputExceeded) + } + Some(UpdateItemError::ConditionalCheckFailedException(_)) => { + Ok(UpdateLogEntryResult::AlreadyCompleted) + } + _ => Err(err.into()), + }, + } } /// Delete existing log entry if it is not already complete @@ -449,48 +461,52 @@ impl DynamoDbLockClient { version: i64, table_path: &str, ) -> Result<(), LockClientError> { - self.retry(|| async { - match self - .dynamodb_client - .delete_item() - .table_name(self.get_lock_table_name()) - .set_key(Some(self.get_primary_key(version, table_path))) - .set_expression_attribute_values(Some(maplit::hashmap! { - ":f".into() => string_attr("false"), - })) - .condition_expression(constants::CONDITION_DELETE_INCOMPLETE.as_str()) - .send() - .await - { - Ok(_) => Ok(()), - Err(err) => match err.as_service_error() { - Some(DeleteItemError::ProvisionedThroughputExceededException(_)) => Err( - backoff::Error::transient(LockClientError::ProvisionedThroughputExceeded), - ), - Some(DeleteItemError::ConditionalCheckFailedException(_)) => Err( - backoff::Error::permanent(LockClientError::VersionAlreadyCompleted { - table_path: table_path.to_owned(), - version, - }), - ), - _ => Err(backoff::Error::permanent(err.into())), - }, + self.retry( + || async { + let _ = self + .dynamodb_client + .delete_item() + .table_name(self.get_lock_table_name()) + .set_key(Some(self.get_primary_key(version, table_path))) + .set_expression_attribute_values(Some(maplit::hashmap! { + ":f".into() => string_attr("false"), + })) + .condition_expression(constants::CONDITION_DELETE_INCOMPLETE.as_str()) + .send() + .await?; + Ok(()) + }, + |err: &SdkError<_, _>| match err.as_service_error() { + Some(DeleteItemError::ProvisionedThroughputExceededException(_)) => true, + _ => false, + }, + ) + .await + .map_err(|err| match err.as_service_error() { + Some(DeleteItemError::ProvisionedThroughputExceededException(_)) => { + LockClientError::ProvisionedThroughputExceeded + } + Some(DeleteItemError::ConditionalCheckFailedException(_)) => { + LockClientError::VersionAlreadyCompleted { + table_path: table_path.to_owned(), + version, + } } + _ => err.into(), }) - .await } - async fn retry(&self, operation: Fn) -> Result + async fn retry(&self, operation: F, when: Wn) -> Result where - Fn: FnMut() -> Fut, - Fut: std::future::Future>>, + F: FnMut() -> Fut, + Fut: std::future::Future>, + Wn: Fn(&E) -> bool, { - let backoff = backoff::ExponentialBackoffBuilder::new() - .with_multiplier(2.) - .with_max_interval(Duration::from_secs(15)) - .with_max_elapsed_time(Some(self.config.max_elapsed_request_time)) - .build(); - backoff::future::retry(backoff, operation).await + use backon::Retryable; + let backoff = backon::ExponentialBuilder::default() + .with_factor(2.) 
+ .with_max_delay(self.config.max_elapsed_request_time); + operation.retry(backoff).when(when).await } } diff --git a/crates/core/src/operations/merge/barrier.rs b/crates/core/src/operations/merge/barrier.rs index f11474291d..9084d721b7 100644 --- a/crates/core/src/operations/merge/barrier.rs +++ b/crates/core/src/operations/merge/barrier.rs @@ -443,14 +443,16 @@ impl UserDefinedLogicalNodeCore for MergeBarrier { } } -pub(crate) fn find_barrier_node(parent: &Arc) -> Option> { - //! Used to locate the physical Barrier Node after the planner converts the logical node - if parent.as_any().downcast_ref::().is_some() { +pub(crate) fn find_node( + parent: &Arc, +) -> Option> { + //! Used to locate a Node:: after the planner converts the logical node + if parent.as_any().downcast_ref::().is_some() { return Some(parent.to_owned()); } for child in &parent.children() { - let res = find_barrier_node(child); + let res = find_node::(child); if res.is_some() { return res; } diff --git a/crates/core/src/operations/merge/filter.rs b/crates/core/src/operations/merge/filter.rs new file mode 100644 index 0000000000..0745c55830 --- /dev/null +++ b/crates/core/src/operations/merge/filter.rs @@ -0,0 +1,943 @@ +//! Utility functions to determine early filters for file/partition pruning +use datafusion::functions_aggregate::expr_fn::{max, min}; +use std::collections::HashMap; + +use datafusion::execution::context::SessionState; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{ScalarValue, TableReference}; +use datafusion_expr::expr::{InList, Placeholder}; +use datafusion_expr::{Aggregate, BinaryExpr, LogicalPlan, Operator}; +use datafusion_expr::{Between, Expr}; + +use either::{Left, Right}; + +use itertools::Itertools; + +use crate::delta_datafusion::execute_plan_to_batch; +use crate::table::state::DeltaTableState; +use crate::DeltaResult; + +#[derive(Debug)] +enum ReferenceTableCheck { + HasReference(String), + NoReference, + Unknown, +} +impl ReferenceTableCheck { + fn has_reference(&self) -> bool { + matches!(self, ReferenceTableCheck::HasReference(_)) + } +} + +fn references_table(expr: &Expr, table: &TableReference) -> ReferenceTableCheck { + let res = match expr { + Expr::Alias(alias) => references_table(&alias.expr, table), + Expr::Column(col) => col + .relation + .as_ref() + .map(|rel| { + if rel == table { + ReferenceTableCheck::HasReference(col.name.to_owned()) + } else { + ReferenceTableCheck::NoReference + } + }) + .unwrap_or(ReferenceTableCheck::NoReference), + Expr::Negative(neg) => references_table(neg, table), + Expr::Cast(cast) => references_table(&cast.expr, table), + Expr::TryCast(try_cast) => references_table(&try_cast.expr, table), + Expr::ScalarFunction(func) => { + if func.args.len() == 1 { + references_table(&func.args[0], table) + } else { + ReferenceTableCheck::Unknown + } + } + Expr::IsNull(inner) => references_table(inner, table), + Expr::Literal(_) => ReferenceTableCheck::NoReference, + _ => ReferenceTableCheck::Unknown, + }; + res +} + +fn construct_placeholder( + binary: BinaryExpr, + source_left: bool, + is_partition_column: bool, + column_name: String, + placeholders: &mut Vec, +) -> Option { + if is_partition_column { + let placeholder_name = format!("{column_name}_{}", placeholders.len()); + let placeholder = Expr::Placeholder(Placeholder { + id: placeholder_name.clone(), + data_type: None, + }); + + let (left, right, source_expr): (Box, Box, Expr) = if source_left { + (placeholder.into(), binary.clone().right, *binary.left) + } else { + 
(binary.clone().left, placeholder.into(), *binary.right) + }; + + let replaced = Expr::BinaryExpr(BinaryExpr { + left, + op: binary.op, + right, + }); + + placeholders.push(PredicatePlaceholder { + expr: source_expr, + alias: placeholder_name, + is_aggregate: false, + }); + + Some(replaced) + } else { + match binary.op { + Operator::Eq => { + let name_min = format!("{column_name}_{}_min", placeholders.len()); + let placeholder_min = Expr::Placeholder(Placeholder { + id: name_min.clone(), + data_type: None, + }); + let name_max = format!("{column_name}_{}_max", placeholders.len()); + let placeholder_max = Expr::Placeholder(Placeholder { + id: name_max.clone(), + data_type: None, + }); + let (source_expr, target_expr) = if source_left { + (*binary.left, *binary.right) + } else { + (*binary.right, *binary.left) + }; + let replaced = Expr::Between(Between { + expr: target_expr.into(), + negated: false, + low: placeholder_min.into(), + high: placeholder_max.into(), + }); + + placeholders.push(PredicatePlaceholder { + expr: min(source_expr.clone()), + alias: name_min, + is_aggregate: true, + }); + placeholders.push(PredicatePlaceholder { + expr: max(source_expr), + alias: name_max, + is_aggregate: true, + }); + Some(replaced) + } + _ => None, + } + } +} + +fn replace_placeholders(expr: Expr, placeholders: &HashMap) -> Expr { + expr.transform(&|expr| match expr { + Expr::Placeholder(Placeholder { id, .. }) => { + let value = placeholders[&id].clone(); + // Replace the placeholder with the value + Ok(Transformed::yes(Expr::Literal(value))) + } + _ => Ok(Transformed::no(expr)), + }) + .unwrap() + .data +} + +pub(crate) struct PredicatePlaceholder { + pub expr: Expr, + pub alias: String, + pub is_aggregate: bool, +} + +/// Takes the predicate provided and does three things: +/// +/// 1. for any relations between a source column and a partition target column, +/// replace source with a placeholder matching the name of the partition +/// columns +/// +/// 2. for any is equal relations between a source column and a non-partition target column, +/// replace source with is between expression with min(source_column) and max(source_column) placeholders +/// +/// 3. for any other relation with a source column, remove them. +/// +/// For example, for the predicate: +/// +/// `source.date = target.date and source.id = target.id and frob > 42` +/// +/// where `date` is a partition column, would result in the expr: +/// +/// `$date_0 = target.date and target.id between $id_1_min and $id_1_max and frob > 42` +/// +/// This leaves us with a predicate that we can push into delta scan after expanding it out to +/// a conjunction between the distinct partitions in the source input. 
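
To make the doc comment above concrete, here is a hedged sketch of the expression shape that `generalize_filter` produces for the documented example (`source.date = target.date and source.id = target.id and frob > 42`, with `date` a partition column). The placeholder names follow the `{column}_{index}` / `_min` / `_max` convention from `construct_placeholder`; the table and column names come from the doc comment itself:

```rust
use datafusion_common::{Column, TableReference};
use datafusion_expr::expr::Placeholder;
use datafusion_expr::{col, lit, Expr};

/// Illustrative only: builds the generalized form of the documented predicate.
fn documented_example() -> Expr {
    let target = TableReference::parse_str("target");
    let target_col = |name: &str| {
        col(Column {
            relation: Some(target.clone()),
            name: name.to_owned(),
        })
    };
    let placeholder = |id: &str| {
        Expr::Placeholder(Placeholder {
            id: id.to_owned(),
            data_type: None,
        })
    };

    // source.date = target.date  ->  $date_0 = target.date  (partition column)
    placeholder("date_0")
        .eq(target_col("date"))
        // source.id = target.id  ->  target.id BETWEEN $id_1_min AND $id_1_max
        .and(target_col("id").between(placeholder("id_1_min"), placeholder("id_1_max")))
        // frob > 42 references neither source nor target, so it passes through
        .and(col("frob").gt(lit(42)))
}
```

`try_construct_early_filter` later fills these placeholders from the source data (distinct partition values, plus min/max aggregates), then ORs the per-row expansions together, as the tests at the end of this file demonstrate.
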
+/// +pub(crate) fn generalize_filter( + predicate: Expr, + partition_columns: &Vec, + source_name: &TableReference, + target_name: &TableReference, + placeholders: &mut Vec, +) -> Option { + match predicate { + Expr::BinaryExpr(binary) => { + if references_table(&binary.right, source_name).has_reference() { + if let ReferenceTableCheck::HasReference(left_target) = + references_table(&binary.left, target_name) + { + return construct_placeholder( + binary, + false, + partition_columns.contains(&left_target), + left_target, + placeholders, + ); + } + return None; + } + if references_table(&binary.left, source_name).has_reference() { + if let ReferenceTableCheck::HasReference(right_target) = + references_table(&binary.right, target_name) + { + return construct_placeholder( + binary, + true, + partition_columns.contains(&right_target), + right_target, + placeholders, + ); + } + return None; + } + + let left = generalize_filter( + *binary.left, + partition_columns, + source_name, + target_name, + placeholders, + ); + let right = generalize_filter( + *binary.right, + partition_columns, + source_name, + target_name, + placeholders, + ); + + match (left, right) { + (None, None) => None, + (None, Some(one_side)) | (Some(one_side), None) => { + // in the case of an AND clause, it's safe to generalize the filter down to just one side of the AND. + // this is because this filter will be more permissive than the actual predicate, so we know that + // we will catch all data that could be matched by the predicate. For OR this is not the case - we + // could potentially eliminate one side of the predicate and the filter would only match half the + // cases that would have satisfied the match predicate. + match binary.op { + Operator::And => Some(one_side), + Operator::Or => None, + _ => None, + } + } + (Some(l), Some(r)) => Expr::BinaryExpr(BinaryExpr { + left: l.into(), + op: binary.op, + right: r.into(), + }) + .into(), + } + } + Expr::InList(in_list) => { + let compare_expr = match generalize_filter( + *in_list.expr, + partition_columns, + source_name, + target_name, + placeholders, + ) { + Some(expr) => expr, + None => return None, // Return early + }; + + let mut list_expr = Vec::new(); + for item in in_list.list.into_iter() { + match item { + // If it's a literal just immediately push it in list_expr so we can avoid the unnecessary generalizing + Expr::Literal(_) => list_expr.push(item), + _ => { + if let Some(item) = generalize_filter( + item.clone(), + partition_columns, + source_name, + target_name, + placeholders, + ) { + list_expr.push(item) + } + } + } + } + if !list_expr.is_empty() { + Expr::InList(InList { + expr: compare_expr.into(), + list: list_expr, + negated: in_list.negated, + }) + .into() + } else { + None + } + } + other => match references_table(&other, source_name) { + ReferenceTableCheck::HasReference(col) => { + let placeholder_name = format!("{col}_{}", placeholders.len()); + + let placeholder = Expr::Placeholder(Placeholder { + id: placeholder_name.clone(), + data_type: None, + }); + + placeholders.push(PredicatePlaceholder { + expr: other, + alias: placeholder_name, + is_aggregate: true, + }); + Some(placeholder) + } + ReferenceTableCheck::NoReference => Some(other), + ReferenceTableCheck::Unknown => None, + }, + } +} + +pub(crate) async fn try_construct_early_filter( + join_predicate: Expr, + table_snapshot: &DeltaTableState, + session_state: &SessionState, + source: &LogicalPlan, + source_name: &TableReference, + target_name: &TableReference, +) -> DeltaResult> { + let 
table_metadata = table_snapshot.metadata(); + let partition_columns = &table_metadata.partition_columns; + + let mut placeholders = Vec::default(); + + match generalize_filter( + join_predicate, + partition_columns, + source_name, + target_name, + &mut placeholders, + ) { + None => Ok(None), + Some(filter) => { + if placeholders.is_empty() { + // if we haven't recognised any source predicates in the join predicate, return our filter with static only predicates + Ok(Some(filter)) + } else { + // if we have some filters, which depend on the source df, then collect the placeholders values from the source data + // We aggregate the distinct values for partitions with the group_columns and stats(min, max) for dynamic filter as agg_columns + // Can be translated into `SELECT partition1 as part1_0, min(id) as id_1_min, max(id) as id_1_max FROM source GROUP BY partition1` + let (agg_columns, group_columns) = placeholders.into_iter().partition_map(|p| { + if p.is_aggregate { + Left(p.expr.alias(p.alias)) + } else { + Right(p.expr.alias(p.alias)) + } + }); + let distinct_partitions = LogicalPlan::Aggregate(Aggregate::try_new( + source.clone().into(), + group_columns, + agg_columns, + )?); + let execution_plan = session_state + .create_physical_plan(&distinct_partitions) + .await?; + let items = execute_plan_to_batch(session_state, execution_plan).await?; + let placeholder_names = items + .schema() + .fields() + .iter() + .map(|f| f.name().to_owned()) + .collect_vec(); + let expr = (0..items.num_rows()) + .map(|i| { + let replacements = placeholder_names + .iter() + .map(|placeholder| { + let col = items.column_by_name(placeholder).unwrap(); + let value = ScalarValue::try_from_array(col, i)?; + DeltaResult::Ok((placeholder.to_owned(), value)) + }) + .try_collect()?; + Ok(replace_placeholders(filter.clone(), &replacements)) + }) + .collect::>>()? 
+ .into_iter() + .reduce(Expr::or); + Ok(expr) + } + } + } +} + +#[cfg(test)] +mod tests { + use crate::operations::merge::tests::setup_table; + use crate::operations::merge::try_construct_early_filter; + use crate::writer::test_utils::get_arrow_schema; + + use arrow::record_batch::RecordBatch; + + use datafusion::datasource::provider_as_source; + + use datafusion::prelude::*; + use datafusion_common::Column; + use datafusion_common::ScalarValue; + use datafusion_common::TableReference; + use datafusion_expr::col; + + use datafusion_expr::Expr; + use datafusion_expr::LogicalPlanBuilder; + use datafusion_expr::Operator; + + use std::sync::Arc; + + #[tokio::test] + async fn test_try_construct_early_filter_with_partitions_expands() { + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["id"])).await; + + assert_eq!(table.version(), 0); + assert_eq!(table.get_files_count(), 0); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::StringArray::from(vec![ + "2021-02-02", + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + let source = ctx.read_batch(batch).unwrap(); + + let source_name = TableReference::parse_str("source"); + let target_name = TableReference::parse_str("target"); + + let source = LogicalPlanBuilder::scan( + source_name.clone(), + provider_as_source(source.into_view()), + None, + ) + .unwrap() + .build() + .unwrap(); + + let join_predicate = col(Column { + relation: Some(source_name.clone()), + name: "id".to_owned(), + }) + .eq(col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + })); + + let pred = try_construct_early_filter( + join_predicate, + table.snapshot().unwrap(), + &ctx.state(), + &source, + &source_name, + &target_name, + ) + .await + .unwrap(); + + assert!(pred.is_some()); + + let split_pred = { + fn split(expr: Expr, parts: &mut Vec<(String, String)>) { + match expr { + Expr::BinaryExpr(ex) if ex.op == Operator::Or => { + split(*ex.left, parts); + split(*ex.right, parts); + } + Expr::BinaryExpr(ex) if ex.op == Operator::Eq => { + let col = match *ex.right { + Expr::Column(col) => col.name, + ex => panic!("expected column in pred, got {ex}!"), + }; + + let value = match *ex.left { + Expr::Literal(ScalarValue::Utf8(Some(value))) => value, + ex => panic!("expected value in predicate, got {ex}!"), + }; + + parts.push((col, value)) + } + + expr => panic!("expected either = or OR, got {expr}"), + } + } + + let mut parts = vec![]; + split(pred.unwrap(), &mut parts); + parts.sort(); + parts + }; + + let expected_pred_parts = [ + ("id".to_owned(), "B".to_owned()), + ("id".to_owned(), "C".to_owned()), + ("id".to_owned(), "X".to_owned()), + ]; + + assert_eq!(split_pred, expected_pred_parts); + } + + #[tokio::test] + async fn test_try_construct_early_filter_with_range() { + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["modified"])).await; + + assert_eq!(table.version(), 0); + assert_eq!(table.get_files_count(), 0); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20])), + Arc::new(arrow::array::StringArray::from(vec![ + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + let source = ctx.read_batch(batch).unwrap(); + + 
let source_name = TableReference::parse_str("source"); + let target_name = TableReference::parse_str("target"); + + let source = LogicalPlanBuilder::scan( + source_name.clone(), + provider_as_source(source.into_view()), + None, + ) + .unwrap() + .build() + .unwrap(); + + let join_predicate = col(Column { + relation: Some(source_name.clone()), + name: "id".to_owned(), + }) + .eq(col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + })); + + let pred = try_construct_early_filter( + join_predicate, + table.snapshot().unwrap(), + &ctx.state(), + &source, + &source_name, + &target_name, + ) + .await + .unwrap(); + + assert!(pred.is_some()); + + let filter = col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + }) + .between( + Expr::Literal(ScalarValue::Utf8(Some("B".to_string()))), + Expr::Literal(ScalarValue::Utf8(Some("C".to_string()))), + ); + assert_eq!(pred.unwrap(), filter); + } + + #[tokio::test] + async fn test_try_construct_early_filter_with_partition_and_range() { + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["modified"])).await; + + assert_eq!(table.version(), 0); + assert_eq!(table.get_files_count(), 0); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20])), + Arc::new(arrow::array::StringArray::from(vec![ + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + let source = ctx.read_batch(batch).unwrap(); + + let source_name = TableReference::parse_str("source"); + let target_name = TableReference::parse_str("target"); + + let source = LogicalPlanBuilder::scan( + source_name.clone(), + provider_as_source(source.into_view()), + None, + ) + .unwrap() + .build() + .unwrap(); + + let join_predicate = col(Column { + relation: Some(source_name.clone()), + name: "id".to_owned(), + }) + .eq(col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + })) + .and( + col(Column { + relation: Some(source_name.clone()), + name: "modified".to_owned(), + }) + .eq(col(Column { + relation: Some(target_name.clone()), + name: "modified".to_owned(), + })), + ); + + let pred = try_construct_early_filter( + join_predicate, + table.snapshot().unwrap(), + &ctx.state(), + &source, + &source_name, + &target_name, + ) + .await + .unwrap(); + + assert!(pred.is_some()); + + let filter = col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + }) + .between( + Expr::Literal(ScalarValue::Utf8(Some("B".to_string()))), + Expr::Literal(ScalarValue::Utf8(Some("C".to_string()))), + ) + .and( + Expr::Literal(ScalarValue::Utf8(Some("2023-07-04".to_string()))).eq(col(Column { + relation: Some(target_name.clone()), + name: "modified".to_owned(), + })), + ); + assert_eq!(pred.unwrap(), filter); + } + + #[tokio::test] + async fn test_try_construct_early_filter_with_is_in_literals() { + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["modified"])).await; + + assert_eq!(table.version(), 0); + assert_eq!(table.get_files_count(), 0); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["A", "B", "C"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::StringArray::from(vec![ + "2023-07-04", + "2023-07-05", + "2023-07-05", + ])), + ], + ) + .unwrap(); + let source_df = 
ctx.read_batch(batch).unwrap(); + + let source_name = TableReference::parse_str("source"); + let target_name = TableReference::parse_str("target"); + + let source_plan = LogicalPlanBuilder::scan( + source_name.clone(), + provider_as_source(source_df.into_view()), + None, + ) + .unwrap() + .build() + .unwrap(); + + let join_predicate = col(Column { + relation: Some(source_name.clone()), + name: "id".to_owned(), + }) + .eq(col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + })) + .and(col("modified".to_owned()).in_list( + vec![lit("2023-07-05"), lit("2023-07-06"), lit("2023-07-07")], + false, + )); + + let pred = try_construct_early_filter( + join_predicate, + table.snapshot().unwrap(), + &ctx.state(), + &source_plan, + &source_name, + &target_name, + ) + .await + .unwrap(); + + assert!(pred.is_some()); + + let filter = col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + }) + .between( + Expr::Literal(ScalarValue::Utf8(Some("A".to_string()))), + Expr::Literal(ScalarValue::Utf8(Some("C".to_string()))), + ) + .and( + col(Column { + relation: None, + name: "modified".to_owned(), + }) + .in_list( + vec![ + Expr::Literal(ScalarValue::Utf8(Some("2023-07-05".to_string()))), + Expr::Literal(ScalarValue::Utf8(Some("2023-07-06".to_string()))), + Expr::Literal(ScalarValue::Utf8(Some("2023-07-07".to_string()))), + ], + false, + ), + ); + assert_eq!(pred.unwrap(), filter); + } + + #[tokio::test] + async fn test_try_construct_early_filter_with_is_in_columns() { + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["modified"])).await; + + assert_eq!(table.version(), 0); + assert_eq!(table.get_files_count(), 0); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["A", "B", "C"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::StringArray::from(vec![ + "2023-07-04", + "2023-07-05", + "2023-07-05", + ])), + ], + ) + .unwrap(); + let source_df = ctx.read_batch(batch).unwrap(); + + let source_name = TableReference::parse_str("source"); + let target_name = TableReference::parse_str("target"); + + let source_plan = LogicalPlanBuilder::scan( + source_name.clone(), + provider_as_source(source_df.into_view()), + None, + ) + .unwrap() + .build() + .unwrap(); + + let join_predicate = col(Column { + relation: Some(source_name.clone()), + name: "id".to_owned(), + }) + .eq(col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + })) + .and(col("modified".to_owned()).in_list( + vec![ + col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + }), + col(Column { + relation: Some(target_name.clone()), + name: "modified".to_owned(), + }), + ], + false, + )); + + let pred = try_construct_early_filter( + join_predicate, + table.snapshot().unwrap(), + &ctx.state(), + &source_plan, + &source_name, + &target_name, + ) + .await + .unwrap(); + + assert!(pred.is_some()); + + let filter = col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + }) + .between( + Expr::Literal(ScalarValue::Utf8(Some("A".to_string()))), + Expr::Literal(ScalarValue::Utf8(Some("C".to_string()))), + ) + .and( + col(Column { + relation: None, + name: "modified".to_owned(), + }) + .in_list( + vec![ + col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + }), + col(Column { + relation: Some(target_name.clone()), + name: "modified".to_owned(), + }), + ], + 
false, + ), + ); + assert_eq!(pred.unwrap(), filter); + } + + #[tokio::test] + async fn test_try_construct_early_filter_with_is_in_ident_and_cols() { + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["modified"])).await; + + assert_eq!(table.version(), 0); + assert_eq!(table.get_files_count(), 0); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["A", "B", "C"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::StringArray::from(vec![ + "2023-07-04", + "2023-07-05", + "2023-07-05", + ])), + ], + ) + .unwrap(); + let source_df = ctx.read_batch(batch).unwrap(); + + let source_name = TableReference::parse_str("source"); + let target_name = TableReference::parse_str("target"); + + let source_plan = LogicalPlanBuilder::scan( + source_name.clone(), + provider_as_source(source_df.into_view()), + None, + ) + .unwrap() + .build() + .unwrap(); + + let join_predicate = col(Column { + relation: Some(source_name.clone()), + name: "id".to_owned(), + }) + .eq(col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + })) + .and(ident("source.id").in_list( + vec![ + col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + }), + col(Column { + relation: Some(target_name.clone()), + name: "modified".to_owned(), + }), + ], + false, + )); + + let pred = try_construct_early_filter( + join_predicate, + table.snapshot().unwrap(), + &ctx.state(), + &source_plan, + &source_name, + &target_name, + ) + .await + .unwrap(); + + assert!(pred.is_some()); + + let filter = col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + }) + .between( + Expr::Literal(ScalarValue::Utf8(Some("A".to_string()))), + Expr::Literal(ScalarValue::Utf8(Some("C".to_string()))), + ) + .and(ident("source.id").in_list( + vec![ + col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + }), + col(Column { + relation: Some(target_name.clone()), + name: "modified".to_owned(), + }), + ], + false, + )); + assert_eq!(pred.unwrap(), filter); + } +} diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index 33f2200616..86930e3a26 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -29,6 +29,7 @@ //! 
```` use std::collections::HashMap; +use std::fmt::Debug; use std::sync::Arc; use std::time::Instant; @@ -37,7 +38,6 @@ use datafusion::datasource::provider_as_source; use datafusion::error::Result as DataFusionResult; use datafusion::execution::context::SessionConfig; use datafusion::execution::session_state::SessionStateBuilder; -use datafusion::functions_aggregate::expr_fn::{max, min}; use datafusion::logical_expr::build_join_schema; use datafusion::physical_plan::metrics::MetricBuilder; use datafusion::physical_planner::{ExtensionPlanner, PhysicalPlanner}; @@ -48,15 +48,12 @@ use datafusion::{ }; use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Column, DFSchema, ScalarValue, TableReference}; -use datafusion_expr::expr::Placeholder; +use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr, JoinType}; use datafusion_expr::{ - col, conditional_expressions::CaseBuilder, lit, when, Between, Expr, JoinType, + Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, UNNAMED_TABLE, }; -use datafusion_expr::{ - Aggregate, BinaryExpr, Extension, LogicalPlan, LogicalPlanBuilder, Operator, - UserDefinedLogicalNode, UNNAMED_TABLE, -}; -use either::{Left, Right}; + +use filter::try_construct_early_filter; use futures::future::BoxFuture; use itertools::Itertools; use parquet::file::properties::WriterProperties; @@ -72,13 +69,13 @@ use crate::delta_datafusion::logical::MetricObserver; use crate::delta_datafusion::physical::{find_metric_node, get_metric, MetricObserverExec}; use crate::delta_datafusion::planner::DeltaPlanner; use crate::delta_datafusion::{ - execute_plan_to_batch, register_store, DataFusionMixins, DeltaColumn, DeltaScanConfigBuilder, + register_store, DataFusionMixins, DeltaColumn, DeltaScan, DeltaScanConfigBuilder, DeltaSessionConfig, DeltaTableProvider, }; use crate::kernel::Action; use crate::logstore::LogStoreRef; use crate::operations::cdc::*; -use crate::operations::merge::barrier::find_barrier_node; +use crate::operations::merge::barrier::find_node; use crate::operations::transaction::CommitBuilder; use crate::operations::write::{write_execution_plan, write_execution_plan_cdc, WriterStatsConfig}; use crate::protocol::{DeltaOperation, MergePredicate}; @@ -86,6 +83,7 @@ use crate::table::state::DeltaTableState; use crate::{DeltaResult, DeltaTable, DeltaTableError}; mod barrier; +mod filter; const SOURCE_COLUMN: &str = "__delta_rs_source"; const TARGET_COLUMN: &str = "__delta_rs_target"; @@ -568,6 +566,10 @@ pub struct MergeMetrics { pub num_target_rows_copied: usize, /// Total number of rows written out pub num_output_rows: usize, + /// Amount of files considered during table scan + pub num_target_files_scanned: usize, + /// Amount of files not considered (pruned) during table scan + pub num_target_files_skipped_during_scan: usize, /// Number of files added to the sink(target) pub num_target_files_added: usize, /// Number of files removed from the sink(target) @@ -672,334 +674,6 @@ impl ExtensionPlanner for MergeMetricExtensionPlanner { } } -struct PredicatePlaceholder { - expr: Expr, - alias: String, - is_aggregate: bool, -} - -/// Takes the predicate provided and does three things: -/// -/// 1. for any relations between a source column and a partition target column, -/// replace source with a placeholder matching the name of the partition -/// columns -/// -/// 2. 
for any is equal relations between a source column and a non-partition target column, -/// replace source with is between expression with min(source_column) and max(source_column) placeholders -/// -/// 3. for any other relation with a source column, remove them. -/// -/// For example, for the predicate: -/// -/// `source.date = target.date and source.id = target.id and frob > 42` -/// -/// where `date` is a partition column, would result in the expr: -/// -/// `$date_0 = target.date and target.id between $id_1_min and $id_1_max and frob > 42` -/// -/// This leaves us with a predicate that we can push into delta scan after expanding it out to -/// a conjunction between the distinct partitions in the source input. -/// -fn generalize_filter( - predicate: Expr, - partition_columns: &Vec, - source_name: &TableReference, - target_name: &TableReference, - placeholders: &mut Vec, -) -> Option { - #[derive(Debug)] - enum ReferenceTableCheck { - HasReference(String), - NoReference, - Unknown, - } - impl ReferenceTableCheck { - fn has_reference(&self) -> bool { - matches!(self, ReferenceTableCheck::HasReference(_)) - } - } - fn references_table(expr: &Expr, table: &TableReference) -> ReferenceTableCheck { - let res = match expr { - Expr::Alias(alias) => references_table(&alias.expr, table), - Expr::Column(col) => col - .relation - .as_ref() - .map(|rel| { - if rel == table { - ReferenceTableCheck::HasReference(col.name.to_owned()) - } else { - ReferenceTableCheck::NoReference - } - }) - .unwrap_or(ReferenceTableCheck::NoReference), - Expr::Negative(neg) => references_table(neg, table), - Expr::Cast(cast) => references_table(&cast.expr, table), - Expr::TryCast(try_cast) => references_table(&try_cast.expr, table), - Expr::ScalarFunction(func) => { - if func.args.len() == 1 { - references_table(&func.args[0], table) - } else { - ReferenceTableCheck::Unknown - } - } - Expr::IsNull(inner) => references_table(inner, table), - Expr::Literal(_) => ReferenceTableCheck::NoReference, - _ => ReferenceTableCheck::Unknown, - }; - res - } - - fn construct_placeholder( - binary: BinaryExpr, - source_left: bool, - is_partition_column: bool, - column_name: String, - placeholders: &mut Vec, - ) -> Option { - if is_partition_column { - let placeholder_name = format!("{column_name}_{}", placeholders.len()); - let placeholder = Expr::Placeholder(Placeholder { - id: placeholder_name.clone(), - data_type: None, - }); - - let (left, right, source_expr): (Box, Box, Expr) = if source_left { - (placeholder.into(), binary.clone().right, *binary.left) - } else { - (binary.clone().left, placeholder.into(), *binary.right) - }; - - let replaced = Expr::BinaryExpr(BinaryExpr { - left, - op: binary.op, - right, - }); - - placeholders.push(PredicatePlaceholder { - expr: source_expr, - alias: placeholder_name, - is_aggregate: false, - }); - - Some(replaced) - } else { - match binary.op { - Operator::Eq => { - let name_min = format!("{column_name}_{}_min", placeholders.len()); - let placeholder_min = Expr::Placeholder(Placeholder { - id: name_min.clone(), - data_type: None, - }); - let name_max = format!("{column_name}_{}_max", placeholders.len()); - let placeholder_max = Expr::Placeholder(Placeholder { - id: name_max.clone(), - data_type: None, - }); - let (source_expr, target_expr) = if source_left { - (*binary.left, *binary.right) - } else { - (*binary.right, *binary.left) - }; - let replaced = Expr::Between(Between { - expr: target_expr.into(), - negated: false, - low: placeholder_min.into(), - high: placeholder_max.into(), - }); - 
- placeholders.push(PredicatePlaceholder { - expr: min(source_expr.clone()), - alias: name_min, - is_aggregate: true, - }); - placeholders.push(PredicatePlaceholder { - expr: max(source_expr), - alias: name_max, - is_aggregate: true, - }); - Some(replaced) - } - _ => None, - } - } - } - - match predicate { - Expr::BinaryExpr(binary) => { - if references_table(&binary.right, source_name).has_reference() { - if let ReferenceTableCheck::HasReference(left_target) = - references_table(&binary.left, target_name) - { - return construct_placeholder( - binary, - false, - partition_columns.contains(&left_target), - left_target, - placeholders, - ); - } - return None; - } - if references_table(&binary.left, source_name).has_reference() { - if let ReferenceTableCheck::HasReference(right_target) = - references_table(&binary.right, target_name) - { - return construct_placeholder( - binary, - true, - partition_columns.contains(&right_target), - right_target, - placeholders, - ); - } - return None; - } - - let left = generalize_filter( - *binary.left, - partition_columns, - source_name, - target_name, - placeholders, - ); - let right = generalize_filter( - *binary.right, - partition_columns, - source_name, - target_name, - placeholders, - ); - - match (left, right) { - (None, None) => None, - (None, Some(one_side)) | (Some(one_side), None) => { - // in the case of an AND clause, it's safe to generalize the filter down to just one side of the AND. - // this is because this filter will be more permissive than the actual predicate, so we know that - // we will catch all data that could be matched by the predicate. For OR this is not the case - we - // could potentially eliminate one side of the predicate and the filter would only match half the - // cases that would have satisfied the match predicate. - match binary.op { - Operator::And => Some(one_side), - Operator::Or => None, - _ => None, - } - } - (Some(l), Some(r)) => Expr::BinaryExpr(BinaryExpr { - left: l.into(), - op: binary.op, - right: r.into(), - }) - .into(), - } - } - other => match references_table(&other, source_name) { - ReferenceTableCheck::HasReference(col) => { - let placeholder_name = format!("{col}_{}", placeholders.len()); - - let placeholder = Expr::Placeholder(Placeholder { - id: placeholder_name.clone(), - data_type: None, - }); - - placeholders.push(PredicatePlaceholder { - expr: other, - alias: placeholder_name, - is_aggregate: true, - }); - - Some(placeholder) - } - ReferenceTableCheck::NoReference => Some(other), - ReferenceTableCheck::Unknown => None, - }, - } -} - -fn replace_placeholders(expr: Expr, placeholders: &HashMap) -> Expr { - expr.transform(&|expr| match expr { - Expr::Placeholder(Placeholder { id, .. 
}) => { - let value = placeholders[&id].clone(); - // Replace the placeholder with the value - Ok(Transformed::yes(Expr::Literal(value))) - } - _ => Ok(Transformed::no(expr)), - }) - .unwrap() - .data -} - -async fn try_construct_early_filter( - join_predicate: Expr, - table_snapshot: &DeltaTableState, - session_state: &SessionState, - source: &LogicalPlan, - source_name: &TableReference, - target_name: &TableReference, -) -> DeltaResult> { - let table_metadata = table_snapshot.metadata(); - let partition_columns = &table_metadata.partition_columns; - - let mut placeholders = Vec::default(); - - match generalize_filter( - join_predicate, - partition_columns, - source_name, - target_name, - &mut placeholders, - ) { - None => Ok(None), - Some(filter) => { - if placeholders.is_empty() { - // if we haven't recognised any source predicates in the join predicate, return our filter with static only predicates - Ok(Some(filter)) - } else { - // if we have some filters, which depend on the source df, then collect the placeholders values from the source data - // We aggregate the distinct values for partitions with the group_columns and stats(min, max) for dynamic filter as agg_columns - // Can be translated into `SELECT partition1 as part1_0, min(id) as id_1_min, max(id) as id_1_max FROM source GROUP BY partition1` - let (agg_columns, group_columns) = placeholders.into_iter().partition_map(|p| { - if p.is_aggregate { - Left(p.expr.alias(p.alias)) - } else { - Right(p.expr.alias(p.alias)) - } - }); - let distinct_partitions = LogicalPlan::Aggregate(Aggregate::try_new( - source.clone().into(), - group_columns, - agg_columns, - )?); - let execution_plan = session_state - .create_physical_plan(&distinct_partitions) - .await?; - let items = execute_plan_to_batch(session_state, execution_plan).await?; - let placeholder_names = items - .schema() - .fields() - .iter() - .map(|f| f.name().to_owned()) - .collect_vec(); - let expr = (0..items.num_rows()) - .map(|i| { - let replacements = placeholder_names - .iter() - .map(|placeholder| { - let col = items.column_by_name(placeholder).unwrap(); - let value = ScalarValue::try_from_array(col, i)?; - DeltaResult::Ok((placeholder.to_owned(), value)) - }) - .try_collect()?; - Ok(replace_placeholders(filter.clone(), &replacements)) - }) - .collect::>>()? 
- .into_iter() - .reduce(Expr::or); - Ok(expr) - } - } - } -} - #[allow(clippy::too_many_arguments)] async fn execute( predicate: Expression, @@ -1509,9 +1183,9 @@ async fn execute( let err = || DeltaTableError::Generic("Unable to locate expected metric node".into()); let source_count = find_metric_node(SOURCE_COUNT_ID, &write).ok_or_else(err)?; let op_count = find_metric_node(OUTPUT_COUNT_ID, &write).ok_or_else(err)?; - let barrier = find_barrier_node(&write).ok_or_else(err)?; + let barrier = find_node::(&write).ok_or_else(err)?; + let scan_count = find_node::(&write).ok_or_else(err)?; - // write projected records let table_partition_cols = current_metadata.partition_columns.clone(); let writer_stats_config = WriterStatsConfig::new( @@ -1586,6 +1260,7 @@ async fn execute( let source_count_metrics = source_count.metrics().unwrap(); let target_count_metrics = op_count.metrics().unwrap(); + let scan_count_metrics = scan_count.metrics().unwrap(); metrics.num_source_rows = get_metric(&source_count_metrics, SOURCE_COUNT_METRIC); metrics.num_target_rows_inserted = get_metric(&target_count_metrics, TARGET_INSERTED_METRIC); @@ -1595,7 +1270,8 @@ async fn execute( metrics.num_output_rows = metrics.num_target_rows_inserted + metrics.num_target_rows_updated + metrics.num_target_rows_copied; - + metrics.num_target_files_scanned = get_metric(&scan_count_metrics, "files_scanned"); + metrics.num_target_files_skipped_during_scan = get_metric(&scan_count_metrics, "files_pruned"); metrics.execution_time_ms = Instant::now().duration_since(exec_start).as_millis() as u64; let app_metadata = &mut commit_properties.app_metadata; @@ -1703,8 +1379,7 @@ mod tests { use crate::kernel::PrimitiveType; use crate::kernel::StructField; use crate::operations::load_cdf::collect_batches; - use crate::operations::merge::generalize_filter; - use crate::operations::merge::try_construct_early_filter; + use crate::operations::merge::filter::generalize_filter; use crate::operations::DeltaOps; use crate::protocol::*; use crate::writer::test_utils::datafusion::get_data; @@ -1718,18 +1393,14 @@ mod tests { use arrow_schema::DataType as ArrowDataType; use arrow_schema::Field; use datafusion::assert_batches_sorted_eq; - use datafusion::datasource::provider_as_source; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::*; use datafusion_common::Column; - use datafusion_common::ScalarValue; use datafusion_common::TableReference; use datafusion_expr::col; use datafusion_expr::expr::Placeholder; use datafusion_expr::lit; use datafusion_expr::Expr; - use datafusion_expr::LogicalPlanBuilder; - use datafusion_expr::Operator; use itertools::Itertools; use regex::Regex; use serde_json::json; @@ -1738,7 +1409,7 @@ mod tests { use super::MergeMetrics; - async fn setup_table(partitions: Option>) -> DeltaTable { + pub(crate) async fn setup_table(partitions: Option>) -> DeltaTable { let table_schema = get_delta_schema(); let table = DeltaOps::new_in_memory() @@ -2301,6 +1972,115 @@ mod tests { assert_batches_sorted_eq!(&expected, &actual); } + #[tokio::test] + async fn test_merge_partitions_with_in() { + /* Validate the join predicate works with table partitions */ + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["modified"])).await; + + let table = write_data(table, &schema).await; + assert_eq!(table.version(), 1); + assert_eq!(table.get_files_count(), 2); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + 
Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::StringArray::from(vec![ + "2021-02-02", + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + let source = ctx.read_batch(batch).unwrap(); + + let (table, metrics) = DeltaOps(table) + .merge( + source, + col("target.id") + .eq(col("source.id")) + .and(col("target.id").in_list( + vec![ + col("source.id"), + col("source.modified"), + col("source.value"), + ], + false, + )) + .and(col("target.modified").in_list(vec![lit("2021-02-02")], false)), + ) + .with_source_alias("source") + .with_target_alias("target") + .when_matched_update(|update| { + update + .update("value", col("source.value")) + .update("modified", col("source.modified")) + }) + .unwrap() + .when_not_matched_by_source_update(|update| { + update + .predicate(col("target.value").eq(lit(1))) + .update("value", col("target.value") + lit(1)) + }) + .unwrap() + .when_not_matched_by_source_update(|update| { + update + .predicate(col("target.modified").eq(lit("2021-02-01"))) + .update("value", col("target.value") - lit(1)) + }) + .unwrap() + .when_not_matched_insert(|insert| { + insert + .set("id", col("source.id")) + .set("value", col("source.value")) + .set("modified", col("source.modified")) + }) + .unwrap() + .await + .unwrap(); + + assert_eq!(table.version(), 2); + assert!(table.get_files_count() >= 3); + assert!(metrics.num_target_files_added >= 3); + assert_eq!(metrics.num_target_files_removed, 2); + assert_eq!(metrics.num_target_rows_copied, 1); + assert_eq!(metrics.num_target_rows_updated, 3); + assert_eq!(metrics.num_target_rows_inserted, 2); + assert_eq!(metrics.num_target_rows_deleted, 0); + assert_eq!(metrics.num_output_rows, 6); + assert_eq!(metrics.num_source_rows, 3); + + let commit_info = table.history(None).await.unwrap(); + let last_commit = &commit_info[0]; + let parameters = last_commit.operation_parameters.clone().unwrap(); + assert!(!parameters.contains_key("predicate")); + assert_eq!( + parameters["mergePredicate"], + "target.id = source.id AND \ + target.id IN (source.id, source.modified, source.value) AND \ + target.modified IN ('2021-02-02')" + ); + + let expected = vec![ + "+----+-------+------------+", + "| id | value | modified |", + "+----+-------+------------+", + "| A | 2 | 2021-02-01 |", + "| B | 9 | 2021-02-01 |", + "| B | 10 | 2021-02-02 |", + "| C | 20 | 2023-07-04 |", + "| D | 100 | 2021-02-02 |", + "| X | 30 | 2023-07-04 |", + "+----+-------+------------+", + ]; + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + } + #[tokio::test] async fn test_merge_delete_matched() { // Validate behaviours of match delete @@ -2943,256 +2723,6 @@ mod tests { assert_eq!(generalized, expected_filter); } - #[tokio::test] - async fn test_try_construct_early_filter_with_partitions_expands() { - let schema = get_arrow_schema(&None); - let table = setup_table(Some(vec!["id"])).await; - - assert_eq!(table.version(), 0); - assert_eq!(table.get_files_count(), 0); - - let ctx = SessionContext::new(); - let batch = RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), - Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), - Arc::new(arrow::array::StringArray::from(vec![ - "2021-02-02", - "2023-07-04", - "2023-07-04", - ])), - ], - ) - .unwrap(); - let source = ctx.read_batch(batch).unwrap(); - - let source_name = TableReference::parse_str("source"); - let 
target_name = TableReference::parse_str("target"); - - let source = LogicalPlanBuilder::scan( - source_name.clone(), - provider_as_source(source.into_view()), - None, - ) - .unwrap() - .build() - .unwrap(); - - let join_predicate = col(Column { - relation: Some(source_name.clone()), - name: "id".to_owned(), - }) - .eq(col(Column { - relation: Some(target_name.clone()), - name: "id".to_owned(), - })); - - let pred = try_construct_early_filter( - join_predicate, - table.snapshot().unwrap(), - &ctx.state(), - &source, - &source_name, - &target_name, - ) - .await - .unwrap(); - - assert!(pred.is_some()); - - let split_pred = { - fn split(expr: Expr, parts: &mut Vec<(String, String)>) { - match expr { - Expr::BinaryExpr(ex) if ex.op == Operator::Or => { - split(*ex.left, parts); - split(*ex.right, parts); - } - Expr::BinaryExpr(ex) if ex.op == Operator::Eq => { - let col = match *ex.right { - Expr::Column(col) => col.name, - ex => panic!("expected column in pred, got {ex}!"), - }; - - let value = match *ex.left { - Expr::Literal(ScalarValue::Utf8(Some(value))) => value, - ex => panic!("expected value in predicate, got {ex}!"), - }; - - parts.push((col, value)) - } - - expr => panic!("expected either = or OR, got {expr}"), - } - } - - let mut parts = vec![]; - split(pred.unwrap(), &mut parts); - parts.sort(); - parts - }; - - let expected_pred_parts = [ - ("id".to_owned(), "B".to_owned()), - ("id".to_owned(), "C".to_owned()), - ("id".to_owned(), "X".to_owned()), - ]; - - assert_eq!(split_pred, expected_pred_parts); - } - - #[tokio::test] - async fn test_try_construct_early_filter_with_range() { - let schema = get_arrow_schema(&None); - let table = setup_table(Some(vec!["modified"])).await; - - assert_eq!(table.version(), 0); - assert_eq!(table.get_files_count(), 0); - - let ctx = SessionContext::new(); - let batch = RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(arrow::array::StringArray::from(vec!["B", "C"])), - Arc::new(arrow::array::Int32Array::from(vec![10, 20])), - Arc::new(arrow::array::StringArray::from(vec![ - "2023-07-04", - "2023-07-04", - ])), - ], - ) - .unwrap(); - let source = ctx.read_batch(batch).unwrap(); - - let source_name = TableReference::parse_str("source"); - let target_name = TableReference::parse_str("target"); - - let source = LogicalPlanBuilder::scan( - source_name.clone(), - provider_as_source(source.into_view()), - None, - ) - .unwrap() - .build() - .unwrap(); - - let join_predicate = col(Column { - relation: Some(source_name.clone()), - name: "id".to_owned(), - }) - .eq(col(Column { - relation: Some(target_name.clone()), - name: "id".to_owned(), - })); - - let pred = try_construct_early_filter( - join_predicate, - table.snapshot().unwrap(), - &ctx.state(), - &source, - &source_name, - &target_name, - ) - .await - .unwrap(); - - assert!(pred.is_some()); - - let filter = col(Column { - relation: Some(target_name.clone()), - name: "id".to_owned(), - }) - .between( - Expr::Literal(ScalarValue::Utf8(Some("B".to_string()))), - Expr::Literal(ScalarValue::Utf8(Some("C".to_string()))), - ); - assert_eq!(pred.unwrap(), filter); - } - - #[tokio::test] - async fn test_try_construct_early_filter_with_partition_and_range() { - let schema = get_arrow_schema(&None); - let table = setup_table(Some(vec!["modified"])).await; - - assert_eq!(table.version(), 0); - assert_eq!(table.get_files_count(), 0); - - let ctx = SessionContext::new(); - let batch = RecordBatch::try_new( - Arc::clone(&schema), - vec![ - Arc::new(arrow::array::StringArray::from(vec!["B", "C"])), - 
Arc::new(arrow::array::Int32Array::from(vec![10, 20])), - Arc::new(arrow::array::StringArray::from(vec![ - "2023-07-04", - "2023-07-04", - ])), - ], - ) - .unwrap(); - let source = ctx.read_batch(batch).unwrap(); - - let source_name = TableReference::parse_str("source"); - let target_name = TableReference::parse_str("target"); - - let source = LogicalPlanBuilder::scan( - source_name.clone(), - provider_as_source(source.into_view()), - None, - ) - .unwrap() - .build() - .unwrap(); - - let join_predicate = col(Column { - relation: Some(source_name.clone()), - name: "id".to_owned(), - }) - .eq(col(Column { - relation: Some(target_name.clone()), - name: "id".to_owned(), - })) - .and( - col(Column { - relation: Some(source_name.clone()), - name: "modified".to_owned(), - }) - .eq(col(Column { - relation: Some(target_name.clone()), - name: "modified".to_owned(), - })), - ); - - let pred = try_construct_early_filter( - join_predicate, - table.snapshot().unwrap(), - &ctx.state(), - &source, - &source_name, - &target_name, - ) - .await - .unwrap(); - - assert!(pred.is_some()); - - let filter = col(Column { - relation: Some(target_name.clone()), - name: "id".to_owned(), - }) - .between( - Expr::Literal(ScalarValue::Utf8(Some("B".to_string()))), - Expr::Literal(ScalarValue::Utf8(Some("C".to_string()))), - ) - .and( - Expr::Literal(ScalarValue::Utf8(Some("2023-07-04".to_string()))).eq(col(Column { - relation: Some(target_name.clone()), - name: "modified".to_owned(), - })), - ); - assert_eq!(pred.unwrap(), filter); - } - #[tokio::test] async fn test_merge_pushdowns() { //See https://github.com/delta-io/delta-rs/issues/2158 diff --git a/crates/core/src/operations/mod.rs b/crates/core/src/operations/mod.rs index 515143f088..676098c832 100644 --- a/crates/core/src/operations/mod.rs +++ b/crates/core/src/operations/mod.rs @@ -283,6 +283,22 @@ pub fn get_num_idx_cols_and_stats_columns( ) } +/// Get the target_file_size from the table configuration in the table state. +/// If table_config does not exist (which can only occur in the first write action), it takes +/// the configuration that was passed to the WriteBuilder. 
+pub(crate) fn get_target_file_size( + config: &Option>, + configuration: &HashMap>, +) -> i64 { + match &config { + Some(conf) => conf.target_file_size(), + _ => configuration + .get("delta.targetFileSize") + .and_then(|v| v.clone().map(|v| v.parse::().unwrap())) + .unwrap_or(crate::table::config::DEFAULT_TARGET_FILE_SIZE), + } +} + #[cfg(feature = "datafusion")] mod datafusion_utils { use datafusion::execution::context::SessionState; diff --git a/crates/core/src/operations/optimize.rs b/crates/core/src/operations/optimize.rs index 632b6a9474..e00fd6451e 100644 --- a/crates/core/src/operations/optimize.rs +++ b/crates/core/src/operations/optimize.rs @@ -203,9 +203,9 @@ pub struct OptimizeBuilder<'a> { commit_properties: CommitProperties, /// Whether to preserve insertion order within files (default false) preserve_insertion_order: bool, - /// Max number of concurrent tasks (default is number of cpus) + /// Maximum number of concurrent tasks (default is number of cpus) max_concurrent_tasks: usize, - /// Maximum number of bytes that are allowed to spill to disk + /// Maximum number of bytes allowed in memory before spilling to disk max_spill_size: usize, /// Optimize type optimize_type: OptimizeType, diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs index e12edf8be4..5c44bf1703 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write.rs @@ -27,7 +27,7 @@ use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; -use std::time::{SystemTime, UNIX_EPOCH}; +use std::time::{Instant, SystemTime, UNIX_EPOCH}; use std::vec; use arrow_array::RecordBatch; @@ -46,6 +46,7 @@ use futures::future::BoxFuture; use futures::StreamExt; use object_store::prefix::PrefixStore; use parquet::file::properties::WriterProperties; +use serde::{Deserialize, Serialize}; use tracing::log::*; use super::cdc::should_write_cdc; @@ -60,7 +61,9 @@ use crate::delta_datafusion::{ }; use crate::delta_datafusion::{DataFusionMixins, DeltaDataChecker}; use crate::errors::{DeltaResult, DeltaTableError}; -use crate::kernel::{Action, Add, AddCDCFile, Metadata, PartitionsExt, Remove, StructType}; +use crate::kernel::{ + Action, ActionType, Add, AddCDCFile, Metadata, PartitionsExt, Remove, StructType, +}; use crate::logstore::LogStoreRef; use crate::operations::cast::{cast_record_batch, merge_schema::merge_arrow_schema}; use crate::protocol::{DeltaOperation, SaveMode}; @@ -162,6 +165,21 @@ pub struct WriteBuilder { configuration: HashMap>, } +#[derive(Default, Debug, Serialize, Deserialize)] +/// Metrics for the Write Operation +pub struct WriteMetrics { + /// Number of files added + pub num_added_files: usize, + /// Number of files removed + pub num_removed_files: usize, + /// Number of partitions + pub num_partitions: usize, + /// Number of rows added + pub num_added_rows: usize, + /// Time taken to execute the entire operation + pub execution_time_ms: u64, +} + impl super::Operation<()> for WriteBuilder {} impl WriteBuilder { @@ -398,7 +416,6 @@ async fn write_execution_plan_with_predicate( } _ => checker, }; - // Write data to disk let mut tasks = vec![]; for i in 0..plan.properties().output_partitioning().partition_count() { @@ -766,6 +783,9 @@ impl std::future::IntoFuture for WriteBuilder { let this = self; Box::pin(async move { + let mut metrics = WriteMetrics::default(); + let exec_start = Instant::now(); + if this.mode == SaveMode::Overwrite { if let Some(snapshot) = &this.snapshot { PROTOCOL.check_append_only(&snapshot.snapshot)?; @@ -857,6 
+877,8 @@ impl std::future::IntoFuture for WriteBuilder { let data = if !partition_columns.is_empty() { // TODO partitioning should probably happen in its own plan ... let mut partitions: HashMap> = HashMap::new(); + let mut num_partitions = 0; + let mut num_added_rows = 0; for batch in batches { let real_batch = match new_schema.clone() { Some(new_schema) => cast_record_batch( @@ -873,7 +895,9 @@ impl std::future::IntoFuture for WriteBuilder { partition_columns.clone(), &real_batch, )?; + num_partitions += divided.len(); for part in divided { + num_added_rows += part.record_batch.num_rows(); let key = part.partition_values.hive_partition_path(); match partitions.get_mut(&key) { Some(part_batches) => { @@ -885,11 +909,14 @@ impl std::future::IntoFuture for WriteBuilder { } } } + metrics.num_partitions = num_partitions; + metrics.num_added_rows = num_added_rows; partitions.into_values().collect::>() } else { match new_schema { Some(ref new_schema) => { let mut new_batches = vec![]; + let mut num_added_rows = 0; for batch in batches { new_batches.push(cast_record_batch( &batch, @@ -897,10 +924,15 @@ impl std::future::IntoFuture for WriteBuilder { this.safe_cast, schema_drift, // Schema drifted so we have to add the missing columns/structfields. )?); + num_added_rows += batch.num_rows(); } + metrics.num_added_rows = num_added_rows; vec![new_batches] } - None => vec![batches], + None => { + metrics.num_added_rows = batches.iter().map(|b| b.num_rows()).sum(); + vec![batches] + } } }; @@ -980,6 +1012,9 @@ impl std::future::IntoFuture for WriteBuilder { .as_ref() .map(|snapshot| snapshot.table_config()); + let target_file_size = this.target_file_size.or_else(|| { + Some(super::get_target_file_size(&config, &this.configuration) as usize) + }); let (num_indexed_cols, stats_columns) = super::get_num_idx_cols_and_stats_columns(config, this.configuration); @@ -987,6 +1022,7 @@ impl std::future::IntoFuture for WriteBuilder { num_indexed_cols, stats_columns, }; + // Here we need to validate if the new data conforms to a predicate if one is provided let add_actions = write_execution_plan_with_predicate( predicate.clone(), @@ -995,13 +1031,14 @@ impl std::future::IntoFuture for WriteBuilder { plan.clone(), partition_columns.clone(), this.log_store.object_store().clone(), - this.target_file_size, + target_file_size, this.write_batch_size, this.writer_properties.clone(), writer_stats_config.clone(), None, ) .await?; + metrics.num_added_files = add_actions.len(); actions.extend(add_actions); // Collect remove actions if we are overwriting the table @@ -1077,8 +1114,15 @@ impl std::future::IntoFuture for WriteBuilder { } }; } + metrics.num_removed_files = actions + .iter() + .filter(|a| a.action_type() == ActionType::Remove) + .count(); } + metrics.execution_time_ms = + Instant::now().duration_since(exec_start).as_millis() as u64; + let operation = DeltaOperation::Write { mode: this.mode, partition_by: if !partition_columns.is_empty() { @@ -1089,7 +1133,13 @@ impl std::future::IntoFuture for WriteBuilder { predicate: predicate_str, }; - let commit = CommitBuilder::from(this.commit_properties) + let mut commit_properties = this.commit_properties.clone(); + commit_properties.app_metadata.insert( + "operationMetrics".to_owned(), + serde_json::to_value(&metrics)?, + ); + + let commit = CommitBuilder::from(commit_properties) .with_actions(actions) .build( this.snapshot.as_ref().map(|f| f as &dyn TableReference), @@ -1189,12 +1239,33 @@ mod tests { use itertools::Itertools; use serde_json::{json, Value}; + async fn 
get_write_metrics(table: DeltaTable) -> WriteMetrics { + let mut commit_info = table.history(Some(1)).await.unwrap(); + let metrics = commit_info + .first_mut() + .unwrap() + .info + .remove("operationMetrics") + .unwrap(); + serde_json::from_value(metrics).unwrap() + } + + fn assert_common_write_metrics(write_metrics: WriteMetrics) { + assert!(write_metrics.execution_time_ms > 0); + assert!(write_metrics.num_added_files > 0); + } + #[tokio::test] async fn test_write_when_delta_table_is_append_only() { let table = setup_table_with_configuration(DeltaConfigKey::AppendOnly, Some("true")).await; let batch = get_record_batch(None, false); // Append let table = write_batch(table, batch.clone()).await; + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_eq!(write_metrics.num_added_rows, batch.num_rows()); + assert_eq!(write_metrics.num_removed_files, 0); + assert_common_write_metrics(write_metrics); + // Overwrite let _err = DeltaOps(table) .write(vec![batch]) @@ -1226,6 +1297,12 @@ mod tests { .unwrap(); assert_eq!(table.version(), 1); assert_eq!(table.get_files_count(), 1); + + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_eq!(write_metrics.num_added_rows, batch.num_rows()); + assert_eq!(write_metrics.num_added_files, table.get_files_count()); + assert_common_write_metrics(write_metrics); + table.load().await.unwrap(); assert_eq!(table.history(None).await.unwrap().len(), 2); assert_eq!( @@ -1233,7 +1310,7 @@ mod tests { .info .clone() .into_iter() - .filter(|(k, _)| k != "clientVersion") + .filter(|(k, _)| k == "k1") .collect::>(), metadata ); @@ -1249,6 +1326,11 @@ mod tests { .unwrap(); assert_eq!(table.version(), 2); assert_eq!(table.get_files_count(), 2); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_eq!(write_metrics.num_added_rows, batch.num_rows()); + assert_eq!(write_metrics.num_added_files, 1); + assert_common_write_metrics(write_metrics); + table.load().await.unwrap(); assert_eq!(table.history(None).await.unwrap().len(), 3); assert_eq!( @@ -1256,7 +1338,7 @@ mod tests { .info .clone() .into_iter() - .filter(|(k, _)| k != "clientVersion") + .filter(|(k, _)| k == "k1") .collect::>(), metadata ); @@ -1265,13 +1347,18 @@ mod tests { let metadata: HashMap = HashMap::from_iter(vec![("k2".to_string(), json!("v2.1"))]); let mut table = DeltaOps(table) - .write(vec![batch]) + .write(vec![batch.clone()]) .with_save_mode(SaveMode::Overwrite) .with_commit_properties(CommitProperties::default().with_metadata(metadata.clone())) .await .unwrap(); assert_eq!(table.version(), 3); assert_eq!(table.get_files_count(), 1); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_eq!(write_metrics.num_added_rows, batch.num_rows()); + assert!(write_metrics.num_removed_files > 0); + assert_common_write_metrics(write_metrics); + table.load().await.unwrap(); assert_eq!(table.history(None).await.unwrap().len(), 4); assert_eq!( @@ -1279,7 +1366,7 @@ mod tests { .info .clone() .into_iter() - .filter(|(k, _)| k != "clientVersion") + .filter(|(k, _)| k == "k2") .collect::>(), metadata ); @@ -1302,6 +1389,9 @@ mod tests { ) .unwrap(); let table = DeltaOps::new_in_memory().write(vec![batch]).await.unwrap(); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_eq!(write_metrics.num_added_rows, 2); + assert_common_write_metrics(write_metrics); let schema = Arc::new(ArrowSchema::new(vec![Field::new( "value", @@ -1326,6 +1416,10 @@ mod tests { .await 
.unwrap(); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_eq!(write_metrics.num_added_rows, 3); + assert_common_write_metrics(write_metrics); + let expected = [ "+-------+", "| value |", @@ -1359,6 +1453,10 @@ mod tests { .unwrap(); let table = DeltaOps::new_in_memory().write(vec![batch]).await.unwrap(); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_eq!(write_metrics.num_added_rows, 1); + assert_common_write_metrics(write_metrics); + let schema = Arc::new(ArrowSchema::new(vec![Field::new( "value", DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".to_string().into())), @@ -1394,7 +1492,9 @@ mod tests { .await .unwrap(); assert_eq!(table.version(), 0); - assert_eq!(table.get_files_count(), 1) + assert_eq!(table.get_files_count(), 1); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_common_write_metrics(write_metrics); } #[tokio::test] @@ -1408,6 +1508,10 @@ mod tests { .unwrap(); assert_eq!(table.version(), 0); assert_eq!(table.get_files_count(), 2); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert!(write_metrics.num_partitions > 0); + assert_eq!(write_metrics.num_added_files, 2); + assert_common_write_metrics(write_metrics); let table = DeltaOps::new_in_memory() .write(vec![batch]) @@ -1416,7 +1520,12 @@ mod tests { .await .unwrap(); assert_eq!(table.version(), 0); - assert_eq!(table.get_files_count(), 4) + assert_eq!(table.get_files_count(), 4); + + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert!(write_metrics.num_partitions > 0); + assert_eq!(write_metrics.num_added_files, 4); + assert_common_write_metrics(write_metrics); } #[tokio::test] @@ -1429,6 +1538,9 @@ mod tests { .unwrap(); assert_eq!(table.version(), 0); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_common_write_metrics(write_metrics); + let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); for field in batch.schema().fields() { if field.name() != "modified" { @@ -1475,6 +1587,9 @@ mod tests { let fields = new_schema.fields(); let names = fields.map(|f| f.name()).collect::>(); assert_eq!(names, vec!["id", "value", "modified", "inserted_by"]); + + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_common_write_metrics(write_metrics); } #[tokio::test] @@ -1488,6 +1603,10 @@ mod tests { .unwrap(); assert_eq!(table.version(), 0); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert!(write_metrics.num_partitions > 0); + assert_common_write_metrics(write_metrics); + let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); for field in batch.schema().fields() { if field.name() != "modified" { @@ -1536,6 +1655,10 @@ mod tests { assert_eq!(names, vec!["id", "inserted_by", "modified", "value"]); let part_cols = table.metadata().unwrap().partition_columns.clone(); assert_eq!(part_cols, vec!["id", "value"]); // we want to preserve partitions + + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert!(write_metrics.num_partitions > 0); + assert_common_write_metrics(write_metrics); } #[tokio::test] @@ -1547,7 +1670,8 @@ mod tests { .await .unwrap(); assert_eq!(table.version(), 0); - + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_common_write_metrics(write_metrics); let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); for field in batch.schema().fields() { if 
field.name() != "modified" { @@ -1600,6 +1724,8 @@ mod tests { .await .unwrap(); assert_eq!(table.version(), 0); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_common_write_metrics(write_metrics); let mut new_schema_builder = arrow_schema::SchemaBuilder::new(); @@ -1655,6 +1781,8 @@ mod tests { let table = DeltaOps(table).write(vec![batch.clone()]).await.unwrap(); assert_eq!(table.version(), 1); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_common_write_metrics(write_metrics); let schema: StructType = serde_json::from_value(json!({ "type": "struct", @@ -1676,7 +1804,7 @@ mod tests { assert_eq!(table.version(), 0); let table = DeltaOps(table).write(vec![batch.clone()]).await; - assert!(table.is_err()) + assert!(table.is_err()); } #[tokio::test] @@ -1697,6 +1825,8 @@ mod tests { .await .unwrap(); assert_eq!(table.version(), 1); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_common_write_metrics(write_metrics); let actual = get_data(&table).await; let expected = DataType::Struct(Fields::from(vec![Field::new( @@ -1735,6 +1865,8 @@ mod tests { .with_partition_columns(["string"]) .await .unwrap(); + let write_metrics: WriteMetrics = get_write_metrics(_table.clone()).await; + assert_common_write_metrics(write_metrics); let table = crate::open_table(tmp_path.as_os_str().to_str().unwrap()) .await @@ -1778,6 +1910,9 @@ mod tests { .await .unwrap(); assert_eq!(table.version(), 0); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_eq!(write_metrics.num_added_rows, 4); + assert_common_write_metrics(write_metrics); let batch_add = RecordBatch::try_new( Arc::clone(&schema), @@ -1796,6 +1931,9 @@ mod tests { .await .unwrap(); assert_eq!(table.version(), 1); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_eq!(write_metrics.num_added_rows, 1); + assert_common_write_metrics(write_metrics); let expected = [ "+----+-------+------------+", @@ -1834,6 +1972,8 @@ mod tests { .await .unwrap(); assert_eq!(table.version(), 0); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_common_write_metrics(write_metrics); // Take clones of these before an operation resulting in error, otherwise it will // be impossible to refer to an in-memory table @@ -1876,6 +2016,8 @@ mod tests { .await .unwrap(); assert_eq!(table.version(), 0); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_common_write_metrics(write_metrics); let batch_add = RecordBatch::try_new( Arc::clone(&schema), @@ -1898,6 +2040,9 @@ mod tests { .await .unwrap(); assert_eq!(table.version(), 1); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_eq!(write_metrics.num_added_rows, 3); + assert_common_write_metrics(write_metrics); let expected = [ "+----+-------+------------+", @@ -1959,6 +2104,9 @@ mod tests { .await .expect("Failed to write first batch"); assert_eq!(table.version(), 1); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_eq!(write_metrics.num_added_rows, 3); + assert_common_write_metrics(write_metrics); let table = DeltaOps(table) .write([second_batch]) @@ -1966,6 +2114,10 @@ mod tests { .await .unwrap(); assert_eq!(table.version(), 2); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_eq!(write_metrics.num_added_rows, 1); + assert!(write_metrics.num_removed_files > 0); + 
assert_common_write_metrics(write_metrics); let snapshot_bytes = table .log_store @@ -2025,6 +2177,10 @@ mod tests { .await .expect("Failed to write first batch"); assert_eq!(table.version(), 1); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_eq!(write_metrics.num_added_rows, 3); + assert!(write_metrics.num_partitions > 0); + assert_common_write_metrics(write_metrics); let table = DeltaOps(table) .write([second_batch]) @@ -2033,6 +2189,12 @@ mod tests { .await .unwrap(); assert_eq!(table.version(), 2); + let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await; + assert_eq!(write_metrics.num_added_rows, 1); + assert!(write_metrics.num_partitions > 0); + assert!(write_metrics.num_removed_files > 0); + assert_common_write_metrics(write_metrics); + let snapshot_bytes = table .log_store .read_commit_entry(2) diff --git a/crates/core/src/table/config.rs b/crates/core/src/table/config.rs index 47307cfecd..68b41d6f67 100644 --- a/crates/core/src/table/config.rs +++ b/crates/core/src/table/config.rs @@ -210,6 +210,8 @@ pub struct TableConfig<'a>(pub(crate) &'a HashMap>); /// Default num index cols pub const DEFAULT_NUM_INDEX_COLS: i32 = 32; +/// Default target file size +pub const DEFAULT_TARGET_FILE_SIZE: i64 = 104857600; impl<'a> TableConfig<'a> { table_config!( diff --git a/docs/usage/writing/writing-to-s3-with-locking-provider.md b/docs/usage/writing/writing-to-s3-with-locking-provider.md index bbe2fa958c..6a275d685a 100644 --- a/docs/usage/writing/writing-to-s3-with-locking-provider.md +++ b/docs/usage/writing/writing-to-s3-with-locking-provider.md @@ -95,6 +95,7 @@ In DynamoDB, you need those permissions: - dynamodb:Query - dynamodb:PutItem - dynamodb:UpdateItem +- dynamodb:DeleteItem ### Enabling concurrent writes for alternative clients diff --git a/python/Cargo.toml b/python/Cargo.toml index 70eb378e20..b851425d53 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "deltalake-python" -version = "0.19.1" +version = "0.19.2" authors = ["Qingping Hou ", "Will Jones "] homepage = "https://github.com/delta-io/delta-rs" license = "Apache-2.0" diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi index 7d63cfa148..17cb1c04f4 100644 --- a/python/deltalake/_internal.pyi +++ b/python/deltalake/_internal.pyi @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple, Union import pyarrow import pyarrow.fs as fs -from deltalake.writer import AddAction, WriterProperties +from deltalake.writer import AddAction, PostCommitHookProperties, WriterProperties __version__: str @@ -59,7 +59,7 @@ class RawDeltaTable: retention_hours: Optional[int], enforce_retention_duration: bool, custom_metadata: Optional[Dict[str, str]], - post_commithook_properties: Optional[Dict[str, Optional[bool]]], + post_commithook_properties: Optional[PostCommitHookProperties], ) -> List[str]: ... def compact_optimize( self, @@ -69,7 +69,7 @@ class RawDeltaTable: min_commit_interval: Optional[int], writer_properties: Optional[WriterProperties], custom_metadata: Optional[Dict[str, str]], - post_commithook_properties: Optional[Dict[str, Optional[bool]]], + post_commithook_properties: Optional[PostCommitHookProperties], ) -> str: ... 
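For readers following the Python-facing change from `Dict[str, Optional[bool]]` to `PostCommitHookProperties`, here is a minimal usage sketch. It assumes the class is importable from `deltalake.writer` (as the updated import in `_internal.pyi` suggests) and exposes `create_checkpoint` and `cleanup_expired_logs`, mirroring the `PyPostCommitHookProperties` struct added later in this diff; the table path is hypothetical.

```python
from deltalake import DeltaTable
from deltalake.writer import PostCommitHookProperties

dt = DeltaTable("path/to/table")  # hypothetical table URI

# Skip checkpoint creation and force expired-log cleanup after the commit.
hook_props = PostCommitHookProperties(
    create_checkpoint=False,
    cleanup_expired_logs=True,
)

# The typed object is now passed through as-is instead of its `.__dict__`.
dt.vacuum(
    retention_hours=168,
    dry_run=False,
    enforce_retention_duration=True,
    post_commithook_properties=hook_props,
)
```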
def z_order_optimize( self, @@ -81,26 +81,26 @@ class RawDeltaTable: min_commit_interval: Optional[int], writer_properties: Optional[WriterProperties], custom_metadata: Optional[Dict[str, str]], - post_commithook_properties: Optional[Dict[str, Optional[bool]]], + post_commithook_properties: Optional[PostCommitHookProperties], ) -> str: ... def add_columns( self, fields: List[Field], custom_metadata: Optional[Dict[str, str]], - post_commithook_properties: Optional[Dict[str, Optional[bool]]], + post_commithook_properties: Optional[PostCommitHookProperties], ) -> None: ... def add_constraints( self, constraints: Dict[str, str], custom_metadata: Optional[Dict[str, str]], - post_commithook_properties: Optional[Dict[str, Optional[bool]]], + post_commithook_properties: Optional[PostCommitHookProperties], ) -> None: ... def drop_constraints( self, name: str, raise_if_not_exists: bool, custom_metadata: Optional[Dict[str, str]], - post_commithook_properties: Optional[Dict[str, Optional[bool]]], + post_commithook_properties: Optional[PostCommitHookProperties], ) -> None: ... def set_table_properties( self, @@ -127,13 +127,13 @@ class RawDeltaTable: predicate: Optional[str], writer_properties: Optional[WriterProperties], custom_metadata: Optional[Dict[str, str]], - post_commithook_properties: Optional[Dict[str, Optional[bool]]], + post_commithook_properties: Optional[PostCommitHookProperties], ) -> str: ... def repair( self, dry_run: bool, custom_metadata: Optional[Dict[str, str]], - post_commithook_properties: Optional[Dict[str, Optional[bool]]], + post_commithook_properties: Optional[PostCommitHookProperties], ) -> str: ... def update( self, @@ -142,9 +142,9 @@ class RawDeltaTable: writer_properties: Optional[WriterProperties], safe_cast: bool, custom_metadata: Optional[Dict[str, str]], - post_commithook_properties: Optional[Dict[str, Optional[bool]]], + post_commithook_properties: Optional[PostCommitHookProperties], ) -> str: ... - def merge_execute( + def create_merge_builder( self, source: pyarrow.RecordBatchReader, predicate: str, @@ -152,19 +152,10 @@ class RawDeltaTable: target_alias: Optional[str], writer_properties: Optional[WriterProperties], custom_metadata: Optional[Dict[str, str]], - post_commithook_properties: Optional[Dict[str, Optional[bool]]], + post_commithook_properties: Optional[PostCommitHookProperties], safe_cast: bool, - matched_update_updates: Optional[List[Dict[str, str]]], - matched_update_predicate: Optional[List[Optional[str]]], - matched_delete_predicate: Optional[List[str]], - matched_delete_all: Optional[bool], - not_matched_insert_updates: Optional[List[Dict[str, str]]], - not_matched_insert_predicate: Optional[List[Optional[str]]], - not_matched_by_source_update_updates: Optional[List[Dict[str, str]]], - not_matched_by_source_update_predicate: Optional[List[Optional[str]]], - not_matched_by_source_delete_predicate: Optional[List[str]], - not_matched_by_source_delete_all: Optional[bool], - ) -> str: ... + ) -> PyMergeBuilder: ... + def merge_execute(self, merge_builder: PyMergeBuilder) -> str: ... def get_active_partitions( self, partitions_filters: Optional[FilterType] = None ) -> Any: ... @@ -176,7 +167,7 @@ class RawDeltaTable: schema: pyarrow.Schema, partitions_filters: Optional[FilterType], custom_metadata: Optional[Dict[str, str]], - post_commithook_properties: Optional[Dict[str, Optional[bool]]], + post_commithook_properties: Optional[PostCommitHookProperties], ) -> None: ... def cleanup_metadata(self) -> None: ... 
def check_can_write_timestamp_ntz(self, schema: pyarrow.Schema) -> None: ... @@ -210,13 +201,14 @@ def write_to_deltalake( table: Optional[RawDeltaTable], schema_mode: Optional[str], predicate: Optional[str], + target_file_size: Optional[int], name: Optional[str], description: Optional[str], configuration: Optional[Mapping[str, Optional[str]]], storage_options: Optional[Dict[str, str]], writer_properties: Optional[WriterProperties], custom_metadata: Optional[Dict[str, str]], - post_commithook_properties: Optional[Dict[str, Optional[bool]]], + post_commithook_properties: Optional[PostCommitHookProperties], ) -> None: ... def convert_to_deltalake( uri: str, @@ -245,6 +237,26 @@ def get_num_idx_cols_and_stats_columns( table: Optional[RawDeltaTable], configuration: Optional[Mapping[str, Optional[str]]] ) -> Tuple[int, Optional[List[str]]]: ... +class PyMergeBuilder: + source_alias: str + target_alias: str + arrow_schema: pyarrow.Schema + + def when_matched_update( + self, updates: Dict[str, str], predicate: Optional[str] + ) -> None: ... + def when_matched_delete(self, predicate: Optional[str]) -> None: ... + def when_not_matched_insert( + self, updates: Dict[str, str], predicate: Optional[str] + ) -> None: ... + def when_not_matched_by_source_update( + self, updates: Dict[str, str], predicate: Optional[str] + ) -> None: ... + def when_not_matched_by_source_delete( + self, + predicate: Optional[str], + ) -> None: ... + # Can't implement inheritance (see note in src/schema.rs), so this is next # best thing. DataType = Union["PrimitiveType", "MapType", "StructType", "ArrayType"] diff --git a/python/deltalake/table.py b/python/deltalake/table.py index f00943e346..b2748ed861 100644 --- a/python/deltalake/table.py +++ b/python/deltalake/table.py @@ -38,7 +38,11 @@ if TYPE_CHECKING: import os -from deltalake._internal import DeltaError, RawDeltaTable +from deltalake._internal import ( + DeltaError, + PyMergeBuilder, + RawDeltaTable, +) from deltalake._internal import create_deltalake as _create_deltalake from deltalake._util import encode_partition_value from deltalake.data_catalog import DataCatalog @@ -508,6 +512,24 @@ def version(self) -> int: """ return self._table.version() + def partitions( + self, + partition_filters: Optional[List[Tuple[str, str, Any]]] = None, + ) -> List[Dict[str, str]]: + """ + Returns the partitions as a list of dicts. Example: `[{'month': '1', 'year': '2020', 'day': '1'}, ...]` + + Args: + partition_filters: The partition filters that will be used for getting the matched partitions, defaults to `None` (no filtering). 
+ """ + + partitions: List[Dict[str, str]] = [] + for partition in self._table.get_active_partitions(partition_filters): + if not partition: + continue + partitions.append({k: v for (k, v) in partition}) + return partitions + def files( self, partition_filters: Optional[List[Tuple[str, str, Any]]] = None ) -> List[str]: @@ -745,7 +767,7 @@ def vacuum( retention_hours, enforce_retention_duration, custom_metadata, - post_commithook_properties.__dict__ if post_commithook_properties else None, + post_commithook_properties, ) def update( @@ -849,9 +871,7 @@ def update( writer_properties, safe_cast=not error_on_type_mismatch, custom_metadata=custom_metadata, - post_commithook_properties=post_commithook_properties.__dict__ - if post_commithook_properties - else None, + post_commithook_properties=post_commithook_properties, ) return json.loads(metrics) @@ -950,8 +970,7 @@ def merge( source.schema, (batch for batch in source) ) - return TableMerger( - self, + py_merge_builder = self._table.create_merge_builder( source=source, predicate=predicate, source_alias=source_alias, @@ -961,6 +980,7 @@ def merge( custom_metadata=custom_metadata, post_commithook_properties=post_commithook_properties, ) + return TableMerger(py_merge_builder, self._table) def restore( self, @@ -1243,10 +1263,7 @@ def delete( the metrics from delete. """ metrics = self._table.delete( - predicate, - writer_properties, - custom_metadata, - post_commithook_properties.__dict__ if post_commithook_properties else None, + predicate, writer_properties, custom_metadata, post_commithook_properties ) return json.loads(metrics) @@ -1283,9 +1300,7 @@ def repair( ``` """ metrics = self._table.repair( - dry_run, - custom_metadata, - post_commithook_properties.__dict__ if post_commithook_properties else None, + dry_run, custom_metadata, post_commithook_properties ) return json.loads(metrics) @@ -1295,37 +1310,11 @@ class TableMerger: def __init__( self, - table: DeltaTable, - source: pyarrow.RecordBatchReader, - predicate: str, - source_alias: Optional[str] = None, - target_alias: Optional[str] = None, - safe_cast: bool = True, - writer_properties: Optional[WriterProperties] = None, - custom_metadata: Optional[Dict[str, str]] = None, - post_commithook_properties: Optional[PostCommitHookProperties] = None, + builder: PyMergeBuilder, + table: RawDeltaTable, ): - self.table = table - self.source = source - self.predicate = predicate - self.source_alias = source_alias - self.target_alias = target_alias - self.safe_cast = safe_cast - self.writer_properties = writer_properties - self.custom_metadata = custom_metadata - self.post_commithook_properties = post_commithook_properties - self.matched_update_updates: Optional[List[Dict[str, str]]] = None - self.matched_update_predicate: Optional[List[Optional[str]]] = None - self.matched_delete_predicate: Optional[List[str]] = None - self.matched_delete_all: Optional[bool] = None - self.not_matched_insert_updates: Optional[List[Dict[str, str]]] = None - self.not_matched_insert_predicate: Optional[List[Optional[str]]] = None - self.not_matched_by_source_update_updates: Optional[List[Dict[str, str]]] = None - self.not_matched_by_source_update_predicate: Optional[List[Optional[str]]] = ( - None - ) - self.not_matched_by_source_delete_predicate: Optional[List[str]] = None - self.not_matched_by_source_delete_all: Optional[bool] = None + self._builder = builder + self._table = table def when_matched_update( self, updates: Dict[str, str], predicate: Optional[str] = None @@ -1372,14 +1361,7 @@ def 
when_matched_update( 2 3 6 ``` """ - if isinstance(self.matched_update_updates, list) and isinstance( - self.matched_update_predicate, list - ): - self.matched_update_updates.append(updates) - self.matched_update_predicate.append(predicate) - else: - self.matched_update_updates = [updates] - self.matched_update_predicate = [predicate] + self._builder.when_matched_update(updates, predicate) return self def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerger": @@ -1424,24 +1406,20 @@ def when_matched_update_all(self, predicate: Optional[str] = None) -> "TableMerg 2 3 6 ``` """ + maybe_source_alias = self._builder.source_alias + maybe_target_alias = self._builder.target_alias - src_alias = (self.source_alias + ".") if self.source_alias is not None else "" - trgt_alias = (self.target_alias + ".") if self.target_alias is not None else "" + src_alias = (maybe_source_alias + ".") if maybe_source_alias is not None else "" + trgt_alias = ( + (maybe_target_alias + ".") if maybe_target_alias is not None else "" + ) updates = { f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`" - for col in self.source.schema + for col in self._builder.arrow_schema } - if isinstance(self.matched_update_updates, list) and isinstance( - self.matched_update_predicate, list - ): - self.matched_update_updates.append(updates) - self.matched_update_predicate.append(predicate) - else: - self.matched_update_updates = [updates] - self.matched_update_predicate = [predicate] - + self._builder.when_matched_update(updates, predicate) return self def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger": @@ -1507,19 +1485,7 @@ def when_matched_delete(self, predicate: Optional[str] = None) -> "TableMerger": 0 1 4 ``` """ - if self.matched_delete_all is not None: - raise ValueError( - """when_matched_delete without a predicate has already been set, which means - it will delete all, any subsequent when_matched_delete, won't make sense.""" - ) - - if predicate is None: - self.matched_delete_all = True - else: - if isinstance(self.matched_delete_predicate, list): - self.matched_delete_predicate.append(predicate) - else: - self.matched_delete_predicate = [predicate] + self._builder.when_matched_delete(predicate) return self def when_not_matched_insert( @@ -1572,16 +1538,7 @@ def when_not_matched_insert( 3 4 7 ``` """ - - if isinstance(self.not_matched_insert_updates, list) and isinstance( - self.not_matched_insert_predicate, list - ): - self.not_matched_insert_updates.append(updates) - self.not_matched_insert_predicate.append(predicate) - else: - self.not_matched_insert_updates = [updates] - self.not_matched_insert_predicate = [predicate] - + self._builder.when_not_matched_insert(updates, predicate) return self def when_not_matched_insert_all( @@ -1630,22 +1587,19 @@ def when_not_matched_insert_all( 3 4 7 ``` """ + maybe_source_alias = self._builder.source_alias + maybe_target_alias = self._builder.target_alias - src_alias = (self.source_alias + ".") if self.source_alias is not None else "" - trgt_alias = (self.target_alias + ".") if self.target_alias is not None else "" + src_alias = (maybe_source_alias + ".") if maybe_source_alias is not None else "" + trgt_alias = ( + (maybe_target_alias + ".") if maybe_target_alias is not None else "" + ) updates = { f"{trgt_alias}`{col.name}`": f"{src_alias}`{col.name}`" - for col in self.source.schema + for col in self._builder.arrow_schema } - if isinstance(self.not_matched_insert_updates, list) and isinstance( - 
self.not_matched_insert_predicate, list - ): - self.not_matched_insert_updates.append(updates) - self.not_matched_insert_predicate.append(predicate) - else: - self.not_matched_insert_updates = [updates] - self.not_matched_insert_predicate = [predicate] + self._builder.when_not_matched_insert(updates, predicate) return self def when_not_matched_by_source_update( @@ -1695,15 +1649,7 @@ def when_not_matched_by_source_update( 2 3 6 ``` """ - - if isinstance(self.not_matched_by_source_update_updates, list) and isinstance( - self.not_matched_by_source_update_predicate, list - ): - self.not_matched_by_source_update_updates.append(updates) - self.not_matched_by_source_update_predicate.append(predicate) - else: - self.not_matched_by_source_update_updates = [updates] - self.not_matched_by_source_update_predicate = [predicate] + self._builder.when_not_matched_by_source_update(updates, predicate) return self def when_not_matched_by_source_delete( @@ -1722,19 +1668,7 @@ def when_not_matched_by_source_delete( Returns: TableMerger: TableMerger Object """ - if self.not_matched_by_source_delete_all is not None: - raise ValueError( - """when_not_matched_by_source_delete without a predicate has already been set, which means - it will delete all, any subsequent when_not_matched_by_source_delete, won't make sense.""" - ) - - if predicate is None: - self.not_matched_by_source_delete_all = True - else: - if isinstance(self.not_matched_by_source_delete_predicate, list): - self.not_matched_by_source_delete_predicate.append(predicate) - else: - self.not_matched_by_source_delete_predicate = [predicate] + self._builder.when_not_matched_by_source_delete(predicate) return self def execute(self) -> Dict[str, Any]: @@ -1743,31 +1677,7 @@ def execute(self) -> Dict[str, Any]: Returns: Dict: metrics """ - metrics = self.table._table.merge_execute( - source=self.source, - predicate=self.predicate, - source_alias=self.source_alias, - target_alias=self.target_alias, - safe_cast=self.safe_cast, - writer_properties=self.writer_properties - if self.writer_properties - else None, - custom_metadata=self.custom_metadata, - post_commithook_properties=self.post_commithook_properties.__dict__ - if self.post_commithook_properties - else None, - matched_update_updates=self.matched_update_updates, - matched_update_predicate=self.matched_update_predicate, - matched_delete_predicate=self.matched_delete_predicate, - matched_delete_all=self.matched_delete_all, - not_matched_insert_updates=self.not_matched_insert_updates, - not_matched_insert_predicate=self.not_matched_insert_predicate, - not_matched_by_source_update_updates=self.not_matched_by_source_update_updates, - not_matched_by_source_update_predicate=self.not_matched_by_source_update_predicate, - not_matched_by_source_delete_predicate=self.not_matched_by_source_delete_predicate, - not_matched_by_source_delete_all=self.not_matched_by_source_delete_all, - ) - self.table.update_incremental() + metrics = self._table.merge_execute(self._builder) return json.loads(metrics) @@ -1808,9 +1718,7 @@ def add_columns( fields = [fields] self.table._table.add_columns( - fields, - custom_metadata, - post_commithook_properties.__dict__ if post_commithook_properties else None, + fields, custom_metadata, post_commithook_properties ) def add_constraint( @@ -1849,9 +1757,7 @@ def add_constraint( ) self.table._table.add_constraints( - constraints, - custom_metadata, - post_commithook_properties.__dict__ if post_commithook_properties else None, + constraints, custom_metadata, post_commithook_properties ) def 
drop_constraint( @@ -1890,10 +1796,7 @@ def drop_constraint( ``` """ self.table._table.drop_constraints( - name, - raise_if_not_exists, - custom_metadata, - post_commithook_properties.__dict__ if post_commithook_properties else None, + name, raise_if_not_exists, custom_metadata, post_commithook_properties ) def set_table_properties( @@ -1999,7 +1902,7 @@ def compact( min_commit_interval, writer_properties, custom_metadata, - post_commithook_properties.__dict__ if post_commithook_properties else None, + post_commithook_properties, ) self.table.update_incremental() return json.loads(metrics) @@ -2030,7 +1933,7 @@ def z_order( max_concurrent_tasks: the maximum number of concurrent tasks to use for file compaction. Defaults to number of CPUs. More concurrent tasks can make compaction faster, but will also use more memory. - max_spill_size: the maximum number of bytes to spill to disk. Defaults to 20GB. + max_spill_size: the maximum number of bytes allowed in memory before spilling to disk. Defaults to 20GB. min_commit_interval: minimum interval in seconds or as timedeltas before a new commit is created. Interval is useful for long running executions. Set to 0 or timedelta(0), if you want a commit per partition. @@ -2069,7 +1972,7 @@ def z_order( min_commit_interval, writer_properties, custom_metadata, - post_commithook_properties.__dict__ if post_commithook_properties else None, + post_commithook_properties, ) self.table.update_incremental() return json.loads(metrics) diff --git a/python/deltalake/writer.py b/python/deltalake/writer.py index 99b915183f..e08d9cc9b8 100644 --- a/python/deltalake/writer.py +++ b/python/deltalake/writer.py @@ -178,6 +178,7 @@ def write_deltalake( schema_mode: Optional[Literal["merge", "overwrite"]] = ..., storage_options: Optional[Dict[str, str]] = ..., predicate: Optional[str] = ..., + target_file_size: Optional[int] = ..., large_dtypes: bool = ..., engine: Literal["rust"] = ..., writer_properties: WriterProperties = ..., @@ -214,6 +215,7 @@ def write_deltalake( storage_options: Optional[Dict[str, str]] = None, partition_filters: Optional[List[Tuple[str, str, Any]]] = None, predicate: Optional[str] = None, + target_file_size: Optional[int] = None, large_dtypes: bool = False, engine: Literal["pyarrow", "rust"] = "rust", writer_properties: Optional[WriterProperties] = None, @@ -267,7 +269,8 @@ def write_deltalake( configuration: A map containing configuration options for the metadata action. schema_mode: If set to "overwrite", allows replacing the schema of the table. Set to "merge" to merge with existing schema. storage_options: options passed to the native delta filesystem. - predicate: When using `Overwrite` mode, replace data that matches a predicate. Only used in rust engine. + predicate: When using `Overwrite` mode, replace data that matches a predicate. Only used in rust engine.' + target_file_size: Override for target file size for data files written to the delta table. If not passed, it's taken from `delta.targetFileSize`. partition_filters: the partition filters that will be used for partition overwrite. Only used in pyarrow engine. large_dtypes: Only used for pyarrow engine engine: writer engine to write the delta table. PyArrow engine is deprecated, and will be removed in v1.0. 
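Because the new `target_file_size` keyword is easy to overlook in the long `write_deltalake` signature, a brief sketch of how it might be used follows; the table path and data are hypothetical, and when the argument is omitted the writer falls back to `delta.targetFileSize` or the 104857600-byte default wired up in `get_target_file_size` earlier in this diff.

```python
import pyarrow as pa
from deltalake import write_deltalake

data = pa.table({"id": [1, 2, 3], "value": ["a", "b", "c"]})

# Roll data files at roughly 32 MiB instead of the table's configured
# delta.targetFileSize (rust engine only, like the `predicate` option).
write_deltalake(
    "path/to/table",  # hypothetical table URI
    data,
    mode="append",
    target_file_size=32 * 1024 * 1024,
)
```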
@@ -308,15 +311,14 @@ def write_deltalake( table=table._table if table is not None else None, schema_mode=schema_mode, predicate=predicate, + target_file_size=target_file_size, name=name, description=description, configuration=configuration, storage_options=storage_options, writer_properties=writer_properties, custom_metadata=custom_metadata, - post_commithook_properties=post_commithook_properties.__dict__ - if post_commithook_properties - else None, + post_commithook_properties=post_commithook_properties, ) if table: table.update_incremental() @@ -549,9 +551,7 @@ def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch: schema, partition_filters, custom_metadata, - post_commithook_properties=post_commithook_properties.__dict__ - if post_commithook_properties - else None, + post_commithook_properties=post_commithook_properties, ) table.update_incremental() else: diff --git a/python/src/lib.rs b/python/src/lib.rs index 7ceb1853be..df78ac5004 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -1,12 +1,12 @@ mod error; mod filesystem; +mod merge; mod schema; mod utils; use std::collections::{HashMap, HashSet}; use std::future::IntoFuture; use std::str::FromStr; -use std::sync::Arc; use std::time; use std::time::{SystemTime, UNIX_EPOCH}; @@ -16,12 +16,9 @@ use delta_kernel::expressions::Scalar; use delta_kernel::schema::StructField; use deltalake::arrow::compute::concat_batches; use deltalake::arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream}; -use deltalake::arrow::record_batch::RecordBatchReader; use deltalake::arrow::record_batch::{RecordBatch, RecordBatchIterator}; use deltalake::arrow::{self, datatypes::Schema as ArrowSchema}; use deltalake::checkpoints::{cleanup_metadata, create_checkpoint}; -use deltalake::datafusion::catalog::TableProvider; -use deltalake::datafusion::datasource::memory::MemTable; use deltalake::datafusion::physical_plan::ExecutionPlan; use deltalake::datafusion::prelude::SessionContext; use deltalake::delta_datafusion::DeltaDataChecker; @@ -37,7 +34,6 @@ use deltalake::operations::delete::DeleteBuilder; use deltalake::operations::drop_constraints::DropConstraintBuilder; use deltalake::operations::filesystem_check::FileSystemCheckBuilder; use deltalake::operations::load_cdf::CdfLoadBuilder; -use deltalake::operations::merge::MergeBuilder; use deltalake::operations::optimize::{OptimizeBuilder, OptimizeType}; use deltalake::operations::restore::RestoreBuilder; use deltalake::operations::set_tbl_properties::SetTablePropertiesBuilder; @@ -56,6 +52,7 @@ use deltalake::DeltaTableBuilder; use deltalake::{DeltaOps, DeltaResult}; use error::DeltaError; use futures::future::join_all; + use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; @@ -65,6 +62,7 @@ use serde_json::{Map, Value}; use crate::error::DeltaProtocolError; use crate::error::PythonError; use crate::filesystem::FsConfig; +use crate::merge::PyMergeBuilder; use crate::schema::{schema_to_pyobject, Field}; use crate::utils::rt; @@ -343,7 +341,7 @@ impl RawDeltaTable { retention_hours: Option, enforce_retention_duration: bool, custom_metadata: Option>, - post_commithook_properties: Option>>, + post_commithook_properties: Option, ) -> PyResult> { let (table, metrics) = py.allow_threads(|| { let mut cmd = VacuumBuilder::new( @@ -378,7 +376,7 @@ impl RawDeltaTable { writer_properties: Option, safe_cast: bool, custom_metadata: Option>, - post_commithook_properties: Option>>, + post_commithook_properties: Option, ) -> PyResult { let (table, 
metrics) = py.allow_threads(|| { let mut cmd = UpdateBuilder::new( @@ -434,7 +432,7 @@ impl RawDeltaTable { min_commit_interval: Option, writer_properties: Option, custom_metadata: Option>, - post_commithook_properties: Option>>, + post_commithook_properties: Option, ) -> PyResult { let (table, metrics) = py.allow_threads(|| { let mut cmd = OptimizeBuilder::new( @@ -494,7 +492,7 @@ impl RawDeltaTable { min_commit_interval: Option, writer_properties: Option, custom_metadata: Option>, - post_commithook_properties: Option>>, + post_commithook_properties: Option, ) -> PyResult { let (table, metrics) = py.allow_threads(|| { let mut cmd = OptimizeBuilder::new( @@ -540,7 +538,7 @@ impl RawDeltaTable { py: Python, fields: Vec, custom_metadata: Option>, - post_commithook_properties: Option>>, + post_commithook_properties: Option, ) -> PyResult<()> { let table = py.allow_threads(|| { let mut cmd = AddColumnBuilder::new( @@ -573,7 +571,7 @@ impl RawDeltaTable { py: Python, constraints: HashMap, custom_metadata: Option>, - post_commithook_properties: Option>>, + post_commithook_properties: Option, ) -> PyResult<()> { let table = py.allow_threads(|| { let mut cmd = ConstraintBuilder::new( @@ -604,7 +602,7 @@ impl RawDeltaTable { name: String, raise_if_not_exists: bool, custom_metadata: Option>, - post_commithook_properties: Option>>, + post_commithook_properties: Option, ) -> PyResult<()> { let table = py.allow_threads(|| { let mut cmd = DropConstraintBuilder::new( @@ -694,7 +692,8 @@ impl RawDeltaTable { } #[allow(clippy::too_many_arguments)] - #[pyo3(signature = (source, + #[pyo3(signature = ( + source, predicate, source_alias = None, target_alias = None, @@ -702,19 +701,9 @@ impl RawDeltaTable { writer_properties = None, post_commithook_properties = None, custom_metadata = None, - matched_update_updates = None, - matched_update_predicate = None, - matched_delete_predicate = None, - matched_delete_all = None, - not_matched_insert_updates = None, - not_matched_insert_predicate = None, - not_matched_by_source_update_updates = None, - not_matched_by_source_update_predicate = None, - not_matched_by_source_delete_predicate = None, - not_matched_by_source_delete_all = None, ))] - pub fn merge_execute( - &mut self, + pub fn create_merge_builder( + &self, py: Python, source: PyArrowType, predicate: String, @@ -722,169 +711,39 @@ impl RawDeltaTable { target_alias: Option, safe_cast: bool, writer_properties: Option, - post_commithook_properties: Option>>, + post_commithook_properties: Option, custom_metadata: Option>, - matched_update_updates: Option>>, - matched_update_predicate: Option>>, - matched_delete_predicate: Option>, - matched_delete_all: Option, - not_matched_insert_updates: Option>>, - not_matched_insert_predicate: Option>>, - not_matched_by_source_update_updates: Option>>, - not_matched_by_source_update_predicate: Option>>, - not_matched_by_source_delete_predicate: Option>, - not_matched_by_source_delete_all: Option, - ) -> PyResult { - let (table, metrics) = py.allow_threads(|| { - let ctx = SessionContext::new(); - let schema = source.0.schema(); - let batches = vec![source.0.map(|batch| batch.unwrap()).collect::>()]; - let table_provider: Arc = - Arc::new(MemTable::try_new(schema, batches).unwrap()); - let source_df = ctx.read_table(table_provider).unwrap(); - - let mut cmd = MergeBuilder::new( + ) -> PyResult { + py.allow_threads(|| { + Ok(PyMergeBuilder::new( self._table.log_store(), self._table.snapshot().map_err(PythonError::from)?.clone(), + source.0, predicate, - source_df, + 
source_alias, + target_alias, + safe_cast, + writer_properties, + post_commithook_properties, + custom_metadata, ) - .with_safe_cast(safe_cast); - - if let Some(src_alias) = source_alias { - cmd = cmd.with_source_alias(src_alias); - } - - if let Some(trgt_alias) = target_alias { - cmd = cmd.with_target_alias(trgt_alias); - } - - if let Some(writer_props) = writer_properties { - cmd = cmd.with_writer_properties( - set_writer_properties(writer_props).map_err(PythonError::from)?, - ); - } - - if let Some(commit_properties) = - maybe_create_commit_properties(custom_metadata, post_commithook_properties) - { - cmd = cmd.with_commit_properties(commit_properties); - } - - if let Some(mu_updates) = matched_update_updates { - if let Some(mu_predicate) = matched_update_predicate { - for it in mu_updates.iter().zip(mu_predicate.iter()) { - let (update_values, predicate_value) = it; - - if let Some(pred) = predicate_value { - cmd = cmd - .when_matched_update(|mut update| { - for (col_name, expression) in update_values { - update = - update.update(col_name.clone(), expression.clone()); - } - update.predicate(pred.clone()) - }) - .map_err(PythonError::from)?; - } else { - cmd = cmd - .when_matched_update(|mut update| { - for (col_name, expression) in update_values { - update = - update.update(col_name.clone(), expression.clone()); - } - update - }) - .map_err(PythonError::from)?; - } - } - } - } - - if let Some(_md_delete_all) = matched_delete_all { - cmd = cmd - .when_matched_delete(|delete| delete) - .map_err(PythonError::from)?; - } else if let Some(md_predicate) = matched_delete_predicate { - for pred in md_predicate.iter() { - cmd = cmd - .when_matched_delete(|delete| delete.predicate(pred.clone())) - .map_err(PythonError::from)?; - } - } - - if let Some(nmi_updates) = not_matched_insert_updates { - if let Some(nmi_predicate) = not_matched_insert_predicate { - for it in nmi_updates.iter().zip(nmi_predicate.iter()) { - let (update_values, predicate_value) = it; - if let Some(pred) = predicate_value { - cmd = cmd - .when_not_matched_insert(|mut insert| { - for (col_name, expression) in update_values { - insert = insert.set(col_name.clone(), expression.clone()); - } - insert.predicate(pred.clone()) - }) - .map_err(PythonError::from)?; - } else { - cmd = cmd - .when_not_matched_insert(|mut insert| { - for (col_name, expression) in update_values { - insert = insert.set(col_name.clone(), expression.clone()); - } - insert - }) - .map_err(PythonError::from)?; - } - } - } - } - - if let Some(nmbsu_updates) = not_matched_by_source_update_updates { - if let Some(nmbsu_predicate) = not_matched_by_source_update_predicate { - for it in nmbsu_updates.iter().zip(nmbsu_predicate.iter()) { - let (update_values, predicate_value) = it; - if let Some(pred) = predicate_value { - cmd = cmd - .when_not_matched_by_source_update(|mut update| { - for (col_name, expression) in update_values { - update = - update.update(col_name.clone(), expression.clone()); - } - update.predicate(pred.clone()) - }) - .map_err(PythonError::from)?; - } else { - cmd = cmd - .when_not_matched_by_source_update(|mut update| { - for (col_name, expression) in update_values { - update = - update.update(col_name.clone(), expression.clone()); - } - update - }) - .map_err(PythonError::from)?; - } - } - } - } - - if let Some(_nmbs_delete_all) = not_matched_by_source_delete_all { - cmd = cmd - .when_not_matched_by_source_delete(|delete| delete) - .map_err(PythonError::from)?; - } else if let Some(nmbs_predicate) = not_matched_by_source_delete_predicate { - 
-                for pred in nmbs_predicate.iter() {
-                    cmd = cmd
-                        .when_not_matched_by_source_delete(|delete| delete.predicate(pred.clone()))
-                        .map_err(PythonError::from)?;
-                }
-            }
+            .map_err(PythonError::from)?)
+        })
+    }
-            rt().block_on(cmd.into_future()).map_err(PythonError::from)
-        })?;
-        self._table.state = table.state;
-        Ok(serde_json::to_string(&metrics).unwrap())
+    #[pyo3(signature=(
+        merge_builder
+    ))]
+    pub fn merge_execute(
+        &mut self,
+        py: Python,
+        merge_builder: &mut PyMergeBuilder,
+    ) -> PyResult<String> {
+        py.allow_threads(|| {
+            let (table, metrics) = merge_builder.execute().map_err(PythonError::from)?;
+            self._table.state = table.state;
+            Ok(metrics)
+        })
     }
 
     // Run the restore command on the Delta Table: restore table to a given version or datetime
@@ -1078,7 +937,7 @@ impl RawDeltaTable {
         schema: PyArrowType<ArrowSchema>,
         partitions_filters: Option>,
         custom_metadata: Option<HashMap<String, String>>,
-        post_commithook_properties: Option<HashMap<String, Option<bool>>>,
+        post_commithook_properties: Option<PyPostCommitHookProperties>,
     ) -> PyResult<()> {
         py.allow_threads(|| {
             let mode = mode.parse().map_err(PythonError::from)?;
@@ -1251,7 +1110,7 @@ impl RawDeltaTable {
         predicate: Option<String>,
         writer_properties: Option<PyWriterProperties>,
         custom_metadata: Option<HashMap<String, String>>,
-        post_commithook_properties: Option<HashMap<String, Option<bool>>>,
+        post_commithook_properties: Option<PyPostCommitHookProperties>,
     ) -> PyResult<String> {
         let (table, metrics) = py.allow_threads(|| {
             let mut cmd = DeleteBuilder::new(
@@ -1310,7 +1169,7 @@ impl RawDeltaTable {
         &mut self,
         dry_run: bool,
         custom_metadata: Option<HashMap<String, String>>,
-        post_commithook_properties: Option<HashMap<String, Option<bool>>>,
+        post_commithook_properties: Option<PyPostCommitHookProperties>,
    ) -> PyResult<String> {
         let mut cmd = FileSystemCheckBuilder::new(
             self._table.log_store(),
@@ -1334,14 +1193,12 @@ impl RawDeltaTable {
 
 fn set_post_commithook_properties(
     mut commit_properties: CommitProperties,
-    post_commithook_properties: HashMap<String, Option<bool>>,
+    post_commithook_properties: PyPostCommitHookProperties,
 ) -> CommitProperties {
-    if let Some(Some(create_checkpoint)) = post_commithook_properties.get("create_checkpoint") {
-        commit_properties = commit_properties.with_create_checkpoint(*create_checkpoint)
-    }
-    if let Some(cleanup_expired_logs) = post_commithook_properties.get("cleanup_expired_logs") {
-        commit_properties = commit_properties.with_cleanup_expired_logs(*cleanup_expired_logs)
-    }
+    commit_properties =
+        commit_properties.with_create_checkpoint(post_commithook_properties.create_checkpoint);
+    commit_properties = commit_properties
+        .with_cleanup_expired_logs(post_commithook_properties.cleanup_expired_logs);
     commit_properties
 }
@@ -1461,7 +1318,7 @@ fn convert_partition_filters(
 
 fn maybe_create_commit_properties(
     custom_metadata: Option<HashMap<String, String>>,
-    post_commithook_properties: Option<HashMap<String, Option<bool>>>,
+    post_commithook_properties: Option<PyPostCommitHookProperties>,
 ) -> Option<CommitProperties> {
     if custom_metadata.is_none() && post_commithook_properties.is_none() {
         return None;
@@ -1743,6 +1600,12 @@ pub struct PyWriterProperties {
     column_properties: Option>>,
 }
 
+#[derive(FromPyObject)]
+pub struct PyPostCommitHookProperties {
+    create_checkpoint: bool,
+    cleanup_expired_logs: Option<bool>,
+}
+
 #[pyfunction]
 #[allow(clippy::too_many_arguments)]
 fn write_to_deltalake(
@@ -1754,13 +1617,14 @@ fn write_to_deltalake(
     schema_mode: Option<String>,
     partition_by: Option<Vec<String>>,
     predicate: Option<String>,
+    target_file_size: Option<usize>,
     name: Option<String>,
     description: Option<String>,
     configuration: Option<HashMap<String, Option<String>>>,
     storage_options: Option<HashMap<String, String>>,
     writer_properties: Option<PyWriterProperties>,
     custom_metadata: Option<HashMap<String, String>>,
-    post_commithook_properties: Option<HashMap<String, Option<bool>>>,
+    post_commithook_properties: Option<PyPostCommitHookProperties>,
 ) -> PyResult<()> {
     py.allow_threads(|| {
         let batches = data.0.map(|batch| batch.unwrap()).collect::<Vec<RecordBatch>>();
@@ -1802,6 +1666,10 @@ fn write_to_deltalake(
             builder = builder.with_replace_where(predicate);
         };
 
+        if let Some(target_file_size) = target_file_size {
+            builder = builder.with_target_file_size(target_file_size)
+        };
+
         if let Some(config) = configuration {
             builder = builder.with_configuration(config);
         };
@@ -2085,6 +1953,7 @@ fn _internal(m: &Bound<'_, PyModule>) -> PyResult<()> {
         m
     )?)?;
     m.add_class::()?;
+    m.add_class::<PyMergeBuilder>()?;
    m.add_class::()?;
    m.add_class::()?;
     // There are issues with submodules, so we will expose them flat for now
diff --git a/python/src/merge.rs b/python/src/merge.rs
new file mode 100644
index 0000000000..8fb1f044d8
--- /dev/null
+++ b/python/src/merge.rs
@@ -0,0 +1,214 @@
+use deltalake::arrow::array::RecordBatchReader;
+use deltalake::arrow::datatypes::Schema as ArrowSchema;
+use deltalake::arrow::ffi_stream::ArrowArrayStreamReader;
+use deltalake::arrow::pyarrow::IntoPyArrow;
+use deltalake::datafusion::catalog::TableProvider;
+use deltalake::datafusion::datasource::MemTable;
+use deltalake::datafusion::prelude::SessionContext;
+use deltalake::logstore::LogStoreRef;
+use deltalake::operations::merge::MergeBuilder;
+use deltalake::table::state::DeltaTableState;
+use deltalake::{DeltaResult, DeltaTable};
+use pyo3::prelude::*;
+use std::collections::HashMap;
+use std::future::IntoFuture;
+use std::sync::Arc;
+
+use crate::error::PythonError;
+use crate::utils::rt;
+use crate::{
+    maybe_create_commit_properties, set_writer_properties, PyPostCommitHookProperties,
+    PyWriterProperties,
+};
+
+#[pyclass(module = "deltalake._internal")]
+pub(crate) struct PyMergeBuilder {
+    _builder: Option<MergeBuilder>,
+    #[pyo3(get)]
+    source_alias: Option<String>,
+    #[pyo3(get)]
+    target_alias: Option<String>,
+    arrow_schema: Arc<ArrowSchema>,
+}
+
+impl PyMergeBuilder {
+    #[allow(clippy::too_many_arguments)]
+    pub fn new(
+        log_store: LogStoreRef,
+        snapshot: DeltaTableState,
+        source: ArrowArrayStreamReader,
+        predicate: String,
+        source_alias: Option<String>,
+        target_alias: Option<String>,
+        safe_cast: bool,
+        writer_properties: Option<PyWriterProperties>,
+        post_commithook_properties: Option<PyPostCommitHookProperties>,
+        custom_metadata: Option<HashMap<String, String>>,
+    ) -> DeltaResult<Self> {
+        let ctx = SessionContext::new();
+        let schema = source.schema();
+        let batches = vec![source.map(|batch| batch.unwrap()).collect::<Vec<_>>()];
+        let table_provider: Arc<dyn TableProvider> =
+            Arc::new(MemTable::try_new(schema.clone(), batches).unwrap());
+        let source_df = ctx.read_table(table_provider).unwrap();
+
+        let mut cmd =
+            MergeBuilder::new(log_store, snapshot, predicate, source_df).with_safe_cast(safe_cast);
+
+        if let Some(src_alias) = &source_alias {
+            cmd = cmd.with_source_alias(src_alias);
+        }
+
+        if let Some(trgt_alias) = &target_alias {
+            cmd = cmd.with_target_alias(trgt_alias);
+        }
+
+        if let Some(writer_props) = writer_properties {
+            cmd = cmd.with_writer_properties(set_writer_properties(writer_props)?);
+        }
+
+        if let Some(commit_properties) =
+            maybe_create_commit_properties(custom_metadata, post_commithook_properties)
+        {
+            cmd = cmd.with_commit_properties(commit_properties);
+        }
+        Ok(Self {
+            _builder: Some(cmd),
+            source_alias,
+            target_alias,
+            arrow_schema: schema,
+        })
+    }
+
+    pub fn execute(&mut self) -> DeltaResult<(DeltaTable, String)> {
+        let (table, metrics) = rt().block_on(self._builder.take().unwrap().into_future())?;
+        Ok((table, serde_json::to_string(&metrics).unwrap()))
+    }
+}
+
+#[pymethods]
+impl PyMergeBuilder {
+    #[getter]
+    fn get_arrow_schema(&self, py: Python) -> PyResult<PyObject> {
+        <ArrowSchema as Clone>::clone(&self.arrow_schema).into_pyarrow(py)
+    }
+
+    #[pyo3(signature=(
+        updates,
+        predicate = None,
+    ))]
+    fn when_matched_update(
+        &mut self,
+        updates: HashMap<String, String>,
+        predicate: Option<String>,
+    ) -> PyResult<()> {
+        self._builder = match self._builder.take() {
+            Some(cmd) => Some(
+                cmd.when_matched_update(|mut update| {
+                    for (column, expression) in updates {
+                        update = update.update(column, expression)
+                    }
+                    if let Some(predicate) = predicate {
+                        update = update.predicate(predicate)
+                    };
+                    update
+                })
+                .map_err(PythonError::from)?,
+            ),
+            None => unreachable!(),
+        };
+        Ok(())
+    }
+
+    #[pyo3(signature=(
+        predicate = None,
+    ))]
+    fn when_matched_delete(&mut self, predicate: Option<String>) -> PyResult<()> {
+        self._builder = match self._builder.take() {
+            Some(cmd) => Some(
+                cmd.when_matched_delete(|mut delete| {
+                    if let Some(predicate) = predicate {
+                        delete = delete.predicate(predicate)
+                    };
+                    delete
+                })
+                .map_err(PythonError::from)?,
+            ),
+            None => unreachable!(),
+        };
+        Ok(())
+    }
+
+    #[pyo3(signature=(
+        updates,
+        predicate = None,
+    ))]
+    fn when_not_matched_insert(
+        &mut self,
+        updates: HashMap<String, String>,
+        predicate: Option<String>,
+    ) -> PyResult<()> {
+        self._builder = match self._builder.take() {
+            Some(cmd) => Some(
+                cmd.when_not_matched_insert(|mut insert| {
+                    for (column, expression) in updates {
+                        insert = insert.set(column, expression)
+                    }
+                    if let Some(predicate) = predicate {
+                        insert = insert.predicate(predicate)
+                    };
+                    insert
+                })
+                .map_err(PythonError::from)?,
+            ),
+            None => unreachable!(),
+        };
+        Ok(())
+    }
+
+    #[pyo3(signature=(
+        updates,
+        predicate = None,
+    ))]
+    fn when_not_matched_by_source_update(
+        &mut self,
+        updates: HashMap<String, String>,
+        predicate: Option<String>,
+    ) -> PyResult<()> {
+        self._builder = match self._builder.take() {
+            Some(cmd) => Some(
+                cmd.when_not_matched_by_source_update(|mut update| {
+                    for (column, expression) in updates {
+                        update = update.update(column, expression)
+                    }
+                    if let Some(predicate) = predicate {
+                        update = update.predicate(predicate)
+                    };
+                    update
+                })
+                .map_err(PythonError::from)?,
+            ),
+            None => unreachable!(),
+        };
+        Ok(())
+    }
+
+    #[pyo3(signature=(
+        predicate = None,
+    ))]
+    fn when_not_matched_by_source_delete(&mut self, predicate: Option<String>) -> PyResult<()> {
+        self._builder = match self._builder.take() {
+            Some(cmd) => Some(
+                cmd.when_not_matched_by_source_delete(|mut delete| {
+                    if let Some(predicate) = predicate {
+                        delete = delete.predicate(predicate)
+                    };
+                    delete
+                })
+                .map_err(PythonError::from)?,
+            ),
+            None => unreachable!(),
+        };
+        Ok(())
+    }
+}
diff --git a/python/tests/test_merge.py b/python/tests/test_merge.py
index 0b047580f1..ea13adf85b 100644
--- a/python/tests/test_merge.py
+++ b/python/tests/test_merge.py
@@ -985,3 +985,54 @@ def test_struct_casting(tmp_path: pathlib.Path):
         .execute()
     )
     assert result is not None
+
+
+def test_merge_isin_partition_pruning(
+    tmp_path: pathlib.Path,
+):
+    nrows = 5
+    data = pa.table(
+        {
+            "id": pa.array([str(x) for x in range(nrows)]),
+            "partition": pa.array(list(range(nrows)), pa.int64()),
+            "sold": pa.array(list(range(nrows)), pa.int32()),
+        }
+    )
+
+    write_deltalake(tmp_path, data, mode="append", partition_by="partition")
+
+    dt = DeltaTable(tmp_path)
+
+    source_table = pa.table(
+        {
+            "id": pa.array(["3", "4"]),
+            "partition": pa.array([3, 4], pa.int64()),
+            "sold": pa.array([10, 20], pa.int32()),
+        }
+    )
+
+    metrics = (
+        dt.merge(
+            source=source_table,
+            predicate="t.id = s.id and t.partition in (3,4)",
+            source_alias="s",
+            target_alias="t",
+        )
+        .when_matched_update_all()
+        .execute()
+    )
+
+    expected = pa.table(
+        {
+            "id": pa.array(["0", "1", "2", "3", "4"]),
+            "partition": pa.array([0, 1, 2, 3, 4], pa.int64()),
+            "sold": pa.array([0, 1, 2, 10, 20], pa.int32()),
+        }
+    )
+    result = dt.to_pyarrow_table().sort_by([("id", "ascending")])
+    last_action = dt.history(1)[0]
+
+    assert last_action["operation"] == "MERGE"
+    assert result == expected
+    assert metrics["num_target_files_scanned"] == 2
+    assert metrics["num_target_files_skipped_during_scan"] == 3
diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py
index cc36fc0274..5ff07ed9e8 100644
--- a/python/tests/test_table_read.py
+++ b/python/tests/test_table_read.py
@@ -1,4 +1,5 @@
 import os
+import tempfile
 from datetime import date, datetime, timezone
 from pathlib import Path
 from random import random
@@ -839,6 +840,83 @@ def test_encode_partition_value(input_value: Any, expected: str) -> None:
     assert encode_partition_value(input_value) == expected
 
+
+def test_partitions_partitioned_table():
+    table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
+    dt = DeltaTable(table_path)
+    expected = [
+        {"year": "2020", "month": "2", "day": "5"},
+        {"year": "2021", "month": "12", "day": "4"},
+        {"year": "2020", "month": "2", "day": "3"},
+        {"year": "2021", "month": "4", "day": "5"},
+        {"year": "2020", "month": "1", "day": "1"},
+        {"year": "2021", "month": "12", "day": "20"},
+    ]
+    actual = dt.partitions()
+    for partition in expected:
+        assert partition in actual
+
+
+def test_partitions_filtering_partitioned_table():
+    table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
+    dt = DeltaTable(table_path)
+    expected = [
+        {"day": "5", "month": "4", "year": "2021"},
+        {"day": "20", "month": "12", "year": "2021"},
+        {"day": "4", "month": "12", "year": "2021"},
+    ]
+
+    partition_filters = [("year", ">=", "2021")]
+    actual = dt.partitions(partition_filters=partition_filters)
+    assert len(expected) == len(actual)
+    for partition in expected:
+        assert partition in actual
+
+
+def test_partitions_date_partitioned_table():
+    table_path = tempfile.gettempdir() + "/date_partition_table"
+    date_partitions = [
+        date(2024, 8, 1),
+        date(2024, 8, 2),
+        date(2024, 8, 3),
+        date(2024, 8, 4),
+    ]
+    sample_data = pa.table(
+        {
+            "date_field": pa.array(date_partitions, pa.date32()),
+            "numeric_data": pa.array([1, 2, 3, 4], pa.int64()),
+        }
+    )
+    write_deltalake(
+        table_path, sample_data, mode="overwrite", partition_by=["date_field"]
+    )
+
+    delta_table = DeltaTable(table_path)
+    expected = [
+        {"date_field": "2024-08-01"},
+        {"date_field": "2024-08-02"},
+        {"date_field": "2024-08-03"},
+        {"date_field": "2024-08-04"},
+    ]
+    actual = sorted(delta_table.partitions(), key=lambda x: x["date_field"])
+    assert expected == actual
+
+
+def test_partitions_special_partitioned_table():
+    table_path = "../crates/test/tests/data/delta-0.8.0-special-partition"
+    dt = DeltaTable(table_path)
+
+    expected = [{"x": "A/A"}, {"x": "B B"}]
+    actual = dt.partitions()
+    for partition in expected:
+        assert partition in actual
+
+
+def test_partitions_unpartitioned_table():
+    table_path = "../crates/test/tests/data/simple_table"
+    dt = DeltaTable(table_path)
+    assert len(dt.partitions()) == 0
+
+
 def test_read_table_last_checkpoint_not_updated():
     dt = DeltaTable("../crates/test/tests/data/table_failed_last_checkpoint_update")
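
Usage sketch (not part of the patch; the table path and column names below are illustrative only): with PyMergeBuilder registered in _internal, dt.merge(...) on the Python side constructs the builder eagerly, each when_* clause forwards to the matching PyMergeBuilder method, and nothing runs until .execute() hands the accumulated builder to merge_execute — the same flow test_merge_isin_partition_pruning exercises above.

    import pyarrow as pa
    from deltalake import DeltaTable, write_deltalake

    # Illustrative target table; any existing Delta table works the same way.
    target = pa.table({"id": pa.array(["1", "2"]), "sold": pa.array([1, 2], pa.int32())})
    write_deltalake("/tmp/merge_example", target, mode="overwrite")
    dt = DeltaTable("/tmp/merge_example")

    source = pa.table({"id": pa.array(["2", "3"]), "sold": pa.array([20, 30], pa.int32())})

    metrics = (
        dt.merge(
            source=source,
            predicate="t.id = s.id",
            source_alias="s",
            target_alias="t",
        )
        # Each clause call mutates the underlying PyMergeBuilder; execution is deferred.
        .when_matched_update(updates={"sold": "s.sold"})
        .when_not_matched_insert(updates={"id": "s.id", "sold": "s.sold"})
        # execute() runs the accumulated builder via merge_execute and returns the metrics dict.
        .execute()
    )
    print(metrics)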