Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert some panics that happen on invalid parquet files to error results #6738

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions parquet/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

//! Common Parquet errors and macros.

use core::num::TryFromIntError;
use std::error::Error;
use std::{cell, io, result, str};

Expand Down Expand Up @@ -81,6 +82,12 @@ impl Error for ParquetError {
}
}

impl From<TryFromIntError> for ParquetError {
fn from(e: TryFromIntError) -> ParquetError {
ParquetError::General(format!("Integer overflow: {e}"))
}
}

impl From<io::Error> for ParquetError {
fn from(e: io::Error) -> ParquetError {
ParquetError::External(Box::new(e))
Expand Down
26 changes: 13 additions & 13 deletions parquet/src/file/metadata/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,8 @@ impl ParquetMetaDataReader {
for rg in t_file_metadata.row_groups {
row_groups.push(RowGroupMetaData::from_thrift(schema_descr.clone(), rg)?);
}
let column_orders = Self::parse_column_orders(t_file_metadata.column_orders, &schema_descr);
let column_orders =
Self::parse_column_orders(t_file_metadata.column_orders, &schema_descr)?;
Comment on lines +630 to +631
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize this would currently panic, but would one ever prefer to just set column_orders to None and continue? The only impact AFAIK would be statistics being unusable, which would only matter if predicates were in use.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point! i agree with the setting to None idea. but i guess this worths a separate issue to discuss and fix.


let file_metadata = FileMetaData::new(
t_file_metadata.version,
Expand All @@ -645,15 +646,13 @@ impl ParquetMetaDataReader {
fn parse_column_orders(
t_column_orders: Option<Vec<TColumnOrder>>,
schema_descr: &SchemaDescriptor,
) -> Option<Vec<ColumnOrder>> {
) -> Result<Option<Vec<ColumnOrder>>> {
match t_column_orders {
Some(orders) => {
// Should always be the case
assert_eq!(
orders.len(),
schema_descr.num_columns(),
"Column order length mismatch"
);
if orders.len() != schema_descr.num_columns() {
return Err(general_err!("Column order length mismatch"));
};
let mut res = Vec::new();
for (i, column) in schema_descr.columns().iter().enumerate() {
match orders[i] {
Expand All @@ -667,9 +666,9 @@ impl ParquetMetaDataReader {
}
}
}
Some(res)
Ok(Some(res))
}
None => None,
None => Ok(None),
}
}
}
Expand Down Expand Up @@ -741,7 +740,7 @@ mod tests {
]);

assert_eq!(
ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr),
ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr).unwrap(),
Some(vec![
ColumnOrder::TYPE_DEFINED_ORDER(SortOrder::SIGNED),
ColumnOrder::TYPE_DEFINED_ORDER(SortOrder::SIGNED)
Expand All @@ -750,20 +749,21 @@ mod tests {

// Test when no column orders are defined.
assert_eq!(
ParquetMetaDataReader::parse_column_orders(None, &schema_descr),
ParquetMetaDataReader::parse_column_orders(None, &schema_descr).unwrap(),
None
);
}

#[test]
#[should_panic(expected = "Column order length mismatch")]
fn test_metadata_column_orders_len_mismatch() {
let schema = SchemaType::group_type_builder("schema").build().unwrap();
let schema_descr = SchemaDescriptor::new(Arc::new(schema));

let t_column_orders = Some(vec![TColumnOrder::TYPEORDER(TypeDefinedOrder::new())]);

ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr);
let res = ParquetMetaDataReader::parse_column_orders(t_column_orders, &schema_descr);
assert!(res.is_err());
assert!(format!("{:?}", res.unwrap_err()).contains("Column order length mismatch"));
}

#[test]
Expand Down
53 changes: 46 additions & 7 deletions parquet/src/file/serialized_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ pub(crate) fn decode_page(
let is_sorted = dict_header.is_sorted.unwrap_or(false);
Page::DictionaryPage {
buf: buffer,
num_values: dict_header.num_values as u32,
num_values: dict_header.num_values.try_into()?,
encoding: Encoding::try_from(dict_header.encoding)?,
is_sorted,
}
Expand All @@ -446,7 +446,7 @@ pub(crate) fn decode_page(
.ok_or_else(|| ParquetError::General("Missing V1 data page header".to_string()))?;
Page::DataPage {
buf: buffer,
num_values: header.num_values as u32,
num_values: header.num_values.try_into()?,
encoding: Encoding::try_from(header.encoding)?,
def_level_encoding: Encoding::try_from(header.definition_level_encoding)?,
rep_level_encoding: Encoding::try_from(header.repetition_level_encoding)?,
Expand All @@ -460,12 +460,12 @@ pub(crate) fn decode_page(
let is_compressed = header.is_compressed.unwrap_or(true);
Page::DataPageV2 {
buf: buffer,
num_values: header.num_values as u32,
num_values: header.num_values.try_into()?,
encoding: Encoding::try_from(header.encoding)?,
num_nulls: header.num_nulls as u32,
num_rows: header.num_rows as u32,
def_levels_byte_len: header.definition_levels_byte_length as u32,
rep_levels_byte_len: header.repetition_levels_byte_length as u32,
num_nulls: header.num_nulls.try_into()?,
num_rows: header.num_rows.try_into()?,
def_levels_byte_len: header.definition_levels_byte_length.try_into()?,
rep_levels_byte_len: header.repetition_levels_byte_length.try_into()?,
is_compressed,
statistics: statistics::from_thrift(physical_type, header.statistics)?,
}
Expand Down Expand Up @@ -578,6 +578,27 @@ impl<R: ChunkReader> Iterator for SerializedPageReader<R> {
}
}

fn verify_page_header_len(header_len: usize, remaining_bytes: usize) -> Result<()> {
if header_len > remaining_bytes {
return Err(eof_err!("Invalid page header"));
}
Ok(())
}

fn verify_page_size(
compressed_size: i32,
uncompressed_size: i32,
remaining_bytes: usize,
) -> Result<()> {
// The page's compressed size should not exceed the remaining bytes that are
// available to read. The page's uncompressed size is the expected size
// after decompression, which can never be negative.
if compressed_size < 0 || compressed_size as usize > remaining_bytes || uncompressed_size < 0 {
return Err(eof_err!("Invalid page header"));
}
Ok(())
}

impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
fn get_next_page(&mut self) -> Result<Option<Page>> {
loop {
Expand All @@ -596,10 +617,16 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
*header
} else {
let (header_len, header) = read_page_header_len(&mut read)?;
verify_page_header_len(header_len, *remaining)?;
*offset += header_len;
*remaining -= header_len;
header
};
verify_page_size(
header.compressed_page_size,
header.uncompressed_page_size,
*remaining,
)?;
let data_len = header.compressed_page_size as usize;
*offset += data_len;
*remaining -= data_len;
Expand Down Expand Up @@ -683,6 +710,7 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
} else {
let mut read = self.reader.get_read(*offset as u64)?;
let (header_len, header) = read_page_header_len(&mut read)?;
verify_page_header_len(header_len, *remaining_bytes)?;
*offset += header_len;
*remaining_bytes -= header_len;
let page_meta = if let Ok(page_meta) = (&header).try_into() {
Expand Down Expand Up @@ -733,12 +761,23 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
next_page_header,
} => {
if let Some(buffered_header) = next_page_header.take() {
verify_page_size(
buffered_header.compressed_page_size,
buffered_header.uncompressed_page_size,
*remaining_bytes,
)?;
// The next page header has already been peeked, so just advance the offset
*offset += buffered_header.compressed_page_size as usize;
*remaining_bytes -= buffered_header.compressed_page_size as usize;
} else {
let mut read = self.reader.get_read(*offset as u64)?;
let (header_len, header) = read_page_header_len(&mut read)?;
verify_page_header_len(header_len, *remaining_bytes)?;
verify_page_size(
header.compressed_page_size,
header.uncompressed_page_size,
*remaining_bytes,
)?;
let data_page_size = header.compressed_page_size as usize;
*offset += header_len + data_page_size;
*remaining_bytes -= header_len + data_page_size;
Expand Down
26 changes: 26 additions & 0 deletions parquet/src/file/statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,32 @@ pub fn from_thrift(
stats.max_value
};

fn check_len(min: &Option<Vec<u8>>, max: &Option<Vec<u8>>, len: usize) -> Result<()> {
if let Some(min) = min {
if min.len() < len {
return Err(ParquetError::General(
"Insufficient bytes to parse min statistic".to_string(),
));
}
}
if let Some(max) = max {
if max.len() < len {
return Err(ParquetError::General(
"Insufficient bytes to parse max statistic".to_string(),
));
}
}
Ok(())
}

match physical_type {
Type::BOOLEAN => check_len(&min, &max, 1),
Type::INT32 | Type::FLOAT => check_len(&min, &max, 4),
Type::INT64 | Type::DOUBLE => check_len(&min, &max, 8),
Type::INT96 => check_len(&min, &max, 12),
_ => Ok(()),
}?;

// Values are encoded using PLAIN encoding definition, except that
// variable-length byte arrays do not include a length prefix.
//
Expand Down
22 changes: 21 additions & 1 deletion parquet/src/schema/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,8 @@ impl<'a> PrimitiveTypeBuilder<'a> {
}
}
PhysicalType::FIXED_LEN_BYTE_ARRAY => {
let max_precision = (2f64.powi(8 * self.length - 1) - 1f64).log10().floor() as i32;
let length = self.length.checked_mul(8).unwrap_or(i32::MAX);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should the overflow error instead of falling through to max precision?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree it would be better to return an error:

I double checked that using i32::MAX results in

Max precision: 2147483647

https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=8b510b23f773540e2b348e80f0691b41

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sg, changed to returning error, thanks!

let max_precision = (2f64.powi(length - 1) - 1f64).log10().floor() as i32;

if self.precision > max_precision {
return Err(general_err!(
Expand Down Expand Up @@ -1171,9 +1172,25 @@ pub fn from_thrift(elements: &[SchemaElement]) -> Result<TypePtr> {
));
}

if !schema_nodes[0].is_group() {
return Err(general_err!("Expected root node to be a group type"));
}

Ok(schema_nodes.remove(0))
}

/// Checks if the logical type is valid.
fn check_logical_type(logical_type: &Option<LogicalType>) -> Result<()> {
if let Some(LogicalType::Integer { bit_width, .. }) = *logical_type {
if bit_width != 8 && bit_width != 16 && bit_width != 32 && bit_width != 64 {
return Err(general_err!(
"Bit width must be 8, 16, 32, or 64 for Integer logical type"
));
}
}
Ok(())
}

/// Constructs a new Type from the `elements`, starting at index `index`.
/// The first result is the starting index for the next Type after this one. If it is
/// equal to `elements.len()`, then this Type is the last one.
Expand All @@ -1198,6 +1215,9 @@ fn from_thrift_helper(elements: &[SchemaElement], index: usize) -> Result<(usize
.logical_type
.as_ref()
.map(|value| LogicalType::from(value.clone()));

check_logical_type(&logical_type)?;

let field_id = elements[index].field_id;
match elements[index].num_children {
// From parquet-format:
Expand Down
35 changes: 29 additions & 6 deletions parquet/src/thrift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl<'a> TCompactSliceInputProtocol<'a> {
let mut shift = 0;
loop {
let byte = self.read_byte()?;
in_progress |= ((byte & 0x7F) as u64) << shift;
in_progress |= ((byte & 0x7F) as u64).wrapping_shl(shift);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the purpose of this change?

As I understand it << panics in debug builds on overflow but wraps in release builds

It seems to me that this change now avoids panic'ing in debug builds too, which isn't obviously better to me

Copy link
Contributor Author

@jp0317 jp0317 Dec 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for bringing this up! I'm not sure if << is guaranteed to always have the same behavior as wrapping_shl in release builds. On the other hand, imho this is not really a logic bug, and the only benefit keeping << is that debug mode can be used to identify invalid inputs (via panics), but i'm not sure if anyone will rely on that in practice.

Copy link
Contributor

@emkornfield emkornfield Dec 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note the C++ code for this loop actually validates [total bytes decoded] (https://github.com/apache/thrift/blob/master/lib/cpp/src/thrift/protocol/TCompactProtocol.tcc#L759) which is probably a good idea (I think this prevents the panic that the original checked shift was meant to do?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for the pointer. iiuc the c++ one may still have an overflow on the 10th bytes where it shifts 63 bits?

Copy link
Contributor

@emkornfield emkornfield Dec 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be misunderstanding but I thoght shifting by 63 bits is well defined on u64? So overflow yes but I wouldn't expect a panic? Also I think silently passing here if there is an overflow would be data corrupt data?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can consider this change independently if we are concerned about perf impacts

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

overflow yes, i agree we can add check on # of bytes decoded separately. For this cl let's just solve the panic ?

shift += 7;
if byte & 0x80 == 0 {
return Ok(in_progress);
Expand Down Expand Up @@ -96,13 +96,22 @@ impl<'a> TCompactSliceInputProtocol<'a> {
}
}

macro_rules! thrift_unimplemented {
() => {
Err(thrift::Error::Protocol(thrift::ProtocolError {
kind: thrift::ProtocolErrorKind::NotImplemented,
message: "not implemented".to_string(),
}))
};
}

impl TInputProtocol for TCompactSliceInputProtocol<'_> {
fn read_message_begin(&mut self) -> thrift::Result<TMessageIdentifier> {
unimplemented!()
}

fn read_message_end(&mut self) -> thrift::Result<()> {
unimplemented!()
thrift_unimplemented!()
}

fn read_struct_begin(&mut self) -> thrift::Result<Option<TStructIdentifier>> {
Expand Down Expand Up @@ -147,7 +156,21 @@ impl TInputProtocol for TCompactSliceInputProtocol<'_> {
),
_ => {
if field_delta != 0 {
self.last_read_field_id += field_delta as i16;
self.last_read_field_id = self
.last_read_field_id
.checked_add(field_delta as i16)
.map_or_else(
|| {
Err(thrift::Error::Protocol(thrift::ProtocolError {
kind: thrift::ProtocolErrorKind::InvalidData,
message: format!(
"cannot add {} to {}",
field_delta, self.last_read_field_id
),
}))
},
Ok,
)?;
} else {
self.last_read_field_id = self.read_i16()?;
};
Expand Down Expand Up @@ -226,15 +249,15 @@ impl TInputProtocol for TCompactSliceInputProtocol<'_> {
}

fn read_set_begin(&mut self) -> thrift::Result<TSetIdentifier> {
unimplemented!()
thrift_unimplemented!()
}

fn read_set_end(&mut self) -> thrift::Result<()> {
unimplemented!()
thrift_unimplemented!()
}

fn read_map_begin(&mut self) -> thrift::Result<TMapIdentifier> {
unimplemented!()
thrift_unimplemented!()
}

fn read_map_end(&mut self) -> thrift::Result<()> {
Expand Down
2 changes: 1 addition & 1 deletion parquet/tests/arrow_reader/bad_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ fn test_arrow_rs_gh_6229_dict_header() {
let err = read_file("ARROW-RS-GH-6229-DICTHEADER.parquet").unwrap_err();
assert_eq!(
err.to_string(),
"External: Parquet argument error: EOF: eof decoding byte array"
"External: Parquet argument error: Parquet error: Integer overflow: out of range integral type conversion attempted"
);
}

Expand Down
Loading