diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs new file mode 100644 index 00000000000..1e2acd99d82 --- /dev/null +++ b/arrow-avro/src/codec.rs @@ -0,0 +1,315 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::schema::{Attributes, ComplexType, PrimitiveType, Record, Schema, TypeName}; +use arrow_schema::{ + ArrowError, DataType, Field, FieldRef, IntervalUnit, SchemaBuilder, SchemaRef, TimeUnit, +}; +use std::borrow::Cow; +use std::collections::HashMap; +use std::sync::Arc; + +/// Avro types are not nullable, with nullability instead encoded as a union +/// where one of the variants is the null type. +/// +/// To accommodate this we special case two-variant unions where one of the +/// variants is the null type, and use this to derive arrow's notion of nullability +#[derive(Debug, Copy, Clone)] +enum Nulls { + /// The nulls are encoded as the first union variant + NullFirst, + /// The nulls are encoded as the second union variant + NullSecond, +} + +/// An Avro datatype mapped to the arrow data model +#[derive(Debug, Clone)] +pub struct AvroDataType { + nulls: Option<Nulls>, + metadata: HashMap<String, String>, + codec: Codec, +} + +impl AvroDataType { + /// Returns an arrow [`Field`] with the given name + pub fn field_with_name(&self, name: &str) -> Field { + let d = self.codec.data_type(); + Field::new(name, d, self.nulls.is_some()).with_metadata(self.metadata.clone()) + } +} + +/// A named [`AvroDataType`] +#[derive(Debug, Clone)] +pub struct AvroField { + name: String, + data_type: AvroDataType, +} + +impl AvroField { + /// Returns the arrow [`Field`] + pub fn field(&self) -> Field { + self.data_type.field_with_name(&self.name) + } + + /// Returns the [`Codec`] + pub fn codec(&self) -> &Codec { + &self.data_type.codec + } +} + +impl<'a> TryFrom<&Schema<'a>> for AvroField { + type Error = ArrowError; + + fn try_from(schema: &Schema<'a>) -> Result<Self, Self::Error> { + match schema { + Schema::Complex(ComplexType::Record(r)) => { + let mut resolver = Resolver::default(); + let data_type = make_data_type(schema, None, &mut resolver)?; + Ok(AvroField { + data_type, + name: r.name.to_string(), + }) + } + _ => Err(ArrowError::ParseError(format!( + "Expected record got {schema:?}" + ))), + } + } +} + +/// An Avro encoding +/// +/// <https://avro.apache.org/docs/1.11.1/specification/#encodings> +#[derive(Debug, Clone)] +pub enum Codec { + Null, + Boolean, + Int32, + Int64, + Float32, + Float64, + Binary, + Utf8, + Date32, + TimeMillis, + TimeMicros, + /// TimestampMillis(is_utc) + TimestampMillis(bool), + /// TimestampMicros(is_utc) + TimestampMicros(bool), + Fixed(i32), + List(Arc<AvroDataType>), + Struct(Arc<[AvroField]>), + Duration, +} + +impl Codec { + fn data_type(&self) -> DataType { + match self { + Self::Null => DataType::Null, + Self::Boolean => DataType::Boolean, + Self::Int32 => DataType::Int32, + Self::Int64 => DataType::Int64, + Self::Float32 => DataType::Float32, + Self::Float64 => DataType::Float64, + Self::Binary => DataType::Binary, + Self::Utf8 => DataType::Utf8, + Self::Date32 => DataType::Date32, + Self::TimeMillis => DataType::Time32(TimeUnit::Millisecond), + Self::TimeMicros => DataType::Time64(TimeUnit::Microsecond), + Self::TimestampMillis(is_utc) => { + DataType::Timestamp(TimeUnit::Millisecond, is_utc.then(|| "+00:00".into())) + } + Self::TimestampMicros(is_utc) => { + DataType::Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into())) + } + Self::Duration => DataType::Interval(IntervalUnit::MonthDayNano), + Self::Fixed(size) => DataType::FixedSizeBinary(*size), + Self::List(f) => DataType::List(Arc::new(f.field_with_name("item"))), + Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()), + } + } +} + +impl From<PrimitiveType> for Codec { + fn from(value: PrimitiveType) -> Self { + match value { + PrimitiveType::Null => Self::Null, + PrimitiveType::Boolean => Self::Boolean, + PrimitiveType::Int => Self::Int32, + PrimitiveType::Long => Self::Int64, + PrimitiveType::Float => Self::Float32, + PrimitiveType::Double => Self::Float64, + PrimitiveType::Bytes => Self::Binary, + PrimitiveType::String => Self::Utf8, + } + } +} + +/// Resolves Avro type names to [`AvroDataType`] +/// +/// See <https://avro.apache.org/docs/1.11.1/specification/#names> +#[derive(Debug, Default)] +struct Resolver<'a> { + map: HashMap<(&'a str, &'a str), AvroDataType>, +} + +impl<'a> Resolver<'a> { + fn register(&mut self, name: &'a str, namespace: Option<&'a str>, schema: AvroDataType) { + self.map.insert((name, namespace.unwrap_or("")), schema); + } + + fn resolve(&self, name: &str, namespace: Option<&'a str>) -> Result<AvroDataType, ArrowError> { + let (namespace, name) = name + .rsplit_once('.') + .unwrap_or_else(|| (namespace.unwrap_or(""), name)); + + self.map + .get(&(namespace, name)) + .ok_or_else(|| ArrowError::ParseError(format!("Failed to resolve {namespace}.{name}"))) + .cloned() + } +} + +/// Parses a [`AvroDataType`] from the provided [`Schema`] and the given `name` and `namespace` +/// +/// `name`: is name used to refer to `schema` in its parent +/// `namespace`: an optional qualifier used as part of a type hierarchy +/// +/// See [`Resolver`] for more information +fn make_data_type<'a>( + schema: &Schema<'a>, + namespace: Option<&'a str>, + resolver: &mut Resolver<'a>, +) -> Result<AvroDataType, ArrowError> { + match schema { + Schema::TypeName(TypeName::Primitive(p)) => Ok(AvroDataType { + nulls: None, + metadata: Default::default(), + codec: (*p).into(), + }), + Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace), + Schema::Union(f) => { + // Special case the common case of nullable primitives + let null = f + .iter() + .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); + match (f.len() == 2, null) { + (true, Some(0)) => { + let mut field = make_data_type(&f[1], namespace, resolver)?; + field.nulls = Some(Nulls::NullFirst); + Ok(field) + } + (true, Some(1)) => { + let mut field = make_data_type(&f[0], namespace, resolver)?; + field.nulls = Some(Nulls::NullSecond); + Ok(field) + } + _ => Err(ArrowError::NotYetImplemented(format!( + "Union of {f:?} not currently supported" + ))), + } + } + Schema::Complex(c) => match c { + ComplexType::Record(r) => { + let namespace = r.namespace.or(namespace); + let fields = r + .fields + .iter() + .map(|field| { + Ok(AvroField { + name: field.name.to_string(), + data_type: make_data_type(&field.r#type, namespace, resolver)?, + }) + }) + .collect::<Result<_, ArrowError>>()?; + + let field = AvroDataType { + nulls: None, + codec: Codec::Struct(fields), + metadata: r.attributes.field_metadata(), + }; + resolver.register(r.name, namespace, field.clone()); + Ok(field) + } + ComplexType::Array(a) => { + let mut field = make_data_type(a.items.as_ref(), namespace, resolver)?; + Ok(AvroDataType { + nulls: None, + metadata: a.attributes.field_metadata(), + codec: Codec::List(Arc::new(field)), + }) + } + ComplexType::Fixed(f) => { + let size = f.size.try_into().map_err(|e| { + ArrowError::ParseError(format!("Overflow converting size to i32: {e}")) + })?; + + let field = AvroDataType { + nulls: None, + metadata: f.attributes.field_metadata(), + codec: Codec::Fixed(size), + }; + resolver.register(f.name, namespace, field.clone()); + Ok(field) + } + ComplexType::Enum(e) => Err(ArrowError::NotYetImplemented(format!( + "Enum of {e:?} not currently supported" + ))), + ComplexType::Map(m) => Err(ArrowError::NotYetImplemented(format!( + "Map of {m:?} not currently supported" + ))), + }, + Schema::Type(t) => { + let mut field = + make_data_type(&Schema::TypeName(t.r#type.clone()), namespace, resolver)?; + + // https://avro.apache.org/docs/1.11.1/specification/#logical-types + match (t.attributes.logical_type, &mut field.codec) { + (Some("decimal"), c @ Codec::Fixed(_)) => { + return Err(ArrowError::NotYetImplemented( + "Decimals are not currently supported".to_string(), + )) + } + (Some("date"), c @ Codec::Int32) => *c = Codec::Date32, + (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis, + (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros, + (Some("timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(true), + (Some("timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(true), + (Some("local-timestamp-millis"), c @ Codec::Int64) => { + *c = Codec::TimestampMillis(false) + } + (Some("local-timestamp-micros"), c @ Codec::Int64) => { + *c = Codec::TimestampMicros(false) + } + (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Duration, + (Some(logical), _) => { + // Insert unrecognized logical type into metadata map + field.metadata.insert("logicalType".into(), logical.into()); + } + (None, _) => {} + } + + if !t.attributes.additional.is_empty() { + for (k, v) in &t.attributes.additional { + field.metadata.insert(k.to_string(), v.to_string()); + } + } + Ok(field) + } + } +} diff --git a/arrow-avro/src/lib.rs b/arrow-avro/src/lib.rs index c76ecb399a4..f74edd7e926 100644 --- a/arrow-avro/src/lib.rs +++ b/arrow-avro/src/lib.rs @@ -27,6 +27,8 @@ mod schema; mod compression; +mod codec; + #[cfg(test)] mod test_util { pub fn arrow_test_data(path: &str) -> String { diff --git a/arrow-avro/src/reader/header.rs b/arrow-avro/src/reader/header.rs index 00e85b39be7..97f5d3b8b11 100644 --- a/arrow-avro/src/reader/header.rs +++ b/arrow-avro/src/reader/header.rs @@ -236,9 +236,11 @@ impl HeaderDecoder { #[cfg(test)] mod test { use super::*; + use crate::codec::{AvroDataType, AvroField}; use crate::reader::read_header; use crate::schema::SCHEMA_METADATA_KEY; use crate::test_util::arrow_test_data; + use arrow_schema::{DataType, Field, Fields, TimeUnit}; use std::fs::File; use std::io::{BufRead, BufReader}; @@ -269,7 +271,34 @@ mod test { let schema_json = header.get(SCHEMA_METADATA_KEY).unwrap(); let expected = br#"{"type":"record","name":"topLevelRecord","fields":[{"name":"id","type":["int","null"]},{"name":"bool_col","type":["boolean","null"]},{"name":"tinyint_col","type":["int","null"]},{"name":"smallint_col","type":["int","null"]},{"name":"int_col","type":["int","null"]},{"name":"bigint_col","type":["long","null"]},{"name":"float_col","type":["float","null"]},{"name":"double_col","type":["double","null"]},{"name":"date_string_col","type":["bytes","null"]},{"name":"string_col","type":["bytes","null"]},{"name":"timestamp_col","type":[{"type":"long","logicalType":"timestamp-micros"},"null"]}]}"#; assert_eq!(schema_json, expected); - let _schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap(); + let schema: Schema<'_> = serde_json::from_slice(schema_json).unwrap(); + let field = AvroField::try_from(&schema).unwrap(); + + assert_eq!( + field.field(), + Field::new( + "topLevelRecord", + DataType::Struct(Fields::from(vec![ + Field::new("id", DataType::Int32, true), + Field::new("bool_col", DataType::Boolean, true), + Field::new("tinyint_col", DataType::Int32, true), + Field::new("smallint_col", DataType::Int32, true), + Field::new("int_col", DataType::Int32, true), + Field::new("bigint_col", DataType::Int64, true), + Field::new("float_col", DataType::Float32, true), + Field::new("double_col", DataType::Float64, true), + Field::new("date_string_col", DataType::Binary, true), + Field::new("string_col", DataType::Binary, true), + Field::new( + "timestamp_col", + DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())), + true + ), + ])), + false + ) + ); + assert_eq!( u128::from_le_bytes(header.sync()), 226966037233754408753420635932530907102 diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index 17b82cf861b..6707f8137c9 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -34,7 +34,7 @@ pub enum TypeName<'a> { /// A primitive type /// /// <https://avro.apache.org/docs/1.11.1/specification/#primitive-types> -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub enum PrimitiveType { Null, @@ -64,6 +64,16 @@ pub struct Attributes<'a> { pub additional: HashMap<&'a str, serde_json::Value>, } +impl<'a> Attributes<'a> { + /// Returns the field metadata for this [`Attributes`] + pub(crate) fn field_metadata(&self) -> HashMap<String, String> { + self.additional + .iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect() + } +} + /// A type definition that is not a variant of [`ComplexType`] #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] @@ -94,8 +104,6 @@ pub enum Schema<'a> { #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "camelCase")] pub enum ComplexType<'a> { - #[serde(borrow)] - Union(Vec<Schema<'a>>), #[serde(borrow)] Record(Record<'a>), #[serde(borrow)] @@ -202,7 +210,10 @@ pub struct Fixed<'a> { #[cfg(test)] mod tests { use super::*; + use crate::codec::{AvroDataType, AvroField}; + use arrow_schema::{DataType, Fields, TimeUnit}; use serde_json::json; + #[test] fn test_deserialize() { let t: Schema = serde_json::from_str("\"string\"").unwrap(); @@ -352,6 +363,10 @@ mod tests { })) ); + // Recursive schema are not supported + let err = AvroField::try_from(&schema).unwrap_err().to_string(); + assert_eq!(err, "Parser error: Failed to resolve .LongList"); + let schema: Schema = serde_json::from_str( r#"{ "type":"record", @@ -409,14 +424,29 @@ mod tests { attributes: Default::default(), })) ); + let codec = AvroField::try_from(&schema).unwrap(); + assert_eq!( + codec.field(), + arrow_schema::Field::new( + "topLevelRecord", + DataType::Struct(Fields::from(vec![ + arrow_schema::Field::new("id", DataType::Int32, true), + arrow_schema::Field::new( + "timestamp_col", + DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())), + true + ), + ])), + false + ) + ); let schema: Schema = serde_json::from_str( r#"{ "type": "record", "name": "HandshakeRequest", "namespace":"org.apache.avro.ipc", "fields": [ - {"name": "clientHash", - "type": {"type": "fixed", "name": "MD5", "size": 16}}, + {"name": "clientHash", "type": {"type": "fixed", "name": "MD5", "size": 16}}, {"name": "clientProtocol", "type": ["null", "string"]}, {"name": "serverHash", "type": "MD5"}, {"name": "meta", "type": ["null", {"type": "map", "values": "bytes"}]}