diff --git a/parquet/src/column/writer/encoder.rs b/parquet/src/column/writer/encoder.rs index 7bd4db30c3a8..0cbcda5b4854 100644 --- a/parquet/src/column/writer/encoder.rs +++ b/parquet/src/column/writer/encoder.rs @@ -290,7 +290,7 @@ where { let first = loop { let next = iter.next()?; - if !is_nan(next) { + if !is_nan(descr, next) { break next; } }; @@ -298,7 +298,7 @@ where let mut min = first; let mut max = first; for val in iter { - if is_nan(val) { + if is_nan(descr, val) { continue; } if compare_greater(descr, min, val) { diff --git a/parquet/src/column/writer/mod.rs b/parquet/src/column/writer/mod.rs index 84bf1911d89c..e657992acc9a 100644 --- a/parquet/src/column/writer/mod.rs +++ b/parquet/src/column/writer/mod.rs @@ -17,6 +17,8 @@ //! Contains column writer API. +use half::f16; + use crate::bloom_filter::Sbbf; use crate::format::{ColumnIndex, OffsetIndex}; use std::collections::{BTreeSet, VecDeque}; @@ -967,18 +969,23 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> { } fn update_min<T: ParquetValueType>(descr: &ColumnDescriptor, val: &T, min: &mut Option<T>) { - update_stat::<T, _>(val, min, |cur| compare_greater(descr, cur, val)) + update_stat::<T, _>(descr, val, min, |cur| compare_greater(descr, cur, val)) } fn update_max<T: ParquetValueType>(descr: &ColumnDescriptor, val: &T, max: &mut Option<T>) { - update_stat::<T, _>(val, max, |cur| compare_greater(descr, val, cur)) + update_stat::<T, _>(descr, val, max, |cur| compare_greater(descr, val, cur)) } #[inline] #[allow(clippy::eq_op)] -fn is_nan<T: ParquetValueType>(val: &T) -> bool { +fn is_nan<T: ParquetValueType>(descr: &ColumnDescriptor, val: &T) -> bool { match T::PHYSICAL_TYPE { Type::FLOAT | Type::DOUBLE => val != val, + Type::FIXED_LEN_BYTE_ARRAY if descr.logical_type() == Some(LogicalType::Float16) => { + let val = val.as_bytes(); + let val = f16::from_le_bytes([val[0], val[1]]); + val.is_nan() + } _ => false, } } @@ -988,11 +995,15 @@ fn is_nan<T: ParquetValueType>(val: &T) -> bool { /// If `cur` is `None`, sets `cur` to `Some(val)`, otherwise calls `should_update` with /// the value of `cur`, and updates `cur` to 
`Some(val)` if it returns `true` -fn update_stat<T: ParquetValueType, F>(val: &T, cur: &mut Option<T>, should_update: F) -where +fn update_stat<T: ParquetValueType, F>( + descr: &ColumnDescriptor, + val: &T, + cur: &mut Option<T>, + should_update: F, +) where F: Fn(&T) -> bool, { - if is_nan(val) { + if is_nan(descr, val) { return; } @@ -1038,6 +1049,14 @@ fn compare_greater<T: ParquetValueType>(descr: &ColumnDescriptor, a: &T, b: &T) }; }; + if let Some(LogicalType::Float16) = descr.logical_type() { + let a = a.as_bytes(); + let a = f16::from_le_bytes([a[0], a[1]]); + let b = b.as_bytes(); + let b = f16::from_le_bytes([b[0], b[1]]); + return a > b; + } + a > b } @@ -1169,6 +1188,7 @@ fn increment_utf8(mut data: Vec<u8>) -> Option<Vec<u8>> { mod tests { use crate::{file::properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH, format::BoundaryOrder}; use bytes::Bytes; + use half::f16; use rand::distributions::uniform::SampleUniform; use std::sync::Arc; @@ -2077,6 +2097,79 @@ mod tests { } } + #[test] + fn test_column_writer_check_float16_min_max() { + let input = [ + -f16::ONE, + f16::from_f32(3.0), + -f16::from_f32(2.0), + f16::from_f32(2.0), + ] + .into_iter() + .map(|s| ByteArray::from(s).into()) + .collect::<Vec<_>>(); + + let stats = float16_statistics_roundtrip(&input); + assert!(stats.has_min_max_set()); + assert!(stats.is_min_max_backwards_compatible()); + assert_eq!(stats.min(), &ByteArray::from(-f16::from_f32(2.0))); + assert_eq!(stats.max(), &ByteArray::from(f16::from_f32(3.0))); + } + + #[test] + fn test_column_writer_check_float16_nan_middle() { + let input = [f16::ONE, f16::NAN, f16::ONE + f16::ONE] + .into_iter() + .map(|s| ByteArray::from(s).into()) + .collect::<Vec<_>>(); + + let stats = float16_statistics_roundtrip(&input); + assert!(stats.has_min_max_set()); + assert!(stats.is_min_max_backwards_compatible()); + assert_eq!(stats.min(), &ByteArray::from(f16::ONE)); + assert_eq!(stats.max(), &ByteArray::from(f16::ONE + f16::ONE)); + } + + #[test] + fn test_float16_statistics_nan_middle() { + let input = [f16::ONE, f16::NAN, f16::ONE + f16::ONE] + 
.into_iter() + .map(|s| ByteArray::from(s).into()) + .collect::<Vec<_>>(); + + let stats = float16_statistics_roundtrip(&input); + assert!(stats.has_min_max_set()); + assert!(stats.is_min_max_backwards_compatible()); + assert_eq!(stats.min(), &ByteArray::from(f16::ONE)); + assert_eq!(stats.max(), &ByteArray::from(f16::ONE + f16::ONE)); + } + + #[test] + fn test_float16_statistics_nan_start() { + let input = [f16::NAN, f16::ONE, f16::ONE + f16::ONE] + .into_iter() + .map(|s| ByteArray::from(s).into()) + .collect::<Vec<_>>(); + + let stats = float16_statistics_roundtrip(&input); + assert!(stats.has_min_max_set()); + assert!(stats.is_min_max_backwards_compatible()); + assert_eq!(stats.min(), &ByteArray::from(f16::ONE)); + assert_eq!(stats.max(), &ByteArray::from(f16::ONE + f16::ONE)); + } + + #[test] + fn test_float16_statistics_nan_only() { + let input = [f16::NAN, f16::NAN] + .into_iter() + .map(|s| ByteArray::from(s).into()) + .collect::<Vec<_>>(); + + let stats = float16_statistics_roundtrip(&input); + assert!(!stats.has_min_max_set()); + assert!(stats.is_min_max_backwards_compatible()); + } + #[test] fn test_float_statistics_nan_middle() { let stats = statistics_roundtrip::<FloatType>(&[1.0, f32::NAN, 2.0]); @@ -2735,6 +2828,50 @@ mod tests { ColumnDescriptor::new(Arc::new(tpe), max_def_level, max_rep_level, path) } + fn float16_statistics_roundtrip( + values: &[FixedLenByteArray], + ) -> ValueStatistics<FixedLenByteArray> { + let page_writer = get_test_page_writer(); + let props = Default::default(); + let mut writer = + get_test_float16_column_writer::<FixedLenByteArrayType>(page_writer, 0, 0, props); + writer.write_batch(values, None, None).unwrap(); + + let metadata = writer.close().unwrap().metadata; + if let Some(Statistics::FixedLenByteArray(stats)) = metadata.statistics() { + stats.clone() + } else { + panic!("metadata missing statistics"); + } + } + + fn get_test_float16_column_writer<T: DataType>( + page_writer: Box<dyn PageWriter>, + max_def_level: i16, + max_rep_level: i16, + props: WriterPropertiesPtr, + ) -> ColumnWriterImpl<'static, T> { + let descr 
= Arc::new(get_test_float16_column_descr::<T>( + max_def_level, + max_rep_level, + )); + let column_writer = get_column_writer(descr, props, page_writer); + get_typed_column_writer::<T>(column_writer) + } + + fn get_test_float16_column_descr<T: DataType>( + max_def_level: i16, + max_rep_level: i16, + ) -> ColumnDescriptor { + let path = ColumnPath::from("col"); + let tpe = SchemaType::primitive_type_builder("col", T::get_physical_type()) + .with_length(2) + .with_logical_type(Some(LogicalType::Float16)) + .build() + .unwrap(); + ColumnDescriptor::new(Arc::new(tpe), max_def_level, max_rep_level, path) + } + /// Returns column writer for UINT32 Column provided as ConvertedType only fn get_test_unsigned_int_given_as_converted_column_writer<'a, T: DataType>( page_writer: Box<dyn PageWriter + 'a>, diff --git a/parquet/src/data_type.rs b/parquet/src/data_type.rs index 7e64478ed940..b1d52b75c723 100644 --- a/parquet/src/data_type.rs +++ b/parquet/src/data_type.rs @@ -18,6 +18,7 @@ //! Data types that connect Parquet physical types with their Rust-specific //! representations. use bytes::Bytes; +use half::f16; use std::cmp::Ordering; use std::fmt; use std::mem; @@ -231,6 +232,12 @@ impl From for ByteArray { } } +impl From<f16> for ByteArray { + fn from(value: f16) -> Self { + Self::from(value.to_le_bytes().as_slice()) + } +} + impl PartialEq for ByteArray { fn eq(&self, other: &ByteArray) -> bool { match (&self.data, &other.data) {