Improvements to UTF-8 statistics truncation #6870
Changes from all commits
```diff
@@ -878,24 +878,67 @@ impl<'a, E: ColumnValueEncoder> GenericColumnWriter<'a, E> {
         }
     }

+    /// Returns `true` if this column's logical type is a UTF-8 string.
+    fn is_utf8(&self) -> bool {
+        self.get_descriptor().logical_type() == Some(LogicalType::String)
+            || self.get_descriptor().converted_type() == ConvertedType::UTF8
+    }
+
+    /// Truncates a binary statistic to at most `truncation_length` bytes.
+    ///
+    /// If truncation is not possible, returns `data`.
+    ///
+    /// The `bool` in the returned tuple indicates whether truncation occurred or not.
+    ///
+    /// UTF-8 Note:
+    /// If the column type indicates UTF-8, and `data` contains valid UTF-8, then the result will
+    /// also remain valid UTF-8, but may be less than `truncation_length` bytes to avoid splitting
+    /// on non-character boundaries.
     fn truncate_min_value(&self, truncation_length: Option<usize>, data: &[u8]) -> (Vec<u8>, bool) {
         truncation_length
             .filter(|l| data.len() > *l)
-            .and_then(|l| match str::from_utf8(data) {
-                Ok(str_data) => truncate_utf8(str_data, l),
-                Err(_) => Some(data[..l].to_vec()),
-            })
+            .and_then(|l|
+                // don't do extra work if this column isn't UTF-8
+                if self.is_utf8() {
+                    match str::from_utf8(data) {
+                        Ok(str_data) => truncate_utf8(str_data, l),
+                        Err(_) => Some(data[..l].to_vec()),
+                    }
+                } else {
+                    Some(data[..l].to_vec())
+                }
+            )
            .map(|truncated| (truncated, true))
            .unwrap_or_else(|| (data.to_vec(), false))
    }
```

Review thread on the `// don't do extra work if this column isn't UTF-8` comment:

Reviewer: 💯

Review thread on the `Err(_) => Some(data[..l].to_vec())` fallback:

Reviewer: It is a somewhat questionable move to truncate this on invalid data, but I see that is what the code used to do, so this seems good to me.

Author: Hmm, good point. The old code simply tried UTF-8 first, and then fell back. Here we're actually expecting valid UTF-8, so perhaps it's better to return an error. I'd hope some string validation was done before getting this far. I'll think on this some more.

Reviewer: I think we should leave it as is, and maybe document that if non-UTF-8 data is passed in it will be truncated with bytes.

Author: I've managed a test that exercises this path via the […]. To paraphrase a wise man I know: "Every day I wake up. And then I remember Parquet exists." 🫤

Author: I've left the logic here as is, but added documentation and a test. We can revisit if this ever becomes an issue.

Reviewer: Thank you.

Reviewer: I solace myself with this quote from a former coworker: […] Not that we can't / shouldn't improve it of course 🤣 Thanks again for all the help here.
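To make the doc comment's point about character boundaries concrete, here is a minimal standalone sketch (not part of the PR; the example string is illustrative) of why byte-exact truncation can corrupt UTF-8, and why the code searches backwards for a boundary instead:

```rust
fn main() {
    let s = "é-string"; // 'é' occupies 2 bytes in UTF-8 (0xC3 0xA9)
    let bytes = s.as_bytes();

    // Naive byte truncation to 1 byte splits 'é' in half: not valid UTF-8.
    assert!(std::str::from_utf8(&bytes[..1]).is_err());

    // Searching back for a char boundary, as truncate_utf8 does, stays valid.
    let split = (1..=3).rfind(|&i| s.is_char_boundary(i)).unwrap();
    assert_eq!(split, 3); // 2 bytes for 'é' plus 1 byte for '-'
    assert_eq!(std::str::from_utf8(&bytes[..split]).unwrap(), "é-");
}
```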
```diff
+    /// Truncates a binary statistic to at most `truncation_length` bytes, then increments the
+    /// final byte(s) to yield a valid upper bound. The result may be less than
+    /// `truncation_length` bytes if the last byte(s) overflows.
+    ///
+    /// If truncation is not possible, returns `data`.
+    ///
+    /// The `bool` in the returned tuple indicates whether truncation occurred or not.
+    ///
+    /// UTF-8 Note:
+    /// If the column type indicates UTF-8, and `data` contains valid UTF-8, then the result will
+    /// also remain valid UTF-8 (but again may be less than `truncation_length` bytes). If `data`
+    /// does not contain valid UTF-8, then truncation will occur as if the column is non-string
+    /// binary.
     fn truncate_max_value(&self, truncation_length: Option<usize>, data: &[u8]) -> (Vec<u8>, bool) {
         truncation_length
             .filter(|l| data.len() > *l)
-            .and_then(|l| match str::from_utf8(data) {
-                Ok(str_data) => truncate_utf8(str_data, l).and_then(increment_utf8),
-                Err(_) => increment(data[..l].to_vec()),
-            })
+            .and_then(|l|
+                // don't do extra work if this column isn't UTF-8
+                if self.is_utf8() {
+                    match str::from_utf8(data) {
+                        Ok(str_data) => truncate_and_increment_utf8(str_data, l),
+                        Err(_) => increment(data[..l].to_vec()),
+                    }
+                } else {
+                    increment(data[..l].to_vec())
+                }
+            )
            .map(|truncated| (truncated, true))
            .unwrap_or_else(|| (data.to_vec(), false))
    }
```
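As background on why the max path increments at all, here is a minimal sketch with illustrative values (separate from the PR code): a plain prefix of a value sorts below it, so it is a valid min bound but not a valid max bound.

```rust
fn main() {
    let full: &[u8] = b"banana";

    // A 3-byte prefix sorts *below* the full value: fine as a min, wrong as a max.
    let prefix = &full[..3]; // b"ban"
    assert!(prefix < full);

    // Incrementing the final byte restores the upper-bound invariant: "ban" -> "bao".
    let mut upper = prefix.to_vec();
    *upper.last_mut().unwrap() += 1;
    assert!(upper.as_slice() > full);
}
```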
```diff
@@ -1418,13 +1461,50 @@ fn compare_greater_byte_array_decimals(a: &[u8], b: &[u8]) -> bool {
     (a[1..]) > (b[1..])
 }

-/// Truncate a UTF8 slice to the longest prefix that is still a valid UTF8 string,
-/// while being less than `length` bytes and non-empty
+/// Truncate a UTF-8 slice to the longest prefix that is still a valid UTF-8 string,
+/// while being less than `length` bytes and non-empty. Returns `None` if truncation
+/// is not possible within those constraints.
+///
+/// The caller guarantees that data.len() > length.
 fn truncate_utf8(data: &str, length: usize) -> Option<Vec<u8>> {
     let split = (1..=length).rfind(|x| data.is_char_boundary(*x))?;
     Some(data.as_bytes()[..split].to_vec())
 }

+/// Truncate a UTF-8 slice and increment its final character. The returned value is the
+/// longest such slice that is still a valid UTF-8 string while being less than `length`
+/// bytes and non-empty. Returns `None` if no such transformation is possible.
+///
+/// The caller guarantees that data.len() > length.
+fn truncate_and_increment_utf8(data: &str, length: usize) -> Option<Vec<u8>> {
+    // UTF-8 is max 4 bytes, so start search 3 back from desired length
+    let lower_bound = length.saturating_sub(3);
+    let split = (lower_bound..=length).rfind(|x| data.is_char_boundary(*x))?;
+    increment_utf8(data.get(..split)?)
+}
+
+/// Increment the final character in a UTF-8 string in such a way that the returned result
+/// is still a valid UTF-8 string. The returned string may be shorter than the input if the
+/// last character(s) cannot be incremented (due to overflow or producing invalid code points).
+/// Returns `None` if the string cannot be incremented.
+///
+/// Note that this implementation will not promote an N-byte code point to (N+1) bytes.
+fn increment_utf8(data: &str) -> Option<Vec<u8>> {
+    for (idx, original_char) in data.char_indices().rev() {
+        let original_len = original_char.len_utf8();
+        if let Some(next_char) = char::from_u32(original_char as u32 + 1) {
+            // do not allow increasing byte width of incremented char
+            if next_char.len_utf8() == original_len {
+                let mut result = data.as_bytes()[..idx + original_len].to_vec();
+                next_char.encode_utf8(&mut result[idx..]);
+                return Some(result);
+            }
+        }
+    }
+
+    None
+}
+
 /// Try and increment the bytes from right to left.
 ///
 /// Returns `None` if all bytes are set to `u8::MAX`.
```
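A small sketch of the two subtleties in the docs above, assuming the `increment_utf8` from this diff is in scope: the surrogate range is skipped because `char::from_u32` rejects 0xD800..=0xDFFF, and code points are never widened.

```rust
fn main() {
    // Surrogates are not valid `char`s, so U+D7FF + 1 is rejected and the search
    // moves left to increment the preceding character instead.
    assert_eq!(char::from_u32(0xD7FF + 1), None);
    assert_eq!(increment_utf8("a\u{D7FF}"), Some(b"b".to_vec()));

    // No width promotion: U+7F + 1 = U+80 would need 2 bytes, so the final char
    // is dropped and 'a' is incremented instead, giving a shorter result.
    assert_eq!(increment_utf8("a\u{7f}"), Some(b"b".to_vec()));
}
```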
```diff
@@ -1441,29 +1521,15 @@ fn increment(mut data: Vec<u8>) -> Option<Vec<u8>> {
     None
 }

-/// Try and increment the the string's bytes from right to left, returning when the result
-/// is a valid UTF8 string. Returns `None` when it can't increment any byte.
-fn increment_utf8(mut data: Vec<u8>) -> Option<Vec<u8>> {
-    for idx in (0..data.len()).rev() {
-        let original = data[idx];
-        let (byte, overflow) = original.overflowing_add(1);
-        if !overflow {
-            data[idx] = byte;
-            if str::from_utf8(&data).is_ok() {
-                return Some(data);
-            }
-            data[idx] = original;
-        }
-    }
-
-    None
-}
-
 #[cfg(test)]
 mod tests {
-    use crate::file::properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH;
+    use crate::{
+        file::{properties::DEFAULT_COLUMN_INDEX_TRUNCATE_LENGTH, writer::SerializedFileWriter},
+        schema::parser::parse_message_type,
+    };
     use core::str;
     use rand::distributions::uniform::SampleUniform;
-    use std::sync::Arc;
+    use std::{fs::File, sync::Arc};

     use crate::column::{
         page::PageReader,
```
```diff
@@ -3140,39 +3206,69 @@ mod tests {

     #[test]
     fn test_increment_utf8() {
+        let test_inc = |o: &str, expected: &str| {
+            if let Ok(v) = String::from_utf8(increment_utf8(o).unwrap()) {
+                // Got the expected result...
+                assert_eq!(v, expected);
+                // and it's greater than the original string
+                assert!(*v > *o);
+                // Also show that BinaryArray level comparison works here
+                let mut greater = ByteArray::new();
+                greater.set_data(Bytes::from(v));
+                let mut original = ByteArray::new();
+                original.set_data(Bytes::from(o.as_bytes().to_vec()));
+                assert!(greater > original);
+            } else {
+                panic!("Expected incremented UTF8 string to also be valid.");
+            }
+        };
+
         // Basic ASCII case
-        let v = increment_utf8("hello".as_bytes().to_vec()).unwrap();
-        assert_eq!(&v, "hellp".as_bytes());
+        test_inc("hello", "hellp");

-        // Also show that BinaryArray level comparison works here
-        let mut greater = ByteArray::new();
-        greater.set_data(Bytes::from(v));
-        let mut original = ByteArray::new();
-        original.set_data(Bytes::from("hello".as_bytes().to_vec()));
-        assert!(greater > original);
+        // 1-byte ending in max 1-byte
+        test_inc("a\u{7f}", "b");
+
+        // 1-byte max should not truncate as it would need 2-byte code points
+        assert!(increment_utf8("\u{7f}\u{7f}").is_none());

         // UTF8 string
-        let s = "❤️🧡💛💚💙💜";
-        let v = increment_utf8(s.as_bytes().to_vec()).unwrap();
+        test_inc("❤️🧡💛💚💙💜", "❤️🧡💛💚💙💝");

-        if let Ok(new) = String::from_utf8(v) {
-            assert_ne!(&new, s);
-            assert_eq!(new, "❤️🧡💛💚💙💝");
-            assert!(new.as_bytes().last().unwrap() > s.as_bytes().last().unwrap());
-        } else {
-            panic!("Expected incremented UTF8 string to also be valid.")
-        }
+        // 2-byte without overflow
+        test_inc("éééé", "éééê");

-        // Max UTF8 character - should be a No-Op
-        let s = char::MAX.to_string();
-        assert_eq!(s.len(), 4);
-        let v = increment_utf8(s.as_bytes().to_vec());
-        assert!(v.is_none());
+        // 2-byte that overflows lowest byte
+        test_inc("\u{ff}\u{ff}", "\u{ff}\u{100}");
+
+        // 2-byte ending in max 2-byte
+        test_inc("a\u{7ff}", "b");
+
+        // Max 2-byte should not truncate as it would need 3-byte code points
+        assert!(increment_utf8("\u{7ff}\u{7ff}").is_none());
+
+        // 3-byte without overflow [U+800, U+800] -> [U+800, U+801] (note that these
+        // characters should render right to left).
+        test_inc("ࠀࠀ", "ࠀࠁ");
+
+        // 3-byte ending in max 3-byte
+        test_inc("a\u{ffff}", "b");
+
+        // Max 3-byte should not truncate as it would need 4-byte code points
+        assert!(increment_utf8("\u{ffff}\u{ffff}").is_none());

-        // Handle multi-byte UTF8 characters
-        let s = "a\u{10ffff}";
-        let v = increment_utf8(s.as_bytes().to_vec());
-        assert_eq!(&v.unwrap(), "b\u{10ffff}".as_bytes());
+        // 4-byte without overflow
+        test_inc("𐀀𐀀", "𐀀𐀁");
+
+        // 4-byte ending in max unicode
+        test_inc("a\u{10ffff}", "b");
+
+        // Max 4-byte should not truncate
+        assert!(increment_utf8("\u{10ffff}\u{10ffff}").is_none());
+
+        // Skip over surrogate pair range (0xD800..=0xDFFF)
+        //test_inc("a\u{D7FF}", "a\u{e000}");
+        test_inc("a\u{D7FF}", "b");
     }

     #[test]
```
```diff
@@ -3182,7 +3278,6 @@ mod tests {
         let r = truncate_utf8(data, data.as_bytes().len()).unwrap();
         assert_eq!(r.len(), data.as_bytes().len());
         assert_eq!(&r, data.as_bytes());
-        println!("len is {}", data.len());

         // We slice it away from the UTF8 boundary
         let r = truncate_utf8(data, 13).unwrap();
```
```diff
@@ -3192,6 +3287,90 @@ mod tests {
         // One multi-byte code point, and a length shorter than it, so we can't slice it
         let r = truncate_utf8("\u{0836}", 1);
         assert!(r.is_none());
+
+        // Test truncate and increment for max bounds on UTF-8 statistics
+        // 7-bit (i.e. ASCII)
+        let r = truncate_and_increment_utf8("yyyyyyyyy", 8).unwrap();
+        assert_eq!(&r, "yyyyyyyz".as_bytes());
+
+        // 2-byte without overflow
+        let r = truncate_and_increment_utf8("ééééé", 7).unwrap();
+        assert_eq!(&r, "ééê".as_bytes());
+
+        // 2-byte that overflows lowest byte
+        let r = truncate_and_increment_utf8("\u{ff}\u{ff}\u{ff}\u{ff}\u{ff}", 8).unwrap();
+        assert_eq!(&r, "\u{ff}\u{ff}\u{ff}\u{100}".as_bytes());
+
+        // max 2-byte should not truncate as it would need 3-byte code points
+        let r = truncate_and_increment_utf8("߿߿߿߿߿", 8);
+        assert!(r.is_none());
+
+        // 3-byte without overflow [U+800, U+800, U+800] -> [U+800, U+801] (note that these
+        // characters should render right to left).
+        let r = truncate_and_increment_utf8("ࠀࠀࠀࠀ", 8).unwrap();
+        assert_eq!(&r, "ࠀࠁ".as_bytes());
+
+        // max 3-byte should not truncate as it would need 4-byte code points
+        let r = truncate_and_increment_utf8("\u{ffff}\u{ffff}\u{ffff}", 8);
+        assert!(r.is_none());
+
+        // 4-byte without overflow
+        let r = truncate_and_increment_utf8("𐀀𐀀𐀀𐀀", 9).unwrap();
+        assert_eq!(&r, "𐀀𐀁".as_bytes());
+
+        // max 4-byte should not truncate
+        let r = truncate_and_increment_utf8("\u{10ffff}\u{10ffff}", 8);
+        assert!(r.is_none());
     }
+
+    #[test]
+    // Check fallback truncation of statistics that should be UTF-8, but aren't
+    // (see https://github.com/apache/arrow-rs/pull/6870).
+    fn test_byte_array_truncate_invalid_utf8_statistics() {
+        let message_type = "
+            message test_schema {
+                OPTIONAL BYTE_ARRAY a (UTF8);
+            }
+        ";
+        let schema = Arc::new(parse_message_type(message_type).unwrap());
+
+        // Create Vec<ByteArray> containing non-UTF8 bytes
+        let data = vec![ByteArray::from(vec![128u8; 32]); 7];
+        let def_levels = [1, 1, 1, 1, 0, 1, 0, 1, 0, 1];
+        let file: File = tempfile::tempfile().unwrap();
+        let props = Arc::new(
+            WriterProperties::builder()
+                .set_statistics_enabled(EnabledStatistics::Chunk)
+                .set_statistics_truncate_length(Some(8))
+                .build(),
+        );
+
+        let mut writer = SerializedFileWriter::new(&file, schema, props).unwrap();
+        let mut row_group_writer = writer.next_row_group().unwrap();
+
+        let mut col_writer = row_group_writer.next_column().unwrap().unwrap();
+        col_writer
+            .typed::<ByteArrayType>()
+            .write_batch(&data, Some(&def_levels), None)
+            .unwrap();
+        col_writer.close().unwrap();
+        row_group_writer.close().unwrap();
+        let file_metadata = writer.close().unwrap();
+        assert!(file_metadata.row_groups[0].columns[0].meta_data.is_some());
+        let stats = file_metadata.row_groups[0].columns[0]
+            .meta_data
+            .as_ref()
+            .unwrap()
+            .statistics
+            .as_ref()
+            .unwrap();
+        assert!(!stats.is_max_value_exact.unwrap());
+        // Truncation of invalid UTF-8 should fall back to binary truncation, so last byte should
+        // be incremented by 1.
+        assert_eq!(
+            stats.max_value,
+            Some([128, 128, 128, 128, 128, 128, 128, 129].to_vec())
+        );
+    }

     #[test]
```
Review thread on the new test:

Reviewer: I assume this works for dictionary-encoded columns as well, right?

Author: Yes. Regardless of the encoding, the statistics are for the data itself; you wouldn't see a dictionary key here.
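A hedged sketch (not from the PR) of how the test above could be adapted to exercise the dictionary path explicitly; `set_dictionary_enabled` is the standard `WriterProperties` builder knob, and the rest of the test body would be unchanged:

```rust
// Same writer properties as test_byte_array_truncate_invalid_utf8_statistics,
// but with dictionary encoding explicitly enabled. Statistics are computed from
// the values themselves, not dictionary keys, so the truncated max statistic
// should still be [128, 128, 128, 128, 128, 128, 128, 129].
let props = Arc::new(
    WriterProperties::builder()
        .set_dictionary_enabled(true)
        .set_statistics_enabled(EnabledStatistics::Chunk)
        .set_statistics_truncate_length(Some(8))
        .build(),
);
```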