Skip to content

Commit

Permalink
f16 for record api
Browse files Browse the repository at this point in the history
  • Loading branch information
Jefffrey committed Nov 7, 2023
1 parent bf43eea commit af39f80
Showing 1 changed file with 85 additions and 3 deletions.
88 changes: 85 additions & 3 deletions parquet/src/record/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
use std::fmt;

use chrono::{TimeZone, Utc};
use half::f16;
use num::Float;
use num_bigint::{BigInt, Sign};

use crate::basic::{ConvertedType, Type as PhysicalType};
use crate::basic::{ConvertedType, LogicalType, Type as PhysicalType};
use crate::data_type::{ByteArray, Decimal, Int96};
use crate::errors::{ParquetError, Result};
use crate::schema::types::ColumnDescPtr;
Expand Down Expand Up @@ -121,6 +123,7 @@ pub trait RowAccessor {
fn get_ushort(&self, i: usize) -> Result<u16>;
fn get_uint(&self, i: usize) -> Result<u32>;
fn get_ulong(&self, i: usize) -> Result<u64>;
fn get_float16(&self, i: usize) -> Result<f16>;
fn get_float(&self, i: usize) -> Result<f32>;
fn get_double(&self, i: usize) -> Result<f64>;
fn get_timestamp_millis(&self, i: usize) -> Result<i64>;
Expand Down Expand Up @@ -215,6 +218,8 @@ impl RowAccessor for Row {

row_primitive_accessor!(get_ulong, ULong, u64);

row_primitive_accessor!(get_float16, Float16, f16);

row_primitive_accessor!(get_float, Float, f32);

row_primitive_accessor!(get_double, Double, f64);
Expand Down Expand Up @@ -293,6 +298,7 @@ pub trait ListAccessor {
fn get_ushort(&self, i: usize) -> Result<u16>;
fn get_uint(&self, i: usize) -> Result<u32>;
fn get_ulong(&self, i: usize) -> Result<u64>;
fn get_float16(&self, i: usize) -> Result<f16>;
fn get_float(&self, i: usize) -> Result<f32>;
fn get_double(&self, i: usize) -> Result<f64>;
fn get_timestamp_millis(&self, i: usize) -> Result<i64>;
Expand Down Expand Up @@ -358,6 +364,8 @@ impl ListAccessor for List {

list_primitive_accessor!(get_ulong, ULong, u64);

list_primitive_accessor!(get_float16, Float16, f16);

list_primitive_accessor!(get_float, Float, f32);

list_primitive_accessor!(get_double, Double, f64);
Expand Down Expand Up @@ -449,6 +457,8 @@ impl<'a> ListAccessor for MapList<'a> {

map_list_primitive_accessor!(get_ulong, ULong, u64);

map_list_primitive_accessor!(get_float16, Float16, f16);

map_list_primitive_accessor!(get_float, Float, f32);

map_list_primitive_accessor!(get_double, Double, f64);
Expand Down Expand Up @@ -510,6 +520,8 @@ pub enum Field {
UInt(u32),
// Unsigned integer UINT_64.
ULong(u64),
/// IEEE 16-bit floating point value.
Float16(f16),
/// IEEE 32-bit floating point value.
Float(f32),
/// IEEE 64-bit floating point value.
Expand Down Expand Up @@ -552,6 +564,7 @@ impl Field {
Field::UShort(_) => "UShort",
Field::UInt(_) => "UInt",
Field::ULong(_) => "ULong",
Field::Float16(_) => "Float16",
Field::Float(_) => "Float",
Field::Double(_) => "Double",
Field::Decimal(_) => "Decimal",
Expand Down Expand Up @@ -636,8 +649,8 @@ impl Field {
Field::Double(value)
}

/// Converts Parquet BYTE_ARRAY type with converted type into either UTF8 string or
/// array of bytes.
/// Converts Parquet BYTE_ARRAY type with converted type into a UTF8
/// string, decimal, float16, or an array of bytes.
#[inline]
pub fn convert_byte_array(descr: &ColumnDescPtr, value: ByteArray) -> Result<Self> {
let field = match descr.physical_type() {
Expand Down Expand Up @@ -666,6 +679,16 @@ impl Field {
descr.type_precision(),
descr.type_scale(),
)),
ConvertedType::NONE if descr.logical_type() == Some(LogicalType::Float16) => {
if value.len() != 2 {
return Err(general_err!(
"Error reading FIXED_LEN_BYTE_ARRAY as FLOAT16. Length must be 2, got {}",
value.len()
));
}
let bytes = [value.data()[0], value.data()[1]];
Field::Float16(f16::from_le_bytes(bytes))
}
ConvertedType::NONE => Field::Bytes(value),
_ => nyi!(descr, value),
},
Expand All @@ -690,6 +713,9 @@ impl Field {
Field::UShort(n) => Value::Number(serde_json::Number::from(*n)),
Field::UInt(n) => Value::Number(serde_json::Number::from(*n)),
Field::ULong(n) => Value::Number(serde_json::Number::from(*n)),
Field::Float16(n) => serde_json::Number::from_f64(f64::from(*n))
.map(Value::Number)
.unwrap_or(Value::Null),
Field::Float(n) => serde_json::Number::from_f64(f64::from(*n))
.map(Value::Number)
.unwrap_or(Value::Null),
Expand Down Expand Up @@ -736,6 +762,15 @@ impl fmt::Display for Field {
Field::UShort(value) => write!(f, "{value}"),
Field::UInt(value) => write!(f, "{value}"),
Field::ULong(value) => write!(f, "{value}"),
Field::Float16(value) => {
if !value.is_finite() {
write!(f, "{value}")
} else if value.trunc() == value {
write!(f, "{value}.0")
} else {
write!(f, "{value}")
}
}
Field::Float(value) => {
if !(1e-15..=1e19).contains(&value) {
write!(f, "{value:E}")
Expand Down Expand Up @@ -1069,6 +1104,24 @@ mod tests {
Field::Decimal(Decimal::from_bytes(value, 17, 5))
);

// FLOAT16
let descr = {
let tpe = PrimitiveTypeBuilder::new("col", PhysicalType::FIXED_LEN_BYTE_ARRAY)
.with_logical_type(Some(LogicalType::Float16))
.with_length(2)
.build()
.unwrap();
Arc::new(ColumnDescriptor::new(
Arc::new(tpe),
0,
0,
ColumnPath::from("col"),
))
};
let value = ByteArray::from(f16::PI);
let row = Field::convert_byte_array(&descr, value.clone());
assert_eq!(row.unwrap(), Field::Float16(f16::PI));

// NONE (FIXED_LEN_BYTE_ARRAY)
let descr = make_column_descr![
PhysicalType::FIXED_LEN_BYTE_ARRAY,
Expand Down Expand Up @@ -1145,6 +1198,18 @@ mod tests {
check_datetime_conversion(2014, 11, 28, 21, 15, 12);
}

#[test]
fn test_convert_float16_to_string() {
assert_eq!(format!("{}", Field::Float16(f16::ONE)), "1.0");
assert_eq!(format!("{}", Field::Float16(f16::PI)), "3.140625");
assert_eq!(format!("{}", Field::Float16(f16::MAX)), "65504.0");
assert_eq!(format!("{}", Field::Float16(f16::NAN)), "NaN");
assert_eq!(format!("{}", Field::Float16(f16::INFINITY)), "inf");
assert_eq!(format!("{}", Field::Float16(f16::NEG_INFINITY)), "-inf");
assert_eq!(format!("{}", Field::Float16(f16::ZERO)), "0.0");
assert_eq!(format!("{}", Field::Float16(f16::NEG_ZERO)), "-0.0");
}

#[test]
fn test_convert_float_to_string() {
assert_eq!(format!("{}", Field::Float(1.0)), "1.0");
Expand Down Expand Up @@ -1218,6 +1283,7 @@ mod tests {
assert_eq!(format!("{}", Field::UShort(2)), "2");
assert_eq!(format!("{}", Field::UInt(3)), "3");
assert_eq!(format!("{}", Field::ULong(4)), "4");
assert_eq!(format!("{}", Field::Float16(f16::E)), "2.71875");
assert_eq!(format!("{}", Field::Float(5.0)), "5.0");
assert_eq!(format!("{}", Field::Float(5.1234)), "5.1234");
assert_eq!(format!("{}", Field::Double(6.0)), "6.0");
Expand Down Expand Up @@ -1284,6 +1350,7 @@ mod tests {
assert!(Field::UShort(2).is_primitive());
assert!(Field::UInt(3).is_primitive());
assert!(Field::ULong(4).is_primitive());
assert!(Field::Float16(f16::E).is_primitive());
assert!(Field::Float(5.0).is_primitive());
assert!(Field::Float(5.1234).is_primitive());
assert!(Field::Double(6.0).is_primitive());
Expand Down Expand Up @@ -1344,6 +1411,7 @@ mod tests {
("15".to_string(), Field::TimestampMillis(1262391174000)),
("16".to_string(), Field::TimestampMicros(1262391174000000)),
("17".to_string(), Field::Decimal(Decimal::from_i32(4, 7, 2))),
("18".to_string(), Field::Float16(f16::PI)),
]);

assert_eq!("null", format!("{}", row.fmt(0)));
Expand All @@ -1370,6 +1438,7 @@ mod tests {
format!("{}", row.fmt(16))
);
assert_eq!("0.04", format!("{}", row.fmt(17)));
assert_eq!("3.140625", format!("{}", row.fmt(18)));
}

#[test]
Expand Down Expand Up @@ -1429,6 +1498,7 @@ mod tests {
Field::Bytes(ByteArray::from(vec![1, 2, 3, 4, 5])),
),
("o".to_string(), Field::Decimal(Decimal::from_i32(4, 7, 2))),
("p".to_string(), Field::Float16(f16::from_f32(9.1))),
]);

assert!(!row.get_bool(1).unwrap());
Expand All @@ -1445,6 +1515,7 @@ mod tests {
assert_eq!("abc", row.get_string(12).unwrap());
assert_eq!(5, row.get_bytes(13).unwrap().len());
assert_eq!(7, row.get_decimal(14).unwrap().precision());
assert!((f16::from_f32(9.1) - row.get_float16(15).unwrap()).abs() < f16::EPSILON);
}

#[test]
Expand All @@ -1469,6 +1540,7 @@ mod tests {
Field::Bytes(ByteArray::from(vec![1, 2, 3, 4, 5])),
),
("o".to_string(), Field::Decimal(Decimal::from_i32(4, 7, 2))),
("p".to_string(), Field::Float16(f16::from_f32(9.1))),
]);

for i in 0..row.len() {
Expand Down Expand Up @@ -1583,6 +1655,9 @@ mod tests {
let list = make_list(vec![Field::ULong(6), Field::ULong(7)]);
assert_eq!(7, list.get_ulong(1).unwrap());

let list = make_list(vec![Field::Float16(f16::PI)]);
assert!((f16::PI - list.get_float16(0).unwrap()).abs() < f16::EPSILON);

let list = make_list(vec![
Field::Float(8.1),
Field::Float(9.2),
Expand Down Expand Up @@ -1633,6 +1708,9 @@ mod tests {
let list = make_list(vec![Field::ULong(6), Field::ULong(7)]);
assert!(list.get_float(1).is_err());

let list = make_list(vec![Field::Float16(f16::PI)]);
assert!(list.get_string(0).is_err());

let list = make_list(vec![
Field::Float(8.1),
Field::Float(9.2),
Expand Down Expand Up @@ -1768,6 +1846,10 @@ mod tests {
Field::ULong(4).to_json_value(),
Value::Number(serde_json::Number::from(4))
);
assert_eq!(
Field::Float16(f16::from_f32(5.0)).to_json_value(),
Value::Number(serde_json::Number::from_f64(5.0).unwrap())
);
assert_eq!(
Field::Float(5.0).to_json_value(),
Value::Number(serde_json::Number::from_f64(5.0).unwrap())
Expand Down

0 comments on commit af39f80

Please sign in to comment.