Skip to content

Commit

Permalink
Handle NaN for f16 statistics writing
Browse files Browse the repository at this point in the history
  • Loading branch information
Jefffrey committed Nov 7, 2023
1 parent af39f80 commit 40f3e5f
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 8 deletions.
4 changes: 2 additions & 2 deletions parquet/src/column/writer/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,15 +290,15 @@ where
{
let first = loop {
let next = iter.next()?;
if !is_nan(next) {
if !is_nan(descr, next) {
break next;
}
};

let mut min = first;
let mut max = first;
for val in iter {
if is_nan(val) {
if is_nan(descr, val) {
continue;
}
if compare_greater(descr, min, val) {
Expand Down
149 changes: 143 additions & 6 deletions parquet/src/column/writer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

//! Contains column writer API.
use half::f16;

use crate::bloom_filter::Sbbf;
use crate::format::{ColumnIndex, OffsetIndex};
use std::collections::{BTreeSet, VecDeque};
Expand Down Expand Up @@ -967,18 +969,23 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> {
}

fn update_min<T: ParquetValueType>(descr: &ColumnDescriptor, val: &T, min: &mut Option<T>) {
update_stat::<T, _>(val, min, |cur| compare_greater(descr, cur, val))
update_stat::<T, _>(descr, val, min, |cur| compare_greater(descr, cur, val))
}

fn update_max<T: ParquetValueType>(descr: &ColumnDescriptor, val: &T, max: &mut Option<T>) {
update_stat::<T, _>(val, max, |cur| compare_greater(descr, val, cur))
update_stat::<T, _>(descr, val, max, |cur| compare_greater(descr, val, cur))
}

#[inline]
#[allow(clippy::eq_op)]
fn is_nan<T: ParquetValueType>(val: &T) -> bool {
fn is_nan<T: ParquetValueType>(descr: &ColumnDescriptor, val: &T) -> bool {
match T::PHYSICAL_TYPE {
Type::FLOAT | Type::DOUBLE => val != val,
Type::FIXED_LEN_BYTE_ARRAY if descr.logical_type() == Some(LogicalType::Float16) => {
let val = val.as_bytes();
let val = f16::from_le_bytes([val[0], val[1]]);
val.is_nan()
}
_ => false,
}
}
Expand All @@ -988,11 +995,15 @@ fn is_nan<T: ParquetValueType>(val: &T) -> bool {
/// If `cur` is `None`, sets `cur` to `Some(val)`, otherwise calls `should_update` with
/// the value of `cur`, and updates `cur` to `Some(val)` if it returns `true`
fn update_stat<T: ParquetValueType, F>(val: &T, cur: &mut Option<T>, should_update: F)
where
fn update_stat<T: ParquetValueType, F>(
descr: &ColumnDescriptor,
val: &T,
cur: &mut Option<T>,
should_update: F,
) where
F: Fn(&T) -> bool,
{
if is_nan(val) {
if is_nan(descr, val) {
return;
}

Expand Down Expand Up @@ -1038,6 +1049,14 @@ fn compare_greater<T: ParquetValueType>(descr: &ColumnDescriptor, a: &T, b: &T)
};
};

if let Some(LogicalType::Float16) = descr.logical_type() {
let a = a.as_bytes();
let a = f16::from_le_bytes([a[0], a[1]]);
let b = b.as_bytes();
let b = f16::from_le_bytes([b[0], b[1]]);
return a > b;
}

a > b
}

Expand Down Expand Up @@ -1169,6 +1188,7 @@ fn increment_utf8(mut data: Vec<u8>) -> Option<Vec<u8>> {
mod tests {
use crate::{file::properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH, format::BoundaryOrder};
use bytes::Bytes;
use half::f16;
use rand::distributions::uniform::SampleUniform;
use std::sync::Arc;

Expand Down Expand Up @@ -2077,6 +2097,79 @@ mod tests {
}
}

#[test]
fn test_column_writer_check_float16_min_max() {
let input = [
-f16::ONE,
f16::from_f32(3.0),
-f16::from_f32(2.0),
f16::from_f32(2.0),
]
.into_iter()
.map(|s| ByteArray::from(s).into())
.collect::<Vec<_>>();

let stats = float16_statistics_roundtrip(&input);
assert!(stats.has_min_max_set());
assert!(stats.is_min_max_backwards_compatible());
assert_eq!(stats.min(), &ByteArray::from(-f16::from_f32(2.0)));
assert_eq!(stats.max(), &ByteArray::from(f16::from_f32(3.0)));
}

#[test]
fn test_column_writer_check_float16_nan_middle() {
let input = [f16::ONE, f16::NAN, f16::ONE + f16::ONE]
.into_iter()
.map(|s| ByteArray::from(s).into())
.collect::<Vec<_>>();

let stats = float16_statistics_roundtrip(&input);
assert!(stats.has_min_max_set());
assert!(stats.is_min_max_backwards_compatible());
assert_eq!(stats.min(), &ByteArray::from(f16::ONE));
assert_eq!(stats.max(), &ByteArray::from(f16::ONE + f16::ONE));
}

#[test]
fn test_float16_statistics_nan_middle() {
let input = [f16::ONE, f16::NAN, f16::ONE + f16::ONE]
.into_iter()
.map(|s| ByteArray::from(s).into())
.collect::<Vec<_>>();

let stats = float16_statistics_roundtrip(&input);
assert!(stats.has_min_max_set());
assert!(stats.is_min_max_backwards_compatible());
assert_eq!(stats.min(), &ByteArray::from(f16::ONE));
assert_eq!(stats.max(), &ByteArray::from(f16::ONE + f16::ONE));
}

#[test]
fn test_float16_statistics_nan_start() {
let input = [f16::NAN, f16::ONE, f16::ONE + f16::ONE]
.into_iter()
.map(|s| ByteArray::from(s).into())
.collect::<Vec<_>>();

let stats = float16_statistics_roundtrip(&input);
assert!(stats.has_min_max_set());
assert!(stats.is_min_max_backwards_compatible());
assert_eq!(stats.min(), &ByteArray::from(f16::ONE));
assert_eq!(stats.max(), &ByteArray::from(f16::ONE + f16::ONE));
}

#[test]
fn test_float16_statistics_nan_only() {
let input = [f16::NAN, f16::NAN]
.into_iter()
.map(|s| ByteArray::from(s).into())
.collect::<Vec<_>>();

let stats = float16_statistics_roundtrip(&input);
assert!(!stats.has_min_max_set());
assert!(stats.is_min_max_backwards_compatible());
}

#[test]
fn test_float_statistics_nan_middle() {
let stats = statistics_roundtrip::<FloatType>(&[1.0, f32::NAN, 2.0]);
Expand Down Expand Up @@ -2735,6 +2828,50 @@ mod tests {
ColumnDescriptor::new(Arc::new(tpe), max_def_level, max_rep_level, path)
}

fn float16_statistics_roundtrip(
values: &[FixedLenByteArray],
) -> ValueStatistics<FixedLenByteArray> {
let page_writer = get_test_page_writer();
let props = Default::default();
let mut writer =
get_test_float16_column_writer::<FixedLenByteArrayType>(page_writer, 0, 0, props);
writer.write_batch(values, None, None).unwrap();

let metadata = writer.close().unwrap().metadata;
if let Some(Statistics::FixedLenByteArray(stats)) = metadata.statistics() {
stats.clone()
} else {
panic!("metadata missing statistics");
}
}

fn get_test_float16_column_writer<T: DataType>(
page_writer: Box<dyn PageWriter>,
max_def_level: i16,
max_rep_level: i16,
props: WriterPropertiesPtr,
) -> ColumnWriterImpl<'static, T> {
let descr = Arc::new(get_test_float16_column_descr::<T>(
max_def_level,
max_rep_level,
));
let column_writer = get_column_writer(descr, props, page_writer);
get_typed_column_writer::<T>(column_writer)
}

fn get_test_float16_column_descr<T: DataType>(
max_def_level: i16,
max_rep_level: i16,
) -> ColumnDescriptor {
let path = ColumnPath::from("col");
let tpe = SchemaType::primitive_type_builder("col", T::get_physical_type())
.with_length(2)
.with_logical_type(Some(LogicalType::Float16))
.build()
.unwrap();
ColumnDescriptor::new(Arc::new(tpe), max_def_level, max_rep_level, path)
}

/// Returns column writer for UINT32 Column provided as ConvertedType only
fn get_test_unsigned_int_given_as_converted_column_writer<'a, T: DataType>(
page_writer: Box<dyn PageWriter + 'a>,
Expand Down
7 changes: 7 additions & 0 deletions parquet/src/data_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
//! Data types that connect Parquet physical types with their Rust-specific
//! representations.
use bytes::Bytes;
use half::f16;
use std::cmp::Ordering;
use std::fmt;
use std::mem;
Expand Down Expand Up @@ -231,6 +232,12 @@ impl From<Bytes> for ByteArray {
}
}

impl From<f16> for ByteArray {
fn from(value: f16) -> Self {
Self::from(value.to_le_bytes().as_slice())
}
}

impl PartialEq for ByteArray {
fn eq(&self, other: &ByteArray) -> bool {
match (&self.data, &other.data) {
Expand Down

0 comments on commit 40f3e5f

Please sign in to comment.