Skip to content

Commit

Permalink
fix: Ignore nullability of list elements when consuming Substrait (#1…
Browse files Browse the repository at this point in the history
…0874)

* Ignore nullability of list elements when consuming Substrait

DataFusion (= Arrow) is quite strict about nullability. Specifically,
when using e.g. LogicalPlan::Values, the given schema must match the
given literals exactly, including nullability.
This is non-trivial to guarantee when the schema and the literals are converted separately.

The existing implementation for from_substrait_literal already creates
lists that are always nullable
(see ScalarValue::new_list => array_into_list_array).
This reverts part of #10640 to
align from_substrait_type with that behavior.

This is the error I was hitting:
```
ArrowError(InvalidArgumentError("column types must match schema types, expected
List(Field { name: \"item\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }) but found
List(Field { name: \"item\", data_type: Int32, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }) at column index 0"), None)
```

* use `Field::new_list_field` in `array_into_(large_)list_array`

just for consistency, to reduce the places where "item" is written out

* add a test for non-nullable lists
  • Loading branch information
Blizzara authored Jun 12, 2024
1 parent 87d8267 commit dfdda7c
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 21 deletions.
14 changes: 7 additions & 7 deletions datafusion/common/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ pub fn longest_consecutive_prefix<T: Borrow<usize>>(
pub fn array_into_list_array(arr: ArrayRef) -> ListArray {
let offsets = OffsetBuffer::from_lengths([arr.len()]);
ListArray::new(
Arc::new(Field::new("item", arr.data_type().to_owned(), true)),
Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)),
offsets,
arr,
None,
Expand All @@ -366,7 +366,7 @@ pub fn array_into_list_array(arr: ArrayRef) -> ListArray {
pub fn array_into_large_list_array(arr: ArrayRef) -> LargeListArray {
let offsets = OffsetBuffer::from_lengths([arr.len()]);
LargeListArray::new(
Arc::new(Field::new("item", arr.data_type().to_owned(), true)),
Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)),
offsets,
arr,
None,
Expand All @@ -379,7 +379,7 @@ pub fn array_into_fixed_size_list_array(
) -> FixedSizeListArray {
let list_size = list_size as i32;
FixedSizeListArray::new(
Arc::new(Field::new("item", arr.data_type().to_owned(), true)),
Arc::new(Field::new_list_field(arr.data_type().to_owned(), true)),
list_size,
arr,
None,
Expand Down Expand Up @@ -420,7 +420,7 @@ pub fn arrays_into_list_array(
let data_type = arr[0].data_type().to_owned();
let values = arr.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
Ok(ListArray::new(
Arc::new(Field::new("item", data_type, true)),
Arc::new(Field::new_list_field(data_type, true)),
OffsetBuffer::from_lengths(lens),
arrow::compute::concat(values.as_slice())?,
None,
Expand All @@ -435,7 +435,7 @@ pub fn arrays_into_list_array(
/// use datafusion_common::utils::base_type;
/// use std::sync::Arc;
///
/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
/// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true)));
/// assert_eq!(base_type(&data_type), DataType::Int32);
///
/// let data_type = DataType::Int32;
Expand All @@ -458,10 +458,10 @@ pub fn base_type(data_type: &DataType) -> DataType {
/// use datafusion_common::utils::coerced_type_with_base_type_only;
/// use std::sync::Arc;
///
/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
/// let data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true)));
/// let base_type = DataType::Float64;
/// let coerced_type = coerced_type_with_base_type_only(&data_type, &base_type);
/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new("item", DataType::Float64, true))));
/// assert_eq!(coerced_type, DataType::List(Arc::new(Field::new_list_field(DataType::Float64, true))));
pub fn coerced_type_with_base_type_only(
data_type: &DataType,
base_type: &DataType,
Expand Down
4 changes: 3 additions & 1 deletion datafusion/substrait/src/logical_plan/consumer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1395,7 +1395,9 @@ fn from_substrait_type(
})?;
let field = Arc::new(Field::new_list_field(
from_substrait_type(inner_type, dfs_names, name_idx)?,
is_substrait_type_nullable(inner_type)?,
// We ignore Substrait's nullability here to match to_substrait_literal
// which always creates nullable lists
true,
));
match list.type_variation_reference {
DEFAULT_CONTAINER_TYPE_VARIATION_REF => Ok(DataType::List(field)),
Expand Down
14 changes: 6 additions & 8 deletions datafusion/substrait/src/logical_plan/producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2309,14 +2309,12 @@ mod test {
round_trip_type(DataType::Decimal128(10, 2))?;
round_trip_type(DataType::Decimal256(30, 2))?;

for nullable in [true, false] {
round_trip_type(DataType::List(
Field::new_list_field(DataType::Int32, nullable).into(),
))?;
round_trip_type(DataType::LargeList(
Field::new_list_field(DataType::Int32, nullable).into(),
))?;
}
round_trip_type(DataType::List(
Field::new_list_field(DataType::Int32, true).into(),
))?;
round_trip_type(DataType::LargeList(
Field::new_list_field(DataType::Int32, true).into(),
))?;

round_trip_type(DataType::Struct(
vec![
Expand Down
32 changes: 27 additions & 5 deletions datafusion/substrait/tests/cases/logical_plans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#[cfg(test)]
mod tests {
use datafusion::common::Result;
use datafusion::dataframe::DataFrame;
use datafusion::prelude::{CsvReadOptions, SessionContext};
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use std::fs::File;
Expand All @@ -38,11 +39,7 @@ mod tests {

// File generated with substrait-java's Isthmus:
// ./isthmus-cli/build/graal/isthmus "select not d from data" -c "create table data (d boolean)"
let path = "tests/testdata/select_not_bool.substrait.json";
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
))
.expect("failed to parse json");
let proto = read_json("tests/testdata/select_not_bool.substrait.json");

let plan = from_substrait_plan(&ctx, &proto).await?;

Expand All @@ -54,6 +51,31 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn non_nullable_lists() -> Result<()> {
    // The Substrait consumer deliberately reports every list as nullable,
    // regardless of what the incoming plan declares, because keeping
    // non-nullability consistent between schema and literals is non-trivial.
    // This test checks that a plan declaring non-nullable lists is still
    // consumed and executed without a schema mismatch.
    let session = create_context().await?;
    let substrait_plan = read_json("tests/testdata/non_nullable_lists.substrait.json");

    let logical_plan = from_substrait_plan(&session, &substrait_plan).await?;

    let rendered = format!("{:?}", &logical_plan);
    assert_eq!(rendered, "Values: (List([1, 2]))");

    // Building the plan is not enough: Arrow only validates column types
    // against the schema once the plan is actually executed.
    let frame = DataFrame::new(session.state(), logical_plan);
    frame.show().await?;

    Ok(())
}

/// Reads the JSON file at `path` and deserializes its contents into a `Plan`.
///
/// # Panics
///
/// Panics if the file cannot be opened or if its contents cannot be
/// deserialized as a `Plan`. The panic message names the offending path and
/// carries the underlying error, so a failing test points at the fixture
/// and the real cause rather than always claiming "file not found".
fn read_json(path: &str) -> Plan {
    let file = File::open(path)
        .unwrap_or_else(|err| panic!("failed to open {path}: {err}"));

    // Buffer the reads: serde_json pulls from the reader in small pieces.
    serde_json::from_reader::<_, Plan>(BufReader::new(file))
        .unwrap_or_else(|err| panic!("failed to parse json from {path}: {err}"))
}

async fn create_context() -> datafusion::common::Result<SessionContext> {
let ctx = SessionContext::new();
ctx.register_csv("DATA", "tests/testdata/data.csv", CsvReadOptions::new())
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
{
"extensionUris": [],
"extensions": [],
"relations": [
{
"root": {
"input": {
"read": {
"common": {
"direct": {
}
},
"baseSchema": {
"names": [
"col"
],
"struct": {
"types": [
{
"list": {
"type": {
"i32": {
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
},
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
}
],
"typeVariationReference": 0,
"nullability": "NULLABILITY_REQUIRED"
}
},
"virtualTable": {
"values": [
{
"fields": [
{
"list": {
"values": [
{
"i32": 1,
"nullable": false,
"typeVariationReference": 0
},
{
"i32": 2,
"nullable": false,
"typeVariationReference": 0
}
]
},
"nullable": false,
"typeVariationReference": 0
}
]
}
]
}
}
},
"names": [
"col"
]
}
}
],
"expectedTypeUrls": []
}

0 comments on commit dfdda7c

Please sign in to comment.