remove panic sources
Blajda committed Oct 13, 2023
1 parent d0d0e63 commit 9ff3057
Showing 3 changed files with 121 additions and 103 deletions.
16 changes: 10 additions & 6 deletions rust/src/delta_datafusion/logical.rs
@@ -1,18 +1,22 @@
//! Logical Operations for DataFusion
use datafusion_expr::{LogicalPlan, UserDefinedLogicalNodeCore};
pub const METRIC_OBSERVER_NAME: &str = "MetricObserver";

// Metric Observer is used to update DataFusion metrics from a record batch.
// See MetricObserverExec for the physical implementation

#[derive(Debug, Hash, Eq, PartialEq)]
pub(crate) struct MetricObserver {
// This acts as an anchor when converting to a physical operator
pub anchor: String,
// id is preserved during conversion to physical node
pub id: String,
pub input: LogicalPlan,
}

impl UserDefinedLogicalNodeCore for MetricObserver {
// Predicate push down is not supported for this node. Try to limit usage
// near the end of the plan.
fn name(&self) -> &str {
METRIC_OBSERVER_NAME
"MetricObserver"
}

fn inputs(&self) -> Vec<&datafusion_expr::LogicalPlan> {
@@ -28,7 +32,7 @@ impl UserDefinedLogicalNodeCore for MetricObserver {
}

fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "MetricObserver name={}", &self.anchor)
write!(f, "MetricObserver id={}", &self.id)
}

fn from_template(
@@ -37,7 +41,7 @@ impl UserDefinedLogicalNodeCore for MetricObserver {
inputs: &[datafusion_expr::LogicalPlan],
) -> Self {
MetricObserver {
anchor: self.anchor.clone(),
id: self.id.clone(),
input: inputs[0].clone(),
}
}
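A usage sketch (not part of this commit): the renamed `id` field is the tag the merge operation later uses to recover the matching physical node, and a logical plan is wrapped in the observer roughly as follows, mirroring how merge.rs builds its observers further down. The `observe` helper is illustrative and import paths may vary by DataFusion version.

use std::sync::Arc;
use datafusion_expr::{Extension, LogicalPlan};

// Wrap a logical plan in a MetricObserver tagged with `id` (crate-internal,
// analogous to the SOURCE_COUNT_ID / TARGET_COUNT_ID nodes in merge.rs).
fn observe(id: &str, input: LogicalPlan) -> LogicalPlan {
    LogicalPlan::Extension(Extension {
        node: Arc::new(MetricObserver {
            id: id.to_owned(),
            input,
        }),
    })
}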
61 changes: 48 additions & 13 deletions rust/src/delta_datafusion/physical.rs
@@ -11,34 +11,53 @@ use datafusion::physical_plan::{
};
use futures::{Stream, StreamExt};

use crate::DeltaTableError;

// Metric Observer is used to update DataFusion metrics from a record batch.
// Typically the null count for a particular column is pulled after performing a
// projection since this count is easy to obtain

pub(crate) type MetricObserverFunction = fn(&RecordBatch, &ExecutionPlanMetricsSet) -> ();

pub(crate) struct MetricObserverExec {
parent: Arc<dyn ExecutionPlan>,
anchor: String,
id: String,
metrics: ExecutionPlanMetricsSet,
update: MetricObserverFunction,
}

impl MetricObserverExec {
pub fn new(anchor: String, parent: Arc<dyn ExecutionPlan>, f: MetricObserverFunction) -> Self {
pub fn new(id: String, parent: Arc<dyn ExecutionPlan>, f: MetricObserverFunction) -> Self {
MetricObserverExec {
parent,
anchor,
id,
metrics: ExecutionPlanMetricsSet::new(),
update: f,
}
}

pub fn anchor(&self) -> &str {
&self.anchor
pub fn try_new(id: String, inputs: &[Arc<dyn ExecutionPlan>], f: MetricObserverFunction) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
match inputs {
[input] => Ok(Arc::new(MetricObserverExec::new(
id,
input.clone(),
f
))),
_ => Err(datafusion_common::DataFusionError::External(Box::new(
DeltaTableError::Generic("MetricObserverExec expects only one child".into()),
))),
}
}

pub fn id(&self) -> &str {
&self.id
}
}

impl std::fmt::Debug for MetricObserverExec {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MetricObserverExec")
.field("anchor", &self.anchor)
.field("id", &self.id)
.field("metrics", &self.metrics)
.finish()
}
@@ -50,7 +69,7 @@ impl DisplayAs for MetricObserverExec {
_: datafusion::physical_plan::DisplayFormatType,
f: &mut std::fmt::Formatter,
) -> std::fmt::Result {
write!(f, "MetricObserverExec anchor={}", self.anchor)
write!(f, "MetricObserverExec id={}", self.id)
}
}

@@ -97,12 +116,7 @@ impl ExecutionPlan for MetricObserverExec {
self: Arc<Self>,
children: Vec<Arc<dyn ExecutionPlan>>,
) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
//TODO: Error on multiple children
Ok(Arc::new(MetricObserverExec::new(
self.anchor.clone(),
children.get(0).unwrap().clone(),
self.update,
)))
MetricObserverExec::try_new(self.id.clone(), &children, self.update)
}

fn metrics(&self) -> Option<MetricsSet> {
@@ -143,3 +157,24 @@ impl RecordBatchStream for MetricObserverStream {
self.schema.clone()
}
}

pub(crate) fn find_metric_node(
id: &str,
parent: &Arc<dyn ExecutionPlan>,
) -> Option<Arc<dyn ExecutionPlan>> {
// Used to locate the physical MetricObserverExec node after the planner converts the logical node
if let Some(metric) = parent.as_any().downcast_ref::<MetricObserverExec>() {
if metric.id().eq(id) {
return Some(parent.to_owned());
}
}

for child in &parent.children() {
let res = find_metric_node(id, child);
if res.is_some() {
return res;
}
}

None
}
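Taken together, `try_new` and `find_metric_node` remove two panic sources: constructing the node now fails with a `DataFusionError` when it is given anything but exactly one child, and locating it again after planning returns an `Option` the caller must handle. A caller-side sketch under assumed names (the id, counter name, and helper function are illustrative):

use std::sync::Arc;
use datafusion::physical_plan::{metrics::MetricBuilder, ExecutionPlan};
use datafusion_common::{DataFusionError, Result as DataFusionResult};
use crate::DeltaTableError;

// Build an observer over a single input plan and locate it again by id,
// propagating errors instead of unwrapping.
fn observed_row_count(
    input: Arc<dyn ExecutionPlan>,
) -> DataFusionResult<Arc<dyn ExecutionPlan>> {
    let plan = MetricObserverExec::try_new(
        "row_count_observer".into(), // hypothetical id
        &[input],                    // exactly one child, otherwise Err(..)
        |batch, metrics| {
            MetricBuilder::new(metrics)
                .global_counter("row_count")
                .add(batch.num_rows());
        },
    )?;
    find_metric_node("row_count_observer", &plan).ok_or_else(|| {
        DataFusionError::External(Box::new(DeltaTableError::Generic(
            "metric node not found".into(),
        )))
    })
}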
147 changes: 63 additions & 84 deletions rust/src/operations/merge.rs
@@ -8,8 +8,7 @@
//! specified matter. See [`MergeBuilder`] for more information
//!
//! *WARNING* The current implementation rewrites the entire delta table so only
//! use on small to medium sized tables. The solution also cannot take advantage
//! of multiple threads and is limited to a single single thread.
//! use on small to medium sized tables.
//! Enhancements tracked at #850
//!
//! # Example
@@ -62,7 +61,7 @@ use super::datafusion_utils::{into_expr, maybe_into_expr, Expression};
use super::transaction::commit;
use crate::delta_datafusion::expr::{fmt_expr_to_sql, parse_predicate_expression};
use crate::delta_datafusion::logical::MetricObserver;
use crate::delta_datafusion::physical::MetricObserverExec;
use crate::delta_datafusion::physical::{find_metric_node, MetricObserverExec};
use crate::delta_datafusion::register_store;
use crate::delta_datafusion::{DeltaScanConfig, DeltaTableProvider};
use crate::{
@@ -573,59 +572,59 @@ impl ExtensionPlanner for MergeMetricExtensionPlanner {
physical_inputs: &[Arc<dyn ExecutionPlan>],
_session_state: &SessionState,
) -> DataFusionResult<Option<Arc<dyn ExecutionPlan>>> {
let metric_observer = node.as_any().downcast_ref::<MetricObserver>().unwrap();

if metric_observer.anchor.eq(SOURCE_COUNT_ID) {
return Ok(Some(Arc::new(MetricObserverExec::new(
SOURCE_COUNT_ID.into(),
physical_inputs[0].clone(),
|batch, metrics| {
MetricBuilder::new(metrics)
.global_counter(SOURCE_COUNT_METRIC)
.add(batch.num_rows());
},
))));
}
if let Some(metric_observer) = node.as_any().downcast_ref::<MetricObserver>() {
if metric_observer.id.eq(SOURCE_COUNT_ID) {
return Ok(Some(MetricObserverExec::try_new(
SOURCE_COUNT_ID.into(),
physical_inputs,
|batch, metrics| {
MetricBuilder::new(metrics)
.global_counter(SOURCE_COUNT_METRIC)
.add(batch.num_rows());
},
)?));
}

if metric_observer.anchor.eq(TARGET_COUNT_ID) {
return Ok(Some(Arc::new(MetricObserverExec::new(
TARGET_COUNT_ID.into(),
physical_inputs[0].clone(),
|batch, metrics| {
MetricBuilder::new(metrics)
.global_counter(TARGET_INSERTED_METRIC)
.add(
batch
.column_by_name(TARGET_INSERT_COLUMN)
.unwrap()
.null_count(),
);
MetricBuilder::new(metrics)
.global_counter(TARGET_UPDATED_METRIC)
.add(
batch
.column_by_name(TARGET_UPDATE_COLUMN)
.unwrap()
.null_count(),
);
MetricBuilder::new(metrics)
.global_counter(TARGET_DELETED_METRIC)
.add(
batch
.column_by_name(TARGET_DELETE_COLUMN)
.unwrap()
.null_count(),
);
MetricBuilder::new(metrics)
.global_counter(TARGET_COPY_METRIC)
.add(
batch
.column_by_name(TARGET_COPY_COLUMN)
.unwrap()
.null_count(),
);
},
))));
if metric_observer.id.eq(TARGET_COUNT_ID) {
return Ok(Some(MetricObserverExec::try_new(
TARGET_COUNT_ID.into(),
physical_inputs,
|batch, metrics| {
MetricBuilder::new(metrics)
.global_counter(TARGET_INSERTED_METRIC)
.add(
batch
.column_by_name(TARGET_INSERT_COLUMN)
.unwrap()
.null_count(),
);
MetricBuilder::new(metrics)
.global_counter(TARGET_UPDATED_METRIC)
.add(
batch
.column_by_name(TARGET_UPDATE_COLUMN)
.unwrap()
.null_count(),
);
MetricBuilder::new(metrics)
.global_counter(TARGET_DELETED_METRIC)
.add(
batch
.column_by_name(TARGET_DELETE_COLUMN)
.unwrap()
.null_count(),
);
MetricBuilder::new(metrics)
.global_counter(TARGET_COPY_METRIC)
.add(
batch
.column_by_name(TARGET_COPY_COLUMN)
.unwrap()
.null_count(),
);
},
)?));
}
}

Ok(None)
@@ -681,7 +680,7 @@ async fn execute(

let source = LogicalPlan::Extension(Extension {
node: Arc::new(MetricObserver {
anchor: SOURCE_COUNT_ID.into(),
id: SOURCE_COUNT_ID.into(),
input: source,
}),
});
@@ -702,7 +701,7 @@ async fn execute(

let source_schema = source.schema();
let target_schema = target.schema();
let join_schema_df = build_join_schema(&source_schema, &target_schema, &JoinType::Full)?;
let join_schema_df = build_join_schema(source_schema, target_schema, &JoinType::Full)?;
let predicate = match predicate {
Expression::DataFusion(expr) => expr,
Expression::String(s) => parse_predicate_expression(&join_schema_df, s, &state)?,
@@ -959,7 +958,7 @@ async fn execute(
let new_columns = new_columns.into_optimized_plan()?;
let operation_count = LogicalPlan::Extension(Extension {
node: Arc::new(MetricObserver {
anchor: TARGET_COUNT_ID.into(),
id: TARGET_COUNT_ID.into(),
input: new_columns,
}),
});
@@ -968,35 +967,15 @@ async fn execute(
let filtered = operation_count.filter(col(DELETE_COLUMN).is_false())?;

let project = filtered.select(write_projection)?;
let optimized = &project.into_optimized_plan()?;
dbg!("{:?}", &optimized);

let state = state.with_query_planner(Arc::new(MergePlanner {}));
let write = state
.create_physical_plan(&project.into_unoptimized_plan())
.await?;
let write = state.create_physical_plan(optimized).await?;

fn find_metric_node(
anchor: &str,
parent: &Arc<dyn ExecutionPlan>,
) -> Option<Arc<dyn ExecutionPlan>> {
if let Some(metric) = parent.as_any().downcast_ref::<MetricObserverExec>() {
if metric.anchor().eq(anchor) {
return Some(parent.to_owned());
}
}

for child in &parent.children() {
let res = find_metric_node(anchor, child);
if res.is_some() {
return res;
}
}

return None;
}
//Find the count nodes..
//TODO: don't unwrap...
let source_count = find_metric_node(SOURCE_COUNT_ID, &write).unwrap();
let op_count = find_metric_node(TARGET_COUNT_ID, &write).unwrap();
let err = || DeltaTableError::Generic("Unable to locate expected metric node".into());
let source_count = find_metric_node(SOURCE_COUNT_ID, &write).ok_or_else(err)?;
let op_count = find_metric_node(TARGET_COUNT_ID, &write).ok_or_else(err)?;

// write projected records
let table_partition_cols = current_metadata.partition_columns.clone();
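With the located nodes in hand, the counters registered on them can be read back from each node's MetricsSet once the plan has executed. A rough sketch of that read path, assuming the convention of summing a named global counter (the helper and its wiring are illustrative, not shown in this diff):

use datafusion::physical_plan::metrics::MetricsSet;

// Sum a named counter out of an optional MetricsSet, defaulting to zero.
fn get_metric(metrics: &Option<MetricsSet>, name: &str) -> usize {
    metrics
        .as_ref()
        .and_then(|m| m.sum_by_name(name))
        .map(|v| v.as_usize())
        .unwrap_or_default()
}

// e.g. after the stream produced by `write` has been collected:
// let rows_scanned = get_metric(&source_count.metrics(), SOURCE_COUNT_METRIC);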
