From 893813f05869ceee90388f67599b879878545345 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Sat, 30 Nov 2024 16:29:42 +0000 Subject: [PATCH 1/6] Hook up Avro Decoder --- arrow-avro/Cargo.toml | 4 +- arrow-avro/src/codec.rs | 42 +++-- arrow-avro/src/reader/cursor.rs | 140 ++++++++++++++++ arrow-avro/src/reader/header.rs | 13 +- arrow-avro/src/reader/mod.rs | 141 ++++++++++++++-- arrow-avro/src/reader/record.rs | 288 ++++++++++++++++++++++++++++++++ 6 files changed, 596 insertions(+), 32 deletions(-) create mode 100644 arrow-avro/src/reader/cursor.rs create mode 100644 arrow-avro/src/reader/record.rs diff --git a/arrow-avro/Cargo.toml b/arrow-avro/Cargo.toml index c82bf0113ea4..f4a466e86938 100644 --- a/arrow-avro/Cargo.toml +++ b/arrow-avro/Cargo.toml @@ -39,7 +39,9 @@ deflate = ["flate2"] snappy = ["snap", "crc"] [dependencies] -arrow-schema = { workspace = true } +arrow-schema = { workspace = true } +arrow-buffer = { workspace = true } +arrow-array = { workspace = true } serde_json = { version = "1.0", default-features = false, features = ["std"] } serde = { version = "1.0.188", features = ["derive"] } flate2 = { version = "1.0", default-features = false, features = ["rust_backend"], optional = true } diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 35fa0339d69d..2ac1ad038bd7 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -29,7 +29,7 @@ use std::sync::Arc; /// To accommodate this we special case two-variant unions where one of the /// variants is the null type, and use this to derive arrow's notion of nullability #[derive(Debug, Copy, Clone)] -enum Nulls { +pub enum Nullability { /// The nulls are encoded as the first union variant NullFirst, /// The nulls are encoded as the second union variant @@ -39,7 +39,7 @@ enum Nulls { /// An Avro datatype mapped to the arrow data model #[derive(Debug, Clone)] pub struct AvroDataType { - nulls: Option, + nullability: Option, metadata: HashMap, codec: Codec, } @@ -48,7 +48,15 @@ impl AvroDataType { /// Returns an arrow [`Field`] with the given name pub fn field_with_name(&self, name: &str) -> Field { let d = self.codec.data_type(); - Field::new(name, d, self.nulls.is_some()).with_metadata(self.metadata.clone()) + Field::new(name, d, self.nullability.is_some()).with_metadata(self.metadata.clone()) + } + + pub fn codec(&self) -> &Codec { + &self.codec + } + + pub fn nullability(&self) -> Option { + self.nullability } } @@ -65,9 +73,13 @@ impl AvroField { self.data_type.field_with_name(&self.name) } - /// Returns the [`Codec`] - pub fn codec(&self) -> &Codec { - &self.data_type.codec + /// Returns the [`AvroDataType`] + pub fn data_type(&self) -> &AvroDataType { + &self.data_type + } + + pub fn name(&self) -> &str { + &self.name } } @@ -114,7 +126,7 @@ pub enum Codec { Fixed(i32), List(Arc), Struct(Arc<[AvroField]>), - Duration, + Interval, } impl Codec { @@ -137,7 +149,7 @@ impl Codec { Self::TimestampMicros(is_utc) => { DataType::Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into())) } - Self::Duration => DataType::Interval(IntervalUnit::MonthDayNano), + Self::Interval => DataType::Interval(IntervalUnit::MonthDayNano), Self::Fixed(size) => DataType::FixedSizeBinary(*size), Self::List(f) => { DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))) @@ -200,7 +212,7 @@ fn make_data_type<'a>( ) -> Result { match schema { Schema::TypeName(TypeName::Primitive(p)) => Ok(AvroDataType { - nulls: None, + nullability: None, metadata: Default::default(), codec: (*p).into(), }), @@ -213,12 +225,12 @@ fn make_data_type<'a>( match (f.len() == 2, null) { (true, Some(0)) => { let mut field = make_data_type(&f[1], namespace, resolver)?; - field.nulls = Some(Nulls::NullFirst); + field.nullability = Some(Nullability::NullFirst); Ok(field) } (true, Some(1)) => { let mut field = make_data_type(&f[0], namespace, resolver)?; - field.nulls = Some(Nulls::NullSecond); + field.nullability = Some(Nullability::NullSecond); Ok(field) } _ => Err(ArrowError::NotYetImplemented(format!( @@ -241,7 +253,7 @@ fn make_data_type<'a>( .collect::>()?; let field = AvroDataType { - nulls: None, + nullability: None, codec: Codec::Struct(fields), metadata: r.attributes.field_metadata(), }; @@ -251,7 +263,7 @@ fn make_data_type<'a>( ComplexType::Array(a) => { let mut field = make_data_type(a.items.as_ref(), namespace, resolver)?; Ok(AvroDataType { - nulls: None, + nullability: None, metadata: a.attributes.field_metadata(), codec: Codec::List(Arc::new(field)), }) @@ -262,7 +274,7 @@ fn make_data_type<'a>( })?; let field = AvroDataType { - nulls: None, + nullability: None, metadata: f.attributes.field_metadata(), codec: Codec::Fixed(size), }; @@ -298,7 +310,7 @@ fn make_data_type<'a>( (Some("local-timestamp-micros"), c @ Codec::Int64) => { *c = Codec::TimestampMicros(false) } - (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Duration, + (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Interval, (Some(logical), _) => { // Insert unrecognized logical type into metadata map field.metadata.insert("logicalType".into(), logical.into()); diff --git a/arrow-avro/src/reader/cursor.rs b/arrow-avro/src/reader/cursor.rs new file mode 100644 index 000000000000..0d4585a5d96d --- /dev/null +++ b/arrow-avro/src/reader/cursor.rs @@ -0,0 +1,140 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::ArrowError; + +/// A wrapper around a byte slice, providing low-level decoding for Avro +/// +/// +#[derive(Debug)] +pub(crate) struct AvroCursor<'a> { + buf: &'a [u8], + start_len: usize, +} + +impl<'a> AvroCursor<'a> { + pub(crate) fn new(buf: &'a [u8]) -> Self { + Self { + buf, + start_len: buf.len(), + } + } + + /// Returns the current cursor position + #[inline] + pub(crate) fn position(&self) -> usize { + self.start_len - self.buf.len() + } + + /// Read a single `u8` + #[inline] + pub(crate) fn get_u8(&mut self) -> Result { + match self.buf.first().copied() { + Some(x) => { + self.buf = &self.buf[1..]; + Ok(x) + } + None => Err(ArrowError::ParseError("Unexpected EOF".to_string())), + } + } + + #[inline] + pub(crate) fn get_bool(&mut self) -> Result { + Ok(self.get_u8()? != 0) + } + + #[inline] + pub(crate) fn get_int(&mut self) -> Result { + let mut in_progress = 0; + let mut shift = 0; + + while let Some(byte) = self.buf.first().copied() { + self.buf = &self.buf[1..]; + in_progress |= ((byte & 0x7F) as u32) << shift; + shift += 7; + if byte & 0x80 == 0 { + let val = in_progress; + in_progress = 0; + shift = 0; + return Ok((val >> 1) as i32 ^ -((val & 1) as i32)); + } + } + Err(ArrowError::ParseError( + "Unexpected EOF reading int".to_string(), + )) + } + + #[inline] + pub(crate) fn get_long(&mut self) -> Result { + let mut in_progress = 0; + let mut shift = 0; + + while let Some(byte) = self.buf.first().copied() { + self.buf = &self.buf[1..]; + in_progress |= ((byte & 0x7F) as u64) << shift; + shift += 7; + if byte & 0x80 == 0 { + let val = in_progress; + in_progress = 0; + shift = 0; + return Ok((val >> 1) as i64 ^ -((val & 1) as i64)); + } + } + Err(ArrowError::ParseError( + "Unexpected EOF reading long".to_string(), + )) + } + + pub(crate) fn get_bytes(&mut self) -> Result<&'a [u8], ArrowError> { + let len: usize = self.get_long()?.try_into().map_err(|_| { + ArrowError::ParseError("offset overflow reading avro bytes".to_string()) + })?; + + if (self.buf.len() < len) { + return Err(ArrowError::ParseError( + "Unexpected EOF reading bytes".to_string(), + )); + } + let ret = &self.buf[..len]; + self.buf = &self.buf[len..]; + Ok(ret) + } + + #[inline] + pub(crate) fn get_float(&mut self) -> Result { + if (self.buf.len() < 4) { + return Err(ArrowError::ParseError( + "Unexpected EOF reading float".to_string(), + )); + } + let ret = f32::from_le_bytes(self.buf[..4].try_into().unwrap()); + self.buf = &self.buf[4..]; + Ok(ret) + } + + #[inline] + pub(crate) fn get_double(&mut self) -> Result { + if (self.buf.len() < 8) { + return Err(ArrowError::ParseError( + "Unexpected EOF reading float".to_string(), + )); + } + let ret = f64::from_le_bytes(self.buf[..8].try_into().unwrap()); + self.buf = &self.buf[8..]; + Ok(ret) + } +} diff --git a/arrow-avro/src/reader/header.rs b/arrow-avro/src/reader/header.rs index 19d48d1f89a1..98c285171bf3 100644 --- a/arrow-avro/src/reader/header.rs +++ b/arrow-avro/src/reader/header.rs @@ -19,7 +19,7 @@ use crate::compression::{CompressionCodec, CODEC_METADATA_KEY}; use crate::reader::vlq::VLQDecoder; -use crate::schema::Schema; +use crate::schema::{Schema, SCHEMA_METADATA_KEY}; use arrow_schema::ArrowError; #[derive(Debug)] @@ -89,6 +89,17 @@ impl Header { ))), } } + + /// Returns the [`Schema`] if any + pub fn schema(&self) -> Result>, ArrowError> { + self.get(SCHEMA_METADATA_KEY) + .map(|x| { + serde_json::from_slice(x).map_err(|e| { + ArrowError::ParseError(format!("Failed to parse Avro schema JSON: {e}")) + }) + }) + .transpose() + } } /// A decoder for [`Header`] diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index 0151db7f855a..12fa67d9c8e3 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -26,6 +26,8 @@ mod header; mod block; +mod cursor; +mod record; mod vlq; /// Read a [`Header`] from the provided [`BufRead`] @@ -73,35 +75,144 @@ fn read_blocks(mut reader: R) -> impl Iterator RecordBatch { + let file = File::open(file).unwrap(); + let mut reader = BufReader::new(file); + let header = read_header(&mut reader).unwrap(); + let compression = header.compression().unwrap(); + let schema = header.schema().unwrap().unwrap(); + let root = AvroField::try_from(&schema).unwrap(); + let mut decoder = RecordDecoder::try_new(root.data_type()).unwrap(); + + for result in read_blocks(reader) { + let block = result.unwrap(); + assert_eq!(block.sync, header.sync()); + if let Some(c) = compression { + let decompressed = c.decompress(&block.data).unwrap(); + + let mut offset = 0; + let mut remaining = block.count; + while remaining > 0 { + let to_read = remaining.max(batch_size); + offset += decoder + .decode(&decompressed[offset..], block.count) + .unwrap(); + + remaining -= to_read; + } + assert_eq!(offset, decompressed.len()); + } + } + decoder.flush().unwrap() + } #[test] - fn test_mux() { + fn test_alltypes() { let files = [ "avro/alltypes_plain.avro", "avro/alltypes_plain.snappy.avro", "avro/alltypes_plain.zstandard.avro", - "avro/alltypes_nulls_plain.avro", ]; + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "id", + Arc::new(Int32Array::from(vec![4, 5, 6, 7, 2, 3, 0, 1])) as _, + true, + ), + ( + "bool_col", + Arc::new(BooleanArray::from_iter((0..8).map(|x| Some(x % 2 == 0)))) as _, + true, + ), + ( + "tinyint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "smallint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "int_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "bigint_col", + Arc::new(Int64Array::from_iter_values((0..8).map(|x| (x % 2) * 10))) as _, + true, + ), + ( + "float_col", + Arc::new(Float32Array::from_iter_values( + (0..8).map(|x| (x % 2) as f32 * 1.1), + )) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64 * 10.1), + )) as _, + true, + ), + ( + "date_string_col", + Arc::new(BinaryArray::from_iter_values([ + [48, 51, 47, 48, 49, 47, 48, 57], + [48, 51, 47, 48, 49, 47, 48, 57], + [48, 52, 47, 48, 49, 47, 48, 57], + [48, 52, 47, 48, 49, 47, 48, 57], + [48, 50, 47, 48, 49, 47, 48, 57], + [48, 50, 47, 48, 49, 47, 48, 57], + [48, 49, 47, 48, 49, 47, 48, 57], + [48, 49, 47, 48, 49, 47, 48, 57], + ])) as _, + true, + ), + ( + "string_col", + Arc::new(BinaryArray::from_iter_values((0..8).map(|x| [48 + x % 2]))) as _, + true, + ), + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1235865600000000, // 2009-03-01T00:00:00.000 + 1235865660000000, // 2009-03-01T00:01:00.000 + 1238544000000000, // 2009-04-01T00:00:00.000 + 1238544060000000, // 2009-04-01T00:01:00.000 + 1233446400000000, // 2009-02-01T00:00:00.000 + 1233446460000000, // 2009-02-01T00:01:00.000 + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ]) + .unwrap(); + for file in files { - println!("file: {file}"); - let file = File::open(arrow_test_data(file)).unwrap(); - let mut reader = BufReader::new(file); - let header = read_header(&mut reader).unwrap(); - let compression = header.compression().unwrap(); - println!("compression: {compression:?}"); - for result in read_blocks(reader) { - let block = result.unwrap(); - assert_eq!(block.sync, header.sync()); - if let Some(c) = compression { - c.decompress(&block.data).unwrap(); - } - } + let file = arrow_test_data(file); + + assert_eq!(read_file(&file, 8), expected); + assert_eq!(read_file(&file, 3), expected); } } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs new file mode 100644 index 000000000000..6245050c8048 --- /dev/null +++ b/arrow-avro/src/reader/record.rs @@ -0,0 +1,288 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::codec::{AvroDataType, Codec, Nullability}; +use crate::reader::block::{Block, BlockDecoder}; +use crate::reader::cursor::AvroCursor; +use crate::reader::header::Header; +use crate::schema::*; +use arrow_array::types::*; +use arrow_array::*; +use arrow_buffer::*; +use arrow_schema::{ + ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, +}; +use std::collections::HashMap; +use std::io::Read; +use std::sync::Arc; + +/// Decodes avro encoded data into [`RecordBatch`] +pub struct RecordDecoder { + schema: SchemaRef, + fields: Vec, +} + +impl RecordDecoder { + pub fn try_new(data_type: &AvroDataType) -> Result { + match Decoder::try_new(data_type)? { + Decoder::Record(fields, encodings) => Ok(Self { + schema: Arc::new(ArrowSchema::new(fields)), + fields: encodings, + }), + encoding => Err(ArrowError::ParseError(format!( + "Expected record got {encoding:?}" + ))), + } + } + + pub fn schema(&self) -> &SchemaRef { + &self.schema + } + + /// Decode `count` records from `buf` + pub fn decode(&mut self, buf: &[u8], count: usize) -> Result { + let mut cursor = AvroCursor::new(buf); + for _ in 0..count { + for field in &mut self.fields { + field.decode(&mut cursor)?; + } + } + Ok(cursor.position()) + } + + pub fn flush(&mut self) -> Result { + let arrays = self + .fields + .iter_mut() + .map(|x| x.flush(None)) + .collect::, _>>()?; + + RecordBatch::try_new(self.schema.clone(), arrays) + } +} + +#[derive(Debug)] +enum Decoder { + Null(usize), + Boolean(BooleanBufferBuilder), + Int32(Vec), + Int64(Vec), + Float32(Vec), + Float64(Vec), + Date32(Vec), + TimeMillis(Vec), + TimeMicros(Vec), + TimestampMillis(bool, Vec), + TimestampMicros(bool, Vec), + Binary(OffsetBufferBuilder, Vec), + String(OffsetBufferBuilder, Vec), + List(FieldRef, OffsetBufferBuilder, Box), + Record(Fields, Vec), + Nullable(Nullability, NullBufferBuilder, Box), +} + +impl Decoder { + fn try_new(data_type: &AvroDataType) -> Result { + let nyi = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string())); + + let decoder = match data_type.codec() { + Codec::Null => Self::Null(0), + Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), + Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Int64 => Self::Int64(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Float32 => Self::Float32(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Float64 => Self::Float64(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Binary => Self::Binary( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ), + Codec::Utf8 => Self::String( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ), + Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimestampMillis(is_utc) => { + Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) + } + Codec::TimestampMicros(is_utc) => { + Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) + } + Codec::Fixed(_) => return nyi("decoding fixed"), + Codec::Interval => return nyi("decoding interval"), + Codec::List(item) => { + let decoder = Self::try_new(item)?; + Self::List( + Arc::new(item.field_with_name("item")), + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(decoder), + ) + } + Codec::Struct(fields) => { + let mut arrow_fields = Vec::with_capacity(fields.len()); + let mut encodings = Vec::with_capacity(fields.len()); + for avro_field in fields.iter() { + let encoding = Self::try_new(avro_field.data_type())?; + arrow_fields.push(avro_field.field()); + encodings.push(encoding); + } + Self::Record(arrow_fields.into(), encodings) + } + }; + + Ok(match data_type.nullability() { + Some(nullability) => Self::Nullable( + nullability, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(decoder), + ), + None => decoder, + }) + } + + fn append_null(&mut self) { + match self { + Self::Null(count) => *count += 1, + Self::Boolean(b) => b.append(false), + Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0), + Self::Int64(v) + | Self::TimeMicros(v) + | Self::TimestampMillis(_, v) + | Self::TimestampMicros(_, v) => v.push(0), + Self::Float32(v) => v.push(0.), + Self::Float64(v) => v.push(0.), + Self::Binary(offsets, _) | Self::String(offsets, _) => offsets.push_length(0), + Self::List(_, offsets, e) => { + offsets.push_length(0); + e.append_null(); + } + Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()), + Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"), + } + } + + fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { + match self { + Self::Null(x) => *x += 1, + Self::Boolean(values) => values.append(buf.get_bool()?), + Self::Int32(values) | Self::Date32(values) | Self::TimeMillis(values) => { + values.push(buf.get_int()?) + } + Self::Int64(values) + | Self::TimeMicros(values) + | Self::TimestampMillis(_, values) + | Self::TimestampMicros(_, values) => values.push(buf.get_long()?), + Self::Float32(values) => values.push(buf.get_float()?), + Self::Float64(values) => values.push(buf.get_double()?), + Self::Binary(offsets, values) | Self::String(offsets, values) => { + let data = buf.get_bytes()?; + offsets.push_length(data.len()); + values.extend_from_slice(data); + } + Self::List(_, _, _) => { + return Err(ArrowError::NotYetImplemented( + "Decoding ListArray".to_string(), + )) + } + Self::Record(_, encodings) => { + for encoding in encodings { + encoding.decode(buf)?; + } + } + Self::Nullable(nullability, nulls, e) => { + let is_valid = buf.get_bool()? == matches!(nullability, Nullability::NullFirst); + nulls.append(is_valid); + match is_valid { + true => e.decode(buf)?, + false => e.append_null(), + } + } + } + Ok(()) + } + + fn flush(&mut self, nulls: Option) -> Result { + Ok(match self { + Self::Nullable(_, n, e) => e.flush(n.finish())?, + Self::Null(size) => Arc::new(NullArray::new(std::mem::replace(size, 0))), + Self::Boolean(b) => Arc::new(BooleanArray::new(b.finish(), nulls)), + Self::Int32(values) => Arc::new(flush_primitive::(values, nulls)), + Self::Date32(values) => Arc::new(flush_primitive::(values, nulls)), + Self::Int64(values) => Arc::new(flush_primitive::(values, nulls)), + Self::TimeMillis(values) => { + Arc::new(flush_primitive::(values, nulls)) + } + Self::TimeMicros(values) => { + Arc::new(flush_primitive::(values, nulls)) + } + Self::TimestampMillis(is_utc, values) => Arc::new( + flush_primitive::(values, nulls) + .with_timezone_opt(is_utc.then(|| "+00:00")), + ), + Self::TimestampMicros(is_utc, values) => Arc::new( + flush_primitive::(values, nulls) + .with_timezone_opt(is_utc.then(|| "+00:00")), + ), + Self::Float32(values) => Arc::new(flush_primitive::(values, nulls)), + Self::Float64(values) => Arc::new(flush_primitive::(values, nulls)), + + Self::Binary(offsets, values) => { + let offsets = flush_offsets(offsets); + let values = flush_values(values).into(); + Arc::new(BinaryArray::new(offsets, values, nulls)) + } + Self::String(offsets, values) => { + let offsets = flush_offsets(offsets); + let values = flush_values(values).into(); + Arc::new(StringArray::new(offsets, values, nulls)) + } + Self::List(field, offsets, values) => { + let values = values.flush(None)?; + let offsets = flush_offsets(offsets); + Arc::new(ListArray::new(field.clone(), offsets, values, nulls)) + } + Self::Record(fields, encodings) => { + let arrays = encodings + .iter_mut() + .map(|x| x.flush(None)) + .collect::, _>>()?; + Arc::new(StructArray::new(fields.clone(), arrays, nulls)) + } + }) + } +} + +#[inline] +fn flush_values(values: &mut Vec) -> Vec { + std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY)) +} + +#[inline] +fn flush_offsets(offsets: &mut OffsetBufferBuilder) -> OffsetBuffer { + std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish() +} + +#[inline] +fn flush_primitive( + values: &mut Vec, + nulls: Option, +) -> PrimitiveArray { + PrimitiveArray::new(flush_values(values).into(), nulls) +} + +const DEFAULT_CAPACITY: usize = 1024; From bbdaf7b3d38166adfe3aed3db525fb2cd28a6742 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Sat, 30 Nov 2024 17:30:18 +0000 Subject: [PATCH 2/6] Docs --- arrow-avro/src/reader/record.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 6245050c8048..52a58cf63303 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -64,6 +64,7 @@ impl RecordDecoder { Ok(cursor.position()) } + /// Flush the decoded records into a [`RecordBatch`] pub fn flush(&mut self) -> Result { let arrays = self .fields @@ -155,6 +156,7 @@ impl Decoder { }) } + /// Append a null record fn append_null(&mut self) { match self { Self::Null(count) => *count += 1, @@ -176,6 +178,7 @@ impl Decoder { } } + /// Decode a single record from `buf` fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { match self { Self::Null(x) => *x += 1, @@ -216,6 +219,7 @@ impl Decoder { Ok(()) } + /// Flush decoded records to an [`ArrayRef`] fn flush(&mut self, nulls: Option) -> Result { Ok(match self { Self::Nullable(_, n, e) => e.flush(n.finish())?, From 0626b90dea2e6bce382857ec206691db5a0dfcf7 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Sun, 1 Dec 2024 16:21:56 +0000 Subject: [PATCH 3/6] Improved varint decode --- arrow-avro/Cargo.toml | 1 + arrow-avro/src/reader/cursor.rs | 49 ++++++------------ arrow-avro/src/reader/vlq.rs | 89 +++++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 34 deletions(-) diff --git a/arrow-avro/Cargo.toml b/arrow-avro/Cargo.toml index f4a466e86938..c103c2ecc0f3 100644 --- a/arrow-avro/Cargo.toml +++ b/arrow-avro/Cargo.toml @@ -51,4 +51,5 @@ crc = { version = "3.0", optional = true } [dev-dependencies] +rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } diff --git a/arrow-avro/src/reader/cursor.rs b/arrow-avro/src/reader/cursor.rs index 0d4585a5d96d..4b6a5a4d65db 100644 --- a/arrow-avro/src/reader/cursor.rs +++ b/arrow-avro/src/reader/cursor.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::reader::vlq::read_varint; use arrow_schema::ArrowError; /// A wrapper around a byte slice, providing low-level decoding for Avro @@ -57,46 +58,26 @@ impl<'a> AvroCursor<'a> { Ok(self.get_u8()? != 0) } + pub(crate) fn read_vlq(&mut self) -> Result { + let (val, offset) = read_varint(self.buf) + .ok_or_else(|| ArrowError::ParseError("bad varint".to_string()))?; + self.buf = &self.buf[offset..]; + Ok(val) + } + #[inline] pub(crate) fn get_int(&mut self) -> Result { - let mut in_progress = 0; - let mut shift = 0; - - while let Some(byte) = self.buf.first().copied() { - self.buf = &self.buf[1..]; - in_progress |= ((byte & 0x7F) as u32) << shift; - shift += 7; - if byte & 0x80 == 0 { - let val = in_progress; - in_progress = 0; - shift = 0; - return Ok((val >> 1) as i32 ^ -((val & 1) as i32)); - } - } - Err(ArrowError::ParseError( - "Unexpected EOF reading int".to_string(), - )) + let varint = self.read_vlq()?; + let val: u32 = varint + .try_into() + .map_err(|_| ArrowError::ParseError("varint overflow".to_string()))?; + Ok((val >> 1) as i32 ^ -((val & 1) as i32)) } #[inline] pub(crate) fn get_long(&mut self) -> Result { - let mut in_progress = 0; - let mut shift = 0; - - while let Some(byte) = self.buf.first().copied() { - self.buf = &self.buf[1..]; - in_progress |= ((byte & 0x7F) as u64) << shift; - shift += 7; - if byte & 0x80 == 0 { - let val = in_progress; - in_progress = 0; - shift = 0; - return Ok((val >> 1) as i64 ^ -((val & 1) as i64)); - } - } - Err(ArrowError::ParseError( - "Unexpected EOF reading long".to_string(), - )) + let val = self.read_vlq()?; + Ok((val >> 1) as i64 ^ -((val & 1) as i64)) } pub(crate) fn get_bytes(&mut self) -> Result<&'a [u8], ArrowError> { diff --git a/arrow-avro/src/reader/vlq.rs b/arrow-avro/src/reader/vlq.rs index 80f1c60eec7d..1c78616e7684 100644 --- a/arrow-avro/src/reader/vlq.rs +++ b/arrow-avro/src/reader/vlq.rs @@ -44,3 +44,92 @@ impl VLQDecoder { None } } + +/// Read a varint from `buf` returning the decoded `u64` and the number of bytes read +#[inline] +pub(crate) fn read_varint(buf: &[u8]) -> Option<(u64, usize)> { + let first = *buf.first()?; + if first < 0x80 { + return Some((first as u64, 1)); + } + + if let Some(array) = buf.get(..10) { + return read_varint_array(array.try_into().unwrap()); + } + + read_varint_slow(buf) +} + +/// Based on +/// - https://github.com/tokio-rs/prost/blob/master/prost/src/encoding/varint.rs#L71 +/// - https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.cc#L365-L406 +/// - https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358 +#[inline] +fn read_varint_array(buf: [u8; 10]) -> Option<(u64, usize)> { + let mut in_progress = 0_u64; + for idx in 0..9 { + let b = buf[idx] as u64; + in_progress += b << (7 * idx); + if b < 0x80 { + return Some((in_progress, idx + 1)); + } + in_progress -= 0x80 << (7 * idx); + } + + let b = buf[9] as u64; + in_progress += b << (7 * 9); + (b < 0x02).then_some((in_progress, 10)) +} + +#[inline(never)] +#[cold] +fn read_varint_slow(buf: &[u8]) -> Option<(u64, usize)> { + let mut value = 0; + for count in 0..buf.len().min(10) { + let byte = buf[count]; + value |= u64::from(byte & 0x7F) << (count * 7); + if byte <= 0x7F { + // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details. + // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358 + return (count != 9 || byte < 2).then_some((value, count + 1)); + } + } + + None +} + +#[cfg(test)] +mod tests { + use super::*; + + fn encode_var(mut n: u64, dst: &mut [u8]) -> usize { + let mut i = 0; + + while n >= 0x80 { + dst[i] = 0x80 | (n as u8); + i += 1; + n >>= 7; + } + + dst[i] = n as u8; + i + 1 + } + + fn varint_test(a: u64) { + let mut buf = [0_u8; 10]; + let len = encode_var(a, &mut buf); + assert_eq!(read_varint(&buf[..len]).unwrap(), (a, len)); + assert_eq!(read_varint(&buf).unwrap(), (a, len)); + } + + #[test] + fn test_varint() { + varint_test(0); + varint_test(4395932); + varint_test(u64::MAX); + + for _ in 0..1000 { + varint_test(rand::random()); + } + } +} From cd2d97183e5831f8a1eee42d30174335841aaf1e Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Sun, 1 Dec 2024 16:25:26 +0000 Subject: [PATCH 4/6] Docs --- arrow-avro/src/reader/vlq.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/arrow-avro/src/reader/vlq.rs b/arrow-avro/src/reader/vlq.rs index 1c78616e7684..1375ff2ebc18 100644 --- a/arrow-avro/src/reader/vlq.rs +++ b/arrow-avro/src/reader/vlq.rs @@ -61,9 +61,9 @@ pub(crate) fn read_varint(buf: &[u8]) -> Option<(u64, usize)> { } /// Based on -/// - https://github.com/tokio-rs/prost/blob/master/prost/src/encoding/varint.rs#L71 -/// - https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.cc#L365-L406 -/// - https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358 +/// - +/// - +/// - #[inline] fn read_varint_array(buf: [u8; 10]) -> Option<(u64, usize)> { let mut in_progress = 0_u64; From 33e0a59e892fe191e9f2643feababc6297da96eb Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Sun, 1 Dec 2024 16:34:07 +0000 Subject: [PATCH 5/6] Clippy --- arrow-avro/src/reader/vlq.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow-avro/src/reader/vlq.rs b/arrow-avro/src/reader/vlq.rs index 1375ff2ebc18..26ab6dbcfbeb 100644 --- a/arrow-avro/src/reader/vlq.rs +++ b/arrow-avro/src/reader/vlq.rs @@ -85,7 +85,7 @@ fn read_varint_array(buf: [u8; 10]) -> Option<(u64, usize)> { #[cold] fn read_varint_slow(buf: &[u8]) -> Option<(u64, usize)> { let mut value = 0; - for count in 0..buf.len().min(10) { + for (count, byte) in buf.iter().take(10).enumerate() { let byte = buf[count]; value |= u64::from(byte & 0x7F) << (count * 7); if byte <= 0x7F { From bdadabb44994f058787bcc9f32ad91d3e8ff20bd Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Sun, 1 Dec 2024 16:45:30 +0000 Subject: [PATCH 6/6] More clippy --- arrow-avro/src/reader/vlq.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/arrow-avro/src/reader/vlq.rs b/arrow-avro/src/reader/vlq.rs index 26ab6dbcfbeb..b198a0d66f24 100644 --- a/arrow-avro/src/reader/vlq.rs +++ b/arrow-avro/src/reader/vlq.rs @@ -67,9 +67,8 @@ pub(crate) fn read_varint(buf: &[u8]) -> Option<(u64, usize)> { #[inline] fn read_varint_array(buf: [u8; 10]) -> Option<(u64, usize)> { let mut in_progress = 0_u64; - for idx in 0..9 { - let b = buf[idx] as u64; - in_progress += b << (7 * idx); + for (idx, b) in buf.into_iter().take(9).enumerate() { + in_progress += (b as u64) << (7 * idx); if b < 0x80 { return Some((in_progress, idx + 1)); }