From 11ea2a59a47e482e11ebbaab74a17006a59b6c9c Mon Sep 17 00:00:00 2001 From: emcake <3726783+emcake@users.noreply.github.com> Date: Wed, 20 Dec 2023 12:13:38 +0000 Subject: [PATCH] feat: merge using partition filters (#1958) # Description This upgrades merge so that it can leverage partitions where specified in the join predicate. There are two ways we can leverage partitions: 1. Static references, e.g. `target.partition = 1`. 2. Inferring from the data, e.g. `source.partition = target.partition`. In the first case, this implements the logic described in [this comment](https://github.com/delta-io/delta-rs/blob/main/crates/deltalake-core/src/operations/merge.rs#L670). Any predicate mentioning the source that is not covered by (2) is pruned, which leaves predicates on just the target columns (and makes them amenable to file pruning). In the second case, we first construct a version of the predicate with references to source replaced with placeholders: ```sql target.partition = source.partition and foo > 42 ``` becomes: ```sql target.partition = $1 and foo > 42 ``` We then stream through the source table, gathering the distinct tuples of the mentioned partitions: ``` | partition | ------------- | 1 | | 5 | | 7 | ``` and then expand out the SQL to take these into account: ```sql (target.partition = 1 and foo > 42) or (target.partition = 5 and foo > 42) or (target.partition = 7 and foo > 42) ``` We then insert this filter into the target chain. We also use the same filter to process the file list, meaning we only make remove actions for files that will be targeted by the scan. I considered whether it would be possible to do this via DataFusion SQL in a generic manner, for example by first joining against the distinct partitions. I don't think it's possible: because each of the filters on the logical plans is static, there's no opportunity for it to push the distinct partition tuples down into the scan. Another variant would be to make it so the source and partition tables share the same `output_partitioning` structure, but as far as I can tell you wouldn't be able to make the partitions line up such that you can do the merge effectively and not read the whole table (plus `DeltaScan` doesn't guarantee that one DataFusion partition is one DeltaTable partition). I think the static bit is a no-brainer, but the eager read of the source table may cause issues if the source table is of a similar size to the target table. It may be prudent to hide that part behind a feature flag on the merge, but I would love comments on it.
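For illustration, here is a minimal sketch of a merge that can now take the fast path, written in the style of the tests added below. The table layout and column names (`partition` as the partition column, `id` as the join key, `value` as a payload column) are hypothetical; the builder methods shown (`merge`, `with_source_alias`, `with_target_alias`, `when_matched_update`, `when_not_matched_insert`) are the existing API used in the tests:

```rust
use datafusion_expr::col;
use deltalake_core::operations::DeltaOps;

// Hypothetical setup, inside an `async fn ... -> DeltaResult<()>` (hence the `?`s):
// `table` is a DeltaTable partitioned by `partition`, and `source` is a
// DataFusion DataFrame holding the changed rows. Because the join predicate
// relates `target.partition` to `source.partition`, merge can gather the
// distinct partitions present in `source` and prune both the target scan and
// the remove actions down to just those partitions.
let (table, metrics) = DeltaOps(table)
    .merge(
        source,
        col("target.partition")
            .eq(col("source.partition"))
            .and(col("target.id").eq(col("source.id"))),
    )
    .with_source_alias("source")
    .with_target_alias("target")
    .when_matched_update(|update| update.update("value", col("source.value")))?
    .when_not_matched_insert(|insert| {
        insert
            .set("id", col("source.id"))
            .set("value", col("source.value"))
            .set("partition", col("source.partition"))
    })?
    .await?;
```

The second (data-inferred) predicate form is what triggers the distinct-partition scan of the source described above; a purely static predicate like `target.partition = 1` skips the source scan entirely.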
# Performance I created a 16GB table locally with 1.25 billion rows over 1k partitions, and when updating 1 partition a full merge takes 1000-ish seconds: ``` merge took 985.0801 seconds merge metrics: MergeMetrics { num_source_rows: 1250000, num_target_rows_inserted: 468790, num_target_rows_updated: 781210, num_target_rows_deleted: 0, num_target_rows_copied: 1249687667, num_output_rows: 1250937667, num_target_files_added: 1001, num_target_files_removed: 1001, execution_time_ms: 983851, scan_time_ms: 0, rewrite_time_ms: 983322 } ``` but with the partition filter it takes about 3 seconds: ``` merge took 2.6337671 seconds merge metrics: MergeMetrics { num_source_rows: 1250000, num_target_rows_inserted: 468877, num_target_rows_updated: 781123, num_target_rows_deleted: 0, num_target_rows_copied: 468877, num_output_rows: 1718877, num_target_files_added: 2, num_target_files_removed: 2, execution_time_ms: 2622, scan_time_ms: 0, rewrite_time_ms: 2316 } ``` In practice, the tables I want to use this for are terabytes in size, so using merge is currently impractical; this would be a significant speed boost for them. # Related Issue(s) closes #1846 --------- Co-authored-by: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> --- .../src/delta_datafusion/mod.rs | 26 + crates/deltalake-core/src/operations/merge.rs | 634 +++++++++++++++++- 2 files changed, 640 insertions(+), 20 deletions(-) diff --git a/crates/deltalake-core/src/delta_datafusion/mod.rs b/crates/deltalake-core/src/delta_datafusion/mod.rs index 9f2818de93..d3883eac8b 100644 --- a/crates/deltalake-core/src/delta_datafusion/mod.rs +++ b/crates/deltalake-core/src/delta_datafusion/mod.rs @@ -68,6 +68,7 @@ use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; use datafusion_proto::logical_plan::LogicalExtensionCodec; use datafusion_proto::physical_plan::PhysicalExtensionCodec; use datafusion_sql::planner::ParserOptions; +use futures::TryStreamExt; use itertools::Itertools; use log::error; @@ -1019,6 +1020,31 @@ pub(crate) fn logical_expr_to_physical_expr( create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() } +pub(crate) async fn execute_plan_to_batch( + state: &SessionState, + plan: Arc<dyn ExecutionPlan>, +) -> DeltaResult<RecordBatch> { + let data = + futures::future::try_join_all((0..plan.output_partitioning().partition_count()).map(|p| { + let plan_copy = plan.clone(); + let task_context = state.task_ctx().clone(); + async move { + let batch_stream = plan_copy.execute(p, task_context)?; + + let schema = batch_stream.schema(); + + let batches = batch_stream.try_collect::<Vec<_>>().await?; + + DataFusionResult::<_>::Ok(arrow::compute::concat_batches(&schema, batches.iter())?) + } + })) + .await?; + + let batch = arrow::compute::concat_batches(&plan.schema(), data.iter())?; + + Ok(batch) +} + /// Responsible for checking batches of data conform to table's invariants.
#[derive(Clone)] pub struct DeltaDataChecker { diff --git a/crates/deltalake-core/src/operations/merge.rs b/crates/deltalake-core/src/operations/merge.rs index 7b03965747..0f0da1c21f 100644 --- a/crates/deltalake-core/src/operations/merge.rs +++ b/crates/deltalake-core/src/operations/merge.rs @@ -50,12 +50,16 @@ use datafusion::{ }, prelude::{DataFrame, SessionContext}, }; +use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Column, DFSchema, ScalarValue, TableReference}; +use datafusion_expr::expr::Placeholder; use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr, JoinType}; use datafusion_expr::{ - Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, UNNAMED_TABLE, + BinaryExpr, Distinct, Extension, Filter, LogicalPlan, LogicalPlanBuilder, Projection, + UserDefinedLogicalNode, UNNAMED_TABLE, }; use futures::future::BoxFuture; +use itertools::Itertools; use parquet::file::properties::WriterProperties; use serde::Serialize; use serde_json::Value; @@ -66,7 +70,8 @@ use crate::delta_datafusion::expr::{fmt_expr_to_sql, parse_predicate_expression} use crate::delta_datafusion::logical::MetricObserver; use crate::delta_datafusion::physical::{find_metric_node, MetricObserverExec}; use crate::delta_datafusion::{ - register_store, DeltaColumn, DeltaScanConfig, DeltaSessionConfig, DeltaTableProvider, + execute_plan_to_batch, register_store, DeltaColumn, DeltaScanConfig, DeltaSessionConfig, + DeltaTableProvider, }; use crate::kernel::{Action, Remove}; use crate::logstore::LogStoreRef; @@ -652,6 +657,242 @@ impl ExtensionPlanner for MergeMetricExtensionPlanner { } } +/// Takes the predicate provided and does two things: +/// +/// 1. for any relation between a source column and a target column, if the target column is a +/// partition column, then replace the source column with a placeholder matching the name of the +/// partition column +/// +/// 2. for any other relation with a source column, remove it. +/// +/// For example, for the predicate: +/// +/// `source.date = target.date and source.id = target.id and frob > 42` +/// +/// where `date` is a partition column, would result in the expr: +/// +/// `$date = target.date and frob > 42` +/// +/// This leaves us with a predicate that we can push into the delta scan after expanding it out to +/// a disjunction over the distinct partitions in the source input.
+/// +/// TODO: A further improvement here might be for non-partition columns to be replaced with min/max +/// checks, so the above example could become: +/// +/// `$date = target.date and target.id between 12345 and 99999 and frob > 42` +fn generalize_filter( + predicate: Expr, + partition_columns: &Vec<String>, + source_name: &TableReference, + target_name: &TableReference, + placeholders: &mut HashMap<String, Expr>, +) -> Option<Expr> { + fn references_table(expr: &Expr, table: &TableReference) -> Option<String> { + match expr { + Expr::Alias(alias) => references_table(&alias.expr, table), + Expr::Column(col) => col.relation.as_ref().and_then(|rel| { + if rel == table { + Some(col.name.to_owned()) + } else { + None + } + }), + Expr::Negative(neg) => references_table(neg, table), + Expr::Cast(cast) => references_table(&cast.expr, table), + Expr::TryCast(try_cast) => references_table(&try_cast.expr, table), + Expr::ScalarFunction(func) => { + if func.args.len() == 1 { + references_table(&func.args[0], table) + } else { + None + } + } + Expr::ScalarUDF(udf) => { + if udf.args.len() == 1 { + references_table(&udf.args[0], table) + } else { + None + } + } + _ => None, + } + } + + match predicate { + Expr::BinaryExpr(binary) => { + if references_table(&binary.right, source_name).is_some() { + if let Some(left_target) = references_table(&binary.left, target_name) { + if partition_columns.contains(&left_target) { + let placeholder_name = format!("{left_target}_{}", placeholders.len()); + + let placeholder = Expr::Placeholder(datafusion_expr::expr::Placeholder { + id: placeholder_name.clone(), + data_type: None, + }); + let replaced = Expr::BinaryExpr(BinaryExpr { + left: binary.left, + op: binary.op, + right: placeholder.into(), + }); + + placeholders.insert(placeholder_name, *binary.right); + + return Some(replaced); + } + } + return None; + } + if references_table(&binary.left, source_name).is_some() { + if let Some(right_target) = references_table(&binary.right, target_name) { + if partition_columns.contains(&right_target) { + let placeholder_name = format!("{right_target}_{}", placeholders.len()); + + let placeholder = Expr::Placeholder(datafusion_expr::expr::Placeholder { + id: placeholder_name.clone(), + data_type: None, + }); + let replaced = Expr::BinaryExpr(BinaryExpr { + right: binary.right, + op: binary.op, + left: placeholder.into(), + }); + + placeholders.insert(placeholder_name, *binary.left); + + return Some(replaced); + } + } + return None; + } + + let left = generalize_filter( + *binary.left, + partition_columns, + source_name, + target_name, + placeholders, + ); + let right = generalize_filter( + *binary.right, + partition_columns, + source_name, + target_name, + placeholders, + ); + + match (left, right) { + (None, None) => None, + (None, Some(r)) => Some(r), + (Some(l), None) => Some(l), + (Some(l), Some(r)) => Expr::BinaryExpr(BinaryExpr { + left: l.into(), + op: binary.op, + right: r.into(), + }) + .into(), + } + } + other => Some(other), + } +} + +fn replace_placeholders(expr: Expr, placeholders: &HashMap<String, ScalarValue>) -> Expr { + expr.transform(&|expr| match expr { + Expr::Placeholder(Placeholder { id, ..
}) => { + let value = placeholders[&id].clone(); + // Replace the placeholder with the value + Ok(Transformed::Yes(Expr::Literal(value))) + } + _ => Ok(Transformed::No(expr)), + }) + .unwrap() +} + +async fn try_construct_early_filter( + join_predicate: Expr, + table_snapshot: &DeltaTableState, + session_state: &SessionState, + source: &LogicalPlan, + source_name: &TableReference<'_>, + target_name: &TableReference<'_>, +) -> DeltaResult<Option<Expr>> { + let table_metadata = table_snapshot.metadata(); + + if table_metadata.is_none() { + return Ok(None); + } + + let table_metadata = table_metadata.unwrap(); + + let partition_columns = &table_metadata.partition_columns; + + if partition_columns.is_empty() { + return Ok(None); + } + + let mut placeholders = HashMap::default(); + + match generalize_filter( + join_predicate, + partition_columns, + source_name, + target_name, + &mut placeholders, + ) { + None => Ok(None), + Some(filter) => { + if placeholders.is_empty() { + // if we haven't recognised any partition-based predicates in the join predicate, return our reduced filter + Ok(Some(filter)) + } else { + // if we have some recognised partitions, then discover the distinct set of partitions in the source data and + // make a new filter, which expands out the placeholders for each distinct partition (and then OR these together) + let distinct_partitions = LogicalPlan::Distinct(Distinct { + input: LogicalPlan::Projection(Projection::try_new( + placeholders + .into_iter() + .map(|(alias, expr)| expr.alias(alias)) + .collect_vec(), + source.clone().into(), + )?) + .into(), + }); + + let execution_plan = session_state + .create_physical_plan(&distinct_partitions) + .await?; + + let items = execute_plan_to_batch(session_state, execution_plan).await?; + + let placeholder_names = items + .schema() + .fields() + .iter() + .map(|f| f.name().to_owned()) + .collect_vec(); + + let expr = (0..items.num_rows()) + .map(|i| { + let replacements = placeholder_names + .iter() + .map(|placeholder| { + let col = items.column_by_name(placeholder).unwrap(); + let value = ScalarValue::try_from_array(col, i)?; + DeltaResult::Ok((placeholder.to_owned(), value)) + }) + .try_collect()?; + Ok(replace_placeholders(filter.clone(), &replacements)) + }) + .collect::<DeltaResult<Vec<Expr>>>()? + .into_iter() + .reduce(Expr::or); + + Ok(expr) + } + } + } +} + #[allow(clippy::too_many_arguments)] async fn execute( predicate: Expression, @@ -693,9 +934,12 @@ async fn execute( }; // This is only done to provide the source columns with a correct table reference. Just renaming the columns does not work - let source = - LogicalPlanBuilder::scan(source_name, provider_as_source(source.into_view()), None)? - .build()?; + let source = LogicalPlanBuilder::scan( + source_name.clone(), + provider_as_source(source.into_view()), + None, + )?
+ .build()?; let source = LogicalPlan::Extension(Extension { node: Arc::new(MetricObserver { @@ -704,17 +948,68 @@ async fn execute( }), }); - let source = DataFrame::new(state.clone(), source); - let source = source.with_column(SOURCE_COLUMN, lit(true))?; - let target_provider = Arc::new(DeltaTableProvider::try_new( snapshot.clone(), log_store.clone(), DeltaScanConfig::default(), )?); + let target_provider = provider_as_source(target_provider); - let target = LogicalPlanBuilder::scan(target_name, target_provider, None)?.build()?; + let target = LogicalPlanBuilder::scan(target_name.clone(), target_provider, None)?.build()?; + + let source_schema = source.schema(); + let target_schema = target.schema(); + let join_schema_df = build_join_schema(source_schema, target_schema, &JoinType::Full)?; + let predicate = match predicate { + Expression::DataFusion(expr) => expr, + Expression::String(s) => parse_predicate_expression(&join_schema_df, s, &state)?, + }; + + let state = state.with_query_planner(Arc::new(MergePlanner {})); + + let (target, files) = { + // Attempt to construct an early filter that we can apply to the Add action list and the delta scan. + // In the case where there are partition columns in the join predicate, we can scan the source table + // to get the distinct list of partitions affected and constrain the search to those. + + if !not_match_source_operations.is_empty() { + // It's only worth trying to create an early filter where there are no `when_not_matched_source` operators, since + // that implies a full scan + (target, snapshot.files().iter().collect_vec()) + } else if let Some(filter) = try_construct_early_filter( + predicate.clone(), + snapshot, + &state, + &source, + &source_name, + &target_name, + ) + .await? + { + let file_filter = filter + .clone() + .transform(&|expr| match expr { + Expr::Column(c) => Ok(Transformed::Yes(Expr::Column(Column { + relation: None, // the file filter won't be looking at columns like `target.partition`, it'll just be `partition` + name: c.name, + }))), + expr => Ok(Transformed::No(expr)), + }) + .unwrap(); + let files = snapshot + .files_matching_predicate(&[file_filter])? + .collect_vec(); + + let new_target = LogicalPlan::Filter(Filter::try_new(filter, target.into())?); + (new_target, files) + } else { + (target, snapshot.files().iter().collect_vec()) + } + }; + + let source = DataFrame::new(state.clone(), source); + let source = source.with_column(SOURCE_COLUMN, lit(true))?; // TODO: This is here to prevent predicate pushdowns. In the future we can replace this node to allow pushdowns depending on which operations are being used. 
let target = LogicalPlan::Extension(Extension { @@ -726,14 +1021,6 @@ async fn execute( let target = DataFrame::new(state.clone(), target); let target = target.with_column(TARGET_COLUMN, lit(true))?; - let source_schema = source.schema(); - let target_schema = target.schema(); - let join_schema_df = build_join_schema(source_schema, target_schema, &JoinType::Full)?; - let predicate = match predicate { - Expression::DataFusion(expr) => expr, - Expression::String(s) => parse_predicate_expression(&join_schema_df, s, &state)?, - }; - let join = source.join(target, JoinType::Full, &[], &[], Some(predicate.clone()))?; let join_schema_df = join.schema().to_owned(); @@ -999,7 +1286,6 @@ async fn execute( let project = filtered.select(write_projection)?; let optimized = &project.into_optimized_plan()?; - let state = state.with_query_planner(Arc::new(MergePlanner {})); let write = state.create_physical_plan(optimized).await?; let err = || DeltaTableError::Generic("Unable to locate expected metric node".into()); @@ -1034,7 +1320,7 @@ async fn execute( let mut actions: Vec<Action> = add_actions.into_iter().map(Action::Add).collect(); metrics.num_target_files_added = actions.len(); - for action in snapshot.files() { + for action in files { metrics.num_target_files_removed += 1; actions.push(Action::Remove(Remove { path: action.path.clone(), @@ -1162,6 +1448,8 @@ mod tests { use crate::kernel::DataType; use crate::kernel::PrimitiveType; use crate::kernel::StructField; + use crate::operations::merge::generalize_filter; + use crate::operations::merge::try_construct_early_filter; use crate::operations::DeltaOps; use crate::protocol::*; use crate::writer::test_utils::datafusion::get_data; @@ -1175,11 +1463,21 @@ mod tests { use arrow_schema::DataType as ArrowDataType; use arrow_schema::Field; use datafusion::assert_batches_sorted_eq; + use datafusion::datasource::provider_as_source; use datafusion::prelude::DataFrame; use datafusion::prelude::SessionContext; + use datafusion_common::Column; + use datafusion_common::ScalarValue; + use datafusion_common::TableReference; use datafusion_expr::col; + use datafusion_expr::expr::Placeholder; use datafusion_expr::lit; + use datafusion_expr::Expr; + use datafusion_expr::LogicalPlanBuilder; + use datafusion_expr::Operator; use serde_json::json; + use std::collections::HashMap; + use std::ops::Neg; use std::sync::Arc; use super::MergeMetrics; @@ -1508,7 +1806,7 @@ mod tests { #[tokio::test] async fn test_merge_partitions() { - /* Validate the join predicate works with partition columns */ + /* Validate the join predicate works with table partitions */ let schema = get_arrow_schema(&None); let table = setup_table(Some(vec!["modified"])).await; @@ -1596,6 +1894,78 @@ mod tests { assert_batches_sorted_eq!(&expected, &actual); } + #[tokio::test] + async fn test_merge_partitions_skipping() { + /* Validate the join predicate can be used for skipping partitions */ + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["id"])).await; + + let table = write_data(table, &schema).await; + assert_eq!(table.version(), 1); + assert_eq!(table.get_file_uris().count(), 4); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + Arc::new(arrow::array::Int32Array::from(vec![999, 999, 999])), + Arc::new(arrow::array::StringArray::from(vec![ + "2023-07-04", + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + let source = ctx.read_batch(batch).unwrap();
+ + let (table, metrics) = DeltaOps(table) + .merge(source, col("target.id").eq(col("source.id"))) + .with_source_alias("source") + .with_target_alias("target") + .when_matched_update(|update| { + update + .update("value", col("source.value")) + .update("modified", col("source.modified")) + }) + .unwrap() + .when_not_matched_insert(|insert| { + insert + .set("id", col("source.id")) + .set("value", col("source.value")) + .set("modified", col("source.modified")) + }) + .unwrap() + .await + .unwrap(); + + assert_eq!(table.version(), 2); + assert!(table.get_file_uris().count() >= 3); + assert_eq!(metrics.num_target_files_added, 3); + assert_eq!(metrics.num_target_files_removed, 2); + assert_eq!(metrics.num_target_rows_copied, 0); + assert_eq!(metrics.num_target_rows_updated, 2); + assert_eq!(metrics.num_target_rows_inserted, 1); + assert_eq!(metrics.num_target_rows_deleted, 0); + assert_eq!(metrics.num_output_rows, 3); + assert_eq!(metrics.num_source_rows, 3); + + let expected = vec![ + "+-------+------------+----+", + "| value | modified | id |", + "+-------+------------+----+", + "| 1 | 2021-02-01 | A |", + "| 100 | 2021-02-02 | D |", + "| 999 | 2023-07-04 | B |", + "| 999 | 2023-07-04 | C |", + "| 999 | 2023-07-04 | X |", + "+-------+------------+----+", + ]; + let actual = get_data(&table).await; + assert_batches_sorted_eq!(&expected, &actual); + } + #[tokio::test] async fn test_merge_delete_matched() { // Validate behaviours of match delete @@ -2015,4 +2385,228 @@ mod tests { let actual = get_data(&table).await; assert_batches_sorted_eq!(&expected, &actual); } + + #[tokio::test] + async fn test_generalize_filter_with_partitions() { + let source = TableReference::parse_str("source"); + let target = TableReference::parse_str("target"); + + let parsed_filter = col(Column::new(source.clone().into(), "id")) + .eq(col(Column::new(target.clone().into(), "id"))); + + let mut placeholders = HashMap::default(); + + let generalized = generalize_filter( + parsed_filter, + &vec!["id".to_owned()], + &source, + &target, + &mut placeholders, + ) + .unwrap(); + + let expected_filter = Expr::Placeholder(Placeholder { + id: "id_0".to_owned(), + data_type: None, + }) + .eq(col(Column::new(target.clone().into(), "id"))); + + assert_eq!(generalized, expected_filter); + } + + #[tokio::test] + async fn test_generalize_filter_with_partitions_captures_expression() { + // Check that when generalizing the filter, the placeholder map captures the expression needed to make the statement the same + // when the distinct values are substituted in + let source = TableReference::parse_str("source"); + let target = TableReference::parse_str("target"); + + let parsed_filter = col(Column::new(source.clone().into(), "id")) + .neg() + .eq(col(Column::new(target.clone().into(), "id"))); + + let mut placeholders = HashMap::default(); + + let generalized = generalize_filter( + parsed_filter, + &vec!["id".to_owned()], + &source, + &target, + &mut placeholders, + ) + .unwrap(); + + let expected_filter = Expr::Placeholder(Placeholder { + id: "id_0".to_owned(), + data_type: None, + }) + .eq(col(Column::new(target.clone().into(), "id"))); + + assert_eq!(generalized, expected_filter); + + assert_eq!(placeholders.len(), 1); + + let placeholder_expr = &placeholders["id_0"]; + + let expected_placeholder = col(Column::new(source.clone().into(), "id")).neg(); + + assert_eq!(placeholder_expr, &expected_placeholder); + } + + #[tokio::test] + async fn test_generalize_filter_keeps_static_target_references() { + let source =
TableReference::parse_str("source"); + let target = TableReference::parse_str("target"); + + let parsed_filter = col(Column::new(source.clone().into(), "id")) + .eq(col(Column::new(target.clone().into(), "id"))) + .and(col(Column::new(target.clone().into(), "id")).eq(lit("C"))); + + let mut placeholders = HashMap::default(); + + let generalized = generalize_filter( + parsed_filter, + &vec!["id".to_owned()], + &source, + &target, + &mut placeholders, + ) + .unwrap(); + + let expected_filter = Expr::Placeholder(Placeholder { + id: "id_0".to_owned(), + data_type: None, + }) + .eq(col(Column::new(target.clone().into(), "id"))) + .and(col(Column::new(target.clone().into(), "id")).eq(lit("C"))); + + assert_eq!(generalized, expected_filter); + } + + #[tokio::test] + async fn test_generalize_filter_removes_source_references() { + let source = TableReference::parse_str("source"); + let target = TableReference::parse_str("target"); + + let parsed_filter = col(Column::new(source.clone().into(), "id")) + .eq(col(Column::new(target.clone().into(), "id"))) + .and(col(Column::new(source.clone().into(), "id")).eq(lit("C"))); + + let mut placeholders = HashMap::default(); + + let generalized = generalize_filter( + parsed_filter, + &vec!["id".to_owned()], + &source, + &target, + &mut placeholders, + ) + .unwrap(); + + let expected_filter = Expr::Placeholder(Placeholder { + id: "id_0".to_owned(), + data_type: None, + }) + .eq(col(Column::new(target.clone().into(), "id"))); + + assert_eq!(generalized, expected_filter); + } + + #[tokio::test] + async fn test_try_construct_early_filter_with_partitions_expands() { + let schema = get_arrow_schema(&None); + let table = setup_table(Some(vec!["id"])).await; + + assert_eq!(table.version(), 0); + assert_eq!(table.get_file_uris().count(), 0); + + let ctx = SessionContext::new(); + let batch = RecordBatch::try_new( + Arc::clone(&schema), + vec![ + Arc::new(arrow::array::StringArray::from(vec!["B", "C", "X"])), + Arc::new(arrow::array::Int32Array::from(vec![10, 20, 30])), + Arc::new(arrow::array::StringArray::from(vec![ + "2021-02-02", + "2023-07-04", + "2023-07-04", + ])), + ], + ) + .unwrap(); + let source = ctx.read_batch(batch).unwrap(); + + let source_name = TableReference::parse_str("source"); + let target_name = TableReference::parse_str("target"); + + let source = LogicalPlanBuilder::scan( + source_name.clone(), + provider_as_source(source.into_view()), + None, + ) + .unwrap() + .build() + .unwrap(); + + let join_predicate = col(Column { + relation: Some(source_name.clone()), + name: "id".to_owned(), + }) + .eq(col(Column { + relation: Some(target_name.clone()), + name: "id".to_owned(), + })); + + let pred = try_construct_early_filter( + join_predicate, + &table.state, + &ctx.state(), + &source, + &source_name, + &target_name, + ) + .await + .unwrap(); + + assert!(pred.is_some()); + + let split_pred = { + fn split(expr: Expr, parts: &mut Vec<(String, String)>) { + match expr { + Expr::BinaryExpr(ex) if ex.op == Operator::Or => { + split(*ex.left, parts); + split(*ex.right, parts); + } + Expr::BinaryExpr(ex) if ex.op == Operator::Eq => { + let col = match *ex.right { + Expr::Column(col) => col.name, + ex => panic!("expected column in pred, got {ex}!"), + }; + + let value = match *ex.left { + Expr::Literal(ScalarValue::Utf8(Some(value))) => value, + ex => panic!("expected value in predicate, got {ex}!"), + }; + + parts.push((col, value)) + } + + expr => panic!("expected either = or OR, got {expr}"), + } + } + + let mut parts = vec![]; + split(pred.unwrap(), &mut 
parts); + parts.sort(); + parts + }; + + let expected_pred_parts = [ + ("id".to_owned(), "B".to_owned()), + ("id".to_owned(), "C".to_owned()), + ("id".to_owned(), "X".to_owned()), + ]; + + assert_eq!(split_pred, expected_pred_parts); + } }