diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index db552c373fcb..1e2acd99d828 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -36,31 +36,38 @@ enum Nulls { NullSecond, } -/// An Avro field mapped to the arrow data model +/// An Avro datatype mapped to the arrow data model #[derive(Debug, Clone)] -pub struct AvroField { +pub struct AvroDataType { nulls: Option, - meta: Arc, + metadata: HashMap, + codec: Codec, +} + +impl AvroDataType { + /// Returns an arrow [`Field`] with the given name + pub fn field_with_name(&self, name: &str) -> Field { + let d = self.codec.data_type(); + Field::new(name, d, self.nulls.is_some()).with_metadata(self.metadata.clone()) + } } +/// A named [`AvroDataType`] #[derive(Debug, Clone)] -struct AvroFieldMeta { +pub struct AvroField { name: String, - metadata: HashMap, - codec: Codec, + data_type: AvroDataType, } impl AvroField { /// Returns the arrow [`Field`] pub fn field(&self) -> Field { - let d = self.meta.codec.data_type(); - Field::new(&self.meta.name, d, self.nulls.is_some()) - .with_metadata(self.meta.metadata.clone()) + self.data_type.field_with_name(&self.name) } /// Returns the [`Codec`] pub fn codec(&self) -> &Codec { - &self.meta.codec + &self.data_type.codec } } @@ -68,8 +75,19 @@ impl<'a> TryFrom<&Schema<'a>> for AvroField { type Error = ArrowError; fn try_from(schema: &Schema<'a>) -> Result { - let mut resolver = Resolver::default(); - make_field(schema, "item", None, &mut resolver) + match schema { + Schema::Complex(ComplexType::Record(r)) => { + let mut resolver = Resolver::default(); + let data_type = make_data_type(schema, None, &mut resolver)?; + Ok(AvroField { + data_type, + name: r.name.to_string(), + }) + } + _ => Err(ArrowError::ParseError(format!( + "Expected record got {schema:?}" + ))), + } } } @@ -94,7 +112,7 @@ pub enum Codec { /// TimestampMicros(is_utc) TimestampMicros(bool), Fixed(i32), - List(Arc), + List(Arc), Struct(Arc<[AvroField]>), Duration, } @@ -121,7 +139,7 @@ impl Codec { } Self::Duration => DataType::Interval(IntervalUnit::MonthDayNano), Self::Fixed(size) => DataType::FixedSizeBinary(*size), - Self::List(f) => DataType::List(Arc::new(f.field())), + Self::List(f) => DataType::List(Arc::new(f.field_with_name("item"))), Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()), } } @@ -142,20 +160,20 @@ impl From for Codec { } } -/// Resolves Avro type names to [`AvroField`] +/// Resolves Avro type names to [`AvroDataType`] /// /// See #[derive(Debug, Default)] struct Resolver<'a> { - map: HashMap<(&'a str, &'a str), AvroField>, + map: HashMap<(&'a str, &'a str), AvroDataType>, } impl<'a> Resolver<'a> { - fn register(&mut self, name: &'a str, namespace: Option<&'a str>, schema: AvroField) { + fn register(&mut self, name: &'a str, namespace: Option<&'a str>, schema: AvroDataType) { self.map.insert((name, namespace.unwrap_or("")), schema); } - fn resolve(&self, name: &str, namespace: Option<&'a str>) -> Result { + fn resolve(&self, name: &str, namespace: Option<&'a str>) -> Result { let (namespace, name) = name .rsplit_once('.') .unwrap_or_else(|| (namespace.unwrap_or(""), name)); @@ -167,26 +185,22 @@ impl<'a> Resolver<'a> { } } -/// Parses a [`AvroField`] from the provided [`Schema`] and the given `name` and `namespace` +/// Parses a [`AvroDataType`] from the provided [`Schema`] and the given `name` and `namespace` /// /// `name`: is name used to refer to `schema` in its parent /// `namespace`: an optional qualifier used as part of a type hierarchy /// /// See [`Resolver`] for more information -fn make_field<'a>( +fn make_data_type<'a>( schema: &Schema<'a>, - name: &'a str, namespace: Option<&'a str>, resolver: &mut Resolver<'a>, -) -> Result { +) -> Result { match schema { - Schema::TypeName(TypeName::Primitive(p)) => Ok(AvroField { + Schema::TypeName(TypeName::Primitive(p)) => Ok(AvroDataType { nulls: None, - meta: Arc::new(AvroFieldMeta { - name: name.to_string(), - metadata: Default::default(), - codec: (*p).into(), - }), + metadata: Default::default(), + codec: (*p).into(), }), Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace), Schema::Union(f) => { @@ -196,12 +210,12 @@ fn make_field<'a>( .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); match (f.len() == 2, null) { (true, Some(0)) => { - let mut field = make_field(&f[1], name, namespace, resolver)?; + let mut field = make_data_type(&f[1], namespace, resolver)?; field.nulls = Some(Nulls::NullFirst); Ok(field) } (true, Some(1)) => { - let mut field = make_field(&f[0], name, namespace, resolver)?; + let mut field = make_data_type(&f[0], namespace, resolver)?; field.nulls = Some(Nulls::NullSecond); Ok(field) } @@ -216,29 +230,28 @@ fn make_field<'a>( let fields = r .fields .iter() - .map(|field| make_field(&field.r#type, field.name, namespace, resolver)) - .collect::>()?; + .map(|field| { + Ok(AvroField { + name: field.name.to_string(), + data_type: make_data_type(&field.r#type, namespace, resolver)?, + }) + }) + .collect::>()?; - let field = AvroField { + let field = AvroDataType { nulls: None, - meta: Arc::new(AvroFieldMeta { - name: r.name.to_string(), - codec: Codec::Struct(fields), - metadata: r.attributes.field_metadata(), - }), + codec: Codec::Struct(fields), + metadata: r.attributes.field_metadata(), }; - resolver.register(name, namespace, field.clone()); + resolver.register(r.name, namespace, field.clone()); Ok(field) } ComplexType::Array(a) => { - let mut field = make_field(a.items.as_ref(), "item", namespace, resolver)?; - Ok(AvroField { + let mut field = make_data_type(a.items.as_ref(), namespace, resolver)?; + Ok(AvroDataType { nulls: None, - meta: Arc::new(AvroFieldMeta { - name: name.to_string(), - metadata: a.attributes.field_metadata(), - codec: Codec::List(Arc::new(field)), - }), + metadata: a.attributes.field_metadata(), + codec: Codec::List(Arc::new(field)), }) } ComplexType::Fixed(f) => { @@ -246,13 +259,10 @@ fn make_field<'a>( ArrowError::ParseError(format!("Overflow converting size to i32: {e}")) })?; - let field = AvroField { + let field = AvroDataType { nulls: None, - meta: Arc::new(AvroFieldMeta { - name: f.name.to_string(), - metadata: f.attributes.field_metadata(), - codec: Codec::Fixed(size), - }), + metadata: f.attributes.field_metadata(), + codec: Codec::Fixed(size), }; resolver.register(f.name, namespace, field.clone()); Ok(field) @@ -265,16 +275,11 @@ fn make_field<'a>( ))), }, Schema::Type(t) => { - let mut field = make_field( - &Schema::TypeName(t.r#type.clone()), - name, - namespace, - resolver, - )?; - let meta = Arc::make_mut(&mut field.meta); + let mut field = + make_data_type(&Schema::TypeName(t.r#type.clone()), namespace, resolver)?; // https://avro.apache.org/docs/1.11.1/specification/#logical-types - match (t.attributes.logical_type, &mut meta.codec) { + match (t.attributes.logical_type, &mut field.codec) { (Some("decimal"), c @ Codec::Fixed(_)) => { return Err(ArrowError::NotYetImplemented( "Decimals are not currently supported".to_string(), @@ -294,14 +299,14 @@ fn make_field<'a>( (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Duration, (Some(logical), _) => { // Insert unrecognized logical type into metadata map - meta.metadata.insert("logicalType".into(), logical.into()); + field.metadata.insert("logicalType".into(), logical.into()); } (None, _) => {} } if !t.attributes.additional.is_empty() { for (k, v) in &t.attributes.additional { - meta.metadata.insert(k.to_string(), v.to_string()); + field.metadata.insert(k.to_string(), v.to_string()); } } Ok(field) diff --git a/arrow-avro/src/reader/header.rs b/arrow-avro/src/reader/header.rs index c2171cccf301..97f5d3b8b112 100644 --- a/arrow-avro/src/reader/header.rs +++ b/arrow-avro/src/reader/header.rs @@ -236,7 +236,7 @@ impl HeaderDecoder { #[cfg(test)] mod test { use super::*; - use crate::codec::AvroField; + use crate::codec::{AvroDataType, AvroField}; use crate::reader::read_header; use crate::schema::SCHEMA_METADATA_KEY; use crate::test_util::arrow_test_data; diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index 6f1f3d6bd012..6707f8137c9b 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -210,7 +210,7 @@ pub struct Fixed<'a> { #[cfg(test)] mod tests { use super::*; - use crate::codec::AvroField; + use crate::codec::{AvroDataType, AvroField}; use arrow_schema::{DataType, Fields, TimeUnit}; use serde_json::json;