From 82ef6bf695186071ceb31c870b06ef957606d7bd Mon Sep 17 00:00:00 2001
From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com>
Date: Sat, 18 Jan 2025 00:51:00 +0100
Subject: [PATCH] refactor: logical plans in writer

Signed-off-by: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com>
---
 crates/core/src/operations/write.rs         | 210 ++++++++------------
 crates/core/tests/integration_datafusion.rs |  76 +++----
 python/tests/test_writer.py                 |  15 ++
 3 files changed, 127 insertions(+), 174 deletions(-)

diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write.rs
index 914f88887c..b4e6edce7d 100644
--- a/crates/core/src/operations/write.rs
+++ b/crates/core/src/operations/write.rs
@@ -33,15 +33,13 @@ use std::vec;
 use arrow_array::RecordBatch;
 use arrow_cast::can_cast_types;
 use arrow_schema::{ArrowError, DataType, Fields, SchemaRef as ArrowSchemaRef};
+use datafusion::catalog::TableProvider;
+use datafusion::datasource::{provider_as_source, MemTable};
 use datafusion::execution::context::{SessionContext, SessionState, TaskContext};
+use datafusion::prelude::DataFrame;
 use datafusion_common::DFSchema;
-use datafusion_expr::{col, lit, when, Expr, ExprSchemable};
-use datafusion_physical_expr::expressions::{self};
-use datafusion_physical_expr::PhysicalExpr;
-use datafusion_physical_plan::filter::FilterExec;
-use datafusion_physical_plan::projection::ProjectionExec;
-use datafusion_physical_plan::union::UnionExec;
-use datafusion_physical_plan::{memory::MemoryExec, ExecutionPlan};
+use datafusion_expr::{col, lit, when, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder};
+use datafusion_physical_plan::ExecutionPlan;
 use futures::future::BoxFuture;
 use futures::StreamExt;
 use object_store::prefix::PrefixStore;
@@ -58,7 +56,7 @@ use super::{CreateBuilder, CustomExecuteHandler, Operation};
 use crate::delta_datafusion::expr::fmt_expr_to_sql;
 use crate::delta_datafusion::expr::parse_predicate_expression;
 use crate::delta_datafusion::{
-    find_files, register_store, DeltaScanBuilder, DeltaScanConfigBuilder,
+    find_files, register_store, DeltaScanConfigBuilder, DeltaTableProvider,
 };
 use crate::delta_datafusion::{DataFusionMixins, DeltaDataChecker};
 use crate::errors::{DeltaResult, DeltaTableError};
@@ -136,7 +134,7 @@ pub struct WriteBuilder {
     /// Delta object store for handling data files
     log_store: LogStoreRef,
     /// The input plan
-    input: Option<Arc<dyn ExecutionPlan>>,
+    input: Option<Arc<LogicalPlan>>,
     /// Datafusion session state relevant for executing the input plan
     state: Option<SessionState>,
     /// SaveMode defines how to treat data already written to table location
@@ -245,8 +243,8 @@ impl WriteBuilder {
         self
     }
 
-    /// Execution plan that produces the data to be written to the delta table
-    pub fn with_input_execution_plan(mut self, plan: Arc<dyn ExecutionPlan>) -> Self {
+    /// Logical execution plan that produces the data to be written to the delta table
+    pub fn with_input_execution_plan(mut self, plan: Arc<LogicalPlan>) -> Self {
         self.input = Some(plan);
         self
     }
@@ -351,7 +349,7 @@ impl WriteBuilder {
         };
 
         let schema: StructType = match &self.input {
-            Some(plan) => (plan.schema()).try_into()?,
+            Some(plan) => (plan.schema().as_arrow()).try_into()?,
             None => (batches[0].schema()).try_into()?,
         };
 
@@ -602,7 +600,7 @@ async fn execute_non_empty_expr(
     writer_properties: Option<WriterProperties>,
     writer_stats_config: WriterStatsConfig,
     partition_scan: bool,
-    insert_plan: Arc<dyn ExecutionPlan>,
+    insert_df: DataFrame,
     operation_id: Uuid,
 ) -> DeltaResult<Vec<Action>> {
     // For each identified file perform a parquet scan + filter + limit (1) + count.
@@ -611,30 +609,30 @@ async fn execute_non_empty_expr(
     // Take the insert plan schema since it might have been schema evolved, if its not
     // it is simply the table schema
-    let df_schema = insert_plan.schema();
-    let input_dfschema: DFSchema = df_schema.as_ref().clone().try_into()?;
-
     let scan_config = DeltaScanConfigBuilder::new()
         .with_schema(snapshot.input_schema()?)
         .build(snapshot)?;
 
-    let scan = DeltaScanBuilder::new(snapshot, log_store.clone(), &state)
-        .with_files(rewrite)
-        // Use input schema which doesn't wrap partition values, otherwise divide_by_partition_value won't work on UTF8 partitions
-        // Since it can't fetch a scalar from a dictionary type
-        .with_scan_config(scan_config)
-        .build()
-        .await?;
-    let scan = Arc::new(scan);
+    let target_provider = Arc::new(
+        DeltaTableProvider::try_new(snapshot.clone(), log_store.clone(), scan_config.clone())?
+            .with_files(rewrite.to_vec()),
+    );
+    let target_provider = provider_as_source(target_provider);
+    let source = LogicalPlanBuilder::scan("target", target_provider.clone(), None)?.build()?;
 
     // We don't want to verify the predicate against existing data
+
+    let df = DataFrame::new(state.clone(), source);
+
     if !partition_scan {
         // Apply the negation of the filter and rewrite files
         let negated_expression = Expr::Not(Box::new(Expr::IsTrue(Box::new(expression.clone()))));
 
-        let predicate_expr = state.create_physical_expr(negated_expression, &input_dfschema)?;
-        let filter: Arc<dyn ExecutionPlan> =
-            Arc::new(FilterExec::try_new(predicate_expr, scan.clone())?);
+        let filter = df
+            .clone()
+            .filter(negated_expression)?
+            .create_physical_plan()
+            .await?;
 
         let add_actions: Vec<Action> = write_execution_plan(
             Some(snapshot),
@@ -662,13 +660,12 @@ async fn execute_non_empty_expr(
         snapshot,
         log_store,
         state.clone(),
-        scan,
-        input_dfschema,
+        df,
         expression,
         partition_columns,
         writer_properties,
         writer_stats_config,
-        insert_plan,
+        insert_df,
         operation_id,
     )
     .await?
@@ -685,71 +682,34 @@ pub(crate) async fn execute_non_empty_expr_cdc(
     snapshot: &DeltaTableState,
     log_store: LogStoreRef,
     state: SessionState,
-    scan: Arc<DeltaScan>,
-    input_dfschema: DFSchema,
+    scan: DataFrame,
     expression: &Expr,
     table_partition_cols: Vec<String>,
     writer_properties: Option<WriterProperties>,
     writer_stats_config: WriterStatsConfig,
-    insert_plan: Arc<dyn ExecutionPlan>,
+    insert_df: DataFrame,
     operation_id: Uuid,
 ) -> DeltaResult<Option<Vec<Action>>> {
     match should_write_cdc(snapshot) {
         // Create CDC scan
         Ok(true) => {
-            let cdc_predicate_expr =
-                state.create_physical_expr(expression.clone(), &input_dfschema)?;
-            let cdc_scan: Arc<dyn ExecutionPlan> =
-                Arc::new(FilterExec::try_new(cdc_predicate_expr, scan.clone())?);
+            let filter = scan.clone().filter(expression.clone())?;
 
             // Add literal column "_change_type"
-            let delete_change_type_expr =
-                state.create_physical_expr(lit("delete"), &input_dfschema)?;
+            let delete_change_type_expr = lit("delete").alias("_change_type");
 
-            let insert_change_type_expr =
-                state.create_physical_expr(lit("insert"), &input_dfschema)?;
+            let insert_change_type_expr = lit("insert").alias("_change_type");
 
-            // Project columns and lit
-            let mut delete_project_expressions: Vec<(Arc<dyn PhysicalExpr>, String)> = scan
-                .schema()
-                .fields()
-                .into_iter()
-                .enumerate()
-                .map(|(idx, field)| -> (Arc<dyn PhysicalExpr>, String) {
-                    (
-                        Arc::new(expressions::Column::new(field.name(), idx)),
-                        field.name().to_owned(),
-                    )
-                })
-                .collect();
+            let delete_df = filter.with_column("_change_type", delete_change_type_expr)?;
 
-            let mut insert_project_expressions = delete_project_expressions.clone();
-            delete_project_expressions.insert(
-                delete_project_expressions.len(),
-                (delete_change_type_expr, "_change_type".to_owned()),
-            );
-            insert_project_expressions.insert(
-                insert_project_expressions.len(),
-                (insert_change_type_expr, "_change_type".to_owned()),
-            );
-
-            let delete_plan: Arc<dyn ExecutionPlan> = Arc::new(ProjectionExec::try_new(
-                delete_project_expressions,
-                cdc_scan.clone(),
-            )?);
+            let insert_df = insert_df.with_column("_change_type", insert_change_type_expr)?;
 
-            let insert_plan: Arc<dyn ExecutionPlan> = Arc::new(ProjectionExec::try_new(
-                insert_project_expressions,
-                insert_plan.clone(),
-            )?);
-
-            let cdc_plan: Arc<dyn ExecutionPlan> =
-                Arc::new(UnionExec::new(vec![delete_plan, insert_plan]));
+            let cdc_df = delete_df.union(insert_df)?;
 
             let cdc_actions = write_execution_plan_cdc(
                 Some(snapshot),
                 state.clone(),
-                cdc_plan.clone(),
+                cdc_df.create_physical_plan().await?,
                 table_partition_cols.clone(),
                 log_store.object_store(Some(operation_id)),
                 Some(snapshot.table_config().target_file_size() as usize),
@@ -776,7 +736,7 @@ async fn prepare_predicate_actions(
     writer_properties: Option<WriterProperties>,
     deletion_timestamp: i64,
     writer_stats_config: WriterStatsConfig,
-    insert_plan: Arc<dyn ExecutionPlan>,
+    insert_df: DataFrame,
     operation_id: Uuid,
 ) -> DeltaResult<Vec<Action>> {
     let candidates =
@@ -792,7 +752,7 @@ async fn prepare_predicate_actions(
         writer_properties,
         writer_stats_config,
         candidates.partition_scan,
-        insert_plan,
+        insert_df,
         operation_id,
     )
     .await?;
@@ -858,19 +818,28 @@ impl std::future::IntoFuture for WriteBuilder {
                 Ok(this.partition_columns.unwrap_or_default())
             }?;
 
+            let state = match this.state {
+                Some(state) => state,
+                None => {
+                    let ctx = SessionContext::new();
+                    register_store(this.log_store.clone(), ctx.runtime_env());
+                    ctx.state()
+                }
+            };
+
             let generated_col_expressions = this
                 .snapshot
                 .as_ref()
                 .map(|v| v.schema().get_generated_columns().unwrap_or_default())
                 .unwrap_or_default();
             let mut schema_drift = false;
-            let mut plan = if let Some(plan) = this.input {
+            let mut df = if let Some(plan) = this.input {
                 if this.schema_mode == Some(SchemaMode::Merge) {
                     return Err(DeltaTableError::Generic(
                         "Schema merge not supported yet for Datafusion".to_string(),
                     ));
                 }
-                Ok(plan)
+                Ok(DataFrame::new(state.clone(), plan.as_ref().clone()))
             } else if let Some(batches) = this.batches {
                 if batches.is_empty() {
                     Err(WriteError::MissingData)
@@ -993,17 +962,19 @@ impl std::future::IntoFuture for WriteBuilder {
                     }
                 };
 
-                Ok(Arc::new(MemoryExec::try_new(
-                    &data,
-                    new_schema.unwrap_or(schema).clone(),
-                    None,
-                )?) as Arc<dyn ExecutionPlan>)
+                let ctx = SessionContext::new();
+                let table_provider: Arc<dyn TableProvider> = Arc::new(
+                    MemTable::try_new(new_schema.unwrap_or(schema).clone(), data).unwrap(),
+                );
+                let df = ctx.read_table(table_provider).unwrap();
+
+                Ok(df)
             } else {
                 Err(WriteError::MissingData)
             }?;
 
-            let schema = plan.schema();
+            let schema = Arc::new(df.schema().as_arrow().clone());
             if this.schema_mode == Some(SchemaMode::Merge) && schema_drift {
                 if let Some(snapshot) = &this.snapshot {
                     let schema_struct: StructType = schema.clone().try_into()?;
@@ -1025,47 +996,33 @@ impl std::future::IntoFuture for WriteBuilder {
                     }
                 }
             }
-            let state = match this.state {
-                Some(state) => state,
-                None => {
-                    let ctx = SessionContext::new();
-                    register_store(this.log_store.clone(), ctx.runtime_env());
-                    ctx.state()
-                }
-            };
 
             // Add when.then expr for generated columns
             if !generated_col_expressions.is_empty() {
                 fn create_field(
-                    idx: usize,
                     field: &arrow_schema::Field,
                     generated_cols_map: &HashMap<String, GeneratedColumn>,
                     state: &datafusion::execution::session_state::SessionState,
                     dfschema: &DFSchema,
-                ) -> DeltaResult<(Arc<dyn PhysicalExpr>, String)> {
+                ) -> DeltaResult<Expr> {
                     match generated_cols_map.get(field.name()) {
                         Some(generated_col) => {
-                            let generation_expr = state.create_physical_expr(
-                                when(
-                                    col(generated_col.get_name()).is_null(),
-                                    state.create_logical_expr(
-                                        generated_col.get_generation_expression(),
-                                        dfschema,
-                                    )?,
-                                )
-                                .otherwise(col(generated_col.get_name()))?
-                                .cast_to(
-                                    &arrow_schema::DataType::try_from(&generated_col.data_type)?,
+                            let generation_expr = when(
+                                col(generated_col.get_name()).is_null(),
+                                state.create_logical_expr(
+                                    generated_col.get_generation_expression(),
                                     dfschema,
                                 )?,
+                            )
+                            .otherwise(col(generated_col.get_name()))?
+                            .cast_to(
+                                &arrow_schema::DataType::try_from(&generated_col.data_type)?,
                                 dfschema,
-                            )?;
-                            Ok((generation_expr, field.name().to_owned()))
+                            )?
+                            .alias(field.name().to_owned());
+                            Ok(generation_expr)
                         }
-                        None => Ok((
-                            Arc::new(expressions::Column::new(field.name(), idx)),
-                            field.name().to_owned(),
-                        )),
+                        None => Ok(col(field.name().to_owned())),
                     }
                 }
@@ -1074,17 +1031,14 @@ impl std::future::IntoFuture for WriteBuilder {
                     .into_iter()
                     .map(|v| (v.name.clone(), v))
                     .collect::<HashMap<String, GeneratedColumn>>();
-                let current_fields: DeltaResult<Vec<(Arc<dyn PhysicalExpr>, String)>> = plan
+                let current_fields: DeltaResult<Vec<Expr>> = df
                     .schema()
                     .fields()
                     .into_iter()
-                    .enumerate()
-                    .map(|(idx, field)| {
-                        create_field(idx, field, &generated_cols_map, &state, &dfschema)
-                    })
+                    .map(|field| create_field(field, &generated_cols_map, &state, &dfschema))
                    .collect();
 
-                plan = Arc::new(ProjectionExec::try_new(current_fields?, plan.clone())?);
+                df = df.select(current_fields?)?;
             };
 
             let (predicate_str, predicate) = match this.predicate {
@@ -1123,7 +1077,7 @@ impl std::future::IntoFuture for WriteBuilder {
                 predicate.clone(),
                 this.snapshot.as_ref(),
                 state.clone(),
-                plan.clone(),
+                df.clone().create_physical_plan().await?,
                 partition_columns.clone(),
                 this.log_store.object_store(Some(operation_id)).clone(),
                 target_file_size,
@@ -1177,7 +1131,7 @@ impl std::future::IntoFuture for WriteBuilder {
                 this.writer_properties,
                 deletion_timestamp,
                 writer_stats_config,
-                plan,
+                df,
                 operation_id,
             )
             .await?;
@@ -2494,13 +2448,13 @@ mod tests {
             .await?;
 
         let ctx = SessionContext::new();
-        let plan = ctx
-            .sql("SELECT 1 as id")
-            .await
-            .unwrap()
-            .create_physical_plan()
-            .await
-            .unwrap();
+        let plan = Arc::new(
+            ctx.sql("SELECT 1 as id")
+                .await
+                .unwrap()
+                .logical_plan()
+                .clone(),
+        );
         let writer = WriteBuilder::new(table.log_store.clone(), table.state)
             .with_input_execution_plan(plan)
             .with_save_mode(SaveMode::Overwrite);
diff --git a/crates/core/tests/integration_datafusion.rs b/crates/core/tests/integration_datafusion.rs
index 2e7a9f50be..080021d6e5 100644
--- a/crates/core/tests/integration_datafusion.rs
+++ b/crates/core/tests/integration_datafusion.rs
@@ -22,9 +22,9 @@ use datafusion_common::ScalarValue::*;
 use datafusion_common::{DataFusionError, Result};
 use datafusion_expr::Expr;
 use datafusion_proto::bytes::{
-    physical_plan_from_bytes_with_extension_codec, physical_plan_to_bytes_with_extension_codec,
+    logical_plan_from_bytes_with_extension_codec, logical_plan_to_bytes_with_extension_codec,
 };
-use deltalake_core::delta_datafusion::{DeltaPhysicalCodec, DeltaScan};
+use deltalake_core::delta_datafusion::DeltaScan;
 use deltalake_core::kernel::{DataType, MapType, PrimitiveType, StructField, StructType};
 use deltalake_core::logstore::logstore_for;
 use deltalake_core::operations::create::CreateBuilder;
@@ -41,8 +41,16 @@ use serial_test::serial;
 use url::Url;
 
 mod local {
-    use datafusion::common::stats::Precision;
-    use deltalake_core::{logstore::default_logstore, writer::JsonWriter};
+    use datafusion::{
+        common::stats::Precision, datasource::provider_as_source, prelude::DataFrame,
+    };
+    use datafusion_expr::LogicalPlanBuilder;
+    use deltalake_core::{
+        delta_datafusion::{DeltaLogicalCodec, DeltaScanConfigBuilder, DeltaTableProvider},
+        logstore::default_logstore,
+        writer::JsonWriter,
+    };
+    use itertools::Itertools;
     use object_store::local::LocalFileSystem;
 
     use super::*;
@@ -194,21 +202,26 @@ mod local {
             let ctx = SessionContext::new();
             let state = ctx.state();
             let source_table = open_table("../test/tests/data/delta-0.8.0-date").await?;
-            let source_scan = source_table.scan(&state, None, &[], None).await?;
-            physical_plan_to_bytes_with_extension_codec(source_scan, &DeltaPhysicalCodec {})?
+
+            let target_provider = provider_as_source(Arc::new(source_table));
+            let source =
+                LogicalPlanBuilder::scan("source", target_provider.clone(), None)?.build()?;
+            // We don't want to verify the predicate against existing data
+            logical_plan_to_bytes_with_extension_codec(&source, &DeltaLogicalCodec {})?
         };
 
        // Build a new context from scratch and deserialize the plan
         let ctx = SessionContext::new();
         let state = ctx.state();
-        let source_scan = physical_plan_from_bytes_with_extension_codec(
+        let source_scan = Arc::new(logical_plan_from_bytes_with_extension_codec(
             &source_scan_bytes,
             &ctx,
-            &DeltaPhysicalCodec {},
-        )?;
-        let schema = StructType::try_from(source_scan.schema()).unwrap();
+            &DeltaLogicalCodec {},
+        )?);
+        let schema = StructType::try_from(source_scan.schema().as_arrow()).unwrap();
         let fields = schema.fields().cloned();
 
+        dbg!(schema.fields().collect_vec().clone());
         // Create target Delta Table
         let target_table = CreateBuilder::new()
             .with_location("memory:///target")
@@ -216,35 +229,6 @@ mod local {
             .with_table_name("target")
             .await?;
 
-        // Trying to execute the write from the input plan without providing Datafusion with a session
-        // state containing the referenced object store in the registry results in an error.
-        assert!(WriteBuilder::new(
-            target_table.log_store(),
-            target_table.snapshot().ok().cloned()
-        )
-        .with_input_execution_plan(source_scan.clone())
-        .await
-        .unwrap_err()
-        .to_string()
-        .contains("No suitable object store found for delta-rs://"));
-
-        // Register the missing source table object store
-        let source_uri = Url::parse(
-            &source_scan
-                .as_any()
-                .downcast_ref::<DeltaScan>()
-                .unwrap()
-                .table_uri
-                .clone(),
-        )
-        .unwrap();
-        let source_store = logstore_for(source_uri, HashMap::new(), None).unwrap();
-        let object_store_url = source_store.object_store_url();
-        let source_store_url: &Url = object_store_url.as_ref();
-        state
-            .runtime_env()
-            .register_object_store(source_store_url, source_store.object_store(None));
-
         // Execute write to the target table with the proper state
         let target_table = WriteBuilder::new(
             target_table.log_store(),
@@ -1129,13 +1113,13 @@ mod local {
         );
         let tbl = tbl.await.unwrap();
         let ctx = SessionContext::new();
-        let plan = ctx
-            .sql("SELECT 1 as id")
-            .await
-            .unwrap()
-            .create_physical_plan()
-            .await
-            .unwrap();
+        let plan = Arc::new(
+            ctx.sql("SELECT 1 as id")
+                .await
+                .unwrap()
+                .logical_plan()
+                .clone(),
+        );
         let write_builder = WriteBuilder::new(log_store, tbl.state);
         let _ = write_builder
             .with_input_execution_plan(plan)
diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py
index 11320743e0..cbf40dbfd1 100644
--- a/python/tests/test_writer.py
+++ b/python/tests/test_writer.py
@@ -2031,3 +2031,18 @@ def test_write_structs(tmp_path: pathlib.Path):
     arrow_dt = dt.to_pyarrow_dataset()
     new_df = pl.scan_pyarrow_dataset(arrow_dt)
     new_df.collect()
+
+
+@pytest.mark.polars
+def test_write_type_coercion_predicate(tmp_path: pathlib.Path):
+    import polars as pl
+
+    df = pl.DataFrame({"A": [1, 2], "B": ["hi", "hello"], "C": ["a", "b"]})
+    df.write_delta(tmp_path)
+
+    df = pl.DataFrame({"A": [10], "B": ["yeah"], "C": ["a"]})
+    df.write_delta(
+        tmp_path,
+        mode="overwrite",
+        delta_write_options=dict(engine="rust", predicate="C = 'a'"),
+    )
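
Usage sketch for the updated builder API (a hedged example, not part of the diff: it mirrors the
tests above, assumes an async context returning DeltaResult, and uses placeholder `log_store` and
`snapshot` values for an existing table):

    use std::sync::Arc;
    use datafusion::prelude::SessionContext;

    let ctx = SessionContext::new();
    // Hand the writer a logical plan; it is turned into a DataFrame and planned to a
    // physical plan inside WriteBuilder, instead of being built by the caller.
    let plan = Arc::new(ctx.sql("SELECT 1 as id").await?.logical_plan().clone());

    let table = WriteBuilder::new(log_store, snapshot)
        .with_input_execution_plan(plan)
        .with_save_mode(SaveMode::Overwrite)
        .await?;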