Skip to content

Commit

Permalink
Review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Dec 17, 2024
1 parent 29d55eb commit 657f005
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 22 deletions.
6 changes: 1 addition & 5 deletions parquet/src/arrow/arrow_reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,13 +696,9 @@ impl<T: ChunkReader + 'static> Iterator for ReaderPageIterator<T> {
let reader = self.reader.clone();

let file_decryptor = Arc::new(self.metadata.file_decryptor().clone().unwrap());
// let aad_file_unique = file_decryptor?.aad_file_unique();
// let aad_prefix = file_decryptor?.aad_prefix();
//
// let file_decryptor = FileDecryptor::new(file_decryptor, aad_file_unique.clone(), aad_prefix.clone());

let crypto_context = CryptoContext::new(
meta.dictionary_page_offset().is_some(), rg_idx.to_i16()?, self.column_idx.to_i16()?, file_decryptor.clone(), file_decryptor);
meta.dictionary_page_offset().is_some(), rg_idx as i16, self.column_idx as i16, file_decryptor.clone(), file_decryptor);
let crypto_context = Arc::new(crypto_context);

let ret = SerializedPageReader::new(reader, meta, total_rows, page_locations, Some(crypto_context));
Expand Down
17 changes: 9 additions & 8 deletions parquet/src/encryption/ciphers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ pub fn create_page_aad(file_aad: &[u8], module_type: ModuleType, row_group_ordin
create_module_aad(file_aad, module_type, row_group_ordinal, column_ordinal, page_ordinal)
}

fn create_module_aad(file_aad: &[u8], module_type: ModuleType, row_group_ordinal: i16,
column_ordinal: i16, page_ordinal: i32) -> Result<Vec<u8>> {
pub fn create_module_aad(file_aad: &[u8], module_type: ModuleType, row_group_ordinal: i16,
column_ordinal: i16, page_ordinal: i32) -> Result<Vec<u8>> {

let module_buf = [module_type as u8];

Expand All @@ -192,17 +192,18 @@ fn create_module_aad(file_aad: &[u8], module_type: ModuleType, row_group_ordinal
if row_group_ordinal < 0 {
return Err(general_err!("Wrong row group ordinal: {}", row_group_ordinal));
}
// todo: this check is a noop here
if row_group_ordinal > i16::MAX {
return Err(general_err!("Encrypted parquet files can't have more than {} row groups: {}",
u16::MAX, row_group_ordinal));
i16::MAX, row_group_ordinal));
}

if column_ordinal < 0 {
return Err(general_err!("Wrong column ordinal: {}", column_ordinal));
}
// todo: this check is a noop here
if column_ordinal > i16::MAX {
return Err(general_err!("Encrypted parquet files can't have more than {} columns: {}",
u16::MAX, column_ordinal));
i16::MAX, column_ordinal));
}

if module_buf[0] != (ModuleType::DataPageHeader as u8) &&
Expand All @@ -218,17 +219,17 @@ fn create_module_aad(file_aad: &[u8], module_type: ModuleType, row_group_ordinal
if page_ordinal < 0 {
return Err(general_err!("Wrong page ordinal: {}", page_ordinal));
}
if page_ordinal > i32::MAX {
if page_ordinal > i16::MAX as i32 {
return Err(general_err!("Encrypted parquet files can't have more than {} pages in a chunk: {}",
u16::MAX, page_ordinal));
i16::MAX, page_ordinal));
}

let mut aad = Vec::with_capacity(file_aad.len() + 7);
aad.extend_from_slice(file_aad);
aad.extend_from_slice(module_buf.as_ref());
aad.extend_from_slice(row_group_ordinal.to_le_bytes().as_ref());
aad.extend_from_slice(column_ordinal.to_le_bytes().as_ref());
aad.extend_from_slice(page_ordinal.to_le_bytes().as_ref());
aad.extend_from_slice((page_ordinal as i16).to_le_bytes().as_ref());
Ok(aad)
}

Expand Down
39 changes: 30 additions & 9 deletions parquet/src/file/serialized_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,31 +342,33 @@ impl<R: 'static + ChunkReader> RowGroupReader for SerializedRowGroupReader<'_, R

/// Reads a [`PageHeader`] from the provided [`Read`]
pub(crate) fn read_page_header<T: Read>(input: &mut T, crypto_context: Option<Arc<CryptoContext>>) -> Result<PageHeader> {
let mut prot = TCompactInputProtocol::new(input);
if let Some(crypto_context) = crypto_context {
// let mut buf = [0; 16 * 1024];
// let size = input.read(&mut buf)?;

let decryptor = &crypto_context.data_decryptor();
let file_decryptor = decryptor.footer_decryptor();
let aad_file_unique = decryptor.aad_file_unique();
// let aad_prefix = decryptor.aad_prefix();

// todo: page ordinal and page type (ModuleType)
let aad = create_page_aad(
aad_file_unique.as_slice(),
ModuleType::DictionaryPageHeader,
ModuleType::DataPageHeader,
crypto_context.row_group_ordinal,
crypto_context.column_ordinal,
0,
)?;

// todo: This currently fails, possibly due to wrongly generated AAD
let buf = file_decryptor.decrypt(prot.read_bytes()?.as_slice(), aad.as_ref());
todo!("Decrypted page header!");
let mut len_bytes = [0; 4];
input.read_exact(&mut len_bytes)?;
let ciphertext_len = u32::from_le_bytes(len_bytes) as usize;
let mut ciphertext = vec![0; 4 + ciphertext_len];
input.read_exact(&mut ciphertext[4..])?;
let buf = file_decryptor.decrypt(&ciphertext, aad.as_ref());

let mut prot = TCompactSliceInputProtocol::new(buf.as_slice());
let page_header = PageHeader::read_from_in_protocol(&mut prot)?;
return Ok(page_header)
}

let mut prot = TCompactInputProtocol::new(input);
let page_header = PageHeader::read_from_in_protocol(&mut prot)?;
Ok(page_header)
}
Expand Down Expand Up @@ -401,6 +403,7 @@ pub(crate) fn decode_page(
buffer: Bytes,
physical_type: Type,
decompressor: Option<&mut Box<dyn Codec>>,
crypto_context: Option<Arc<CryptoContext>>,
) -> Result<Page> {
// Verify the 32-bit CRC checksum of the page
#[cfg(feature = "crc")]
Expand All @@ -426,6 +429,22 @@ pub(crate) fn decode_page(
// When is_compressed flag is missing the page is considered compressed
can_decompress = header_v2.is_compressed.unwrap_or(true);
}
if crypto_context.is_some() {
let crypto_context = crypto_context.as_ref().unwrap();
let decryptor = crypto_context.data_decryptor();
let file_decryptor = decryptor.footer_decryptor();

// todo: page ordinal
let aad = create_page_aad(
decryptor.aad_file_unique().as_slice(),
ModuleType::DataPage,
crypto_context.row_group_ordinal,
crypto_context.column_ordinal,
0,
)?;
let decrypted = file_decryptor.decrypt(&buffer.as_ref()[offset..], &aad);
todo!("page decrypted!");
}

// TODO: page header could be huge because of statistics. We should set a
// maximum page header size and abort if that is exceeded.
Expand Down Expand Up @@ -665,6 +684,7 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
Bytes::from(buffer),
self.physical_type,
self.decompressor.as_mut(),
self.crypto_context.clone(),
)?
}
SerializedPageReaderState::Pages {
Expand Down Expand Up @@ -694,6 +714,7 @@ impl<R: ChunkReader> PageReader for SerializedPageReader<R> {
bytes,
self.physical_type,
self.decompressor.as_mut(),
None,
)?
}
};
Expand Down

0 comments on commit 657f005

Please sign in to comment.