From 04491a6dfc32001a2b3ac8d5d0be6ed350c46d8b Mon Sep 17 00:00:00 2001 From: Victor Barua Date: Thu, 5 Dec 2024 10:06:14 +0000 Subject: [PATCH] feat(substrait): modular substrait consumer --- .../substrait/src/logical_plan/consumer.rs | 2325 ++++++++++------- .../substrait/src/logical_plan/producer.rs | 13 +- .../tests/cases/roundtrip_logical_plan.rs | 59 +- datafusion/substrait/tests/utils.rs | 27 +- 4 files changed, 1434 insertions(+), 990 deletions(-) diff --git a/datafusion/substrait/src/logical_plan/consumer.rs b/datafusion/substrait/src/logical_plan/consumer.rs index a9e411e35ae88..bc3215fae85ec 100644 --- a/datafusion/substrait/src/logical_plan/consumer.rs +++ b/datafusion/substrait/src/logical_plan/consumer.rs @@ -21,23 +21,24 @@ use datafusion::arrow::array::{GenericListArray, MapArray}; use datafusion::arrow::datatypes::{ DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, }; -use datafusion::common::plan_err; use datafusion::common::{ - not_impl_err, plan_datafusion_err, substrait_datafusion_err, substrait_err, DFSchema, - DFSchemaRef, + not_impl_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, + substrait_datafusion_err, substrait_err, DFSchema, DFSchemaRef, }; use datafusion::datasource::provider_as_source; use datafusion::logical_expr::expr::{Exists, InSubquery, Sort}; use datafusion::logical_expr::{ - Aggregate, BinaryExpr, Case, EmptyRelation, Expr, ExprSchemable, LogicalPlan, - Operator, Projection, SortExpr, TryCast, Values, + Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, LogicalPlan, + Operator, Projection, SortExpr, Subquery, TryCast, Values, }; use substrait::proto::aggregate_rel::Grouping; +use substrait::proto::expression as substrait_expression; use substrait::proto::expression::subquery::set_predicate::PredicateOp; use substrait::proto::expression_reference::ExprType; use url::Url; +use super::state::SubstraitPlanningState; use crate::extensions::Extensions; use crate::variation_const::{ DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF, @@ -53,16 +54,19 @@ use crate::variation_const::{ TIMESTAMP_MICRO_TYPE_VARIATION_REF, TIMESTAMP_MILLI_TYPE_VARIATION_REF, TIMESTAMP_NANO_TYPE_VARIATION_REF, TIMESTAMP_SECOND_TYPE_VARIATION_REF, }; +use async_trait::async_trait; use datafusion::arrow::array::{new_empty_array, AsArray}; use datafusion::arrow::temporal_conversions::NANOSECONDS; +use datafusion::catalog::TableProvider; use datafusion::common::scalar::ScalarStructBuilder; +use datafusion::execution::{FunctionRegistry, SessionState}; use datafusion::logical_expr::builder::project; use datafusion::logical_expr::expr::InList; use datafusion::logical_expr::{ - col, expr, Cast, Extension, GroupingSet, Like, LogicalPlanBuilder, Partitioning, - Repartition, Subquery, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, + col, expr, GroupingSet, Like, LogicalPlanBuilder, Partitioning, Repartition, + WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition, }; -use datafusion::prelude::JoinType; +use datafusion::prelude::{JoinType, SessionContext}; use datafusion::sql::TableReference; use datafusion::{ error::Result, logical_expr::utils::split_conjunction, prelude::Column, @@ -70,17 +74,21 @@ use datafusion::{ }; use std::collections::HashSet; use std::sync::Arc; +use substrait::proto; use substrait::proto::exchange_rel::ExchangeKind; use substrait::proto::expression::cast::FailureBehavior::ReturnNull; use substrait::proto::expression::literal::user_defined::Val; use substrait::proto::expression::literal::{ interval_day_to_second, IntervalCompound, IntervalDayToSecond, IntervalYearToMonth, - UserDefined, }; use substrait::proto::expression::subquery::SubqueryType; -use substrait::proto::expression::{FieldReference, Literal, ScalarFunction}; +use substrait::proto::expression::{ + Enum, FieldReference, IfThen, Literal, MultiOrList, Nested, ScalarFunction, + SingularOrList, SwitchExpression, WindowFunction, +}; use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile; use substrait::proto::rel_common::{Emit, EmitKind}; +use substrait::proto::set_rel::SetOp; use substrait::proto::{ aggregate_function::AggregationInvocation, expression::{ @@ -94,13 +102,406 @@ use substrait::proto::{ join_rel, plan_rel, r#type, read_rel::ReadType, rel::RelType, - rel_common, set_rel, + rel_common, sort_field::{SortDirection, SortKind::*}, - AggregateFunction, Expression, NamedStruct, Plan, Rel, RelCommon, Type, + AggregateFunction, AggregateRel, ConsistentPartitionWindowRel, CrossRel, ExchangeRel, + Expression, ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel, FetchRel, + FilterRel, JoinRel, NamedStruct, Plan, ProjectRel, ReadRel, Rel, RelCommon, SetRel, + SortRel, Type, }; use substrait::proto::{ExtendedExpression, FunctionArgument, SortField}; -use super::state::SubstraitPlanningState; +#[async_trait] +/// This trait is used to consume Substrait plans, converting them into DataFusion Logical Plans. +/// It can be implemented by users to allow for custom handling of relations, expressions, etc. +/// +/// # Example Usage +/// +/// ``` +/// use async_trait::async_trait; +/// use datafusion::catalog::TableProvider; +/// use datafusion::common::{not_impl_err, substrait_err, DFSchema, ScalarValue, TableReference}; +/// use datafusion::error::Result; +/// use datafusion::execution::SessionState; +/// use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; +/// use std::sync::Arc; +/// use substrait::proto; +/// use substrait::proto::{ExtensionLeafRel, FilterRel, ProjectRel}; +/// use datafusion::arrow::datatypes::DataType; +/// use datafusion::logical_expr::expr::ScalarFunction; +/// use datafusion_substrait::extensions::Extensions; +/// use datafusion_substrait::logical_plan::consumer::{ +/// from_project_rel, from_substrait_rel, from_substrait_rex, SubstraitConsumer +/// }; +/// +/// use datafusion_substrait::logical_plan::state::SubstraitPlanningState; +/// +/// struct CustomSubstraitConsumer { +/// extensions: Arc, +/// state: Arc, +/// } +/// +/// #[async_trait] +/// impl SubstraitConsumer for CustomSubstraitConsumer { +/// async fn resolve_table_ref( +/// &self, +/// table_ref: &TableReference, +/// ) -> Result>> { +/// self.state.table(table_ref).await +/// } +/// +/// fn get_extensions(&self) -> &Extensions { +/// self.extensions.as_ref() +/// } +/// +/// fn get_state(&self) -> &SessionState { +/// self.state.as_ref() +/// } +/// +/// // You can reuse existing consumer code to assist in handling advanced extensions +/// async fn consume_project(&self, rel: &ProjectRel) -> Result { +/// let df_plan = from_project_rel(self, rel).await?; +/// if let Some(advanced_extension) = rel.advanced_extension.as_ref() { +/// not_impl_err!( +/// "decode and handle an advanced extension: {:?}", +/// advanced_extension +/// ) +/// } else { +/// Ok(df_plan) +/// } +/// } +/// +/// // You can implement a fully custom consumer method if you need special handling +/// async fn consume_filter(&self, rel: &FilterRel) -> Result { +/// let input = from_substrait_rel(self, rel.input.as_ref().unwrap()).await?; +/// let expression = +/// from_substrait_rex(self, rel.condition.as_ref().unwrap(), input.schema()) +/// .await?; +/// // though this one is quite boring +/// LogicalPlanBuilder::from(input).filter(expression)?.build() +/// } +/// +/// // You can add handlers for extension relations +/// async fn consume_extension_leaf( +/// &self, +/// rel: &ExtensionLeafRel, +/// ) -> Result { +/// not_impl_err!( +/// "handle protobuf Any {} as you need", +/// rel.detail.as_ref().unwrap().type_url +/// ) +/// } +/// +/// // and handlers for user-define types +/// async fn consume_user_defined_type(&self, typ: &proto::r#type::UserDefined) -> Result { +/// let type_string = self.extensions.types.get(&typ.type_reference).unwrap(); +/// match type_string.as_str() { +/// "u!foo" => not_impl_err!("handle foo conversion"), +/// "u!bar" => not_impl_err!("handle bar conversion"), +/// _ => substrait_err!("unexpected type") +/// } +/// } +/// +/// // and user-defined literals +/// async fn consume_user_defined_literal(&self, literal: &proto::expression::literal::UserDefined) -> Result { +/// let type_string = self.extensions.types.get(&literal.type_reference).unwrap(); +/// match type_string.as_str() { +/// "u!foo" => not_impl_err!("handle foo conversion"), +/// "u!bar" => not_impl_err!("handle bar conversion"), +/// _ => substrait_err!("unexpected type") +/// } +/// } +/// } +/// ``` +/// +pub trait SubstraitConsumer: Send + Sync + Sized { + async fn resolve_table_ref( + &self, + table_ref: &TableReference, + ) -> Result>>; + + // TODO: Remove these two methods + // Ideally, the abstract consumer should not place any constraints on implementations. + // The functionality for which the Extensions and SessionState is needed should be abstracted + // out into methods on the trait. As an example, resolve_table_reference is such a method. + fn get_extensions(&self) -> &Extensions; + fn get_state(&self) -> &SessionState; + + // Relation Methods + // There is one method per Substrait relation to allow for easy overriding of consumer behaviour. + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + + async fn consume_read(&self, rel: &ReadRel) -> Result { + from_read_rel(self, rel).await + } + + async fn consume_filter(&self, rel: &FilterRel) -> Result { + from_filter_rel(self, rel).await + } + + async fn consume_fetch(&self, rel: &FetchRel) -> Result { + from_fetch_rel(self, rel).await + } + + async fn consume_aggregate(&self, rel: &AggregateRel) -> Result { + from_aggregate_rel(self, rel).await + } + + async fn consume_sort(&self, rel: &SortRel) -> Result { + from_sort_rel(self, rel).await + } + + async fn consume_join(&self, rel: &JoinRel) -> Result { + from_join_rel(self, rel).await + } + + async fn consume_project(&self, rel: &ProjectRel) -> Result { + from_project_rel(self, rel).await + } + + async fn consume_set(&self, rel: &SetRel) -> Result { + from_set_rel(self, rel).await + } + + async fn consume_cross(&self, rel: &CrossRel) -> Result { + from_cross_rel(self, rel).await + } + + async fn consume_consistent_partition_window( + &self, + _rel: &ConsistentPartitionWindowRel, + ) -> Result { + not_impl_err!("Consistent Partition Window Rel not supported") + } + + async fn consume_exchange(&self, rel: &ExchangeRel) -> Result { + from_exchange_rel(self, rel).await + } + + // Expression Methods + // There is one method per Substrait expression to allow for easy overriding of consumer behaviour + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. + + async fn consume_literal(&self, expr: &Literal) -> Result { + from_literal(self, expr).await + } + + async fn consume_selection( + &self, + expr: &FieldReference, + input_schema: &DFSchema, + ) -> Result { + from_selection(self, expr, input_schema).await + } + + async fn consume_scalar_function( + &self, + expr: &ScalarFunction, + input_schema: &DFSchema, + ) -> Result { + from_scalar_function(self, expr, input_schema).await + } + + async fn consume_window_function( + &self, + expr: &WindowFunction, + input_schema: &DFSchema, + ) -> Result { + from_window_function(self, expr, input_schema).await + } + + async fn consume_if_then( + &self, + expr: &IfThen, + input_schema: &DFSchema, + ) -> Result { + from_if_then(self, expr, input_schema).await + } + + async fn consume_switch( + &self, + _expr: &SwitchExpression, + _input_schema: &DFSchema, + ) -> Result { + not_impl_err!("Switch expression not supported") + } + + async fn consume_singular_or_list( + &self, + expr: &SingularOrList, + input_schema: &DFSchema, + ) -> Result { + from_singular_or_list(self, expr, input_schema).await + } + + async fn consume_multi_or_list( + &self, + _expr: &MultiOrList, + _input_schema: &DFSchema, + ) -> Result { + not_impl_err!("Multi Or List expression not supported") + } + + async fn consume_cast( + &self, + expr: &substrait_expression::Cast, + input_schema: &DFSchema, + ) -> Result { + from_cast(self, expr, input_schema).await + } + + async fn consume_subquery( + &self, + expr: &substrait_expression::Subquery, + input_schema: &DFSchema, + ) -> Result { + from_subquery(self, expr, input_schema).await + } + + async fn consume_nested( + &self, + _expr: &Nested, + _input_schema: &DFSchema, + ) -> Result { + not_impl_err!("Nested expression not supported") + } + + async fn consume_enum(&self, _expr: &Enum, _input_schema: &DFSchema) -> Result { + not_impl_err!("Enum expression not supported") + } + + // User-Defined Functionality + + // The details of extension relations, and how to handle them, are fully up to users to specify. + // The following methods allow users to customize the consumer behaviour + + async fn consume_extension_leaf( + &self, + rel: &ExtensionLeafRel, + ) -> Result { + if let Some(detail) = rel.detail.as_ref() { + return substrait_err!( + "Missing handler for ExtensionLeafRel: {}", + detail.type_url + ); + } + substrait_err!("Missing handler for ExtensionLeafRel") + } + + async fn consume_extension_single( + &self, + rel: &ExtensionSingleRel, + ) -> Result { + if let Some(detail) = rel.detail.as_ref() { + return substrait_err!( + "Missing handler for ExtensionSingleRel: {}", + detail.type_url + ); + } + substrait_err!("Missing handler for ExtensionSingleRel") + } + + async fn consume_extension_multi( + &self, + rel: &ExtensionMultiRel, + ) -> Result { + if let Some(detail) = rel.detail.as_ref() { + return substrait_err!( + "Missing handler for ExtensionMultiRel: {}", + detail.type_url + ); + } + substrait_err!("Missing handler for ExtensionMultiRel") + } + + // Users can bring their own types to Substrait which require custom handling + + fn consume_user_defined_type( + &self, + user_defined_type: &r#type::UserDefined, + ) -> Result { + substrait_err!( + "Missing handler for user-defined type: {}", + user_defined_type.type_reference + ) + } + + fn consume_user_defined_literal( + &self, + user_defined_literal: &proto::expression::literal::UserDefined, + ) -> Result { + substrait_err!( + "Missing handler for user-defined literals {}", + user_defined_literal.type_reference + ) + } +} + +/// Convert Substrait Rel to DataFusion DataFrame +#[async_recursion] +pub async fn from_substrait_rel( + consumer: &impl SubstraitConsumer, + relation: &Rel, +) -> Result { + let plan: Result = match &relation.rel_type { + Some(rel_type) => match rel_type { + RelType::Read(rel) => consumer.consume_read(rel).await, + RelType::Filter(rel) => consumer.consume_filter(rel).await, + RelType::Fetch(rel) => consumer.consume_fetch(rel).await, + RelType::Aggregate(rel) => consumer.consume_aggregate(rel).await, + RelType::Sort(rel) => consumer.consume_sort(rel).await, + RelType::Join(rel) => consumer.consume_join(rel).await, + RelType::Project(rel) => consumer.consume_project(rel).await, + RelType::Set(rel) => consumer.consume_set(rel).await, + RelType::ExtensionSingle(rel) => consumer.consume_extension_single(rel).await, + RelType::ExtensionMulti(rel) => consumer.consume_extension_multi(rel).await, + RelType::ExtensionLeaf(rel) => consumer.consume_extension_leaf(rel).await, + RelType::Cross(rel) => consumer.consume_cross(rel).await, + RelType::Window(rel) => { + consumer.consume_consistent_partition_window(rel).await + } + RelType::Exchange(rel) => consumer.consume_exchange(rel).await, + rt => not_impl_err!("{rt:?} rel not supported yet"), + }, + None => return substrait_err!("rel must set rel_type"), + }; + apply_emit_kind(retrieve_rel_common(relation), plan?) +} + +/// Can be used to consume standard Substrait without user-defined extensions +pub struct DefaultSubstraitConsumer { + extensions: Arc, + state: Arc, +} + +impl Default for DefaultSubstraitConsumer { + fn default() -> Self { + DefaultSubstraitConsumer { + extensions: Arc::new(Extensions::default()), + state: Arc::new(SessionContext::default().state()), + } + } +} + +#[async_trait] +impl SubstraitConsumer for DefaultSubstraitConsumer { + async fn resolve_table_ref( + &self, + table_ref: &TableReference, + ) -> Result>> { + self.state.table(table_ref).await + } + + fn get_extensions(&self) -> &Extensions { + self.extensions.as_ref() + } + + fn get_state(&self) -> &SessionState { + self.state.as_ref() + } +} // Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which // is the same as the expectation for any non-empty timezone in DF, so any non-empty timezone @@ -202,16 +603,15 @@ fn split_eq_and_noneq_join_predicate_with_nulls_equality( } async fn union_rels( + consumer: &impl SubstraitConsumer, rels: &[Rel], - state: &dyn SubstraitPlanningState, - extensions: &Extensions, is_all: bool, ) -> Result { let mut union_builder = Ok(LogicalPlanBuilder::from( - from_substrait_rel(state, &rels[0], extensions).await?, + from_substrait_rel(consumer, &rels[0]).await?, )); for input in &rels[1..] { - let rel_plan = from_substrait_rel(state, input, extensions).await?; + let rel_plan = from_substrait_rel(consumer, input).await?; union_builder = if is_all { union_builder?.union(rel_plan) @@ -223,17 +623,16 @@ async fn union_rels( } async fn intersect_rels( + consumer: &impl SubstraitConsumer, rels: &[Rel], - state: &dyn SubstraitPlanningState, - extensions: &Extensions, is_all: bool, ) -> Result { - let mut rel = from_substrait_rel(state, &rels[0], extensions).await?; + let mut rel = from_substrait_rel(consumer, &rels[0]).await?; for input in &rels[1..] { rel = LogicalPlanBuilder::intersect( rel, - from_substrait_rel(state, input, extensions).await?, + from_substrait_rel(consumer, input).await?, is_all, )? } @@ -242,17 +641,16 @@ async fn intersect_rels( } async fn except_rels( + consumer: &impl SubstraitConsumer, rels: &[Rel], - state: &dyn SubstraitPlanningState, - extensions: &Extensions, is_all: bool, ) -> Result { - let mut rel = from_substrait_rel(state, &rels[0], extensions).await?; + let mut rel = from_substrait_rel(consumer, &rels[0]).await?; for input in &rels[1..] { rel = LogicalPlanBuilder::except( rel, - from_substrait_rel(state, input, extensions).await?, + from_substrait_rel(consumer, input).await?, is_all, )? } @@ -262,7 +660,25 @@ async fn except_rels( /// Convert Substrait Plan to DataFusion LogicalPlan pub async fn from_substrait_plan( - state: &dyn SubstraitPlanningState, + state: &SessionState, + plan: &Plan, +) -> Result { + // Register function extension + let extensions = Extensions::try_from(&plan.extensions)?; + if !extensions.type_variations.is_empty() { + return not_impl_err!("Type variation extensions are not supported"); + } + + let consumer = DefaultSubstraitConsumer { + extensions: Arc::new(extensions), + state: Arc::new(state.clone()), + }; + from_substrait_plan_with_consumer(&consumer, plan).await +} + +/// Convert Substrait Plan to DataFusion LogicalPlan using the given consumer +pub async fn from_substrait_plan_with_consumer( + consumer: &impl SubstraitConsumer, plan: &Plan, ) -> Result { // Register function extension @@ -277,10 +693,10 @@ pub async fn from_substrait_plan( match plan.relations[0].rel_type.as_ref() { Some(rt) => match rt { plan_rel::RelType::Rel(rel) => { - Ok(from_substrait_rel(state, rel, &extensions).await?) + Ok(from_substrait_rel(consumer, rel).await?) }, plan_rel::RelType::Root(root) => { - let plan = from_substrait_rel(state, root.input.as_ref().unwrap(), &extensions).await?; + let plan = from_substrait_rel(consumer, root.input.as_ref().unwrap()).await?; if root.names.is_empty() { // Backwards compatibility for plans missing names return Ok(plan); @@ -341,7 +757,7 @@ pub struct ExprContainer { /// between systems. This is often useful for scenarios like pushdown where filter /// expressions need to be sent to remote systems. pub async fn from_substrait_extended_expr( - state: &dyn SubstraitPlanningState, + state: &SessionState, extended_expr: &ExtendedExpression, ) -> Result { // Register function extension @@ -350,8 +766,13 @@ pub async fn from_substrait_extended_expr( return not_impl_err!("Type variation extensions are not supported"); } + let consumer = DefaultSubstraitConsumer { + extensions: Arc::new(extensions), + state: Arc::new(state.clone()), + }; + let input_schema = DFSchemaRef::new(match &extended_expr.base_schema { - Some(base_schema) => from_substrait_named_struct(base_schema, &extensions), + Some(base_schema) => from_substrait_named_struct(&consumer, base_schema), None => { plan_err!("required property `base_schema` missing from Substrait ExtendedExpression message") } @@ -369,8 +790,7 @@ pub async fn from_substrait_extended_expr( plan_err!("required property `expr_type` missing from Substrait ExpressionReference message") } }?; - let expr = - from_substrait_rex(state, scalar_expr, &input_schema, &extensions).await?; + let expr = from_substrait_rex(&consumer, scalar_expr, &input_schema).await?; let (output_type, expected_nullability) = expr.data_type_and_nullable(&input_schema)?; let output_field = Field::new("", output_type, expected_nullability); @@ -557,570 +977,510 @@ fn make_renamed_schema( ) } -/// Convert Substrait Rel to DataFusion DataFrame -#[allow(deprecated)] #[async_recursion] -pub async fn from_substrait_rel( - state: &dyn SubstraitPlanningState, - rel: &Rel, - extensions: &Extensions, +pub async fn from_project_rel( + consumer: &impl SubstraitConsumer, + p: &ProjectRel, ) -> Result { - let plan: Result = match &rel.rel_type { - Some(RelType::Project(p)) => { - if let Some(input) = p.input.as_ref() { - let mut input = LogicalPlanBuilder::from( - from_substrait_rel(state, input, extensions).await?, - ); - let original_schema = input.schema().clone(); - - // Ensure that all expressions have a unique display name, so that - // validate_unique_names does not fail when constructing the project. - let mut name_tracker = NameTracker::new(); - - // By default, a Substrait Project emits all inputs fields followed by all expressions. - // We build the explicit expressions first, and then the input expressions to avoid - // adding aliases to the explicit expressions (as part of ensuring unique names). - // - // This is helpful for plan visualization and tests, because when DataFusion produces - // Substrait Projects it adds an output mapping that excludes all input columns - // leaving only explicit expressions. - - let mut explicit_exprs: Vec = vec![]; - for expr in &p.expressions { - let e = from_substrait_rex( - state, - expr, - input.clone().schema(), - extensions, - ) - .await?; - // if the expression is WindowFunction, wrap in a Window relation - if let Expr::WindowFunction(_) = &e { - // Adding the same expression here and in the project below - // works because the project's builder uses columnize_expr(..) - // to transform it into a column reference - input = input.window(vec![e.clone()])? - } - explicit_exprs.push(name_tracker.get_uniquely_named_expr(e)?); - } - - let mut final_exprs: Vec = vec![]; - for index in 0..original_schema.fields().len() { - let e = Expr::Column(Column::from( - original_schema.qualified_field(index), - )); - final_exprs.push(name_tracker.get_uniquely_named_expr(e)?); - } - final_exprs.append(&mut explicit_exprs); - - input.project(final_exprs)?.build() - } else { - not_impl_err!("Projection without an input is not supported") + if let Some(input) = p.input.as_ref() { + let mut input = + LogicalPlanBuilder::from(from_substrait_rel(consumer, input).await?); + let original_schema = input.schema().clone(); + + // Ensure that all expressions have a unique display name, so that + // validate_unique_names does not fail when constructing the project. + let mut name_tracker = NameTracker::new(); + + // By default, a Substrait Project emits all inputs fields followed by all expressions. + // We build the explicit expressions first, and then the input expressions to avoid + // adding aliases to the explicit expressions (as part of ensuring unique names). + // + // This is helpful for plan visualization and tests, because when DataFusion produces + // Substrait Projects it adds an output mapping that excludes all input columns + // leaving only explicit expressions. + + let mut explicit_exprs: Vec = vec![]; + for expr in &p.expressions { + let e = from_substrait_rex(consumer, expr, input.clone().schema()).await?; + // if the expression is WindowFunction, wrap in a Window relation + if let Expr::WindowFunction(_) = &e { + // Adding the same expression here and in the project below + // works because the project's builder uses columnize_expr(..) + // to transform it into a column reference + input = input.window(vec![e.clone()])? } + explicit_exprs.push(name_tracker.get_uniquely_named_expr(e)?); } - Some(RelType::Filter(filter)) => { - if let Some(input) = filter.input.as_ref() { - let input = LogicalPlanBuilder::from( - from_substrait_rel(state, input, extensions).await?, - ); - if let Some(condition) = filter.condition.as_ref() { - let expr = - from_substrait_rex(state, condition, input.schema(), extensions) - .await?; - input.filter(expr)?.build() - } else { - not_impl_err!("Filter without an condition is not valid") - } - } else { - not_impl_err!("Filter without an input is not valid") - } + + let mut final_exprs: Vec = vec![]; + for index in 0..original_schema.fields().len() { + let e = Expr::Column(Column::from(original_schema.qualified_field(index))); + final_exprs.push(name_tracker.get_uniquely_named_expr(e)?); } - Some(RelType::Fetch(fetch)) => { - if let Some(input) = fetch.input.as_ref() { - let input = LogicalPlanBuilder::from( - from_substrait_rel(state, input, extensions).await?, - ); - let offset = fetch.offset as usize; - // -1 means that ALL records should be returned - let count = if fetch.count == -1 { - None - } else { - Some(fetch.count as usize) - }; - input.limit(offset, count)?.build() - } else { - not_impl_err!("Fetch without an input is not valid") - } + final_exprs.append(&mut explicit_exprs); + input.project(final_exprs)?.build() + } else { + not_impl_err!("Projection without an input is not supported") + } +} + +#[async_recursion] +pub async fn from_filter_rel( + consumer: &impl SubstraitConsumer, + filter: &FilterRel, +) -> Result { + if let Some(input) = filter.input.as_ref() { + let input = LogicalPlanBuilder::from(from_substrait_rel(consumer, input).await?); + if let Some(condition) = filter.condition.as_ref() { + let expr = from_substrait_rex(consumer, condition, input.schema()).await?; + input.filter(expr)?.build() + } else { + not_impl_err!("Filter without an condition is not valid") } - Some(RelType::Sort(sort)) => { - if let Some(input) = sort.input.as_ref() { - let input = LogicalPlanBuilder::from( - from_substrait_rel(state, input, extensions).await?, - ); - let sorts = - from_substrait_sorts(state, &sort.sorts, input.schema(), extensions) - .await?; - input.sort(sorts)?.build() - } else { - not_impl_err!("Sort without an input is not valid") - } + } else { + not_impl_err!("Filter without an input is not valid") + } +} + +#[async_recursion] +pub async fn from_fetch_rel( + consumer: &impl SubstraitConsumer, + fetch: &FetchRel, +) -> Result { + if let Some(input) = fetch.input.as_ref() { + let input = LogicalPlanBuilder::from(from_substrait_rel(consumer, input).await?); + let offset = fetch.offset as usize; + // -1 means that ALL records should be returned + let count = if fetch.count == -1 { + None + } else { + Some(fetch.count as usize) + }; + input.limit(offset, count)?.build() + } else { + not_impl_err!("Fetch without an input is not valid") + } +} + +pub async fn from_sort_rel( + consumer: &impl SubstraitConsumer, + sort: &SortRel, +) -> Result { + if let Some(input) = sort.input.as_ref() { + let input = LogicalPlanBuilder::from(from_substrait_rel(consumer, input).await?); + let sorts = from_substrait_sorts(consumer, &sort.sorts, input.schema()).await?; + input.sort(sorts)?.build() + } else { + not_impl_err!("Sort without an input is not valid") + } +} + +pub async fn from_aggregate_rel( + consumer: &impl SubstraitConsumer, + agg: &AggregateRel, +) -> Result { + if let Some(input) = agg.input.as_ref() { + let input = LogicalPlanBuilder::from(from_substrait_rel(consumer, input).await?); + let mut ref_group_exprs = vec![]; + + for e in &agg.grouping_expressions { + let x = from_substrait_rex(consumer, e, input.schema()).await?; + ref_group_exprs.push(x); } - Some(RelType::Aggregate(agg)) => { - if let Some(input) = agg.input.as_ref() { - let input = LogicalPlanBuilder::from( - from_substrait_rel(state, input, extensions).await?, - ); - let mut ref_group_exprs = vec![]; - for e in &agg.grouping_expressions { - let x = - from_substrait_rex(state, e, input.schema(), extensions).await?; - ref_group_exprs.push(x); + let mut group_exprs = vec![]; + let mut aggr_exprs = vec![]; + + match agg.groupings.len() { + 1 => { + group_exprs.extend_from_slice( + &from_substrait_grouping( + consumer, + &agg.groupings[0], + &ref_group_exprs, + input.schema(), + ) + .await?, + ); + } + _ => { + let mut grouping_sets = vec![]; + for grouping in &agg.groupings { + let grouping_set = from_substrait_grouping( + consumer, + grouping, + &ref_group_exprs, + input.schema(), + ) + .await?; + grouping_sets.push(grouping_set); } + // Single-element grouping expression of type Expr::GroupingSet. + // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when + // parsed by the producer and consumer, since Substrait does not have a type dedicated + // to ROLLUP. Only vector of Groupings (grouping sets) is available. + group_exprs + .push(Expr::GroupingSet(GroupingSet::GroupingSets(grouping_sets))); + } + }; - let mut group_exprs = vec![]; - let mut aggr_exprs = vec![]; - - match agg.groupings.len() { - 1 => { - group_exprs.extend_from_slice( - &from_substrait_grouping( - state, - &agg.groupings[0], - &ref_group_exprs, - input.schema(), - extensions, - ) - .await?, - ); - } - _ => { - let mut grouping_sets = vec![]; - for grouping in &agg.groupings { - let grouping_set = from_substrait_grouping( - state, - grouping, - &ref_group_exprs, - input.schema(), - extensions, - ) - .await?; - grouping_sets.push(grouping_set); + for m in &agg.measures { + let filter = match &m.filter { + Some(fil) => Some(Box::new( + from_substrait_rex(consumer, fil, input.schema()).await?, + )), + None => None, + }; + let agg_func = match &m.measure { + Some(f) => { + let distinct = match f.invocation { + _ if f.invocation == AggregationInvocation::Distinct as i32 => { + true } - // Single-element grouping expression of type Expr::GroupingSet. - // Note that GroupingSet::Rollup would become GroupingSet::GroupingSets, when - // parsed by the producer and consumer, since Substrait does not have a type dedicated - // to ROLLUP. Only vector of Groupings (grouping sets) is available. - group_exprs.push(Expr::GroupingSet(GroupingSet::GroupingSets( - grouping_sets, - ))); - } - }; - - for m in &agg.measures { - let filter = match &m.filter { - Some(fil) => Some(Box::new( - from_substrait_rex(state, fil, input.schema(), extensions) - .await?, - )), - None => None, + _ if f.invocation == AggregationInvocation::All as i32 => false, + _ => false, }; - let agg_func = match &m.measure { - Some(f) => { - let distinct = match f.invocation { - _ if f.invocation - == AggregationInvocation::Distinct as i32 => - { - true - } - _ if f.invocation - == AggregationInvocation::All as i32 => - { - false - } - _ => false, - }; - let order_by = if !f.sorts.is_empty() { - Some( - from_substrait_sorts( - state, - &f.sorts, - input.schema(), - extensions, - ) - .await?, - ) - } else { - None - }; - - from_substrait_agg_func( - state, - f, - input.schema(), - extensions, - filter, - order_by, - distinct, - ) - .await - } - None => not_impl_err!( - "Aggregate without aggregate function is not supported" - ), + let order_by = if !f.sorts.is_empty() { + Some( + from_substrait_sorts(consumer, &f.sorts, input.schema()) + .await?, + ) + } else { + None }; - aggr_exprs.push(agg_func?.as_ref().clone()); - } - input.aggregate(group_exprs, aggr_exprs)?.build() - } else { - not_impl_err!("Aggregate without an input is not valid") - } - } - Some(RelType::Join(join)) => { - if join.post_join_filter.is_some() { - return not_impl_err!( - "JoinRel with post_join_filter is not yet supported" - ); - } - let left: LogicalPlanBuilder = LogicalPlanBuilder::from( - from_substrait_rel(state, join.left.as_ref().unwrap(), extensions) - .await?, - ); - let right = LogicalPlanBuilder::from( - from_substrait_rel(state, join.right.as_ref().unwrap(), extensions) - .await?, - ); - let (left, right) = requalify_sides_if_needed(left, right)?; - - let join_type = from_substrait_jointype(join.r#type)?; - // The join condition expression needs full input schema and not the output schema from join since we lose columns from - // certain join types such as semi and anti joins - let in_join_schema = left.schema().join(right.schema())?; - - // If join expression exists, parse the `on` condition expression, build join and return - // Otherwise, build join with only the filter, without join keys - match &join.expression.as_ref() { - Some(expr) => { - let on = from_substrait_rex(state, expr, &in_join_schema, extensions) - .await?; - // The join expression can contain both equal and non-equal ops. - // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. - // So we extract each part as follows: - // - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector - // - Otherwise we add the expression to join_filter (use conjunction if filter already exists) - let (join_ons, nulls_equal_nulls, join_filter) = - split_eq_and_noneq_join_predicate_with_nulls_equality(&on); - let (left_cols, right_cols): (Vec<_>, Vec<_>) = - itertools::multiunzip(join_ons); - left.join_detailed( - right.build()?, - join_type, - (left_cols, right_cols), - join_filter, - nulls_equal_nulls, - )? - .build() + from_substrait_agg_func( + consumer, + f, + input.schema(), + filter, + order_by, + distinct, + ) + .await } None => { - let on: Vec = vec![]; - left.join_detailed( - right.build()?, - join_type, - (on.clone(), on), - None, - false, - )? - .build() + not_impl_err!("Aggregate without aggregate function is not supported") } - } + }; + aggr_exprs.push(agg_func?.as_ref().clone()); } - Some(RelType::Cross(cross)) => { - let left = LogicalPlanBuilder::from( - from_substrait_rel(state, cross.left.as_ref().unwrap(), extensions) - .await?, - ); - let right = LogicalPlanBuilder::from( - from_substrait_rel(state, cross.right.as_ref().unwrap(), extensions) - .await?, - ); - let (left, right) = requalify_sides_if_needed(left, right)?; - left.cross_join(right.build()?)?.build() + input.aggregate(group_exprs, aggr_exprs)?.build() + } else { + not_impl_err!("Aggregate without an input is not valid") + } +} + +pub async fn from_join_rel( + consumer: &impl SubstraitConsumer, + join: &JoinRel, +) -> Result { + if join.post_join_filter.is_some() { + return not_impl_err!("JoinRel with post_join_filter is not yet supported"); + } + + let left: LogicalPlanBuilder = LogicalPlanBuilder::from( + from_substrait_rel(consumer, join.left.as_ref().unwrap()).await?, + ); + let right = LogicalPlanBuilder::from( + from_substrait_rel(consumer, join.right.as_ref().unwrap()).await?, + ); + let (left, right) = requalify_sides_if_needed(left, right)?; + + let join_type = from_substrait_jointype(join.r#type)?; + // The join condition expression needs full input schema and not the output schema from join since we lose columns from + // certain join types such as semi and anti joins + let in_join_schema = left.schema().join(right.schema())?; + + // If join expression exists, parse the `on` condition expression, build join and return + // Otherwise, build join with only the filter, without join keys + match &join.expression.as_ref() { + Some(expr) => { + let on = from_substrait_rex(consumer, expr, &in_join_schema).await?; + // The join expression can contain both equal and non-equal ops. + // As of datafusion 31.0.0, the equal and non equal join conditions are in separate fields. + // So we extract each part as follows: + // - If an Eq or IsNotDistinctFrom op is encountered, add the left column, right column and is_null_equal_nulls to `join_ons` vector + // - Otherwise we add the expression to join_filter (use conjunction if filter already exists) + let (join_ons, nulls_equal_nulls, join_filter) = + split_eq_and_noneq_join_predicate_with_nulls_equality(&on); + let (left_cols, right_cols): (Vec<_>, Vec<_>) = + itertools::multiunzip(join_ons); + left.join_detailed( + right.build()?, + join_type, + (left_cols, right_cols), + join_filter, + nulls_equal_nulls, + )? + .build() } - Some(RelType::Read(read)) => { - async fn read_with_schema( - state: &dyn SubstraitPlanningState, - table_ref: TableReference, - schema: DFSchema, - projection: &Option, - ) -> Result { - let schema = schema.replace_qualifier(table_ref.clone()); - - let plan = { - let provider = match state.table(&table_ref).await? { - Some(ref provider) => Arc::clone(provider), - _ => return plan_err!("No table named '{table_ref}'"), - }; + None => { + let on: Vec = vec![]; + left.join_detailed(right.build()?, join_type, (on.clone(), on), None, false)? + .build() + } + } +} - LogicalPlanBuilder::scan( - table_ref, - provider_as_source(Arc::clone(&provider)), - None, - )? - .build()? - }; +pub async fn from_cross_rel( + consumer: &impl SubstraitConsumer, + cross: &CrossRel, +) -> Result { + let left = LogicalPlanBuilder::from( + from_substrait_rel(consumer, cross.left.as_ref().unwrap()).await?, + ); + let right = LogicalPlanBuilder::from( + from_substrait_rel(consumer, cross.right.as_ref().unwrap()).await?, + ); + let (left, right) = requalify_sides_if_needed(left, right)?; + left.cross_join(right.build()?)?.build() +} - ensure_schema_compatability(plan.schema(), schema.clone())?; +#[allow(deprecated)] +pub async fn from_read_rel( + consumer: &impl SubstraitConsumer, + read: &ReadRel, +) -> Result { + async fn read_with_schema( + consumer: &impl SubstraitConsumer, + table_ref: TableReference, + schema: DFSchema, + projection: &Option, + ) -> Result { + let schema = schema.replace_qualifier(table_ref.clone()); + + let plan = { + let provider = match consumer.resolve_table_ref(&table_ref).await? { + Some(ref provider) => Arc::clone(provider), + _ => return plan_err!("No table named '{table_ref}'"), + }; - let schema = apply_masking(schema, projection)?; + LogicalPlanBuilder::scan( + table_ref, + provider_as_source(Arc::clone(&provider)), + None, + )? + .build()? + }; - apply_projection(plan, schema) - } + ensure_schema_compatability(plan.schema(), schema.clone())?; - let named_struct = read.base_schema.as_ref().ok_or_else(|| { - substrait_datafusion_err!("No base schema provided for Read Relation") - })?; + let schema = apply_masking(schema, projection)?; - let substrait_schema = from_substrait_named_struct(named_struct, extensions)?; + apply_projection(plan, schema) + } - match &read.as_ref().read_type { - Some(ReadType::NamedTable(nt)) => { - let table_reference = match nt.names.len() { - 0 => { - return plan_err!("No table name found in NamedTable"); - } - 1 => TableReference::Bare { - table: nt.names[0].clone().into(), - }, - 2 => TableReference::Partial { - schema: nt.names[0].clone().into(), - table: nt.names[1].clone().into(), - }, - _ => TableReference::Full { - catalog: nt.names[0].clone().into(), - schema: nt.names[1].clone().into(), - table: nt.names[2].clone().into(), - }, - }; + let named_struct = read.base_schema.as_ref().ok_or_else(|| { + substrait_datafusion_err!("No base schema provided for Read Relation") + })?; - read_with_schema( - state, - table_reference, - substrait_schema, - &read.projection, - ) - .await + let substrait_schema = from_substrait_named_struct(consumer, named_struct)?; + + match &read.read_type { + Some(ReadType::NamedTable(nt)) => { + let table_reference = match nt.names.len() { + 0 => { + return plan_err!("No table name found in NamedTable"); } - Some(ReadType::VirtualTable(vt)) => { - if vt.values.is_empty() { - return Ok(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: DFSchemaRef::new(substrait_schema), - })); + 1 => TableReference::Bare { + table: nt.names[0].clone().into(), + }, + 2 => TableReference::Partial { + schema: nt.names[0].clone().into(), + table: nt.names[1].clone().into(), + }, + _ => TableReference::Full { + catalog: nt.names[0].clone().into(), + schema: nt.names[1].clone().into(), + table: nt.names[2].clone().into(), + }, + }; + + read_with_schema( + consumer, + table_reference, + substrait_schema, + &read.projection, + ) + .await + } + Some(ReadType::VirtualTable(vt)) => { + if vt.values.is_empty() { + return Ok(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: DFSchemaRef::new(substrait_schema), + })); + } + + let values = vt + .values + .iter() + .map(|row| { + let mut name_idx = 0; + let lits = row + .fields + .iter() + .map(|lit| { + name_idx += 1; // top-level names are provided through schema + Ok(Expr::Literal(from_substrait_literal( + consumer, + lit, + &named_struct.names, + &mut name_idx, + )?)) + }) + .collect::>()?; + if name_idx != named_struct.names.len() { + return substrait_err!( + "Names list must match exactly to nested schema, but found {} uses for {} names", + name_idx, + named_struct.names.len() + ); } + Ok(lits) + }) + .collect::>()?; + + Ok(LogicalPlan::Values(Values { + schema: DFSchemaRef::new(substrait_schema), + values, + })) + } + Some(ReadType::LocalFiles(lf)) => { + fn extract_filename(name: &str) -> Option { + let corrected_url = + if name.starts_with("file://") && !name.starts_with("file:///") { + name.replacen("file://", "file:///", 1) + } else { + name.to_string() + }; + + Url::parse(&corrected_url).ok().and_then(|url| { + let path = url.path(); + std::path::Path::new(path) + .file_name() + .map(|filename| filename.to_string_lossy().to_string()) + }) + } + + // we could use the file name to check the original table provider + // TODO: currently does not support multiple local files + let filename: Option = + lf.items.first().and_then(|x| match x.path_type.as_ref() { + Some(UriFile(name)) => extract_filename(name), + _ => None, + }); + + if lf.items.len() > 1 || filename.is_none() { + return not_impl_err!("Only single file reads are supported"); + } + let name = filename.unwrap(); + // directly use unwrap here since we could determine it is a valid one + let table_reference = TableReference::Bare { table: name.into() }; + + read_with_schema( + consumer, + table_reference, + substrait_schema, + &read.projection, + ) + .await + } + _ => { + not_impl_err!("Unsupported ReadType: {:?}", read.read_type) + } + } +} + +pub async fn from_set_rel( + consumer: &impl SubstraitConsumer, + set: &SetRel, +) -> Result { + if set.inputs.len() < 2 { + substrait_err!("Set operation requires at least two inputs") + } else { + match set.op() { + SetOp::UnionAll => union_rels(consumer, &set.inputs, true).await, + SetOp::UnionDistinct => union_rels(consumer, &set.inputs, false).await, + SetOp::IntersectionPrimary => LogicalPlanBuilder::intersect( + from_substrait_rel(consumer, &set.inputs[0]).await?, + union_rels(consumer, &set.inputs[1..], true).await?, + false, + ), + SetOp::IntersectionMultiset => { + intersect_rels(consumer, &set.inputs, false).await + } + SetOp::IntersectionMultisetAll => { + intersect_rels(consumer, &set.inputs, true).await + } + SetOp::MinusPrimary => except_rels(consumer, &set.inputs, false).await, + SetOp::MinusPrimaryAll => except_rels(consumer, &set.inputs, true).await, + set_op => not_impl_err!("Unsupported set operator: {set_op:?}"), + } + } +} - let values = vt - .values - .iter() - .map(|row| { - let mut name_idx = 0; - let lits = row - .fields - .iter() - .map(|lit| { - name_idx += 1; // top-level names are provided through schema - Ok(Expr::Literal(from_substrait_literal( - lit, - extensions, - &named_struct.names, - &mut name_idx, - )?)) - }) - .collect::>()?; - if name_idx != named_struct.names.len() { - return substrait_err!( - "Names list must match exactly to nested schema, but found {} uses for {} names", - name_idx, - named_struct.names.len() - ); - } - Ok(lits) - }) - .collect::>()?; +pub async fn from_extension_leaf_rel( + consumer: &impl SubstraitConsumer, + extension_leaf_rel: &ExtensionLeafRel, +) -> Result { + consumer.consume_extension_leaf(extension_leaf_rel).await +} - Ok(LogicalPlan::Values(Values { - schema: DFSchemaRef::new(substrait_schema), - values, - })) - } - Some(ReadType::LocalFiles(lf)) => { - fn extract_filename(name: &str) -> Option { - let corrected_url = if name.starts_with("file://") - && !name.starts_with("file:///") - { - name.replacen("file://", "file:///", 1) - } else { - name.to_string() - }; +pub async fn from_extension_single_rel( + consumer: &impl SubstraitConsumer, + extension_single_rel: &ExtensionSingleRel, +) -> Result { + consumer + .consume_extension_single(extension_single_rel) + .await +} - Url::parse(&corrected_url).ok().and_then(|url| { - let path = url.path(); - std::path::Path::new(path) - .file_name() - .map(|filename| filename.to_string_lossy().to_string()) - }) - } +pub async fn from_extension_multi_rel( + consumer: &impl SubstraitConsumer, + extension_multi_rel: &ExtensionMultiRel, +) -> Result { + consumer.consume_extension_multi(extension_multi_rel).await +} - // we could use the file name to check the original table provider - // TODO: currently does not support multiple local files - let filename: Option = - lf.items.first().and_then(|x| match x.path_type.as_ref() { - Some(UriFile(name)) => extract_filename(name), - _ => None, - }); +pub async fn from_exchange_rel( + consumer: &impl SubstraitConsumer, + exchange: &ExchangeRel, +) -> Result { + let Some(input) = exchange.input.as_ref() else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + let input = Arc::new(from_substrait_rel(consumer, input).await?); - if lf.items.len() > 1 || filename.is_none() { - return not_impl_err!("Only single file reads are supported"); - } - let name = filename.unwrap(); - // directly use unwrap here since we could determine it is a valid one - let table_reference = TableReference::Bare { table: name.into() }; - - read_with_schema( - state, - table_reference, - substrait_schema, - &read.projection, - ) - .await - } - _ => { - not_impl_err!("Unsupported ReadType: {:?}", &read.as_ref().read_type) - } - } - } - Some(RelType::Set(set)) => match set_rel::SetOp::try_from(set.op) { - Ok(set_op) => { - if set.inputs.len() < 2 { - substrait_err!("Set operation requires at least two inputs") - } else { - match set_op { - set_rel::SetOp::UnionAll => { - union_rels(&set.inputs, state, extensions, true).await - } - set_rel::SetOp::UnionDistinct => { - union_rels(&set.inputs, state, extensions, false).await - } - set_rel::SetOp::IntersectionPrimary => { - LogicalPlanBuilder::intersect( - from_substrait_rel(state, &set.inputs[0], extensions) - .await?, - union_rels(&set.inputs[1..], state, extensions, true) - .await?, - false, - ) - } - set_rel::SetOp::IntersectionMultiset => { - intersect_rels(&set.inputs, state, extensions, false).await - } - set_rel::SetOp::IntersectionMultisetAll => { - intersect_rels(&set.inputs, state, extensions, true).await - } - set_rel::SetOp::MinusPrimary => { - except_rels(&set.inputs, state, extensions, false).await - } - set_rel::SetOp::MinusPrimaryAll => { - except_rels(&set.inputs, state, extensions, true).await - } - _ => not_impl_err!("Unsupported set operator: {set_op:?}"), - } - } + let Some(exchange_kind) = &exchange.exchange_kind else { + return substrait_err!("Unexpected empty input in ExchangeRel"); + }; + + // ref: https://substrait.io/relations/physical_relations/#exchange-types + let partitioning_scheme = match exchange_kind { + ExchangeKind::ScatterByFields(scatter_fields) => { + let mut partition_columns = vec![]; + let input_schema = input.schema(); + for field_ref in &scatter_fields.fields { + let column = from_substrait_field_reference(field_ref, input_schema)?; + partition_columns.push(column); } - Err(e) => not_impl_err!("Invalid set operation type {}: {e}", set.op), - }, - Some(RelType::ExtensionLeaf(extension)) => { - let Some(ext_detail) = &extension.detail else { - return substrait_err!("Unexpected empty detail in ExtensionLeafRel"); - }; - let plan = state - .serializer_registry() - .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; - Ok(LogicalPlan::Extension(Extension { node: plan })) - } - Some(RelType::ExtensionSingle(extension)) => { - let Some(ext_detail) = &extension.detail else { - return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); - }; - let plan = state - .serializer_registry() - .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; - let Some(input_rel) = &extension.input else { - return substrait_err!( - "ExtensionSingleRel doesn't contains input rel. Try use ExtensionLeafRel instead" - ); - }; - let input_plan = from_substrait_rel(state, input_rel, extensions).await?; - let plan = - plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?; - Ok(LogicalPlan::Extension(Extension { node: plan })) + Partitioning::Hash(partition_columns, exchange.partition_count as usize) } - Some(RelType::ExtensionMulti(extension)) => { - let Some(ext_detail) = &extension.detail else { - return substrait_err!("Unexpected empty detail in ExtensionSingleRel"); - }; - let plan = state - .serializer_registry() - .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; - let mut inputs = Vec::with_capacity(extension.inputs.len()); - for input in &extension.inputs { - let input_plan = from_substrait_rel(state, input, extensions).await?; - inputs.push(input_plan); - } - let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?; - Ok(LogicalPlan::Extension(Extension { node: plan })) + ExchangeKind::RoundRobin(_) => { + Partitioning::RoundRobinBatch(exchange.partition_count as usize) } - Some(RelType::Exchange(exchange)) => { - let Some(input) = exchange.input.as_ref() else { - return substrait_err!("Unexpected empty input in ExchangeRel"); - }; - let input = Arc::new(from_substrait_rel(state, input, extensions).await?); - - let Some(exchange_kind) = &exchange.exchange_kind else { - return substrait_err!("Unexpected empty input in ExchangeRel"); - }; - - // ref: https://substrait.io/relations/physical_relations/#exchange-types - let partitioning_scheme = match exchange_kind { - ExchangeKind::ScatterByFields(scatter_fields) => { - let mut partition_columns = vec![]; - let input_schema = input.schema(); - for field_ref in &scatter_fields.fields { - let column = - from_substrait_field_reference(field_ref, input_schema)?; - partition_columns.push(column); - } - Partitioning::Hash( - partition_columns, - exchange.partition_count as usize, - ) - } - ExchangeKind::RoundRobin(_) => { - Partitioning::RoundRobinBatch(exchange.partition_count as usize) - } - ExchangeKind::SingleTarget(_) - | ExchangeKind::MultiTarget(_) - | ExchangeKind::Broadcast(_) => { - return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}"); - } - }; - Ok(LogicalPlan::Repartition(Repartition { - input, - partitioning_scheme, - })) + ExchangeKind::SingleTarget(_) + | ExchangeKind::MultiTarget(_) + | ExchangeKind::Broadcast(_) => { + return not_impl_err!("Unsupported exchange kind: {exchange_kind:?}"); } - _ => not_impl_err!("Unsupported RelType: {:?}", rel.rel_type), }; - apply_emit_kind(retrieve_rel_common(rel), plan?) + Ok(LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + })) } fn retrieve_rel_common(rel: &Rel) -> Option<&RelCommon> { @@ -1371,7 +1731,7 @@ fn compatible_nullabilities( } /// (Re)qualify the sides of a join if needed, i.e. if the columns from one side would otherwise -/// conflict with the columns from the other. +/// conflict with the columns from the other. /// Substrait doesn't currently allow specifying aliases, neither for columns nor for tables. For /// Substrait the names don't matter since it only refers to columns by indices, however DataFusion /// requires columns to be uniquely identifiable, in some places (see e.g. DFSchema::check_names). @@ -1417,16 +1777,14 @@ fn from_substrait_jointype(join_type: i32) -> Result { /// Convert Substrait Sorts to DataFusion Exprs pub async fn from_substrait_sorts( - state: &dyn SubstraitPlanningState, + consumer: &impl SubstraitConsumer, substrait_sorts: &Vec, input_schema: &DFSchema, - extensions: &Extensions, ) -> Result> { let mut sorts: Vec = vec![]; for s in substrait_sorts { let expr = - from_substrait_rex(state, s.expr.as_ref().unwrap(), input_schema, extensions) - .await?; + from_substrait_rex(consumer, s.expr.as_ref().unwrap(), input_schema).await?; let asc_nullfirst = match &s.sort_kind { Some(k) => match k { Direction(d) => { @@ -1467,15 +1825,13 @@ pub async fn from_substrait_sorts( /// Convert Substrait Expressions to DataFusion Exprs pub async fn from_substrait_rex_vec( - state: &dyn SubstraitPlanningState, + consumer: &impl SubstraitConsumer, exprs: &Vec, input_schema: &DFSchema, - extensions: &Extensions, ) -> Result> { let mut expressions: Vec = vec![]; for expr in exprs { - let expression = - from_substrait_rex(state, expr, input_schema, extensions).await?; + let expression = from_substrait_rex(consumer, expr, input_schema).await?; expressions.push(expression); } Ok(expressions) @@ -1483,16 +1839,15 @@ pub async fn from_substrait_rex_vec( /// Convert Substrait FunctionArguments to DataFusion Exprs pub async fn from_substrait_func_args( - state: &dyn SubstraitPlanningState, + consumer: &impl SubstraitConsumer, arguments: &Vec, input_schema: &DFSchema, - extensions: &Extensions, ) -> Result> { let mut args: Vec = vec![]; for arg in arguments { let arg_expr = match &arg.arg_type { Some(ArgType::Value(e)) => { - from_substrait_rex(state, e, input_schema, extensions).await + from_substrait_rex(consumer, e, input_schema).await } _ => not_impl_err!("Function argument non-Value type not supported"), }; @@ -1503,370 +1858,416 @@ pub async fn from_substrait_func_args( /// Convert Substrait AggregateFunction to DataFusion Expr pub async fn from_substrait_agg_func( - state: &dyn SubstraitPlanningState, + consumer: &impl SubstraitConsumer, f: &AggregateFunction, input_schema: &DFSchema, - extensions: &Extensions, filter: Option>, order_by: Option>, distinct: bool, ) -> Result> { - let args = - from_substrait_func_args(state, &f.arguments, input_schema, extensions).await?; - - let Some(function_name) = extensions.functions.get(&f.function_reference) else { + let Some(fn_signature) = consumer + .get_extensions() + .functions + .get(&f.function_reference) + else { return plan_err!( "Aggregate function not registered: function anchor = {:?}", f.function_reference ); }; - let function_name = substrait_fun_name(function_name); - // try udaf first, then built-in aggr fn. - if let Ok(fun) = state.udaf(function_name) { - // deal with situation that count(*) got no arguments - let args = if fun.name() == "count" && args.is_empty() { - vec![Expr::Literal(ScalarValue::Int64(Some(1)))] - } else { - args - }; - - Ok(Arc::new(Expr::AggregateFunction( - expr::AggregateFunction::new_udf(fun, args, distinct, filter, order_by, None), - ))) - } else { - not_impl_err!( + let fn_name = substrait_fun_name(fn_signature); + let udaf = consumer.get_state().udaf(fn_name); + let udaf = udaf.map_err(|_| { + not_impl_datafusion_err!( "Aggregate function {} is not supported: function anchor = {:?}", - function_name, + fn_signature, f.function_reference ) - } + })?; + + let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?; + + // deal with situation that count(*) got no arguments + let args = if udaf.name() == "count" && args.is_empty() { + vec![Expr::Literal(ScalarValue::Int64(Some(1)))] + } else { + args + }; + + Ok(Arc::new(Expr::AggregateFunction( + expr::AggregateFunction::new_udf(udaf, args, distinct, filter, order_by, None), + ))) } /// Convert Substrait Rex to DataFusion Expr -#[async_recursion] pub async fn from_substrait_rex( - state: &dyn SubstraitPlanningState, - e: &Expression, + consumer: &impl SubstraitConsumer, + expression: &Expression, input_schema: &DFSchema, - extensions: &Extensions, ) -> Result { - match &e.rex_type { - Some(RexType::SingularOrList(s)) => { - let substrait_expr = s.value.as_ref().unwrap(); - let substrait_list = s.options.as_ref(); - Ok(Expr::InList(InList { - expr: Box::new( - from_substrait_rex(state, substrait_expr, input_schema, extensions) - .await?, - ), - list: from_substrait_rex_vec( - state, - substrait_list, - input_schema, - extensions, - ) - .await?, - negated: false, - })) - } - Some(RexType::Selection(field_ref)) => { - Ok(from_substrait_field_reference(field_ref, input_schema)?) - } - Some(RexType::IfThen(if_then)) => { - // Parse `ifs` - // If the first element does not have a `then` part, then we can assume it's a base expression - let mut when_then_expr: Vec<(Box, Box)> = vec![]; - let mut expr = None; - for (i, if_expr) in if_then.ifs.iter().enumerate() { - if i == 0 { - // Check if the first element is type base expression - if if_expr.then.is_none() { - expr = Some(Box::new( - from_substrait_rex( - state, - if_expr.r#if.as_ref().unwrap(), - input_schema, - extensions, - ) - .await?, - )); - continue; - } - } - when_then_expr.push(( - Box::new( - from_substrait_rex( - state, - if_expr.r#if.as_ref().unwrap(), - input_schema, - extensions, - ) - .await?, - ), - Box::new( - from_substrait_rex( - state, - if_expr.then.as_ref().unwrap(), - input_schema, - extensions, - ) - .await?, - ), - )); + match &expression.rex_type { + Some(t) => match t { + RexType::Literal(expr) => consumer.consume_literal(expr).await, + RexType::Selection(expr) => { + consumer.consume_selection(expr, input_schema).await + } + RexType::ScalarFunction(expr) => { + consumer.consume_scalar_function(expr, input_schema).await + } + RexType::WindowFunction(expr) => { + consumer.consume_window_function(expr, input_schema).await + } + RexType::IfThen(expr) => consumer.consume_if_then(expr, input_schema).await, + RexType::SwitchExpression(expr) => { + consumer.consume_switch(expr, input_schema).await + } + RexType::SingularOrList(expr) => { + consumer.consume_singular_or_list(expr, input_schema).await } - // Parse `else` - let else_expr = match &if_then.r#else { - Some(e) => Some(Box::new( - from_substrait_rex(state, e, input_schema, extensions).await?, - )), - None => None, - }; - Ok(Expr::Case(Case { - expr, - when_then_expr, - else_expr, - })) - } - Some(RexType::ScalarFunction(f)) => { - let Some(fn_name) = extensions.functions.get(&f.function_reference) else { - return plan_err!( - "Scalar function not found: function reference = {:?}", - f.function_reference - ); - }; - let fn_name = substrait_fun_name(fn_name); - let args = - from_substrait_func_args(state, &f.arguments, input_schema, extensions) - .await?; + RexType::MultiOrList(expr) => { + consumer.consume_multi_or_list(expr, input_schema).await + } - // try to first match the requested function into registered udfs, then built-in ops - // and finally built-in expressions - if let Ok(func) = state.udf(fn_name) { - Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( - func.to_owned(), - args, - ))) - } else if let Some(op) = name_to_op(fn_name) { - if f.arguments.len() < 2 { - return not_impl_err!( - "Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}", - f.arguments.len() - ); - } - // Some expressions are binary in DataFusion but take in a variadic number of args in Substrait. - // In those cases we iterate through all the arguments, applying the binary expression against them all - let combined_expr = args - .into_iter() - .fold(None, |combined_expr: Option, arg: Expr| { - Some(match combined_expr { - Some(expr) => Expr::BinaryExpr(BinaryExpr { - left: Box::new(expr), - op, - right: Box::new(arg), - }), - None => arg, - }) - }) - .unwrap(); + RexType::Cast(expr) => { + consumer.consume_cast(expr.as_ref(), input_schema).await + } - Ok(combined_expr) - } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { - builder.build(state, f, input_schema, extensions).await - } else { - not_impl_err!("Unsupported function name: {fn_name:?}") + RexType::Subquery(expr) => { + consumer.consume_subquery(expr.as_ref(), input_schema).await } - } - Some(RexType::Literal(lit)) => { - let scalar_value = from_substrait_literal_without_names(lit, extensions)?; - Ok(Expr::Literal(scalar_value)) - } - Some(RexType::Cast(cast)) => match cast.as_ref().r#type.as_ref() { - Some(output_type) => { - let input_expr = Box::new( + RexType::Nested(expr) => consumer.consume_nested(expr, input_schema).await, + RexType::Enum(expr) => consumer.consume_enum(expr, input_schema).await, + }, + None => substrait_err!("Expression must set rex_type: {:?}", expression), + } +} + +pub async fn from_singular_or_list( + consumer: &impl SubstraitConsumer, + expr: &SingularOrList, + input_schema: &DFSchema, +) -> Result { + let substrait_expr = expr.value.as_ref().unwrap(); + let substrait_list = expr.options.as_ref(); + Ok(Expr::InList(InList { + expr: Box::new(from_substrait_rex(consumer, substrait_expr, input_schema).await?), + list: from_substrait_rex_vec(consumer, substrait_list, input_schema).await?, + negated: false, + })) +} + +pub async fn from_selection( + _consumer: &impl SubstraitConsumer, + field_ref: &FieldReference, + input_schema: &DFSchema, +) -> Result { + from_substrait_field_reference(field_ref, input_schema) +} + +pub async fn from_if_then( + consumer: &impl SubstraitConsumer, + if_then: &IfThen, + input_schema: &DFSchema, +) -> Result { + // Parse `ifs` + // If the first element does not have a `then` part, then we can assume it's a base expression + let mut when_then_expr: Vec<(Box, Box)> = vec![]; + let mut expr = None; + for (i, if_expr) in if_then.ifs.iter().enumerate() { + if i == 0 { + // Check if the first element is type base expression + if if_expr.then.is_none() { + expr = Some(Box::new( from_substrait_rex( - state, - cast.as_ref().input.as_ref().unwrap().as_ref(), + consumer, + if_expr.r#if.as_ref().unwrap(), input_schema, - extensions, ) .await?, - ); - let data_type = - from_substrait_type_without_names(output_type, extensions)?; - if cast.failure_behavior() == ReturnNull { - Ok(Expr::TryCast(TryCast::new(input_expr, data_type))) - } else { - Ok(Expr::Cast(Cast::new(input_expr, data_type))) - } + )); + continue; } - None => substrait_err!("Cast expression without output type is not allowed"), - }, - Some(RexType::WindowFunction(window)) => { - let Some(fn_name) = extensions.functions.get(&window.function_reference) - else { - return plan_err!( - "Window function not found: function reference = {:?}", - window.function_reference - ); - }; - let fn_name = substrait_fun_name(fn_name); - - // check udwf first, then udaf, then built-in window and aggregate functions - let fun = if let Ok(udwf) = state.udwf(fn_name) { - Ok(WindowFunctionDefinition::WindowUDF(udwf)) - } else if let Ok(udaf) = state.udaf(fn_name) { - Ok(WindowFunctionDefinition::AggregateUDF(udaf)) - } else { - not_impl_err!( - "Window function {} is not supported: function anchor = {:?}", - fn_name, - window.function_reference + } + when_then_expr.push(( + Box::new( + from_substrait_rex( + consumer, + if_expr.r#if.as_ref().unwrap(), + input_schema, ) - }?; - - let order_by = - from_substrait_sorts(state, &window.sorts, input_schema, extensions) - .await?; - - let bound_units = - match BoundsType::try_from(window.bounds_type).map_err(|e| { - plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type) - })? { - BoundsType::Rows => WindowFrameUnits::Rows, - BoundsType::Range => WindowFrameUnits::Range, - BoundsType::Unspecified => { - // If the plan does not specify the bounds type, then we use a simple logic to determine the units - // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary - // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row - if order_by.is_empty() { - WindowFrameUnits::Rows - } else { - WindowFrameUnits::Range - } - } - }; - Ok(Expr::WindowFunction(expr::WindowFunction { - fun, - args: from_substrait_func_args( - state, - &window.arguments, + .await?, + ), + Box::new( + from_substrait_rex( + consumer, + if_expr.then.as_ref().unwrap(), input_schema, - extensions, ) .await?, - partition_by: from_substrait_rex_vec( - state, - &window.partitions, + ), + )); + } + // Parse `else` + let else_expr = match &if_then.r#else { + Some(e) => Some(Box::new( + from_substrait_rex(consumer, e, input_schema).await?, + )), + None => None, + }; + Ok(Expr::Case(Case { + expr, + when_then_expr, + else_expr, + })) +} + +pub async fn from_scalar_function( + consumer: &impl SubstraitConsumer, + f: &ScalarFunction, + input_schema: &DFSchema, +) -> Result { + let Some(fn_signature) = consumer + .get_extensions() + .functions + .get(&f.function_reference) + else { + return plan_err!( + "Scalar function not found: function reference = {:?}", + f.function_reference + ); + }; + let fn_name = substrait_fun_name(fn_signature); + let args = from_substrait_func_args(consumer, &f.arguments, input_schema).await?; + + // try to first match the requested function into registered udfs, then built-in ops + // and finally built-in expressions + if let Ok(func) = consumer.get_state().udf(fn_name) { + Ok(Expr::ScalarFunction(expr::ScalarFunction::new_udf( + func.to_owned(), + args, + ))) + } else if let Some(op) = name_to_op(fn_name) { + if f.arguments.len() < 2 { + return not_impl_err!( + "Expect at least two arguments for binary operator {op:?}, the provided number of operators is {:?}", + f.arguments.len() + ); + } + // Some expressions are binary in DataFusion but take in a variadic number of args in Substrait. + // In those cases we iterate through all the arguments, applying the binary expression against them all + let combined_expr = args + .into_iter() + .fold(None, |combined_expr: Option, arg: Expr| { + Some(match combined_expr { + Some(expr) => Expr::BinaryExpr(BinaryExpr { + left: Box::new(expr), + op, + right: Box::new(arg), + }), + None => arg, + }) + }) + .unwrap(); + + Ok(combined_expr) + } else if let Some(builder) = BuiltinExprBuilder::try_from_name(fn_name) { + builder.build(consumer, f, input_schema).await + } else { + not_impl_err!("Unsupported function name: {fn_name:?}") + } +} + +pub async fn from_literal( + consumer: &impl SubstraitConsumer, + expr: &Literal, +) -> Result { + let scalar_value = from_substrait_literal_without_names(consumer, expr)?; + Ok(Expr::Literal(scalar_value)) +} + +pub async fn from_cast( + consumer: &impl SubstraitConsumer, + cast: &substrait_expression::Cast, + input_schema: &DFSchema, +) -> Result { + match cast.r#type.as_ref() { + Some(output_type) => { + let input_expr = Box::new( + from_substrait_rex( + consumer, + cast.input.as_ref().unwrap().as_ref(), input_schema, - extensions, ) .await?, - order_by, - window_frame: datafusion::logical_expr::WindowFrame::new_bounds( - bound_units, - from_substrait_bound(&window.lower_bound, true)?, - from_substrait_bound(&window.upper_bound, false)?, - ), - null_treatment: None, - })) + ); + let data_type = from_substrait_type_without_names(consumer, output_type)?; + if cast.failure_behavior() == ReturnNull { + Ok(Expr::TryCast(TryCast::new(input_expr, data_type))) + } else { + Ok(Expr::Cast(Cast::new(input_expr, data_type))) + } } - Some(RexType::Subquery(subquery)) => match &subquery.as_ref().subquery_type { - Some(subquery_type) => match subquery_type { - SubqueryType::InPredicate(in_predicate) => { - if in_predicate.needles.len() != 1 { - substrait_err!("InPredicate Subquery type must have exactly one Needle expression") - } else { - let needle_expr = &in_predicate.needles[0]; - let haystack_expr = &in_predicate.haystack; - if let Some(haystack_expr) = haystack_expr { - let haystack_expr = - from_substrait_rel(state, haystack_expr, extensions) - .await?; - let outer_refs = haystack_expr.all_out_ref_exprs(); - Ok(Expr::InSubquery(InSubquery { - expr: Box::new( - from_substrait_rex( - state, - needle_expr, - input_schema, - extensions, - ) + None => substrait_err!("Cast expression without output type is not allowed"), + } +} + +pub async fn from_window_function( + consumer: &impl SubstraitConsumer, + window: &WindowFunction, + input_schema: &DFSchema, +) -> Result { + let Some(fn_signature) = consumer + .get_extensions() + .functions + .get(&window.function_reference) + else { + return plan_err!( + "Window function not found: function reference = {:?}", + window.function_reference + ); + }; + let fn_name = substrait_fun_name(fn_signature); + + // check udwf first, then udaf, then built-in window and aggregate functions + let fun = if let Ok(udwf) = consumer.get_state().udwf(fn_name) { + Ok(WindowFunctionDefinition::WindowUDF(udwf)) + } else if let Ok(udaf) = consumer.get_state().udaf(fn_name) { + Ok(WindowFunctionDefinition::AggregateUDF(udaf)) + } else { + not_impl_err!( + "Window function {} is not supported: function anchor = {:?}", + fn_name, + window.function_reference + ) + }?; + + let order_by = from_substrait_sorts(consumer, &window.sorts, input_schema).await?; + + let bound_units = match BoundsType::try_from(window.bounds_type).map_err(|e| { + plan_datafusion_err!("Invalid bound type {}: {e}", window.bounds_type) + })? { + BoundsType::Rows => WindowFrameUnits::Rows, + BoundsType::Range => WindowFrameUnits::Range, + BoundsType::Unspecified => { + // If the plan does not specify the bounds type, then we use a simple logic to determine the units + // If there is no `ORDER BY`, then by default, the frame counts each row from the lower up to upper boundary + // If there is `ORDER BY`, then by default, each frame is a range starting from unbounded preceding to current row + if order_by.is_empty() { + WindowFrameUnits::Rows + } else { + WindowFrameUnits::Range + } + } + }; + Ok(Expr::WindowFunction(expr::WindowFunction { + fun, + args: from_substrait_func_args(consumer, &window.arguments, input_schema).await?, + partition_by: from_substrait_rex_vec(consumer, &window.partitions, input_schema) + .await?, + order_by, + window_frame: datafusion::logical_expr::WindowFrame::new_bounds( + bound_units, + from_substrait_bound(&window.lower_bound, true)?, + from_substrait_bound(&window.upper_bound, false)?, + ), + null_treatment: None, + })) +} + +pub async fn from_subquery( + consumer: &impl SubstraitConsumer, + subquery: &substrait_expression::Subquery, + input_schema: &DFSchema, +) -> Result { + match &subquery.subquery_type { + Some(subquery_type) => match subquery_type { + SubqueryType::InPredicate(in_predicate) => { + if in_predicate.needles.len() != 1 { + substrait_err!("InPredicate Subquery type must have exactly one Needle expression") + } else { + let needle_expr = &in_predicate.needles[0]; + let haystack_expr = &in_predicate.haystack; + if let Some(haystack_expr) = haystack_expr { + let haystack_expr = + from_substrait_rel(consumer, haystack_expr).await?; + let outer_refs = haystack_expr.all_out_ref_exprs(); + Ok(Expr::InSubquery(InSubquery { + expr: Box::new( + from_substrait_rex(consumer, needle_expr, input_schema) .await?, - ), - subquery: Subquery { - subquery: Arc::new(haystack_expr), - outer_ref_columns: outer_refs, - }, - negated: false, - })) - } else { - substrait_err!("InPredicate Subquery type must have a Haystack expression") - } + ), + subquery: Subquery { + subquery: Arc::new(haystack_expr), + outer_ref_columns: outer_refs, + }, + negated: false, + })) + } else { + substrait_err!( + "InPredicate Subquery type must have a Haystack expression" + ) } } - SubqueryType::Scalar(query) => { - let plan = from_substrait_rel( - state, - &(query.input.clone()).unwrap_or_default(), - extensions, - ) - .await?; - let outer_ref_columns = plan.all_out_ref_exprs(); - Ok(Expr::ScalarSubquery(Subquery { - subquery: Arc::new(plan), - outer_ref_columns, - })) - } - SubqueryType::SetPredicate(predicate) => { - match predicate.predicate_op() { - // exist - PredicateOp::Exists => { - let relation = &predicate.tuples; - let plan = from_substrait_rel( - state, - &relation.clone().unwrap_or_default(), - extensions, - ) - .await?; - let outer_ref_columns = plan.all_out_ref_exprs(); - Ok(Expr::Exists(Exists::new( - Subquery { - subquery: Arc::new(plan), - outer_ref_columns, - }, - false, - ))) - } - other_type => substrait_err!( - "unimplemented type {:?} for set predicate", - other_type - ), + } + SubqueryType::Scalar(query) => { + let plan = from_substrait_rel( + consumer, + &(query.input.clone()).unwrap_or_default(), + ) + .await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Expr::ScalarSubquery(Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + })) + } + SubqueryType::SetPredicate(predicate) => { + match predicate.predicate_op() { + // exist + PredicateOp::Exists => { + let relation = &predicate.tuples; + let plan = from_substrait_rel( + consumer, + &relation.clone().unwrap_or_default(), + ) + .await?; + let outer_ref_columns = plan.all_out_ref_exprs(); + Ok(Expr::Exists(Exists::new( + Subquery { + subquery: Arc::new(plan), + outer_ref_columns, + }, + false, + ))) } + other_type => substrait_err!( + "unimplemented type {:?} for set predicate", + other_type + ), } - other_type => { - substrait_err!("Subquery type {:?} not implemented", other_type) - } - }, - None => { - substrait_err!("Subquery expression without SubqueryType is not allowed") + } + other_type => { + substrait_err!("Subquery type {:?} not implemented", other_type) } }, - _ => not_impl_err!("unsupported rex_type"), + None => { + substrait_err!("Subquery expression without SubqueryType is not allowed") + } } } pub(crate) fn from_substrait_type_without_names( + consumer: &impl SubstraitConsumer, dt: &Type, - extensions: &Extensions, ) -> Result { - from_substrait_type(dt, extensions, &[], &mut 0) + from_substrait_type(consumer, dt, &[], &mut 0) } fn from_substrait_type( + consumer: &impl SubstraitConsumer, dt: &Type, - extensions: &Extensions, dfs_names: &[String], name_idx: &mut usize, ) -> Result { @@ -1979,7 +2380,7 @@ fn from_substrait_type( substrait_datafusion_err!("List type must have inner type") })?; let field = Arc::new(Field::new_list_field( - from_substrait_type(inner_type, extensions, dfs_names, name_idx)?, + from_substrait_type(consumer, inner_type, dfs_names, name_idx)?, // We ignore Substrait's nullability here to match to_substrait_literal // which always creates nullable lists true, @@ -2001,12 +2402,12 @@ fn from_substrait_type( })?; let key_field = Arc::new(Field::new( "key", - from_substrait_type(key_type, extensions, dfs_names, name_idx)?, + from_substrait_type(consumer, key_type, dfs_names, name_idx)?, false, )); let value_field = Arc::new(Field::new( "value", - from_substrait_type(value_type, extensions, dfs_names, name_idx)?, + from_substrait_type(consumer, value_type, dfs_names, name_idx)?, true, )); Ok(DataType::Map( @@ -2037,42 +2438,48 @@ fn from_substrait_type( Ok(DataType::Interval(IntervalUnit::MonthDayNano)) } r#type::Kind::UserDefined(u) => { - if let Some(name) = extensions.types.get(&u.type_reference) { + if let Ok(data_type) = consumer.consume_user_defined_type(u) { + return Ok(data_type); + } + + // TODO: remove the code below once the producer has been updated + if let Some(name) = consumer.get_extensions().types.get(&u.type_reference) + { #[allow(deprecated)] - match name.as_ref() { - // Kept for backwards compatibility, producers should use IntervalCompound instead - INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), + match name.as_ref() { + // Kept for backwards compatibility, producers should use IntervalCompound instead + INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), _ => not_impl_err!( "Unsupported Substrait user defined type with ref {} and variation {}", u.type_reference, u.type_variation_reference ), - } + } } else { #[allow(deprecated)] - match u.type_reference { - // Kept for backwards compatibility, producers should use IntervalYear instead - INTERVAL_YEAR_MONTH_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::YearMonth)) - } - // Kept for backwards compatibility, producers should use IntervalDay instead - INTERVAL_DAY_TIME_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::DayTime)) - } - // Kept for backwards compatibility, producers should use IntervalCompound instead - INTERVAL_MONTH_DAY_NANO_TYPE_REF => { - Ok(DataType::Interval(IntervalUnit::MonthDayNano)) - } - _ => not_impl_err!( + match u.type_reference { + // Kept for backwards compatibility, producers should use IntervalYear instead + INTERVAL_YEAR_MONTH_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::YearMonth)) + } + // Kept for backwards compatibility, producers should use IntervalDay instead + INTERVAL_DAY_TIME_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::DayTime)) + } + // Kept for backwards compatibility, producers should use IntervalCompound instead + INTERVAL_MONTH_DAY_NANO_TYPE_REF => { + Ok(DataType::Interval(IntervalUnit::MonthDayNano)) + } + _ => not_impl_err!( "Unsupported Substrait user defined type with ref {} and variation {}", u.type_reference, u.type_variation_reference ), - } + } } } r#type::Kind::Struct(s) => Ok(DataType::Struct(from_substrait_struct_type( - s, extensions, dfs_names, name_idx, + consumer, s, dfs_names, name_idx, )?)), r#type::Kind::Varchar(_) => Ok(DataType::Utf8), r#type::Kind::FixedChar(_) => Ok(DataType::Utf8), @@ -2083,8 +2490,8 @@ fn from_substrait_type( } fn from_substrait_struct_type( + consumer: &impl SubstraitConsumer, s: &r#type::Struct, - extensions: &Extensions, dfs_names: &[String], name_idx: &mut usize, ) -> Result { @@ -2092,7 +2499,7 @@ fn from_substrait_struct_type( for (i, f) in s.types.iter().enumerate() { let field = Field::new( next_struct_field_name(i, dfs_names, name_idx)?, - from_substrait_type(f, extensions, dfs_names, name_idx)?, + from_substrait_type(consumer, f, dfs_names, name_idx)?, true, // We assume everything to be nullable since that's easier than ensuring it matches ); fields.push(field); @@ -2120,15 +2527,15 @@ fn next_struct_field_name( /// Convert Substrait NamedStruct to DataFusion DFSchemaRef pub fn from_substrait_named_struct( + consumer: &impl SubstraitConsumer, base_schema: &NamedStruct, - extensions: &Extensions, ) -> Result { let mut name_idx = 0; let fields = from_substrait_struct_type( + consumer, base_schema.r#struct.as_ref().ok_or_else(|| { substrait_datafusion_err!("Named struct must contain a struct") })?, - extensions, &base_schema.names, &mut name_idx, ); @@ -2189,15 +2596,15 @@ fn from_substrait_bound( } pub(crate) fn from_substrait_literal_without_names( + consumer: &impl SubstraitConsumer, lit: &Literal, - extensions: &Extensions, ) -> Result { - from_substrait_literal(lit, extensions, &vec![], &mut 0) + from_substrait_literal(consumer, lit, &vec![], &mut 0) } fn from_substrait_literal( + consumer: &impl SubstraitConsumer, lit: &Literal, - extensions: &Extensions, dfs_names: &Vec, name_idx: &mut usize, ) -> Result { @@ -2333,12 +2740,7 @@ fn from_substrait_literal( .iter() .map(|el| { element_name_idx = *name_idx; - from_substrait_literal( - el, - extensions, - dfs_names, - &mut element_name_idx, - ) + from_substrait_literal(consumer, el, dfs_names, &mut element_name_idx) }) .collect::>>()?; *name_idx = element_name_idx; @@ -2362,8 +2764,8 @@ fn from_substrait_literal( } Some(LiteralType::EmptyList(l)) => { let element_type = from_substrait_type( + consumer, l.r#type.clone().unwrap().as_ref(), - extensions, dfs_names, name_idx, )?; @@ -2389,14 +2791,14 @@ fn from_substrait_literal( .map(|kv| { entry_name_idx = *name_idx; let key_sv = from_substrait_literal( + consumer, kv.key.as_ref().unwrap(), - extensions, dfs_names, &mut entry_name_idx, )?; let value_sv = from_substrait_literal( + consumer, kv.value.as_ref().unwrap(), - extensions, dfs_names, &mut entry_name_idx, )?; @@ -2434,8 +2836,8 @@ fn from_substrait_literal( Some(v) => Ok(v), _ => plan_err!("Missing value type for empty map"), }?; - let key_type = from_substrait_type(key, extensions, dfs_names, name_idx)?; - let value_type = from_substrait_type(value, extensions, dfs_names, name_idx)?; + let key_type = from_substrait_type(consumer, key, dfs_names, name_idx)?; + let value_type = from_substrait_type(consumer, value, dfs_names, name_idx)?; // new_empty_array on a MapType creates a too empty array // We want it to contain an empty struct array to align with an empty MapBuilder one @@ -2461,7 +2863,7 @@ fn from_substrait_literal( let mut builder = ScalarStructBuilder::new(); for (i, field) in s.fields.iter().enumerate() { let name = next_struct_field_name(i, dfs_names, name_idx)?; - let sv = from_substrait_literal(field, extensions, dfs_names, name_idx)?; + let sv = from_substrait_literal(consumer, field, dfs_names, name_idx)?; // We assume everything to be nullable, since Arrow's strict about things matching // and it's hard to match otherwise. builder = builder.with_scalar(Field::new(name, sv.data_type(), true), sv); @@ -2469,7 +2871,7 @@ fn from_substrait_literal( builder.build()? } Some(LiteralType::Null(ntype)) => { - from_substrait_null(ntype, extensions, dfs_names, name_idx)? + from_substrait_null(consumer, ntype, dfs_names, name_idx)? } Some(LiteralType::IntervalDayToSecond(IntervalDayToSecond { days, @@ -2533,9 +2935,15 @@ fn from_substrait_literal( }, Some(LiteralType::FixedChar(c)) => ScalarValue::Utf8(Some(c.clone())), Some(LiteralType::UserDefined(user_defined)) => { + if let Ok(value) = consumer.consume_user_defined_literal(user_defined) { + return Ok(value); + } + + // TODO: remove the code below once the producer has been updated + // Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed let interval_month_day_nano = - |user_defined: &UserDefined| -> Result { + |user_defined: &proto::expression::literal::UserDefined| -> Result { let Some(Val::Value(raw_val)) = user_defined.val.as_ref() else { return substrait_err!("Interval month day nano value is empty"); }; @@ -2559,7 +2967,11 @@ fn from_substrait_literal( ))) }; - if let Some(name) = extensions.types.get(&user_defined.type_reference) { + if let Some(name) = consumer + .get_extensions() + .types + .get(&user_defined.type_reference) + { match name.as_ref() { // Kept for backwards compatibility - producers should use IntervalCompound instead #[allow(deprecated)] @@ -2632,8 +3044,8 @@ fn from_substrait_literal( } fn from_substrait_null( + consumer: &impl SubstraitConsumer, null_type: &Type, - extensions: &Extensions, dfs_names: &[String], name_idx: &mut usize, ) -> Result { @@ -2751,8 +3163,8 @@ fn from_substrait_null( r#type::Kind::List(l) => { let field = Field::new_list_field( from_substrait_type( + consumer, l.r#type.clone().unwrap().as_ref(), - extensions, dfs_names, name_idx, )?, @@ -2779,9 +3191,9 @@ fn from_substrait_null( })?; let key_type = - from_substrait_type(key_type, extensions, dfs_names, name_idx)?; + from_substrait_type(consumer, key_type, dfs_names, name_idx)?; let value_type = - from_substrait_type(value_type, extensions, dfs_names, name_idx)?; + from_substrait_type(consumer, value_type, dfs_names, name_idx)?; let entries_field = Arc::new(Field::new_struct( "entries", vec![ @@ -2795,7 +3207,7 @@ fn from_substrait_null( } r#type::Kind::Struct(s) => { let fields = - from_substrait_struct_type(s, extensions, dfs_names, name_idx)?; + from_substrait_struct_type(consumer, s, dfs_names, name_idx)?; Ok(ScalarStructBuilder::new_null(fields)) } _ => not_impl_err!("Unsupported Substrait type for null: {kind:?}"), @@ -2807,16 +3219,15 @@ fn from_substrait_null( #[allow(deprecated)] async fn from_substrait_grouping( - state: &dyn SubstraitPlanningState, + consumer: &impl SubstraitConsumer, grouping: &Grouping, expressions: &[Expr], input_schema: &DFSchemaRef, - extensions: &Extensions, ) -> Result> { let mut group_exprs = vec![]; if !grouping.grouping_expressions.is_empty() { for e in &grouping.grouping_expressions { - let expr = from_substrait_rex(state, e, input_schema, extensions).await?; + let expr = from_substrait_rex(consumer, e, input_schema).await?; group_exprs.push(expr); } return Ok(group_exprs); @@ -2869,29 +3280,17 @@ impl BuiltinExprBuilder { pub async fn build( self, - state: &dyn SubstraitPlanningState, + consumer: &impl SubstraitConsumer, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &Extensions, ) -> Result { match self.expr_name.as_str() { - "like" => { - Self::build_like_expr(state, false, f, input_schema, extensions).await - } - "ilike" => { - Self::build_like_expr(state, true, f, input_schema, extensions).await - } + "like" => Self::build_like_expr(consumer, false, f, input_schema).await, + "ilike" => Self::build_like_expr(consumer, true, f, input_schema).await, "not" | "negative" | "negate" | "is_null" | "is_not_null" | "is_true" | "is_false" | "is_not_true" | "is_not_false" | "is_unknown" | "is_not_unknown" => { - Self::build_unary_expr( - state, - &self.expr_name, - f, - input_schema, - extensions, - ) - .await + Self::build_unary_expr(consumer, &self.expr_name, f, input_schema).await } _ => { not_impl_err!("Unsupported builtin expression: {}", self.expr_name) @@ -2900,11 +3299,10 @@ impl BuiltinExprBuilder { } async fn build_unary_expr( - state: &dyn SubstraitPlanningState, + consumer: &impl SubstraitConsumer, fn_name: &str, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &Extensions, ) -> Result { if f.arguments.len() != 1 { return substrait_err!("Expect one argument for {fn_name} expr"); @@ -2912,8 +3310,7 @@ impl BuiltinExprBuilder { let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { return substrait_err!("Invalid arguments type for {fn_name} expr"); }; - let arg = - from_substrait_rex(state, expr_substrait, input_schema, extensions).await?; + let arg = from_substrait_rex(consumer, expr_substrait, input_schema).await?; let arg = Box::new(arg); let expr = match fn_name { @@ -2934,11 +3331,10 @@ impl BuiltinExprBuilder { } async fn build_like_expr( - state: &dyn SubstraitPlanningState, + consumer: &impl SubstraitConsumer, case_insensitive: bool, f: &ScalarFunction, input_schema: &DFSchema, - extensions: &Extensions, ) -> Result { let fn_name = if case_insensitive { "ILIKE" } else { "LIKE" }; if f.arguments.len() != 2 && f.arguments.len() != 3 { @@ -2948,14 +3344,12 @@ impl BuiltinExprBuilder { let Some(ArgType::Value(expr_substrait)) = &f.arguments[0].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let expr = - from_substrait_rex(state, expr_substrait, input_schema, extensions).await?; + let expr = from_substrait_rex(consumer, expr_substrait, input_schema).await?; let Some(ArgType::Value(pattern_substrait)) = &f.arguments[1].arg_type else { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; let pattern = - from_substrait_rex(state, pattern_substrait, input_schema, extensions) - .await?; + from_substrait_rex(consumer, pattern_substrait, input_schema).await?; // Default case: escape character is Literal(Utf8(None)) let escape_char = if f.arguments.len() == 3 { @@ -2964,13 +3358,8 @@ impl BuiltinExprBuilder { return substrait_err!("Invalid arguments type for `{fn_name}` expr"); }; - let escape_char_expr = from_substrait_rex( - state, - escape_char_substrait, - input_schema, - extensions, - ) - .await?; + let escape_char_expr = + from_substrait_rex(consumer, escape_char_substrait, input_schema).await?; match escape_char_expr { Expr::Literal(ScalarValue::Utf8(escape_char_string)) => { @@ -2999,8 +3388,9 @@ impl BuiltinExprBuilder { #[cfg(test)] mod test { - use crate::extensions::Extensions; - use crate::logical_plan::consumer::from_substrait_literal_without_names; + use crate::logical_plan::consumer::{ + from_substrait_literal_without_names, DefaultSubstraitConsumer, + }; use arrow_buffer::IntervalMonthDayNano; use datafusion::error::Result; use datafusion::scalar::ScalarValue; @@ -3033,8 +3423,9 @@ mod test { })), }; + let consumer = DefaultSubstraitConsumer::default(); assert_eq!( - from_substrait_literal_without_names(&substrait, &Extensions::default())?, + from_substrait_literal_without_names(&consumer, &substrait)?, ScalarValue::IntervalMonthDayNano(Some(IntervalMonthDayNano { months: 14, days: 3, diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index a128b90e6889e..2a6e0068c76d0 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -2205,6 +2205,7 @@ mod test { use crate::logical_plan::consumer::{ from_substrait_extended_expr, from_substrait_literal_without_names, from_substrait_named_struct, from_substrait_type_without_names, + DefaultSubstraitConsumer, }; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion::arrow::array::{ @@ -2215,6 +2216,10 @@ mod test { use datafusion::common::DFSchema; use datafusion::execution::SessionStateBuilder; + fn test_consumer() -> DefaultSubstraitConsumer { + DefaultSubstraitConsumer::default() + } + #[test] fn round_trip_literals() -> Result<()> { round_trip_literal(ScalarValue::Boolean(None))?; @@ -2339,7 +2344,7 @@ mod test { let mut extensions = Extensions::default(); let substrait_literal = to_substrait_literal(&scalar, &mut extensions)?; let roundtrip_scalar = - from_substrait_literal_without_names(&substrait_literal, &extensions)?; + from_substrait_literal_without_names(&test_consumer(), &substrait_literal)?; assert_eq!(scalar, roundtrip_scalar); Ok(()) } @@ -2418,8 +2423,8 @@ mod test { // As DataFusion doesn't consider nullability as a property of the type, but field, // it doesn't matter if we set nullability to true or false here. let substrait = to_substrait_type(&dt, true)?; - let roundtrip_dt = - from_substrait_type_without_names(&substrait, &Extensions::default())?; + let consumer = test_consumer(); + let roundtrip_dt = from_substrait_type_without_names(&consumer, &substrait)?; assert_eq!(dt, roundtrip_dt); Ok(()) } @@ -2470,7 +2475,7 @@ mod test { ); let roundtrip_schema = - from_substrait_named_struct(&named_struct, &Extensions::default())?; + from_substrait_named_struct(&test_consumer(), &named_struct)?; assert_eq!(schema.as_ref(), &roundtrip_schema); Ok(()) } diff --git a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs index f836dea03c61a..8ae64cc64e09a 100644 --- a/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/substrait/tests/cases/roundtrip_logical_plan.rs @@ -25,24 +25,31 @@ use datafusion_substrait::logical_plan::{ use std::cmp::Ordering; use std::mem::size_of_val; +use async_trait::async_trait; use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit}; -use datafusion::common::{not_impl_err, plan_err, DFSchema, DFSchemaRef}; +use datafusion::catalog::TableProvider; +use datafusion::catalog_common::TableReference; +use datafusion::common::{not_impl_err, plan_err, substrait_err, DFSchema, DFSchemaRef}; use datafusion::error::Result; use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::session_state::SessionStateBuilder; +use datafusion::execution::SessionState; use datafusion::logical_expr::{ Extension, LogicalPlan, PartitionEvaluator, Repartition, UserDefinedLogicalNode, Values, Volatility, }; use datafusion::optimizer::simplify_expressions::expr_simplifier::THRESHOLD_INLINE_INLIST; use datafusion::prelude::*; +use datafusion_substrait::extensions::Extensions; +use datafusion_substrait::logical_plan::consumer::{ + from_substrait_plan_with_consumer, SubstraitConsumer, +}; use std::hash::Hash; use std::sync::Arc; - -use datafusion::execution::session_state::SessionStateBuilder; use substrait::proto::extensions::simple_extension_declaration::MappingType; use substrait::proto::rel::RelType; -use substrait::proto::{plan_rel, Plan, Rel}; +use substrait::proto::{plan_rel, ExtensionLeafRel, Plan, Rel}; #[derive(Debug)] struct MockSerializerRegistry; @@ -977,6 +984,43 @@ async fn new_test_grammar() -> Result<()> { .await } +struct MockSubstraitConsumer { + extensions: Extensions, + state: SessionState, +} + +#[async_trait] +impl SubstraitConsumer for MockSubstraitConsumer { + async fn resolve_table_ref( + &self, + _: &TableReference, + ) -> Result>> { + not_impl_err!("MockSubstraitConsumer does not implement resolve_table_ref") + } + + fn get_extensions(&self) -> &Extensions { + &self.extensions + } + + fn get_state(&self) -> &SessionState { + &self.state + } + + async fn consume_extension_leaf( + &self, + rel: &ExtensionLeafRel, + ) -> Result { + let Some(ext_detail) = &rel.detail else { + return substrait_err!("Unexpected empty detail in ExtensionLeafRel"); + }; + let plan = self + .state + .serializer_registry() + .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?; + Ok(LogicalPlan::Extension(Extension { node: plan })) + } +} + #[tokio::test] async fn extension_logical_plan() -> Result<()> { let ctx = create_context().await?; @@ -989,8 +1033,13 @@ async fn extension_logical_plan() -> Result<()> { }), }); + let consumer = MockSubstraitConsumer { + state: ctx.state(), + extensions: Extensions::default(), + }; + let proto = to_substrait_plan(&ext_plan, &ctx.state())?; - let plan2 = from_substrait_plan(&ctx.state(), &proto).await?; + let plan2 = from_substrait_plan_with_consumer(&consumer, &proto).await?; let plan1str = format!("{ext_plan}"); let plan2str = format!("{plan2}"); diff --git a/datafusion/substrait/tests/utils.rs b/datafusion/substrait/tests/utils.rs index 00cbfb0c412cf..37f1b84fb46cb 100644 --- a/datafusion/substrait/tests/utils.rs +++ b/datafusion/substrait/tests/utils.rs @@ -23,8 +23,9 @@ pub mod test { use datafusion::datasource::TableProvider; use datafusion::error::Result; use datafusion::prelude::SessionContext; - use datafusion_substrait::extensions::Extensions; - use datafusion_substrait::logical_plan::consumer::from_substrait_named_struct; + use datafusion_substrait::logical_plan::consumer::{ + from_substrait_named_struct, DefaultSubstraitConsumer, SubstraitConsumer, + }; use std::collections::HashMap; use std::fs::File; use std::io::BufReader; @@ -50,7 +51,8 @@ pub mod test { ctx: SessionContext, plan: &Plan, ) -> Result { - let schemas = TestSchemaCollector::collect_schemas(plan)?; + let consumer = Arc::new(DefaultSubstraitConsumer::default()); + let schemas = TestSchemaCollector::collect_schemas(consumer, plan)?; let mut schema_map: HashMap> = HashMap::new(); for (table_reference, table) in schemas.into_iter() { @@ -71,21 +73,24 @@ pub mod test { Ok(ctx) } - pub struct TestSchemaCollector { + pub struct TestSchemaCollector { schemas: Vec<(TableReference, Arc)>, + consumer: Arc, } - impl TestSchemaCollector { - fn new() -> Self { + impl TestSchemaCollector { + fn new(consumer: Arc) -> Self { TestSchemaCollector { schemas: Vec::new(), + consumer, } } fn collect_schemas( + consumer: Arc, plan: &Plan, ) -> Result)>> { - let mut schema_collector = Self::new(); + let mut schema_collector = Self::new(consumer); for plan_rel in plan.relations.iter() { let rel_type = plan_rel @@ -132,14 +137,8 @@ pub mod test { "No base schema found for NamedTable: {}", table_reference ))?; - let empty_extensions = Extensions { - functions: Default::default(), - types: Default::default(), - type_variations: Default::default(), - }; - let df_schema = - from_substrait_named_struct(substrait_schema, &empty_extensions)? + from_substrait_named_struct(self.consumer.as_ref(), substrait_schema)? .replace_qualifier(table_reference.clone()); let table = EmptyTable::new(df_schema.inner().clone());