diff --git a/crates/deltalake-core/src/delta_datafusion/mod.rs b/crates/deltalake-core/src/delta_datafusion/mod.rs
index 8e255cd08b..ce2cdf962c 100644
--- a/crates/deltalake-core/src/delta_datafusion/mod.rs
+++ b/crates/deltalake-core/src/delta_datafusion/mod.rs
@@ -68,6 +68,7 @@ use datafusion_physical_expr::{create_physical_expr, PhysicalExpr};
 use datafusion_proto::logical_plan::LogicalExtensionCodec;
 use datafusion_proto::physical_plan::PhysicalExtensionCodec;
 use datafusion_sql::planner::ParserOptions;
+use futures::TryStreamExt;
 use itertools::Itertools;
 use log::error;
 
@@ -1019,6 +1020,31 @@ pub(crate) fn logical_expr_to_physical_expr(
     create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap()
 }
 
+pub(crate) async fn execute_plan_to_batch(
+    state: &SessionState,
+    plan: Arc<dyn ExecutionPlan>,
+) -> DeltaResult<RecordBatch> {
+    let data =
+        futures::future::try_join_all((0..plan.output_partitioning().partition_count()).map(|p| {
+            let plan_copy = plan.clone();
+            let task_context = state.task_ctx().clone();
+            async move {
+                let batch_stream = plan_copy.execute(p, task_context)?;
+
+                let schema = batch_stream.schema();
+
+                let batches = batch_stream.try_collect::<Vec<_>>().await?;
+
+                DataFusionResult::<_>::Ok(arrow::compute::concat_batches(&schema, batches.iter())?)
+            }
+        }))
+        .await?;
+
+    let batch = arrow::compute::concat_batches(&plan.schema(), data.iter())?;
+
+    Ok(batch)
+}
+
 /// Responsible for checking batches of data conform to table's invariants.
 #[derive(Clone)]
 pub struct DeltaDataChecker {
diff --git a/crates/deltalake-core/src/operations/merge.rs b/crates/deltalake-core/src/operations/merge.rs
index 7b03965747..0f0da1c21f 100644
--- a/crates/deltalake-core/src/operations/merge.rs
+++ b/crates/deltalake-core/src/operations/merge.rs
@@ -50,12 +50,16 @@ use datafusion::{
     },
     prelude::{DataFrame, SessionContext},
 };
+use datafusion_common::tree_node::{Transformed, TreeNode};
 use datafusion_common::{Column, DFSchema, ScalarValue, TableReference};
+use datafusion_expr::expr::Placeholder;
 use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr, JoinType};
 use datafusion_expr::{
-    Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, UNNAMED_TABLE,
+    BinaryExpr, Distinct, Extension, Filter, LogicalPlan, LogicalPlanBuilder, Projection,
+    UserDefinedLogicalNode, UNNAMED_TABLE,
 };
 use futures::future::BoxFuture;
+use itertools::Itertools;
 use parquet::file::properties::WriterProperties;
 use serde::Serialize;
 use serde_json::Value;
@@ -66,7 +70,8 @@ use crate::delta_datafusion::expr::{fmt_expr_to_sql, parse_predicate_expression}
 use crate::delta_datafusion::logical::MetricObserver;
 use crate::delta_datafusion::physical::{find_metric_node, MetricObserverExec};
 use crate::delta_datafusion::{
-    register_store, DeltaColumn, DeltaScanConfig, DeltaSessionConfig, DeltaTableProvider,
+    execute_plan_to_batch, register_store, DeltaColumn, DeltaScanConfig, DeltaSessionConfig,
+    DeltaTableProvider,
 };
 use crate::kernel::{Action, Remove};
 use crate::logstore::LogStoreRef;
@@ -652,6 +657,242 @@ impl ExtensionPlanner for MergeMetricExtensionPlanner {
     }
 }
 
+/// Takes the predicate provided and does two things:
+///
+/// 1. for any relation between a source column and a target column, if the target column is a
+/// partition column, replace the source expression with a placeholder matching the name of the
+/// partition column
+///
+/// 2. for any other relation involving a source column, remove it.
+///
+/// For example, for the predicate:
+///
+/// `source.date = target.date and source.id = target.id and frob > 42`
+///
+/// where `date` is a partition column, would result in the expr:
+///
+/// `$date = target.date and frob > 42`
+///
+/// This leaves us with a predicate that we can push into the delta scan after expanding it out to
+/// a disjunction over the distinct partitions in the source input.
+///
+/// TODO: A further improvement here might be for non-partition columns to be replaced with min/max
+/// checks, so the above example could become:
+///
+/// `$date = target.date and target.id between 12345 and 99999 and frob > 42`
+fn generalize_filter(
+    predicate: Expr,
+    partition_columns: &Vec<String>,
+    source_name: &TableReference,
+    target_name: &TableReference,
+    placeholders: &mut HashMap<String, Expr>,
+) -> Option<Expr> {
+    fn references_table(expr: &Expr, table: &TableReference) -> Option<String> {
+        match expr {
+            Expr::Alias(alias) => references_table(&alias.expr, table),
+            Expr::Column(col) => col.relation.as_ref().and_then(|rel| {
+                if rel == table {
+                    Some(col.name.to_owned())
+                } else {
+                    None
+                }
+            }),
+            Expr::Negative(neg) => references_table(neg, table),
+            Expr::Cast(cast) => references_table(&cast.expr, table),
+            Expr::TryCast(try_cast) => references_table(&try_cast.expr, table),
+            Expr::ScalarFunction(func) => {
+                if func.args.len() == 1 {
+                    references_table(&func.args[0], table)
+                } else {
+                    None
+                }
+            }
+            Expr::ScalarUDF(udf) => {
+                if udf.args.len() == 1 {
+                    references_table(&udf.args[0], table)
+                } else {
+                    None
+                }
+            }
+            _ => None,
+        }
+    }
+
+    match predicate {
+        Expr::BinaryExpr(binary) => {
+            if references_table(&binary.right, source_name).is_some() {
+                if let Some(left_target) = references_table(&binary.left, target_name) {
+                    if partition_columns.contains(&left_target) {
+                        let placeholder_name = format!("{left_target}_{}", placeholders.len());
+
+                        let placeholder = Expr::Placeholder(datafusion_expr::expr::Placeholder {
+                            id: placeholder_name.clone(),
+                            data_type: None,
+                        });
+                        let replaced = Expr::BinaryExpr(BinaryExpr {
+                            left: binary.left,
+                            op: binary.op,
+                            right: placeholder.into(),
+                        });
+
+                        placeholders.insert(placeholder_name, *binary.right);
+
+                        return Some(replaced);
+                    }
+                }
+                return None;
+            }
+            if references_table(&binary.left, source_name).is_some() {
+                if let Some(right_target) = references_table(&binary.right, target_name) {
+                    if partition_columns.contains(&right_target) {
+                        let placeholder_name = format!("{right_target}_{}", placeholders.len());
+
+                        let placeholder = Expr::Placeholder(datafusion_expr::expr::Placeholder {
+                            id: placeholder_name.clone(),
+                            data_type: None,
+                        });
+                        let replaced = Expr::BinaryExpr(BinaryExpr {
+                            right: binary.right,
+                            op: binary.op,
+                            left: placeholder.into(),
+                        });
+
+                        placeholders.insert(placeholder_name, *binary.left);
+
+                        return Some(replaced);
+                    }
+                }
+                return None;
+            }
+
+            let left = generalize_filter(
+                *binary.left,
+                partition_columns,
+                source_name,
+                target_name,
+                placeholders,
+            );
+            let right = generalize_filter(
+                *binary.right,
+                partition_columns,
+                source_name,
+                target_name,
+                placeholders,
+            );
+
+            match (left, right) {
+                (None, None) => None,
+                (None, Some(r)) => Some(r),
+                (Some(l), None) => Some(l),
+                (Some(l), Some(r)) => Expr::BinaryExpr(BinaryExpr {
+                    left: l.into(),
+                    op: binary.op,
+                    right: r.into(),
+                })
+                .into(),
+            }
+        }
+        other => Some(other),
+    }
+}
+
+fn replace_placeholders(expr: Expr, placeholders: &HashMap<String, ScalarValue>) -> Expr {
+    expr.transform(&|expr| match expr {
+        Expr::Placeholder(Placeholder { id, .. }) => {
+            let value = placeholders[&id].clone();
+            // Replace the placeholder with the value
+            Ok(Transformed::Yes(Expr::Literal(value)))
+        }
+        _ => Ok(Transformed::No(expr)),
+    })
+    .unwrap()
+}
+
+async fn try_construct_early_filter(
+    join_predicate: Expr,
+    table_snapshot: &DeltaTableState,
+    session_state: &SessionState,
+    source: &LogicalPlan,
+    source_name: &TableReference<'_>,
+    target_name: &TableReference<'_>,
+) -> DeltaResult<Option<Expr>> {
+    let table_metadata = table_snapshot.metadata();
+
+    if table_metadata.is_none() {
+        return Ok(None);
+    }
+
+    let table_metadata = table_metadata.unwrap();
+
+    let partition_columns = &table_metadata.partition_columns;
+
+    if partition_columns.is_empty() {
+        return Ok(None);
+    }
+
+    let mut placeholders = HashMap::default();
+
+    match generalize_filter(
+        join_predicate,
+        partition_columns,
+        source_name,
+        target_name,
+        &mut placeholders,
+    ) {
+        None => Ok(None),
+        Some(filter) => {
+            if placeholders.is_empty() {
+                // if we haven't recognised any partition-based predicates in the join predicate, return our reduced filter
+                Ok(Some(filter))
+            } else {
+                // if we have some recognised partitions, then discover the distinct set of partitions in the source data and
+                // make a new filter, which expands out the placeholders for each distinct partition (and then OR these together)
+                let distinct_partitions = LogicalPlan::Distinct(Distinct {
+                    input: LogicalPlan::Projection(Projection::try_new(
+                        placeholders
+                            .into_iter()
+                            .map(|(alias, expr)| expr.alias(alias))
+                            .collect_vec(),
+                        source.clone().into(),
+                    )?)
+                    .into(),
+                });
+
+                let execution_plan = session_state
+                    .create_physical_plan(&distinct_partitions)
+                    .await?;
+
+                let items = execute_plan_to_batch(session_state, execution_plan).await?;
+
+                let placeholder_names = items
+                    .schema()
+                    .fields()
+                    .iter()
+                    .map(|f| f.name().to_owned())
+                    .collect_vec();
+
+                let expr = (0..items.num_rows())
+                    .map(|i| {
+                        let replacements = placeholder_names
+                            .iter()
+                            .map(|placeholder| {
+                                let col = items.column_by_name(placeholder).unwrap();
+                                let value = ScalarValue::try_from_array(col, i)?;
+                                DeltaResult::Ok((placeholder.to_owned(), value))
+                            })
+                            .try_collect()?;
+                        Ok(replace_placeholders(filter.clone(), &replacements))
+                    })
+                    .collect::<DeltaResult<Vec<_>>>()?
+                    .into_iter()
+                    .reduce(Expr::or);
+
+                Ok(expr)
+            }
+        }
+    }
+}
+
 #[allow(clippy::too_many_arguments)]
 async fn execute(
     predicate: Expression,
@@ -693,9 +934,12 @@ async fn execute(
     };
 
     // This is only done to provide the source columns with a correct table reference. Just renaming the columns does not work
-    let source =
-        LogicalPlanBuilder::scan(source_name, provider_as_source(source.into_view()), None)?
-            .build()?;
+    let source = LogicalPlanBuilder::scan(
+        source_name.clone(),
+        provider_as_source(source.into_view()),
+        None,
+    )?
+ .build()?; let source = LogicalPlan::Extension(Extension { node: Arc::new(MetricObserver { @@ -704,17 +948,68 @@ async fn execute( }), }); - let source = DataFrame::new(state.clone(), source); - let source = source.with_column(SOURCE_COLUMN, lit(true))?; - let target_provider = Arc::new(DeltaTableProvider::try_new( snapshot.clone(), log_store.clone(), DeltaScanConfig::default(), )?); + let target_provider = provider_as_source(target_provider); - let target = LogicalPlanBuilder::scan(target_name, target_provider, None)?.build()?; + let target = LogicalPlanBuilder::scan(target_name.clone(), target_provider, None)?.build()?; + + let source_schema = source.schema(); + let target_schema = target.schema(); + let join_schema_df = build_join_schema(source_schema, target_schema, &JoinType::Full)?; + let predicate = match predicate { + Expression::DataFusion(expr) => expr, + Expression::String(s) => parse_predicate_expression(&join_schema_df, s, &state)?, + }; + + let state = state.with_query_planner(Arc::new(MergePlanner {})); + + let (target, files) = { + // Attempt to construct an early filter that we can apply to the Add action list and the delta scan. + // In the case where there are partition columns in the join predicate, we can scan the source table + // to get the distinct list of partitions affected and constrain the search to those. + + if !not_match_source_operations.is_empty() { + // It's only worth trying to create an early filter where there are no `when_not_matched_source` operators, since + // that implies a full scan + (target, snapshot.files().iter().collect_vec()) + } else if let Some(filter) = try_construct_early_filter( + predicate.clone(), + snapshot, + &state, + &source, + &source_name, + &target_name, + ) + .await? + { + let file_filter = filter + .clone() + .transform(&|expr| match expr { + Expr::Column(c) => Ok(Transformed::Yes(Expr::Column(Column { + relation: None, // the file filter won't be looking at columns like `target.partition`, it'll just be `partition` + name: c.name, + }))), + expr => Ok(Transformed::No(expr)), + }) + .unwrap(); + let files = snapshot + .files_matching_predicate(&[file_filter])? + .collect_vec(); + + let new_target = LogicalPlan::Filter(Filter::try_new(filter, target.into())?); + (new_target, files) + } else { + (target, snapshot.files().iter().collect_vec()) + } + }; + + let source = DataFrame::new(state.clone(), source); + let source = source.with_column(SOURCE_COLUMN, lit(true))?; // TODO: This is here to prevent predicate pushdowns. In the future we can replace this node to allow pushdowns depending on which operations are being used. 
let target = LogicalPlan::Extension(Extension { @@ -726,14 +1021,6 @@ async fn execute( let target = DataFrame::new(state.clone(), target); let target = target.with_column(TARGET_COLUMN, lit(true))?; - let source_schema = source.schema(); - let target_schema = target.schema(); - let join_schema_df = build_join_schema(source_schema, target_schema, &JoinType::Full)?; - let predicate = match predicate { - Expression::DataFusion(expr) => expr, - Expression::String(s) => parse_predicate_expression(&join_schema_df, s, &state)?, - }; - let join = source.join(target, JoinType::Full, &[], &[], Some(predicate.clone()))?; let join_schema_df = join.schema().to_owned(); @@ -999,7 +1286,6 @@ async fn execute( let project = filtered.select(write_projection)?; let optimized = &project.into_optimized_plan()?; - let state = state.with_query_planner(Arc::new(MergePlanner {})); let write = state.create_physical_plan(optimized).await?; let err = || DeltaTableError::Generic("Unable to locate expected metric node".into()); @@ -1034,7 +1320,7 @@ async fn execute( let mut actions: Vec = add_actions.into_iter().map(Action::Add).collect(); metrics.num_target_files_added = actions.len(); - for action in snapshot.files() { + for action in files { metrics.num_target_files_removed += 1; actions.push(Action::Remove(Remove { path: action.path.clone(), @@ -1162,6 +1448,8 @@ mod tests { use crate::kernel::DataType; use crate::kernel::PrimitiveType; use crate::kernel::StructField; + use crate::operations::merge::generalize_filter; + use crate::operations::merge::try_construct_early_filter; use crate::operations::DeltaOps; use crate::protocol::*; use crate::writer::test_utils::datafusion::get_data; @@ -1175,11 +1463,21 @@ mod tests { use arrow_schema::DataType as ArrowDataType; use arrow_schema::Field; use datafusion::assert_batches_sorted_eq; + use datafusion::datasource::provider_as_source; use datafusion::prelude::DataFrame; use datafusion::prelude::SessionContext; + use datafusion_common::Column; + use datafusion_common::ScalarValue; + use datafusion_common::TableReference; use datafusion_expr::col; + use datafusion_expr::expr::Placeholder; use datafusion_expr::lit; + use datafusion_expr::Expr; + use datafusion_expr::LogicalPlanBuilder; + use datafusion_expr::Operator; use serde_json::json; + use std::collections::HashMap; + use std::ops::Neg; use std::sync::Arc; use super::MergeMetrics; @@ -1508,7 +1806,7 @@ mod tests { #[tokio::test] async fn test_merge_partitions() { - /* Validate the join predicate works with partition columns */ + /* Validate the join predicate works with table partitions */ let schema = get_arrow_schema(&None); let table = setup_table(Some(vec!["modified"])).await; @@ -1596,6 +1894,78 @@ mod tests { assert_batches_sorted_eq!(&expected, &actual); } + #[tokio::test] + async fn test_merge_partitions_skipping() { + /* Validate the join predicate can be used for skipping partitions */ + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["id"])).await; + + let table = write_data(table, &schema).await; + assert_eq!(table.version(), 1); + assert_eq!(table.get_file_uris().count(), 4); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + Arc::new(arrow::array::Int32Array::from(vec![999, 999, 999])), + Arc::new(arrow::array::StringArray::from(vec![ + "2023-07-04", + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + let source = ctx.read_batch(batch).unwrap(); 
+ + let (table, metrics) = DeltaOps(table) + .merge(source, col("target.id").eq(col("source.id"))) + .with_source_alias("source") + .with_target_alias("target") + .when_matched_update(|update| { + update + .update("value", col("source.value")) + .update("modified", col("source.modified")) + }) + .unwrap() + .when_not_matched_insert(|insert| { + insert + .set("id", col("source.id")) + .set("value", col("source.value")) + .set("modified", col("source.modified")) + }) + .unwrap() + .await + .unwrap(); + + assert_eq!(table.version(), 2); + assert!(table.get_file_uris().count() >= 3); + assert_eq!(metrics.num_target_files_added, 3); + assert_eq!(metrics.num_target_files_removed, 2); + assert_eq!(metrics.num_target_rows_copied, 0); + assert_eq!(metrics.num_target_rows_updated, 2); + assert_eq!(metrics.num_target_rows_inserted, 1); + assert_eq!(metrics.num_target_rows_deleted, 0); + assert_eq!(metrics.num_output_rows, 3); + assert_eq!(metrics.num_source_rows, 3); + + let expected = vec![ + "+-------+------------+----+", + "| value | modified | id |", + "+-------+------------+----+", + "| 1 | 2021-02-01 | A |", + "| 100 | 2021-02-02 | D |", + "| 999 | 2023-07-04 | B |", + "| 999 | 2023-07-04 | C |", + "| 999 | 2023-07-04 | X |", + "+-------+------------+----+", + ]; + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + } + #[tokio::test] async fn test_merge_delete_matched() { // Validate behaviours of match delete @@ -2015,4 +2385,228 @@ mod tests { let actual = get_data(&table).await; assert_batches_sorted_eq!(&expected, &actual); } + + #[tokio::test] + async fn test_generalize_filter_with_partitions() { + let source = TableReference::parse_str("source"); + let target = TableReference::parse_str("target"); + + let parsed_filter = col(Column::new(source.clone().into(), "id")) + .eq(col(Column::new(target.clone().into(), "id"))); + + let mut placeholders = HashMap::default(); + + let generalized = generalize_filter( + parsed_filter, + &vec!["id".to_owned()], + &source, + &target, + &mut placeholders, + ) + .unwrap(); + + let expected_filter = Expr::Placeholder(Placeholder { + id: "id_0".to_owned(), + data_type: None, + }) + .eq(col(Column::new(target.clone().into(), "id"))); + + assert_eq!(generalized, expected_filter); + } + + #[tokio::test] + async fn test_generalize_filter_with_partitions_captures_expression() { + // Check that when generalizing the filter, the placeholder map captures the expression needed to make the statement the same + // when the distinct values are substitiuted in + let source = TableReference::parse_str("source"); + let target = TableReference::parse_str("target"); + + let parsed_filter = col(Column::new(source.clone().into(), "id")) + .neg() + .eq(col(Column::new(target.clone().into(), "id"))); + + let mut placeholders = HashMap::default(); + + let generalized = generalize_filter( + parsed_filter, + &vec!["id".to_owned()], + &source, + &target, + &mut placeholders, + ) + .unwrap(); + + let expected_filter = Expr::Placeholder(Placeholder { + id: "id_0".to_owned(), + data_type: None, + }) + .eq(col(Column::new(target.clone().into(), "id"))); + + assert_eq!(generalized, expected_filter); + + assert_eq!(placeholders.len(), 1); + + let placeholder_expr = &placeholders["id_0"]; + + let expected_placeholder = col(Column::new(source.clone().into(), "id")).neg(); + + assert_eq!(placeholder_expr, &expected_placeholder); + } + + #[tokio::test] + async fn test_generalize_filter_keeps_static_target_references() { + let source = 
TableReference::parse_str("source"); + let target = TableReference::parse_str("target"); + + let parsed_filter = col(Column::new(source.clone().into(), "id")) + .eq(col(Column::new(target.clone().into(), "id"))) + .and(col(Column::new(target.clone().into(), "id")).eq(lit("C"))); + + let mut placeholders = HashMap::default(); + + let generalized = generalize_filter( + parsed_filter, + &vec!["id".to_owned()], + &source, + &target, + &mut placeholders, + ) + .unwrap(); + + let expected_filter = Expr::Placeholder(Placeholder { + id: "id_0".to_owned(), + data_type: None, + }) + .eq(col(Column::new(target.clone().into(), "id"))) + .and(col(Column::new(target.clone().into(), "id")).eq(lit("C"))); + + assert_eq!(generalized, expected_filter); + } + + #[tokio::test] + async fn test_generalize_filter_removes_source_references() { + let source = TableReference::parse_str("source"); + let target = TableReference::parse_str("target"); + + let parsed_filter = col(Column::new(source.clone().into(), "id")) + .eq(col(Column::new(target.clone().into(), "id"))) + .and(col(Column::new(source.clone().into(), "id")).eq(lit("C"))); + + let mut placeholders = HashMap::default(); + + let generalized = generalize_filter( + parsed_filter, + &vec!["id".to_owned()], + &source, + &target, + &mut placeholders, + ) + .unwrap(); + + let expected_filter = Expr::Placeholder(Placeholder { + id: "id_0".to_owned(), + data_type: None, + }) + .eq(col(Column::new(target.clone().into(), "id"))); + + assert_eq!(generalized, expected_filter); + } + + #[tokio::test] + async fn test_try_construct_early_filter_with_partitions_expands() { + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["id"])).await; + + assert_eq!(table.version(), 0); + assert_eq!(table.get_file_uris().count(), 0); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::StringArray::from(vec![ + "2021-02-02", + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + let source = ctx.read_batch(batch).unwrap(); + + let source_name = TableReference::parse_str("source"); + let target_name = TableReference::parse_str("target"); + + let source = LogicalPlanBuilder::scan( + source_name.clone(), + provider_as_source(source.into_view()), + None, + ) + .unwrap() + .build() + .unwrap(); + + let join_predicate = col(Column { + relation: Some(source_name.clone()), + name: "id".to_owned(), + }) + .eq(col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + })); + + let pred = try_construct_early_filter( + join_predicate, + &table.state, + &ctx.state(), + &source, + &source_name, + &target_name, + ) + .await + .unwrap(); + + assert!(pred.is_some()); + + let split_pred = { + fn split(expr: Expr, parts: &mut Vec<(String, String)>) { + match expr { + Expr::BinaryExpr(ex) if ex.op == Operator::Or => { + split(*ex.left, parts); + split(*ex.right, parts); + } + Expr::BinaryExpr(ex) if ex.op == Operator::Eq => { + let col = match *ex.right { + Expr::Column(col) => col.name, + ex => panic!("expected column in pred, got {ex}!"), + }; + + let value = match *ex.left { + Expr::Literal(ScalarValue::Utf8(Some(value))) => value, + ex => panic!("expected value in predicate, got {ex}!"), + }; + + parts.push((col, value)) + } + + expr => panic!("expected either = or OR, got {expr}"), + } + } + + let mut parts = vec![]; + split(pred.unwrap(), &mut 
parts); + parts.sort(); + parts + }; + + let expected_pred_parts = [ + ("id".to_owned(), "B".to_owned()), + ("id".to_owned(), "C".to_owned()), + ("id".to_owned(), "X".to_owned()), + ]; + + assert_eq!(split_pred, expected_pred_parts); + } }
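For context, here is a minimal, hypothetical sketch of the caller-side merge this change speeds up; it is not part of the diff itself. The table path, schema, and column names are invented for illustration, and the snippet assumes `deltalake-core` (with its `datafusion` feature), `arrow`, `datafusion`, and `tokio` as dependencies. The builder methods it uses (`merge`, `with_source_alias`, `with_target_alias`, `when_matched_update`, `when_not_matched_insert`) are the same ones exercised by the tests in this diff. With the target table partitioned on `id` and a join predicate of `target.id = source.id`, `try_construct_early_filter` can restrict both the delta scan and the resulting remove actions to the partitions that actually appear in the source batch.

use std::sync::Arc;

use arrow::array::{Int32Array, StringArray};
use arrow::datatypes::{DataType as ArrowDataType, Field, Schema as ArrowSchema};
use arrow::record_batch::RecordBatch;
use datafusion::prelude::{col, SessionContext};
use deltalake_core::operations::DeltaOps;
use deltalake_core::DeltaResult;

#[tokio::main]
async fn main() -> DeltaResult<()> {
    // Hypothetical table location; assume the table is partitioned on `id`.
    let table = deltalake_core::open_table("./data/partitioned_table").await?;

    // Source batch touching only the `B` and `C` partitions.
    let schema = Arc::new(ArrowSchema::new(vec![
        Field::new("id", ArrowDataType::Utf8, true),
        Field::new("value", ArrowDataType::Int32, true),
        Field::new("modified", ArrowDataType::Utf8, true),
    ]));
    let batch = RecordBatch::try_new(
        schema,
        vec![
            Arc::new(StringArray::from(vec!["B", "C"])),
            Arc::new(Int32Array::from(vec![10, 20])),
            Arc::new(StringArray::from(vec!["2023-07-04", "2023-07-04"])),
        ],
    )
    .unwrap();
    let source = SessionContext::new().read_batch(batch).unwrap();

    // Because `id` is a partition column and appears in the join predicate,
    // the early filter should prune the scan down to the `B` and `C` partitions.
    let (_table, metrics) = DeltaOps(table)
        .merge(source, col("target.id").eq(col("source.id")))
        .with_source_alias("source")
        .with_target_alias("target")
        .when_matched_update(|update| {
            update
                .update("value", col("source.value"))
                .update("modified", col("source.modified"))
        })?
        .when_not_matched_insert(|insert| {
            insert
                .set("id", col("source.id"))
                .set("value", col("source.value"))
                .set("modified", col("source.modified"))
        })?
        .await?;

    println!("target files removed: {}", metrics.num_target_files_removed);
    Ok(())
}

As in `test_merge_partitions_skipping` above, `metrics.num_target_files_removed` should then reflect only the rewritten partitions rather than every file in the table.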