diff --git a/rust/src/delta_datafusion/logical.rs b/rust/src/delta_datafusion/logical.rs
index 05bddf146f..7b05dd57d9 100644
--- a/rust/src/delta_datafusion/logical.rs
+++ b/rust/src/delta_datafusion/logical.rs
@@ -1,18 +1,22 @@
 //! Logical Operations for DataFusion
 
 use datafusion_expr::{LogicalPlan, UserDefinedLogicalNodeCore};
-pub const METRIC_OBSERVER_NAME: &str = "MetricObserver";
+
+// Metric Observer is used to update DataFusion metrics from a record batch.
+// See MetricObserverExec for the physical implementation
 
 #[derive(Debug, Hash, Eq, PartialEq)]
 pub(crate) struct MetricObserver {
-    // This acts as an anchor when converting to a physical operator
-    pub anchor: String,
+    // id is preserved during conversion to physical node
+    pub id: String,
     pub input: LogicalPlan,
 }
 
 impl UserDefinedLogicalNodeCore for MetricObserver {
+    // Predicate push down is not supported for this node. Try to limit usage
+    // near the end of plan.
     fn name(&self) -> &str {
-        METRIC_OBSERVER_NAME
+        "MetricObserver"
     }
 
     fn inputs(&self) -> Vec<&datafusion_expr::LogicalPlan> {
@@ -28,7 +32,7 @@ impl UserDefinedLogicalNodeCore for MetricObserver {
     }
 
     fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
-        write!(f, "MetricObserver name={}", &self.anchor)
+        write!(f, "MetricObserver id={}", &self.id)
     }
 
     fn from_template(
@@ -37,7 +41,7 @@ impl UserDefinedLogicalNodeCore for MetricObserver {
         inputs: &[datafusion_expr::LogicalPlan],
     ) -> Self {
         MetricObserver {
-            anchor: self.anchor.clone(),
+            id: self.id.clone(),
             input: inputs[0].clone(),
         }
     }
diff --git a/rust/src/delta_datafusion/physical.rs b/rust/src/delta_datafusion/physical.rs
index 4238372bc6..c357f286d3 100644
--- a/rust/src/delta_datafusion/physical.rs
+++ b/rust/src/delta_datafusion/physical.rs
@@ -11,34 +11,53 @@ use datafusion::physical_plan::{
 };
 use futures::{Stream, StreamExt};
 
+use crate::DeltaTableError;
+
+// Metric Observer is used to update DataFusion metrics from a record batch.
+// Typically the null count for a particular column is pulled after performing a
+// projection since this count is easy to obtain
+
 pub(crate) type MetricObserverFunction = fn(&RecordBatch, &ExecutionPlanMetricsSet) -> ();
 
 pub(crate) struct MetricObserverExec {
     parent: Arc<dyn ExecutionPlan>,
-    anchor: String,
+    id: String,
     metrics: ExecutionPlanMetricsSet,
     update: MetricObserverFunction,
 }
 
 impl MetricObserverExec {
-    pub fn new(anchor: String, parent: Arc<dyn ExecutionPlan>, f: MetricObserverFunction) -> Self {
+    pub fn new(id: String, parent: Arc<dyn ExecutionPlan>, f: MetricObserverFunction) -> Self {
         MetricObserverExec {
             parent,
-            anchor,
+            id,
             metrics: ExecutionPlanMetricsSet::new(),
             update: f,
         }
     }
 
-    pub fn anchor(&self) -> &str {
-        &self.anchor
+    pub fn try_new(id: String, inputs: &[Arc<dyn ExecutionPlan>], f: MetricObserverFunction) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
+        match inputs {
+            [input] => Ok(Arc::new(MetricObserverExec::new(
+                id,
+                input.clone(),
+                f
+            ))),
+            _ => Err(datafusion_common::DataFusionError::External(Box::new(
+                DeltaTableError::Generic("MetricObserverExec expects only one child".into()),
+            ))),
+        }
+    }
+
+    pub fn id(&self) -> &str {
+        &self.id
     }
 }
 
 impl std::fmt::Debug for MetricObserverExec {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
         f.debug_struct("MetricObserverExec")
-            .field("anchor", &self.anchor)
+            .field("id", &self.id)
             .field("metrics", &self.metrics)
             .finish()
     }
 }
@@ -50,7 +69,7 @@ impl DisplayAs for MetricObserverExec {
         _: datafusion::physical_plan::DisplayFormatType,
         f: &mut std::fmt::Formatter,
     ) -> std::fmt::Result {
-        write!(f, "MetricObserverExec anchor={}", self.anchor)
+        write!(f, "MetricObserverExec id={}", self.id)
     }
 }
 
@@ -97,12 +116,7 @@ impl ExecutionPlan for MetricObserverExec {
         self: Arc<Self>,
         children: Vec<Arc<dyn ExecutionPlan>>,
     ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
-        //TODO: Error on multiple children
-        Ok(Arc::new(MetricObserverExec::new(
-            self.anchor.clone(),
-            children.get(0).unwrap().clone(),
-            self.update,
-        )))
+        MetricObserverExec::try_new(self.id.clone(), &children, self.update)
     }
 
     fn metrics(&self) -> Option<MetricsSet> {
@@ -143,3 +157,24 @@ impl RecordBatchStream for MetricObserverStream {
         self.schema.clone()
     }
 }
+
+pub(crate) fn find_metric_node(
+    id: &str,
+    parent: &Arc<dyn ExecutionPlan>,
+) -> Option<Arc<dyn ExecutionPlan>> {
+    //! Used to locate the physical MetricCountExec Node after the planner converts the logical node
+    if let Some(metric) = parent.as_any().downcast_ref::<MetricObserverExec>() {
+        if metric.id().eq(id) {
+            return Some(parent.to_owned());
+        }
+    }
+
+    for child in &parent.children() {
+        let res = find_metric_node(id, child);
+        if res.is_some() {
+            return res;
+        }
+    }
+
+    None
+}
diff --git a/rust/src/operations/merge.rs b/rust/src/operations/merge.rs
index 9c3050650e..1009ce6315 100644
--- a/rust/src/operations/merge.rs
+++ b/rust/src/operations/merge.rs
@@ -8,8 +8,7 @@
 //! specified matter. See [`MergeBuilder`] for more information
 //!
 //! *WARNING* The current implementation rewrites the entire delta table so only
-//! use on small to medium sized tables. The solution also cannot take advantage
-//! of multiple threads and is limited to a single thread.
+//! use on small to medium sized tables.
 //! Enhancements tracked at #850
 //!
 //! # Example
@@ -62,7 +61,7 @@ use super::datafusion_utils::{into_expr, maybe_into_expr, Expression};
 use super::transaction::commit;
 use crate::delta_datafusion::expr::{fmt_expr_to_sql, parse_predicate_expression};
 use crate::delta_datafusion::logical::MetricObserver;
-use crate::delta_datafusion::physical::MetricObserverExec;
+use crate::delta_datafusion::physical::{find_metric_node, MetricObserverExec};
 use crate::delta_datafusion::register_store;
 use crate::delta_datafusion::{DeltaScanConfig, DeltaTableProvider};
 use crate::{
@@ -573,59 +572,59 @@ impl ExtensionPlanner for MergeMetricExtensionPlanner {
         physical_inputs: &[Arc<dyn ExecutionPlan>],
         _session_state: &SessionState,
     ) -> DataFusionResult<Option<Arc<dyn ExecutionPlan>>> {
-        let metric_observer = node.as_any().downcast_ref::<MetricObserver>().unwrap();
-
-        if metric_observer.anchor.eq(SOURCE_COUNT_ID) {
-            return Ok(Some(Arc::new(MetricObserverExec::new(
-                SOURCE_COUNT_ID.into(),
-                physical_inputs[0].clone(),
-                |batch, metrics| {
-                    MetricBuilder::new(metrics)
-                        .global_counter(SOURCE_COUNT_METRIC)
-                        .add(batch.num_rows());
-                },
-            ))));
-        }
+        if let Some(metric_observer) = node.as_any().downcast_ref::<MetricObserver>() {
+            if metric_observer.id.eq(SOURCE_COUNT_ID) {
+                return Ok(Some(MetricObserverExec::try_new(
+                    SOURCE_COUNT_ID.into(),
+                    physical_inputs,
+                    |batch, metrics| {
+                        MetricBuilder::new(metrics)
+                            .global_counter(SOURCE_COUNT_METRIC)
+                            .add(batch.num_rows());
+                    },
+                )?));
+            }
 
-        if metric_observer.anchor.eq(TARGET_COUNT_ID) {
-            return Ok(Some(Arc::new(MetricObserverExec::new(
-                TARGET_COUNT_ID.into(),
-                physical_inputs[0].clone(),
-                |batch, metrics| {
-                    MetricBuilder::new(metrics)
-                        .global_counter(TARGET_INSERTED_METRIC)
-                        .add(
-                            batch
-                                .column_by_name(TARGET_INSERT_COLUMN)
-                                .unwrap()
-                                .null_count(),
-                        );
-                    MetricBuilder::new(metrics)
-                        .global_counter(TARGET_UPDATED_METRIC)
-                        .add(
-                            batch
-                                .column_by_name(TARGET_UPDATE_COLUMN)
-                                .unwrap()
-                                .null_count(),
-                        );
-                    MetricBuilder::new(metrics)
-                        .global_counter(TARGET_DELETED_METRIC)
-                        .add(
-                            batch
-                                .column_by_name(TARGET_DELETE_COLUMN)
-                                .unwrap()
-                                .null_count(),
-                        );
-                    MetricBuilder::new(metrics)
-                        .global_counter(TARGET_COPY_METRIC)
-                        .add(
-                            batch
-                                .column_by_name(TARGET_COPY_COLUMN)
-                                .unwrap()
-                                .null_count(),
-                        );
-                },
-            ))));
+            if metric_observer.id.eq(TARGET_COUNT_ID) {
+                return Ok(Some(MetricObserverExec::try_new(
+                    TARGET_COUNT_ID.into(),
+                    physical_inputs,
+                    |batch, metrics| {
+                        MetricBuilder::new(metrics)
+                            .global_counter(TARGET_INSERTED_METRIC)
+                            .add(
+                                batch
+                                    .column_by_name(TARGET_INSERT_COLUMN)
+                                    .unwrap()
+                                    .null_count(),
+                            );
+                        MetricBuilder::new(metrics)
+                            .global_counter(TARGET_UPDATED_METRIC)
+                            .add(
+                                batch
+                                    .column_by_name(TARGET_UPDATE_COLUMN)
+                                    .unwrap()
+                                    .null_count(),
+                            );
+                        MetricBuilder::new(metrics)
+                            .global_counter(TARGET_DELETED_METRIC)
+                            .add(
+                                batch
+                                    .column_by_name(TARGET_DELETE_COLUMN)
+                                    .unwrap()
+                                    .null_count(),
+                            );
+                        MetricBuilder::new(metrics)
+                            .global_counter(TARGET_COPY_METRIC)
+                            .add(
+                                batch
+                                    .column_by_name(TARGET_COPY_COLUMN)
+                                    .unwrap()
+                                    .null_count(),
+                            );
+                    },
+                )?));
+            }
         }
 
         Ok(None)
@@ -681,7 +680,7 @@ async fn execute(
 
     let source = LogicalPlan::Extension(Extension {
         node: Arc::new(MetricObserver {
-            anchor: SOURCE_COUNT_ID.into(),
+            id: SOURCE_COUNT_ID.into(),
             input: source,
         }),
     });
@@ -702,7 +701,7 @@ async fn execute(
 
     let source_schema = source.schema();
    let target_schema = target.schema();
-    let join_schema_df = build_join_schema(&source_schema, &target_schema, &JoinType::Full)?;
+    let join_schema_df = build_join_schema(source_schema, target_schema, &JoinType::Full)?;
     let predicate = match predicate {
         Expression::DataFusion(expr) => expr,
         Expression::String(s) => parse_predicate_expression(&join_schema_df, s, &state)?,
@@ -959,7 +958,7 @@ async fn execute(
     let new_columns = new_columns.into_optimized_plan()?;
     let operation_count = LogicalPlan::Extension(Extension {
         node: Arc::new(MetricObserver {
-            anchor: TARGET_COUNT_ID.into(),
+            id: TARGET_COUNT_ID.into(),
             input: new_columns,
         }),
     });
@@ -968,35 +967,15 @@ async fn execute(
     let filtered = operation_count.filter(col(DELETE_COLUMN).is_false())?;
     let project = filtered.select(write_projection)?;
 
+    let optimized = &project.into_optimized_plan()?;
+    dbg!("{:?}", &optimized);
     let state = state.with_query_planner(Arc::new(MergePlanner {}));
-    let write = state
-        .create_physical_plan(&project.into_unoptimized_plan())
-        .await?;
+    let write = state.create_physical_plan(optimized).await?;
 
-    fn find_metric_node(
-        anchor: &str,
-        parent: &Arc<dyn ExecutionPlan>,
-    ) -> Option<Arc<dyn ExecutionPlan>> {
-        if let Some(metric) = parent.as_any().downcast_ref::<MetricObserverExec>() {
-            if metric.anchor().eq(anchor) {
-                return Some(parent.to_owned());
-            }
-        }
-
-        for child in &parent.children() {
-            let res = find_metric_node(anchor, child);
-            if res.is_some() {
-                return res;
-            }
-        }
-
-        return None;
-    }
-
     //Find the count nodes..
-    //TODO: don't unwrap...
-    let source_count = find_metric_node(SOURCE_COUNT_ID, &write).unwrap();
-    let op_count = find_metric_node(TARGET_COUNT_ID, &write).unwrap();
+    let err = || DeltaTableError::Generic("Unable to locate expected metric node".into());
+    let source_count = find_metric_node(SOURCE_COUNT_ID, &write).ok_or_else(err)?;
+    let op_count = find_metric_node(TARGET_COUNT_ID, &write).ok_or_else(err)?;
 
     // write projected records
     let table_partition_cols = current_metadata.partition_columns.clone();
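
For reference, a minimal sketch of the observer-function contract this diff relies on, assuming only the DataFusion MetricBuilder API already used above; the function name count_rows and the metric name "num_source_rows" are illustrative and not part of the change.

use datafusion::arrow::record_batch::RecordBatch;
use datafusion::physical_plan::metrics::{ExecutionPlanMetricsSet, MetricBuilder};

// Matches the MetricObserverFunction alias: fn(&RecordBatch, &ExecutionPlanMetricsSet) -> ().
// Every batch flowing through a MetricObserverExec bumps a global counter by its row count,
// the same pattern the planner installs for SOURCE_COUNT_METRIC in merge.rs.
fn count_rows(batch: &RecordBatch, metrics: &ExecutionPlanMetricsSet) {
    // "num_source_rows" is an illustrative name; merge.rs uses its own *_METRIC constants.
    MetricBuilder::new(metrics)
        .global_counter("num_source_rows")
        .add(batch.num_rows());
}

After planning, the node holding these metrics is recovered with find_metric_node(id, &physical_plan) and its metrics read back, which is what execute() does for SOURCE_COUNT_ID and TARGET_COUNT_ID.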