Skip to content

Commit

Permalink
[bugfix] ScalarFunctionExpr does not preserve the nullable flag on ro…
Browse files Browse the repository at this point in the history
…undtrip (#13830)

* [test] coalesce round trip schema mismatch

* [proto] added the nullable flag in PhysicalScalarUdfNode

* [bugfix] propagate the nullable flag for serialized scalar UDFS
  • Loading branch information
ccciudatu authored Dec 19, 2024
1 parent 2439979 commit 3467011
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 7 deletions.
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,7 @@ message PhysicalScalarUdfNode {
repeated PhysicalExprNode args = 2;
optional bytes fun_definition = 3;
datafusion_common.ArrowType return_type = 4;
bool nullable = 5;
}

message PhysicalAggregateExprNode {
Expand Down
17 changes: 17 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 9 additions & 6 deletions datafusion/proto/src/physical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -358,12 +358,15 @@ pub fn parse_physical_expr(

let args = parse_physical_exprs(&e.args, registry, input_schema, codec)?;

Arc::new(ScalarFunctionExpr::new(
e.name.as_str(),
scalar_fun_def,
args,
convert_required!(e.return_type)?,
))
Arc::new(
ScalarFunctionExpr::new(
e.name.as_str(),
scalar_fun_def,
args,
convert_required!(e.return_type)?,
)
.with_nullable(e.nullable),
)
}
ExprType::LikeExpr(like_expr) => Arc::new(LikeExpr::new(
like_expr.negated,
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/physical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ pub fn serialize_physical_expr(
args: serialize_physical_exprs(expr.args(), codec)?,
fun_definition: (!buf.is_empty()).then_some(buf),
return_type: Some(expr.return_type().try_into()?),
nullable: expr.nullable(),
},
)),
})
Expand Down
43 changes: 42 additions & 1 deletion datafusion/proto/tests/cases/roundtrip_physical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ use crate::cases::{
use datafusion::arrow::array::ArrayRef;
use datafusion::arrow::compute::kernels::sort::SortOptions;
use datafusion::arrow::datatypes::{DataType, Field, IntervalUnit, Schema};
use datafusion::datasource::empty::EmptyTable;
use datafusion::datasource::file_format::csv::CsvSink;
use datafusion::datasource::file_format::json::JsonSink;
use datafusion::datasource::file_format::parquet::ParquetSink;
Expand Down Expand Up @@ -83,7 +84,7 @@ use datafusion::physical_plan::windows::{
WindowAggExec,
};
use datafusion::physical_plan::{
ExecutionPlan, InputOrderMode, Partitioning, PhysicalExpr, Statistics,
displayable, ExecutionPlan, InputOrderMode, Partitioning, PhysicalExpr, Statistics,
};
use datafusion::prelude::SessionContext;
use datafusion::scalar::ScalarValue;
Expand All @@ -106,6 +107,7 @@ use datafusion_proto::physical_plan::{
AsExecutionPlan, DefaultPhysicalExtensionCodec, PhysicalExtensionCodec,
};
use datafusion_proto::protobuf;
use datafusion_proto::protobuf::PhysicalPlanNode;

/// Perform a serde roundtrip and assert that the string representation of the before and after plans
/// are identical. Note that this often isn't sufficient to guarantee that no information is
Expand Down Expand Up @@ -1525,3 +1527,42 @@ fn roundtrip_unnest() -> Result<()> {
);
roundtrip_test(Arc::new(unnest))
}

#[tokio::test]
async fn roundtrip_coalesce() -> Result<()> {
let ctx = SessionContext::new();
ctx.register_table(
"t",
Arc::new(EmptyTable::new(Arc::new(Schema::new(Fields::from([
Arc::new(Field::new("f", DataType::Int64, false)),
]))))),
)?;
let df = ctx.sql("select coalesce(f) as f from t").await?;
let plan = df.create_physical_plan().await?;

let node = PhysicalPlanNode::try_from_physical_plan(
plan.clone(),
&DefaultPhysicalExtensionCodec {},
)?;
let node = PhysicalPlanNode::decode(node.encode_to_vec().as_slice())
.map_err(|e| DataFusionError::External(Box::new(e)))?;
let restored = node.try_into_physical_plan(
&ctx,
ctx.runtime_env().as_ref(),
&DefaultPhysicalExtensionCodec {},
)?;

assert_eq!(
plan.schema(),
restored.schema(),
"Schema mismatch for plans:\n>> initial:\n{}>> final: \n{}",
displayable(plan.as_ref())
.set_show_schema(true)
.indent(true),
displayable(restored.as_ref())
.set_show_schema(true)
.indent(true),
);

Ok(())
}

0 comments on commit 3467011

Please sign in to comment.