Skip to content

Commit

Permalink
[substrait] Add support for ExtensionTable
Browse files Browse the repository at this point in the history
  • Loading branch information
ccciudatu authored and adragomir committed Jan 9, 2025
1 parent 02861d6 commit e6d2b03
Show file tree
Hide file tree
Showing 5 changed files with 289 additions and 66 deletions.
24 changes: 2 additions & 22 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ use datafusion_expr::{
expr_rewriter::FunctionRewrite,
logical_plan::{DdlStatement, Statement},
planner::ExprPlanner,
Expr, UserDefinedLogicalNode, WindowUDF,
Expr, WindowUDF,
};

// backwards compatibility
Expand Down Expand Up @@ -1682,27 +1682,7 @@ pub enum RegisterFunction {
#[derive(Debug)]
pub struct EmptySerializerRegistry;

impl SerializerRegistry for EmptySerializerRegistry {
fn serialize_logical_plan(
&self,
node: &dyn UserDefinedLogicalNode,
) -> Result<Vec<u8>> {
not_impl_err!(
"Serializing user defined logical plan node `{}` is not supported",
node.name()
)
}

fn deserialize_logical_plan(
&self,
name: &str,
_bytes: &[u8],
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
not_impl_err!(
"Deserializing user defined logical plan node `{name}` is not supported"
)
}
}
impl SerializerRegistry for EmptySerializerRegistry {}

/// Describes which SQL statements can be run.
///
Expand Down
44 changes: 39 additions & 5 deletions datafusion/expr/src/registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
use crate::expr_rewriter::FunctionRewrite;
use crate::planner::ExprPlanner;
use crate::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
use crate::{AggregateUDF, ScalarUDF, TableSource, UserDefinedLogicalNode, WindowUDF};
use datafusion_common::{not_impl_err, plan_datafusion_err, HashMap, Result};
use std::collections::HashSet;
use std::fmt::Debug;
Expand Down Expand Up @@ -123,24 +123,58 @@ pub trait FunctionRegistry {
}
}

/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].
/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]
/// and custom table providers for which the name alone is meaningless in the target
/// execution context, e.g. UDTFs, manually registered tables etc.
pub trait SerializerRegistry: Debug + Send + Sync {
/// Serialize this node to a byte array. This serialization should not include
/// input plans.
fn serialize_logical_plan(
&self,
node: &dyn UserDefinedLogicalNode,
) -> Result<Vec<u8>>;
) -> Result<NamedBytes> {
not_impl_err!(
"Serializing user defined logical plan node `{}` is not supported",
node.name()
)
}

/// Deserialize user defined logical plan node ([UserDefinedLogicalNode]) from
/// bytes.
fn deserialize_logical_plan(
&self,
name: &str,
bytes: &[u8],
) -> Result<Arc<dyn UserDefinedLogicalNode>>;
_bytes: &[u8],
) -> Result<Arc<dyn UserDefinedLogicalNode>> {
not_impl_err!(
"Deserializing user defined logical plan node `{name}` is not supported"
)
}

/// Serialized table definition for UDTFs or some other table provider implementation that
/// can't be marshaled by reference.
fn serialize_custom_table(
&self,
_table: &dyn TableSource,
) -> Result<Option<NamedBytes>> {
Ok(None)
}

/// Deserialize a custom table.
fn deserialize_custom_table(
&self,
name: &str,
_bytes: &[u8],
) -> Result<Arc<dyn TableSource>> {
not_impl_err!("Deserializing custom table `{name}` is not supported")
}
}

/// A sequence of bytes with a string qualifier. Meant to encapsulate serialized extensions
/// that need to carry their type, e.g. the `type_url` for protobuf messages.
#[derive(Debug, Clone)]
pub struct NamedBytes(pub String, pub Vec<u8>);

/// A [`FunctionRegistry`] that uses in memory [`HashMap`]s
#[derive(Default, Debug)]
pub struct MemoryFunctionRegistry {
Expand Down
90 changes: 69 additions & 21 deletions datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use datafusion::logical_expr::expr::{Exists, InSubquery, Sort};

use datafusion::logical_expr::{
Aggregate, BinaryExpr, Case, Cast, EmptyRelation, Expr, ExprSchemable, Extension,
LogicalPlan, Operator, Projection, SortExpr, Subquery, TryCast, Values,
LogicalPlan, Operator, Projection, SortExpr, Subquery, TableSource, TryCast, Values,
};
use substrait::proto::aggregate_rel::Grouping;
use substrait::proto::expression as substrait_expression;
Expand Down Expand Up @@ -86,6 +86,7 @@ use substrait::proto::expression::{
SingularOrList, SwitchExpression, WindowFunction,
};
use substrait::proto::read_rel::local_files::file_or_files::PathType::UriFile;
use substrait::proto::read_rel::ExtensionTable;
use substrait::proto::rel_common::{Emit, EmitKind};
use substrait::proto::set_rel::SetOp;
use substrait::proto::{
Expand Down Expand Up @@ -457,6 +458,20 @@ pub trait SubstraitConsumer: Send + Sync + Sized {
user_defined_literal.type_reference
)
}

fn consume_extension_table(
&self,
extension_table: &ExtensionTable,
) -> Result<Arc<dyn TableSource>> {
if let Some(ext_detail) = extension_table.detail.as_ref() {
substrait_err!(
"Missing handler for extension table: {}",
&ext_detail.type_url
)
} else {
substrait_err!("Unexpected empty detail in ExtensionTable")
}
}
}

/// Convert Substrait Rel to DataFusion DataFrame
Expand Down Expand Up @@ -578,6 +593,19 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?;
Ok(LogicalPlan::Extension(Extension { node: plan }))
}

fn consume_extension_table(
&self,
extension_table: &ExtensionTable,
) -> Result<Arc<dyn TableSource>> {
if let Some(ext_detail) = &extension_table.detail {
self.state
.serializer_registry()
.deserialize_custom_table(&ext_detail.type_url, &ext_detail.value)
} else {
substrait_err!("Unexpected empty detail in ExtensionTable")
}
}
}

// Substrait PrecisionTimestampTz indicates that the timestamp is relative to UTC, which
Expand Down Expand Up @@ -1323,26 +1351,14 @@ pub async fn from_read_rel(
read: &ReadRel,
) -> Result<LogicalPlan> {
async fn read_with_schema(
consumer: &impl SubstraitConsumer,
table_ref: TableReference,
table_source: Arc<dyn TableSource>,
schema: DFSchema,
projection: &Option<MaskExpression>,
) -> Result<LogicalPlan> {
let schema = schema.replace_qualifier(table_ref.clone());

let plan = {
let provider = match consumer.resolve_table_ref(&table_ref).await? {
Some(ref provider) => Arc::clone(provider),
_ => return plan_err!("No table named '{table_ref}'"),
};

LogicalPlanBuilder::scan(
table_ref,
provider_as_source(Arc::clone(&provider)),
None,
)?
.build()?
};
let plan = { LogicalPlanBuilder::scan(table_ref, table_source, None)?.build()? };

ensure_schema_compatibility(plan.schema(), schema.clone())?;

Expand All @@ -1351,6 +1367,17 @@ pub async fn from_read_rel(
apply_projection(plan, schema)
}

async fn table_source(
consumer: &impl SubstraitConsumer,
table_ref: &TableReference,
) -> Result<Arc<dyn TableSource>> {
if let Some(provider) = consumer.resolve_table_ref(table_ref).await? {
Ok(provider_as_source(provider))
} else {
plan_err!("No table named '{table_ref}'")
}
}

let named_struct = read.base_schema.as_ref().ok_or_else(|| {
substrait_datafusion_err!("No base schema provided for Read Relation")
})?;
Expand All @@ -1376,10 +1403,10 @@ pub async fn from_read_rel(
table: nt.names[2].clone().into(),
},
};

let table_source = table_source(consumer, &table_reference).await?;
read_with_schema(
consumer,
table_reference,
table_source,
substrait_schema,
&read.projection,
)
Expand Down Expand Up @@ -1458,17 +1485,38 @@ pub async fn from_read_rel(
let name = filename.unwrap();
// directly use unwrap here since we could determine it is a valid one
let table_reference = TableReference::Bare { table: name.into() };
let table_source = table_source(consumer, &table_reference).await?;

read_with_schema(
consumer,
table_reference,
table_source,
substrait_schema,
&read.projection,
)
.await
}
Some(ReadType::ExtensionTable(ext)) => {
// look for the original table name under `rel.common.hint.alias`
// in case the producer was kind enough to put it there.
let name_hint = read
.common
.as_ref()
.and_then(|rel_common| rel_common.hint.as_ref())
.map(|hint| hint.alias.as_str().trim())
.filter(|alias| !alias.is_empty());
// if no name hint was provided, use the name that datafusion
// sets for UDTFs
let table_name = name_hint.unwrap_or("tmp_table");
read_with_schema(
TableReference::from(table_name),
consumer.consume_extension_table(ext)?,
substrait_schema,
&read.projection,
)
.await
}
_ => {
not_impl_err!("Unsupported ReadType: {:?}", read.read_type)
None => {
substrait_err!("Unexpected empty read_type")
}
}
}
Expand Down Expand Up @@ -1871,7 +1919,7 @@ pub async fn from_substrait_sorts(
},
None => not_impl_err!("Sort without sort kind is invalid"),
};
let (asc, nulls_first) = asc_nullfirst.unwrap();
let (asc, nulls_first) = asc_nullfirst?;
sorts.push(Sort {
expr,
asc,
Expand Down
Loading

0 comments on commit e6d2b03

Please sign in to comment.