diff --git a/parquet/src/arrow/schema/complex.rs b/parquet/src/arrow/schema/complex.rs index bf4835994ac3..1d527eab3b57 100644 --- a/parquet/src/arrow/schema/complex.rs +++ b/parquet/src/arrow/schema/complex.rs @@ -72,6 +72,54 @@ impl ParquetField { } } + /// Converts `self` into an arrow list, with its current type as the field type. + /// Accepts an optional `list_data_type` to specify the type of list to create. + /// + /// This is used to convert [deprecated repeated columns] (not in a list) into their arrow representation. + /// + /// [deprecated repeated columns]: https://github.com/apache/parquet-format/blob/9fd57b59e0ce1a82a69237dcf8977d3e72a2965d/LogicalTypes.md?plain=1#L649-L650 + fn into_list_with_arrow_list_hint( + self, + parquet_field_type: &Type, + list_data_type: Option<DataType>, + ) -> Result<Self> { + let arrow_field = match &list_data_type { + Some(DataType::List(field_hint)) + | Some(DataType::LargeList(field_hint)) + | Some(DataType::FixedSizeList(field_hint, _)) => Some(field_hint.as_ref()), + Some(_) => return Err(general_err!( + "Internal error: should be validated earlier that list_data_type is only a type of list" + )), + None => None, + }; + + let arrow_field = convert_field( + parquet_field_type, + &self, + arrow_field, + // Only add the field id to the list and not to the element + false, + )? 
+ .with_nullable(false); + + Ok(ParquetField { + rep_level: self.rep_level, + def_level: self.def_level, + nullable: false, + arrow_type: match list_data_type { + Some(DataType::List(_)) => DataType::List(Arc::new(arrow_field)), + Some(DataType::LargeList(_)) => DataType::LargeList(Arc::new(arrow_field)), + Some(DataType::FixedSizeList(_, len)) => { + DataType::FixedSizeList(Arc::new(arrow_field), len) + } + _ => DataType::List(Arc::new(arrow_field)), + }, + field_type: ParquetFieldType::Group { + children: vec![self], + }, + }) + } + /// Returns a list of [`ParquetField`] children if this is a group type pub fn children(&self) -> Option<&[Self]> { match &self.field_type { @@ -100,6 +148,13 @@ struct VisitorContext { def_level: i16, /// An optional [`DataType`] sourced from the embedded arrow schema data_type: Option<DataType>, + + /// Whether to treat repeated types as list from arrow types + /// when true, if data_type provided it should be DataType::List() (or other list type) + /// and the list field data type would be treated as the hint for the parquet type + /// + /// when false, if data_type provided it will be treated as the hint without unwrapping + treat_repeated_as_list_arrow_hint: bool, } impl VisitorContext { @@ -143,7 +198,25 @@ impl Visitor { let repetition = get_repetition(primitive_type); let (def_level, rep_level, nullable) = context.levels(repetition); - let arrow_type = convert_primitive(primitive_type, context.data_type)?; + let primitive_arrow_data_type = match repetition { + Repetition::REPEATED if context.treat_repeated_as_list_arrow_hint => { + let arrow_field = match &context.data_type { + Some(DataType::List(f)) => Some(f.as_ref()), + Some(DataType::LargeList(f)) => Some(f.as_ref()), + Some(DataType::FixedSizeList(f, _)) => Some(f.as_ref()), + Some(d) => return Err(arrow_err!( + "incompatible arrow schema, expected list got {} for repeated primitive field", + d + )), + None => None, + }; + + arrow_field.map(|f| f.data_type().clone()) + } + _ => 
context.data_type.clone(), + }; + + let arrow_type = convert_primitive(primitive_type, primitive_arrow_data_type)?; let primitive_field = ParquetField { rep_level, @@ -157,6 +230,9 @@ impl Visitor { }; Ok(Some(match repetition { + Repetition::REPEATED if context.treat_repeated_as_list_arrow_hint => { + primitive_field.into_list_with_arrow_list_hint(primitive_type, context.data_type)? + } Repetition::REPEATED => primitive_field.into_list(primitive_type.name()), _ => primitive_field, })) @@ -174,7 +250,27 @@ impl Visitor { let parquet_fields = struct_type.get_fields(); // Extract any arrow fields from the hints - let arrow_fields = match &context.data_type { + let arrow_struct = match repetition { + Repetition::REPEATED if context.treat_repeated_as_list_arrow_hint => { + let arrow_field = match &context.data_type { + Some(DataType::List(f)) => Some(f.as_ref()), + Some(DataType::LargeList(f)) => Some(f.as_ref()), + Some(DataType::FixedSizeList(f, _)) => Some(f.as_ref()), + Some(d) => { + return Err(arrow_err!( + "incompatible arrow schema, expected list got {} for repeated struct field", + d + )) + } + None => None, + }; + + arrow_field.map(|f| f.data_type()) + } + _ => context.data_type.as_ref(), + }; + + let arrow_fields = match &arrow_struct { Some(DataType::Struct(fields)) => { if fields.len() != parquet_fields.len() { return Err(arrow_err!( @@ -219,12 +315,13 @@ impl Visitor { rep_level, def_level, data_type, + treat_repeated_as_list_arrow_hint: true, }; - if let Some(mut child) = self.dispatch(parquet_field, child_ctx)? { + if let Some(child) = self.dispatch(parquet_field, child_ctx)? 
{ // The child type returned may be different from what is encoded in the arrow // schema in the event of a mismatch or a projection - child_fields.push(convert_field(parquet_field, &mut child, arrow_field)?); + child_fields.push(convert_field(parquet_field, &child, arrow_field, true)?); children.push(child); } } @@ -242,6 +339,9 @@ impl Visitor { }; Ok(Some(match repetition { + Repetition::REPEATED if context.treat_repeated_as_list_arrow_hint => { + struct_field.into_list_with_arrow_list_hint(struct_type, context.data_type)? + } Repetition::REPEATED => struct_field.into_list(struct_type.name()), _ => struct_field, })) @@ -336,6 +436,8 @@ impl Visitor { rep_level, def_level, data_type: arrow_key.map(|x| x.data_type().clone()), + // Key is not repeated + treat_repeated_as_list_arrow_hint: false, }; self.dispatch(map_key, context)? @@ -346,6 +448,8 @@ impl Visitor { rep_level, def_level, data_type: arrow_value.map(|x| x.data_type().clone()), + // Value type can be repeated + treat_repeated_as_list_arrow_hint: true, }; self.dispatch(map_value, context)? @@ -353,13 +457,13 @@ impl Visitor { // Need both columns to be projected match (maybe_key, maybe_value) { - (Some(mut key), Some(mut value)) => { + (Some(key), Some(value)) => { let key_field = Arc::new( - convert_field(map_key, &mut key, arrow_key)? + convert_field(map_key, &key, arrow_key, true)? 
// The key is always non-nullable (#5630) .with_nullable(false), ); - let value_field = Arc::new(convert_field(map_value, &mut value, arrow_value)?); + let value_field = Arc::new(convert_field(map_value, &value, arrow_value, true)?); let field_metadata = match arrow_map { Some(field) => field.metadata().clone(), _ => HashMap::default(), @@ -442,6 +546,7 @@ impl Visitor { rep_level: context.rep_level, def_level, data_type: arrow_field.map(|f| f.data_type().clone()), + treat_repeated_as_list_arrow_hint: false, }; return match self.visit_primitive(repeated_field, context) { @@ -473,6 +578,7 @@ impl Visitor { rep_level: context.rep_level, def_level, data_type: arrow_field.map(|f| f.data_type().clone()), + treat_repeated_as_list_arrow_hint: false, }; return match self.visit_struct(repeated_field, context) { @@ -493,11 +599,12 @@ impl Visitor { def_level, rep_level, data_type: arrow_field.map(|f| f.data_type().clone()), + treat_repeated_as_list_arrow_hint: true, }; match self.dispatch(item_type, new_context) { - Ok(Some(mut item)) => { - let item_field = Arc::new(convert_field(item_type, &mut item, arrow_field)?); + Ok(Some(item)) => { + let item_field = Arc::new(convert_field(item_type, &item, arrow_field, true)?); // Use arrow type as hint for index size let arrow_type = match context.data_type { @@ -547,8 +654,9 @@ impl Visitor { /// dictated by the `parquet_type`, and any metadata from `arrow_hint` fn convert_field( parquet_type: &Type, - field: &mut ParquetField, + field: &ParquetField, arrow_hint: Option<&Field>, + add_field_id: bool, ) -> Result { let name = parquet_type.name(); let data_type = field.arrow_type.clone(); @@ -572,7 +680,7 @@ fn convert_field( None => { let mut ret = Field::new(name, data_type, nullable); let basic_info = parquet_type.get_basic_info(); - if basic_info.has_id() { + if add_field_id && basic_info.has_id() { let mut meta = HashMap::with_capacity(1); meta.insert( PARQUET_FIELD_ID_META_KEY.to_string(), @@ -604,6 +712,7 @@ pub fn 
convert_schema( rep_level: 0, def_level: 0, data_type: embedded_arrow_schema.map(|fields| DataType::Struct(fields.clone())), + treat_repeated_as_list_arrow_hint: true, }; visitor.dispatch(&schema.root_schema_ptr(), context) @@ -620,7 +729,1070 @@ pub fn convert_type(parquet_type: &TypePtr) -> Result<ParquetField> { rep_level: 0, def_level: 0, data_type: None, + // We might be inside list + treat_repeated_as_list_arrow_hint: false, }; Ok(visitor.dispatch(parquet_type, context)?.unwrap()) } + +#[cfg(test)] +mod tests { + use crate::arrow::schema::complex::convert_schema; + use crate::arrow::{ProjectionMask, PARQUET_FIELD_ID_META_KEY}; + use crate::schema::parser::parse_message_type; + use crate::schema::types::SchemaDescriptor; + use arrow_schema::{DataType, Field, Fields}; + use std::sync::Arc; + + trait WithFieldId { + fn with_field_id(self, id: i32) -> Self; + } + impl WithFieldId for arrow_schema::Field { + fn with_field_id(self, id: i32) -> Self { + let mut metadata = self.metadata().clone(); + metadata.insert(PARQUET_FIELD_ID_META_KEY.to_string(), id.to_string()); + self.with_metadata(metadata) + } + } + + fn test_roundtrip(message_type: &str) -> crate::errors::Result<()> { + let parsed_input_schema = Arc::new(parse_message_type(message_type)?); + let schema = SchemaDescriptor::new(parsed_input_schema); + + let converted = convert_schema(&schema, ProjectionMask::all(), None)?.unwrap(); + + let DataType::Struct(schema_fields) = &converted.arrow_type else { + panic!("Expected struct from convert_schema"); + }; + + // Should be able to convert the same thing + let converted_again = + convert_schema(&schema, ProjectionMask::all(), Some(schema_fields))?.unwrap(); + + // Assert that feeding the inferred schema back as a hint yields the same type + assert_eq!(converted_again.arrow_type, converted.arrow_type); + + Ok(()) + } + + fn test_expected_type( + message_type: &str, + expected_fields: Fields, + ) -> crate::errors::Result<()> { + test_roundtrip(message_type)?; + + let parsed_input_schema = 
Arc::new(parse_message_type(message_type)?); + let schema = SchemaDescriptor::new(parsed_input_schema); + + let converted = convert_schema(&schema, ProjectionMask::all(), None)?.unwrap(); + + let DataType::Struct(schema_fields) = &converted.arrow_type else { + panic!("Expected struct from convert_schema"); + }; + + assert_eq!(schema_fields, &expected_fields); + + Ok(()) + } + + /// Taken from the example in [Parquet Format - Nested Types - Lists - Backward-compatibility rules](https://github.com/apache/parquet-format/blob/9fd57b59e0ce1a82a69237dcf8977d3e72a2965d/LogicalTypes.md?plain=1#L766-L769) + #[test] + fn basic_backward_compatible_list_1() -> crate::errors::Result<()> { + test_expected_type( + " + message schema { + optional group my_list (LIST) { + repeated int32 element; + } + } + ", + Fields::from(vec![ + // Rule 1: List<Integer> (nullable list, non-null elements) + Field::new( + "my_list", + DataType::List(Arc::new(Field::new("element", DataType::Int32, false))), + true, + ), + ]), + ) + } + + /// Taken from the example in [Parquet Format - Nested Types - Lists - Backward-compatibility rules](https://github.com/apache/parquet-format/blob/9fd57b59e0ce1a82a69237dcf8977d3e72a2965d/LogicalTypes.md?plain=1#L771-L777) + #[test] + fn basic_backward_compatible_list_2() -> crate::errors::Result<()> { + test_expected_type( + " + message schema { + optional group my_list (LIST) { + repeated group element { + required binary str (STRING); + required int32 num; + } + } + } + ", + Fields::from(vec![ + // Rule 2: List<Tuple<String, Integer>> (nullable list, non-null elements) + Field::new( + "my_list", + DataType::List(Arc::new(Field::new( + "element", + DataType::Struct(Fields::from(vec![ + Field::new("str", DataType::Utf8, false), + Field::new("num", DataType::Int32, false), + ])), + false, + ))), + true, + ), + ]), + ) + } + + /// Taken from the example in [Parquet Format - Nested Types - Lists - Backward-compatibility 
rules](https://github.com/apache/parquet-format/blob/9fd57b59e0ce1a82a69237dcf8977d3e72a2965d/LogicalTypes.md?plain=1#L779-L784) + #[test] + fn basic_backward_compatible_list_3() -> crate::errors::Result<()> { + test_expected_type( + " + message schema { + optional group my_list (LIST) { + repeated group array (LIST) { + repeated int32 array; + } + } + } + ", + Fields::from(vec![ + // Rule 3: List<List<Integer>> (nullable outer list, non-null elements) + Field::new( + "my_list", + DataType::List(Arc::new(Field::new( + "array", + DataType::List(Arc::new(Field::new("array", DataType::Int32, false))), + false, + ))), + true, + ), + ]), + ) + } + + /// Taken from the example in [Parquet Format - Nested Types - Lists - Backward-compatibility rules](https://github.com/apache/parquet-format/blob/9fd57b59e0ce1a82a69237dcf8977d3e72a2965d/LogicalTypes.md?plain=1#L786-L791) + #[test] + fn basic_backward_compatible_list_4_1() -> crate::errors::Result<()> { + test_expected_type( + " + message schema { + optional group my_list (LIST) { + repeated group array { + required binary str (STRING); + } + } + } + ", + Fields::from(vec![ + // Rule 4: List<OneTuple<String>> (nullable list, non-null elements) + Field::new( + "my_list", + DataType::List(Arc::new(Field::new( + "array", + DataType::Struct(Fields::from(vec![Field::new( + "str", + DataType::Utf8, + false, + )])), + false, + ))), + true, + ), + ]), + ) + } + + /// Taken from the example in [Parquet Format - Nested Types - Lists - Backward-compatibility rules](https://github.com/apache/parquet-format/blob/9fd57b59e0ce1a82a69237dcf8977d3e72a2965d/LogicalTypes.md?plain=1#L793-L798) + #[test] + fn basic_backward_compatible_list_4_2() -> crate::errors::Result<()> { + test_expected_type( + " + message schema { + optional group my_list (LIST) { + repeated group my_list_tuple { + required binary str (STRING); + } + } + } + ", + Fields::from(vec![ + // Rule 4: List<OneTuple<String>> (nullable list, non-null elements) + Field::new( + "my_list", + DataType::List(Arc::new(Field::new( + 
"my_list_tuple", + DataType::Struct(Fields::from(vec![Field::new( + "str", + DataType::Utf8, + false, + )])), + false, + ))), + true, + ), + ]), + ) + } + + /// Taken from the example in [Parquet Format - Nested Types - Lists - Backward-compatibility rules](https://github.com/apache/parquet-format/blob/9fd57b59e0ce1a82a69237dcf8977d3e72a2965d/LogicalTypes.md?plain=1#L800-L805) + #[test] + fn basic_backward_compatible_list_5() -> crate::errors::Result<()> { + test_expected_type( + " + message schema { + optional group my_list (LIST) { + repeated group element { + optional binary str (STRING); + } + } + } + ", + Fields::from(vec![ + // Rule 5: List<String> (nullable list, nullable elements) + Field::new( + "my_list", + DataType::List(Arc::new(Field::new("str", DataType::Utf8, true))), + true, + ), + ]), + ) + } + + #[test] + fn basic_backward_compatible_map_1() -> crate::errors::Result<()> { + test_expected_type( + " + message schema { + optional group my_map (MAP) { + repeated group map { + required binary str (STRING); + required int32 num; + } + } + } + ", + Fields::from(vec![ + // Map<String, Integer> (nullable map, non-null values) + Field::new( + "my_map", + DataType::Map( + Arc::new(Field::new( + "map", + DataType::Struct(Fields::from(vec![ + Field::new("str", DataType::Utf8, false), + Field::new("num", DataType::Int32, false), + ])), + false, + )), + false, + ), + true, + ), + ]), + ) + } + + #[test] + fn basic_backward_compatible_map_2() -> crate::errors::Result<()> { + test_expected_type( + " + message schema { + optional group my_map (MAP_KEY_VALUE) { + repeated group map { + required binary key (STRING); + optional int32 value; + } + } + } + ", + Fields::from(vec![ + // Map<String, Integer> (nullable map, nullable values) + Field::new( + "my_map", + DataType::Map( + Arc::new(Field::new( + "map", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ])), + false, + )), + false, + ), + true, + ), + ]), + ) + } + + #[test] + 
fn convert_schema_with_nested_list_repeated_primitive() -> crate::errors::Result<()> { + test_roundtrip( + " + message schema { + optional group f1 (LIST) { + repeated group element { + repeated int32 element; + } + } + } + ", + ) + } + + #[test] + fn convert_schema_with_repeated_primitive_keep_field_id() -> crate::errors::Result<()> { + let message_type = " + message schema { + repeated BYTE_ARRAY col_1 = 1; + } + "; + + let parsed_input_schema = Arc::new(parse_message_type(message_type)?); + let schema = SchemaDescriptor::new(parsed_input_schema); + + let converted = convert_schema(&schema, ProjectionMask::all(), None)?.unwrap(); + + let DataType::Struct(schema_fields) = &converted.arrow_type else { + panic!("Expected struct from convert_schema"); + }; + + assert_eq!(schema_fields.len(), 1); + + let expected_schema = DataType::Struct(Fields::from(vec![Arc::new( + arrow_schema::Field::new( + "col_1", + DataType::List(Arc::new( + // No metadata on inner field + arrow_schema::Field::new("col_1", DataType::Binary, false), + )), + false, + ) + // add the field id to the outer list + .with_field_id(1), + )])); + + assert_eq!(converted.arrow_type, expected_schema); + + Ok(()) + } + + #[test] + fn convert_schema_with_repeated_primitive_should_use_inferred_schema( + ) -> crate::errors::Result<()> { + let message_type = " + message schema { + repeated BYTE_ARRAY col_1 = 1; + } + "; + + let parsed_input_schema = Arc::new(parse_message_type(message_type)?); + let schema = SchemaDescriptor::new(parsed_input_schema); + + let converted = convert_schema(&schema, ProjectionMask::all(), None)?.unwrap(); + + let DataType::Struct(schema_fields) = &converted.arrow_type else { + panic!("Expected struct from convert_schema"); + }; + + assert_eq!(schema_fields.len(), 1); + + let expected_schema = DataType::Struct(Fields::from(vec![Arc::new( + arrow_schema::Field::new( + "col_1", + DataType::List(Arc::new(arrow_schema::Field::new( + "col_1", + DataType::Binary, + false, + ))), + false, + 
) + .with_metadata(schema_fields[0].metadata().clone()), + )])); + + assert_eq!(converted.arrow_type, expected_schema); + + let utf8_instead_of_binary = Fields::from(vec![Arc::new( + arrow_schema::Field::new( + "col_1", + DataType::List(Arc::new(arrow_schema::Field::new( + "col_1", + DataType::Utf8, + false, + ))), + false, + ) + .with_metadata(schema_fields[0].metadata().clone()), + )]); + + // Should be able to convert the same thing + let converted_again = convert_schema( + &schema, + ProjectionMask::all(), + Some(&utf8_instead_of_binary), + )? + .unwrap(); + + // Assert that we changed to Utf8 + assert_eq!( + converted_again.arrow_type, + DataType::Struct(utf8_instead_of_binary) + ); + + Ok(()) + } + + #[test] + fn convert_schema_with_repeated_primitive_should_use_inferred_schema_for_list_as_well( + ) -> crate::errors::Result<()> { + let message_type = " + message schema { + repeated BYTE_ARRAY col_1 = 1; + } + "; + + let parsed_input_schema = Arc::new(parse_message_type(message_type)?); + let schema = SchemaDescriptor::new(parsed_input_schema); + + let converted = convert_schema(&schema, ProjectionMask::all(), None)?.unwrap(); + + let DataType::Struct(schema_fields) = &converted.arrow_type else { + panic!("Expected struct from convert_schema"); + }; + + assert_eq!(schema_fields.len(), 1); + + let expected_schema = DataType::Struct(Fields::from(vec![Arc::new( + arrow_schema::Field::new( + "col_1", + DataType::List(Arc::new(arrow_schema::Field::new( + "col_1", + DataType::Binary, + false, + ))), + false, + ) + .with_metadata(schema_fields[0].metadata().clone()), + )])); + + assert_eq!(converted.arrow_type, expected_schema); + + let utf8_instead_of_binary = Fields::from(vec![Arc::new( + arrow_schema::Field::new( + "col_1", + // Inferring as LargeList instead of List + DataType::LargeList(Arc::new(arrow_schema::Field::new( + "col_1", + DataType::Utf8, + false, + ))), + false, + ) + .with_metadata(schema_fields[0].metadata().clone()), + )]); + + // Should be able 
to convert the same thing + let converted_again = convert_schema( + &schema, + ProjectionMask::all(), + Some(&utf8_instead_of_binary), + )? + .unwrap(); + + // Assert that we changed to Utf8 + assert_eq!( + converted_again.arrow_type, + DataType::Struct(utf8_instead_of_binary) + ); + + Ok(()) + } + + #[test] + fn convert_schema_with_repeated_struct_and_inferred_schema() -> crate::errors::Result<()> { + test_roundtrip( + " + message schema { + repeated group my_col_1 = 1 { + optional binary my_col_2 = 2; + optional binary my_col_3 = 3; + optional group my_col_4 = 4 { + optional int64 my_col_5 = 5; + optional int32 my_col_6 = 6; + } + } + } + ", + ) + } + + #[test] + fn convert_schema_with_repeated_struct_and_inferred_schema_and_field_id( + ) -> crate::errors::Result<()> { + let message_type = " + message schema { + repeated group my_col_1 = 1 { + optional binary my_col_2 = 2; + optional binary my_col_3 = 3; + optional group my_col_4 = 4 { + optional int64 my_col_5 = 5; + optional int32 my_col_6 = 6; + } + } + } + "; + + let parsed_input_schema = Arc::new(parse_message_type(message_type)?); + let schema = SchemaDescriptor::new(parsed_input_schema); + + let converted = convert_schema(&schema, ProjectionMask::all(), None)?.unwrap(); + + let DataType::Struct(schema_fields) = &converted.arrow_type else { + panic!("Expected struct from convert_schema"); + }; + + assert_eq!(schema_fields.len(), 1); + + // Should be able to convert the same thing + let converted_again = + convert_schema(&schema, ProjectionMask::all(), Some(schema_fields))?.unwrap(); + + // Assert that the inferred schema (including field ids) round-trips unchanged + assert_eq!(converted_again.arrow_type, converted.arrow_type); + + Ok(()) + } + + #[test] + fn convert_schema_with_nested_repeated_struct_and_primitives() -> crate::errors::Result<()> { + let message_type = " +message schema { + repeated group my_col_1 = 1 { + optional binary my_col_2 = 2; + repeated BYTE_ARRAY my_col_3 = 3; + repeated group my_col_4 = 4 { + optional int64 my_col_5 = 5; + 
repeated binary my_col_6 = 6; + } + } +} +"; + + let parsed_input_schema = Arc::new(parse_message_type(message_type)?); + let schema = SchemaDescriptor::new(parsed_input_schema); + + let converted = convert_schema(&schema, ProjectionMask::all(), None)?.unwrap(); + + let DataType::Struct(schema_fields) = &converted.arrow_type else { + panic!("Expected struct from convert_schema"); + }; + + assert_eq!(schema_fields.len(), 1); + + // Build expected schema + let expected_schema = DataType::Struct(Fields::from(vec![Arc::new( + arrow_schema::Field::new( + "my_col_1", + DataType::List(Arc::new(arrow_schema::Field::new( + "my_col_1", + DataType::Struct(Fields::from(vec![ + Arc::new( + arrow_schema::Field::new("my_col_2", DataType::Binary, true) + .with_field_id(2), + ), + Arc::new( + arrow_schema::Field::new( + "my_col_3", + DataType::List(Arc::new(arrow_schema::Field::new( + "my_col_3", + DataType::Binary, + false, + ))), + false, + ) + // add the field id to the outer list + .with_field_id(3), + ), + Arc::new( + arrow_schema::Field::new( + "my_col_4", + DataType::List(Arc::new(arrow_schema::Field::new( + "my_col_4", + DataType::Struct(Fields::from(vec![ + Arc::new( + arrow_schema::Field::new( + "my_col_5", + DataType::Int64, + true, + ) + // add the field id to the outer list + .with_field_id(5), + ), + Arc::new( + arrow_schema::Field::new( + "my_col_6", + DataType::List(Arc::new(arrow_schema::Field::new( + "my_col_6", + DataType::Binary, + false, + ))), + false, + ) + // add the field id to the outer list + .with_field_id(6), + ), + ])), + false, + ))), + false, + ) + // add the field id to the outer list + .with_field_id(4), + ), + ])), + false, + ))), + false, + ) + // add the field id to the outer list + .with_field_id(1), + )])); + + assert_eq!(converted.arrow_type, expected_schema); + + // Test conversion with inferred schema + let converted_again = + convert_schema(&schema, ProjectionMask::all(), Some(schema_fields))?.unwrap(); + + 
assert_eq!(converted_again.arrow_type, converted.arrow_type); + + // Test conversion with modified schema (change lists to either LargeList or FixedSizeList) + // as well as changing Binary to Utf8 or BinaryView + let modified_schema_fields = Fields::from(vec![Arc::new( + arrow_schema::Field::new( + "my_col_1", + DataType::LargeList(Arc::new(arrow_schema::Field::new( + "my_col_1", + DataType::Struct(Fields::from(vec![ + Arc::new( + arrow_schema::Field::new("my_col_2", DataType::LargeBinary, true) + .with_field_id(2), + ), + Arc::new( + arrow_schema::Field::new( + "my_col_3", + DataType::LargeList(Arc::new(arrow_schema::Field::new( + "my_col_3", + DataType::Utf8, + false, + ))), + false, + ) + // add the field id to the outer list + .with_field_id(3), + ), + Arc::new( + arrow_schema::Field::new( + "my_col_4", + DataType::FixedSizeList( + Arc::new(arrow_schema::Field::new( + "my_col_4", + DataType::Struct(Fields::from(vec![ + Arc::new( + arrow_schema::Field::new( + "my_col_5", + DataType::Int64, + true, + ) + .with_field_id(5), + ), + Arc::new( + arrow_schema::Field::new( + "my_col_6", + DataType::LargeList(Arc::new( + arrow_schema::Field::new( + "my_col_6", + DataType::BinaryView, + false, + ), + )), + false, + ) + // add the field id to the outer list + .with_field_id(6), + ), + ])), + false, + )), + 3, + ), + false, + ) + // add the field id to the outer list + .with_field_id(4), + ), + ])), + false, + ))), + false, + ) + // add the field id to the outer list + .with_field_id(1), + )]); + + let converted_with_modified = convert_schema( + &schema, + ProjectionMask::all(), + Some(&modified_schema_fields), + )? 
+ .unwrap(); + + assert_eq!( + converted_with_modified.arrow_type, + DataType::Struct(modified_schema_fields) + ); + + Ok(()) + } + + /// Backwards-compatibility: LIST with nullable element type - 1 - standard + /// Taken from [Spark](https://github.com/apache/spark/blob/8ab50765cd793169091d983b50d87a391f6ac1f4/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala#L452-L466) + #[test] + fn list_nullable_element_standard() -> crate::errors::Result<()> { + test_expected_type( + " + message root { + optional group f1 (LIST) { + repeated group list { + optional int32 element; + } + } + }", + Fields::from(vec![Field::new( + "f1", + DataType::List(Arc::new(Field::new("element", DataType::Int32, true))), + true, + )]), + ) + } + + /// Backwards-compatibility: LIST with nullable element type - 2 + /// Taken from [Spark](https://github.com/apache/spark/blob/8ab50765cd793169091d983b50d87a391f6ac1f4/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala#L468-L482) + #[test] + fn list_nullable_element_nested() -> crate::errors::Result<()> { + test_expected_type( + " + message root { + optional group f1 (LIST) { + repeated group element { + optional int32 num; + } + } + }", + Fields::from(vec![Field::new( + "f1", + DataType::List(Arc::new(Field::new("num", DataType::Int32, true))), + true, + )]), + ) + } + + /// Backwards-compatibility: LIST with non-nullable element type - 1 - standard + /// Taken from [Spark](https://github.com/apache/spark/blob/8ab50765cd793169091d983b50d87a391f6ac1f4/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala#L484-L495) + #[test] + fn list_required_element_standard() -> crate::errors::Result<()> { + test_expected_type( + " + message root { + optional group f1 (LIST) { + repeated group list { + required int32 element; + } + } + }", + Fields::from(vec![Field::new( + "f1", + DataType::List(Arc::new(Field::new("element", DataType::Int32, false))), + true, + )]), + ) + } + + /// 
Backwards-compatibility: LIST with non-nullable element type - 2 + /// Taken from [Spark](https://github.com/apache/spark/blob/8ab50765cd793169091d983b50d87a391f6ac1f4/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala#L497-L508) + #[test] + fn list_required_element_nested() -> crate::errors::Result<()> { + test_expected_type( + " + message root { + optional group f1 (LIST) { + repeated group element { + required int32 num; + } + } + }", + Fields::from(vec![Field::new( + "f1", + DataType::List(Arc::new(Field::new("num", DataType::Int32, false))), + true, + )]), + ) + } + + /// Backwards-compatibility: LIST with non-nullable element type - 3 + /// Taken from [Spark](https://github.com/apache/spark/blob/8ab50765cd793169091d983b50d87a391f6ac1f4/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala#L510-L519) + #[test] + fn list_required_element_primitive() -> crate::errors::Result<()> { + test_expected_type( + " + message root { + optional group f1 (LIST) { + repeated int32 element; + } + }", + Fields::from(vec![Field::new( + "f1", + DataType::List(Arc::new(Field::new("element", DataType::Int32, false))), + true, + )]), + ) + } + + /// Backwards-compatibility: LIST with non-nullable element type - 4 + /// Taken from [Spark](https://github.com/apache/spark/blob/8ab50765cd793169091d983b50d87a391f6ac1f4/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala#L521-L540) + #[test] + fn list_required_element_struct() -> crate::errors::Result<()> { + test_expected_type( + " + message root { + optional group f1 (LIST) { + repeated group element { + required binary str (UTF8); + required int32 num; + } + } + }", + Fields::from(vec![Field::new( + "f1", + DataType::List(Arc::new(Field::new( + "element", + DataType::Struct(Fields::from(vec![ + Field::new("str", DataType::Utf8, false), + Field::new("num", DataType::Int32, false), + ])), + false, + ))), + true, + )]), + ) + } + + /// Backwards-compatibility: 
LIST with non-nullable element type - 5 - parquet-avro style + /// Taken from [Spark](https://github.com/apache/spark/blob/8ab50765cd793169091d983b50d87a391f6ac1f4/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala#L542-L559) + #[test] + fn list_required_element_avro_style() -> crate::errors::Result<()> { + test_expected_type( + " + message root { + optional group f1 (LIST) { + repeated group array { + required binary str (UTF8); + } + } + }", + Fields::from(vec![Field::new( + "f1", + DataType::List(Arc::new(Field::new( + "array", + DataType::Struct(Fields::from(vec![Field::new("str", DataType::Utf8, false)])), + false, + ))), + true, + )]), + ) + } + + /// Backwards-compatibility: LIST with non-nullable element type - 6 - parquet-thrift style + /// Taken from [Spark](https://github.com/apache/spark/blob/8ab50765cd793169091d983b50d87a391f6ac1f4/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala#L561-L578) + #[test] + fn list_required_element_thrift_style() -> crate::errors::Result<()> { + test_expected_type( + " + message root { + optional group f1 (LIST) { + repeated group f1_tuple { + required binary str (UTF8); + } + } + }", + Fields::from(vec![Field::new( + "f1", + DataType::List(Arc::new(Field::new( + "f1_tuple", + DataType::Struct(Fields::from(vec![Field::new("str", DataType::Utf8, false)])), + false, + ))), + true, + )]), + ) + } + + /// Backwards-compatibility: MAP with non-nullable value type - 1 - standard + /// Taken from [Spark](https://github.com/apache/spark/blob/8ab50765cd793169091d983b50d87a391f6ac1f4/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala#L652-L667) + #[test] + fn map_required_value_standard() -> crate::errors::Result<()> { + test_expected_type( + " + message root { + optional group f1 (MAP) { + repeated group key_value { + required int32 key; + required binary value (UTF8); + } + } + }", + Fields::from(vec![Field::new_map( + "f1", + "key_value", + 
Field::new("key", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + false, + true, + )]), + ) + } + + /// Backwards-compatibility: MAP with non-nullable value type - 2 + /// Taken from [Spark](https://github.com/apache/spark/blob/8ab50765cd793169091d983b50d87a391f6ac1f4/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala#L669-L684) + #[test] + fn map_required_value_map_key_value() -> crate::errors::Result<()> { + test_expected_type( + " + message root { + optional group f1 (MAP_KEY_VALUE) { + repeated group map { + required int32 num; + required binary str (UTF8); + } + } + }", + Fields::from(vec![Field::new_map( + "f1", + "map", + Field::new("num", DataType::Int32, false), + Field::new("str", DataType::Utf8, false), + false, + true, + )]), + ) + } + + /// Backwards-compatibility: MAP with non-nullable value type - 3 - prior to 1.4.x + /// Taken from [Spark](https://github.com/apache/spark/blob/8ab50765cd793169091d983b50d87a391f6ac1f4/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala#L686-L701) + #[test] + fn map_required_value_legacy() -> crate::errors::Result<()> { + test_expected_type( + " + message root { + optional group f1 (MAP) { + repeated group map (MAP_KEY_VALUE) { + required int32 key; + required binary value (UTF8); + } + } + }", + Fields::from(vec![Field::new_map( + "f1", + "map", + Field::new("key", DataType::Int32, false), + Field::new("value", DataType::Utf8, false), + false, + true, + )]), + ) + } + + /// Backwards-compatibility: MAP with nullable value type - 1 - standard + /// Taken from [Spark](https://github.com/apache/spark/blob/8ab50765cd793169091d983b50d87a391f6ac1f4/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala#L703-L718) + #[test] + fn map_optional_value_standard() -> crate::errors::Result<()> { + test_expected_type( + " + message root { + optional group f1 (MAP) { + repeated group key_value { + required int32 key; + optional binary 
value (UTF8); + } + } + }", + Fields::from(vec![Field::new_map( + "f1", + "key_value", + Field::new("key", DataType::Int32, false), + Field::new("value", DataType::Utf8, true), + false, + true, + )]), + ) + } + + /// Backwards-compatibility: MAP with nullable value type - 2 + /// Taken from [Spark](https://github.com/apache/spark/blob/8ab50765cd793169091d983b50d87a391f6ac1f4/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala#L720-L735) + #[test] + fn map_optional_value_map_key_value() -> crate::errors::Result<()> { + test_expected_type( + " + message root { + optional group f1 (MAP_KEY_VALUE) { + repeated group map { + required int32 num; + optional binary str (UTF8); + } + } + }", + Fields::from(vec![Field::new_map( + "f1", + "map", + Field::new("num", DataType::Int32, false), + Field::new("str", DataType::Utf8, true), + false, + true, + )]), + ) + } + + /// Backwards-compatibility: MAP with nullable value type - 3 - parquet-avro style + /// Taken from [Spark](https://github.com/apache/spark/blob/8ab50765cd793169091d983b50d87a391f6ac1f4/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala#L737-L752) + #[test] + fn map_optional_value_avro_style() -> crate::errors::Result<()> { + test_expected_type( + " + message root { + optional group f1 (MAP) { + repeated group map (MAP_KEY_VALUE) { + required int32 key; + optional binary value (UTF8); + } + } + }", + Fields::from(vec![Field::new_map( + "f1", + "map", + Field::new("key", DataType::Int32, false), + Field::new("value", DataType::Utf8, true), + false, + true, + )]), + ) + } +}