Skip to content

Commit

Permalink
GH-43141: [C++][Parquet] Replace use of int with int32_t in the inter…
Browse files Browse the repository at this point in the history
…nal Parquet encryption APIs (#43413)

### Rationale for this change

See #43141

### What changes are included in this PR?

* Change uses of `int` to `int32_t` in the Encryptor and Decryptor APIs, except where interfacing with OpenSSL.
* Also change RandBytes to use size_t instead of int and check for overflow.
* Check the return code from OpenSSL's `RAND_bytes` in case there is a failure generating random bytes.

### Are these changes tested?

Yes, this doesn't change behaviour and is covered by existing tests.

### Are there any user-facing changes?

No
* GitHub Issue: #43141

Authored-by: Adam Reeve <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
adamreeve authored Aug 21, 2024
1 parent f9911ee commit f078942
Show file tree
Hide file tree
Showing 14 changed files with 233 additions and 193 deletions.
4 changes: 2 additions & 2 deletions cpp/src/parquet/column_reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -468,8 +468,8 @@ std::shared_ptr<Page> SerializedPageReader::NextPage() {
// Advance the stream offset
PARQUET_THROW_NOT_OK(stream_->Advance(header_size));

int compressed_len = current_page_header_.compressed_page_size;
int uncompressed_len = current_page_header_.uncompressed_page_size;
int32_t compressed_len = current_page_header_.compressed_page_size;
int32_t uncompressed_len = current_page_header_.uncompressed_page_size;
if (compressed_len < 0 || uncompressed_len < 0) {
throw ParquetException("Invalid page header");
}
Expand Down
6 changes: 2 additions & 4 deletions cpp/src/parquet/encryption/crypto_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ std::shared_ptr<FileEncryptionProperties> CryptoFactory::GetFileEncryptionProper
int dek_length = dek_length_bits / 8;

std::string footer_key(dek_length, '\0');
RandBytes(reinterpret_cast<uint8_t*>(&footer_key[0]),
static_cast<int>(footer_key.size()));
RandBytes(reinterpret_cast<uint8_t*>(footer_key.data()), footer_key.size());

std::string footer_key_metadata =
key_wrapper.GetEncryptionKeyMetadata(footer_key, footer_key_id, true);
Expand Down Expand Up @@ -148,8 +147,7 @@ ColumnPathToEncryptionPropertiesMap CryptoFactory::GetColumnEncryptionProperties
}

std::string column_key(dek_length, '\0');
RandBytes(reinterpret_cast<uint8_t*>(&column_key[0]),
static_cast<int>(column_key.size()));
RandBytes(reinterpret_cast<uint8_t*>(column_key.data()), column_key.size());
std::string column_key_key_metadata =
key_wrapper->GetEncryptionKeyMetadata(column_key, column_key_id, false);

Expand Down
251 changes: 146 additions & 105 deletions cpp/src/parquet/encryption/encryption_internal.cc

Large diffs are not rendered by default.

46 changes: 23 additions & 23 deletions cpp/src/parquet/encryption/encryption_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ using parquet::ParquetCipher;

namespace parquet::encryption {

constexpr int kGcmTagLength = 16;
constexpr int kNonceLength = 12;
constexpr int32_t kGcmTagLength = 16;
constexpr int32_t kNonceLength = 12;

// Module types
constexpr int8_t kFooter = 0;
Expand All @@ -49,13 +49,13 @@ class PARQUET_EXPORT AesEncryptor {
public:
/// Can serve one key length only. Possible values: 16, 24, 32 bytes.
/// If write_length is true, prepend ciphertext length to the ciphertext
explicit AesEncryptor(ParquetCipher::type alg_id, int key_len, bool metadata,
explicit AesEncryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata,
bool write_length = true);

static std::unique_ptr<AesEncryptor> Make(ParquetCipher::type alg_id, int key_len,
static std::unique_ptr<AesEncryptor> Make(ParquetCipher::type alg_id, int32_t key_len,
bool metadata);

static std::unique_ptr<AesEncryptor> Make(ParquetCipher::type alg_id, int key_len,
static std::unique_ptr<AesEncryptor> Make(ParquetCipher::type alg_id, int32_t key_len,
bool metadata, bool write_length);

~AesEncryptor();
Expand All @@ -65,17 +65,17 @@ class PARQUET_EXPORT AesEncryptor {

/// Encrypts plaintext with the key and aad. Key length is passed only for validation.
/// If different from value in constructor, exception will be thrown.
int Encrypt(::arrow::util::span<const uint8_t> plaintext,
::arrow::util::span<const uint8_t> key,
::arrow::util::span<const uint8_t> aad,
::arrow::util::span<uint8_t> ciphertext);
int32_t Encrypt(::arrow::util::span<const uint8_t> plaintext,
::arrow::util::span<const uint8_t> key,
::arrow::util::span<const uint8_t> aad,
::arrow::util::span<uint8_t> ciphertext);

/// Encrypts plaintext footer, in order to compute footer signature (tag).
int SignedFooterEncrypt(::arrow::util::span<const uint8_t> footer,
::arrow::util::span<const uint8_t> key,
::arrow::util::span<const uint8_t> aad,
::arrow::util::span<const uint8_t> nonce,
::arrow::util::span<uint8_t> encrypted_footer);
int32_t SignedFooterEncrypt(::arrow::util::span<const uint8_t> footer,
::arrow::util::span<const uint8_t> key,
::arrow::util::span<const uint8_t> aad,
::arrow::util::span<const uint8_t> nonce,
::arrow::util::span<uint8_t> encrypted_footer);

void WipeOut();

Expand All @@ -90,7 +90,7 @@ class PARQUET_EXPORT AesDecryptor {
public:
/// Can serve one key length only. Possible values: 16, 24, 32 bytes.
/// If contains_length is true, expect ciphertext length prepended to the ciphertext
explicit AesDecryptor(ParquetCipher::type alg_id, int key_len, bool metadata,
explicit AesDecryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata,
bool contains_length = true);

/// \brief Factory function to create an AesDecryptor
Expand All @@ -102,26 +102,26 @@ class PARQUET_EXPORT AesDecryptor {
/// out when decryption is finished
/// \return shared pointer to a new AesDecryptor
static std::shared_ptr<AesDecryptor> Make(
ParquetCipher::type alg_id, int key_len, bool metadata,
ParquetCipher::type alg_id, int32_t key_len, bool metadata,
std::vector<std::weak_ptr<AesDecryptor>>* all_decryptors);

~AesDecryptor();
void WipeOut();

/// The size of the plaintext, for this cipher and the specified ciphertext length.
[[nodiscard]] int PlaintextLength(int ciphertext_len) const;
[[nodiscard]] int32_t PlaintextLength(int32_t ciphertext_len) const;

/// The size of the ciphertext, for this cipher and the specified plaintext length.
[[nodiscard]] int CiphertextLength(int plaintext_len) const;
[[nodiscard]] int32_t CiphertextLength(int32_t plaintext_len) const;

/// Decrypts ciphertext with the key and aad. Key length is passed only for
/// validation. If different from value in constructor, exception will be thrown.
/// The caller is responsible for ensuring that the plaintext buffer is at least as
/// large as PlaintextLength(ciphertext_len).
int Decrypt(::arrow::util::span<const uint8_t> ciphertext,
::arrow::util::span<const uint8_t> key,
::arrow::util::span<const uint8_t> aad,
::arrow::util::span<uint8_t> plaintext);
int32_t Decrypt(::arrow::util::span<const uint8_t> ciphertext,
::arrow::util::span<const uint8_t> key,
::arrow::util::span<const uint8_t> aad,
::arrow::util::span<uint8_t> plaintext);

private:
// PIMPL Idiom
Expand All @@ -139,7 +139,7 @@ std::string CreateFooterAad(const std::string& aad_prefix_bytes);
void QuickUpdatePageAad(int32_t new_page_ordinal, std::string* AAD);

// Wraps OpenSSL RAND_bytes function
void RandBytes(unsigned char* buf, int num);
void RandBytes(unsigned char* buf, size_t num);

// Ensure OpenSSL is initialized.
//
Expand Down
47 changes: 24 additions & 23 deletions cpp/src/parquet/encryption/encryption_internal_nossl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ class AesEncryptor::AesEncryptorImpl {};

AesEncryptor::~AesEncryptor() {}

int AesEncryptor::SignedFooterEncrypt(::arrow::util::span<const uint8_t> footer,
::arrow::util::span<const uint8_t> key,
::arrow::util::span<const uint8_t> aad,
::arrow::util::span<const uint8_t> nonce,
::arrow::util::span<uint8_t> encrypted_footer) {
int32_t AesEncryptor::SignedFooterEncrypt(::arrow::util::span<const uint8_t> footer,
::arrow::util::span<const uint8_t> key,
::arrow::util::span<const uint8_t> aad,
::arrow::util::span<const uint8_t> nonce,
::arrow::util::span<uint8_t> encrypted_footer) {
ThrowOpenSSLRequiredException();
return -1;
}
Expand All @@ -45,25 +45,25 @@ int32_t AesEncryptor::CiphertextLength(int64_t plaintext_len) const {
return -1;
}

int AesEncryptor::Encrypt(::arrow::util::span<const uint8_t> plaintext,
::arrow::util::span<const uint8_t> key,
::arrow::util::span<const uint8_t> aad,
::arrow::util::span<uint8_t> ciphertext) {
int32_t AesEncryptor::Encrypt(::arrow::util::span<const uint8_t> plaintext,
::arrow::util::span<const uint8_t> key,
::arrow::util::span<const uint8_t> aad,
::arrow::util::span<uint8_t> ciphertext) {
ThrowOpenSSLRequiredException();
return -1;
}

AesEncryptor::AesEncryptor(ParquetCipher::type alg_id, int key_len, bool metadata,
AesEncryptor::AesEncryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata,
bool write_length) {
ThrowOpenSSLRequiredException();
}

class AesDecryptor::AesDecryptorImpl {};

int AesDecryptor::Decrypt(::arrow::util::span<const uint8_t> ciphertext,
::arrow::util::span<const uint8_t> key,
::arrow::util::span<const uint8_t> aad,
::arrow::util::span<uint8_t> plaintext) {
int32_t AesDecryptor::Decrypt(::arrow::util::span<const uint8_t> ciphertext,
::arrow::util::span<const uint8_t> key,
::arrow::util::span<const uint8_t> aad,
::arrow::util::span<uint8_t> plaintext) {
ThrowOpenSSLRequiredException();
return -1;
}
Expand All @@ -72,36 +72,37 @@ void AesDecryptor::WipeOut() { ThrowOpenSSLRequiredException(); }

AesDecryptor::~AesDecryptor() {}

std::unique_ptr<AesEncryptor> AesEncryptor::Make(ParquetCipher::type alg_id, int key_len,
bool metadata) {
std::unique_ptr<AesEncryptor> AesEncryptor::Make(ParquetCipher::type alg_id,
int32_t key_len, bool metadata) {
ThrowOpenSSLRequiredException();
return NULLPTR;
}

std::unique_ptr<AesEncryptor> AesEncryptor::Make(ParquetCipher::type alg_id, int key_len,
bool metadata, bool write_length) {
std::unique_ptr<AesEncryptor> AesEncryptor::Make(ParquetCipher::type alg_id,
int32_t key_len, bool metadata,
bool write_length) {
ThrowOpenSSLRequiredException();
return NULLPTR;
}

AesDecryptor::AesDecryptor(ParquetCipher::type alg_id, int key_len, bool metadata,
AesDecryptor::AesDecryptor(ParquetCipher::type alg_id, int32_t key_len, bool metadata,
bool contains_length) {
ThrowOpenSSLRequiredException();
}

std::shared_ptr<AesDecryptor> AesDecryptor::Make(
ParquetCipher::type alg_id, int key_len, bool metadata,
ParquetCipher::type alg_id, int32_t key_len, bool metadata,
std::vector<std::weak_ptr<AesDecryptor>>* all_decryptors) {
ThrowOpenSSLRequiredException();
return NULLPTR;
}

int AesDecryptor::PlaintextLength(int ciphertext_len) const {
int32_t AesDecryptor::PlaintextLength(int32_t ciphertext_len) const {
ThrowOpenSSLRequiredException();
return -1;
}

int AesDecryptor::CiphertextLength(int plaintext_len) const {
int32_t AesDecryptor::CiphertextLength(int32_t plaintext_len) const {
ThrowOpenSSLRequiredException();
return -1;
}
Expand All @@ -122,7 +123,7 @@ void QuickUpdatePageAad(int32_t new_page_ordinal, std::string* AAD) {
ThrowOpenSSLRequiredException();
}

void RandBytes(unsigned char* buf, int num) { ThrowOpenSSLRequiredException(); }
void RandBytes(unsigned char* buf, size_t num) { ThrowOpenSSLRequiredException(); }

void EnsureBackendInitialized() {}

Expand Down
22 changes: 11 additions & 11 deletions cpp/src/parquet/encryption/encryption_internal_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,22 @@ class TestAesEncryption : public ::testing::Test {
encryptor.CiphertextLength(static_cast<int64_t>(plain_text_.size()));
std::vector<uint8_t> ciphertext(expected_ciphertext_len, '\0');

int ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_),
str2span(aad_), ciphertext);
int32_t ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_),
str2span(aad_), ciphertext);

ASSERT_EQ(ciphertext_length, expected_ciphertext_len);

AesDecryptor decryptor(cipher_type, key_length_, metadata, write_length);

int expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length);
int32_t expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length);
std::vector<uint8_t> decrypted_text(expected_plaintext_length, '\0');

int plaintext_length =
int32_t plaintext_length =
decryptor.Decrypt(ciphertext, str2span(key_), str2span(aad_), decrypted_text);

std::string decrypted_text_str(decrypted_text.begin(), decrypted_text.end());

ASSERT_EQ(plaintext_length, static_cast<int>(plain_text_.size()));
ASSERT_EQ(plaintext_length, static_cast<int32_t>(plain_text_.size()));
ASSERT_EQ(plaintext_length, expected_plaintext_length);
ASSERT_EQ(decrypted_text_str, plain_text_);
}
Expand All @@ -68,10 +68,10 @@ class TestAesEncryption : public ::testing::Test {
AesDecryptor decryptor(cipher_type, key_length_, metadata, write_length);

// Create ciphertext of all zeros, so the ciphertext length will be read as zero
const int ciphertext_length = 100;
constexpr int32_t ciphertext_length = 100;
std::vector<uint8_t> ciphertext(ciphertext_length, '\0');

int expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length);
int32_t expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length);
std::vector<uint8_t> decrypted_text(expected_plaintext_length, '\0');

EXPECT_THROW(
Expand All @@ -89,12 +89,12 @@ class TestAesEncryption : public ::testing::Test {
encryptor.CiphertextLength(static_cast<int64_t>(plain_text_.size()));
std::vector<uint8_t> ciphertext(expected_ciphertext_len, '\0');

int ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_),
str2span(aad_), ciphertext);
int32_t ciphertext_length = encryptor.Encrypt(str2span(plain_text_), str2span(key_),
str2span(aad_), ciphertext);

AesDecryptor decryptor(cipher_type, key_length_, metadata, write_length);

int expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length);
int32_t expected_plaintext_length = decryptor.PlaintextLength(ciphertext_length);
std::vector<uint8_t> decrypted_text(expected_plaintext_length, '\0');

::arrow::util::span<uint8_t> truncated_ciphertext(ciphertext.data(),
Expand All @@ -105,7 +105,7 @@ class TestAesEncryption : public ::testing::Test {
}

private:
int key_length_ = 0;
int32_t key_length_ = 0;
std::string key_;
std::string aad_;
std::string plain_text_;
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/parquet/encryption/file_key_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,10 @@ std::string FileKeyWrapper::GetEncryptionKeyMetadata(const std::string& data_key
KeyEncryptionKey FileKeyWrapper::CreateKeyEncryptionKey(
const std::string& master_key_id) {
std::string kek_bytes(kKeyEncryptionKeyLength, '\0');
RandBytes(reinterpret_cast<uint8_t*>(&kek_bytes[0]), kKeyEncryptionKeyLength);
RandBytes(reinterpret_cast<uint8_t*>(kek_bytes.data()), kKeyEncryptionKeyLength);

std::string kek_id(kKeyEncryptionKeyIdLength, '\0');
RandBytes(reinterpret_cast<uint8_t*>(&kek_id[0]), kKeyEncryptionKeyIdLength);
RandBytes(reinterpret_cast<uint8_t*>(kek_id.data()), kKeyEncryptionKeyIdLength);

// Encrypt KEK with Master key
std::string encoded_wrapped_kek = kms_client_->WrapKey(kek_bytes, master_key_id);
Expand Down
12 changes: 6 additions & 6 deletions cpp/src/parquet/encryption/internal_file_decryptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,16 @@ Decryptor::Decryptor(std::shared_ptr<encryption::AesDecryptor> aes_decryptor,
aad_(aad),
pool_(pool) {}

int Decryptor::PlaintextLength(int ciphertext_len) const {
int32_t Decryptor::PlaintextLength(int32_t ciphertext_len) const {
return aes_decryptor_->PlaintextLength(ciphertext_len);
}

int Decryptor::CiphertextLength(int plaintext_len) const {
int32_t Decryptor::CiphertextLength(int32_t plaintext_len) const {
return aes_decryptor_->CiphertextLength(plaintext_len);
}

int Decryptor::Decrypt(::arrow::util::span<const uint8_t> ciphertext,
::arrow::util::span<uint8_t> plaintext) {
int32_t Decryptor::Decrypt(::arrow::util::span<const uint8_t> ciphertext,
::arrow::util::span<uint8_t> plaintext) {
return aes_decryptor_->Decrypt(ciphertext, str2span(key_), str2span(aad_), plaintext);
}

Expand Down Expand Up @@ -143,7 +143,7 @@ std::shared_ptr<Decryptor> InternalFileDecryptor::GetFooterDecryptor(

// Create both data and metadata decryptors to avoid redundant retrieval of key
// from the key_retriever.
int key_len = static_cast<int>(footer_key.size());
auto key_len = static_cast<int32_t>(footer_key.size());
std::shared_ptr<encryption::AesDecryptor> aes_metadata_decryptor;
std::shared_ptr<encryption::AesDecryptor> aes_data_decryptor;

Expand Down Expand Up @@ -197,7 +197,7 @@ std::shared_ptr<Decryptor> InternalFileDecryptor::GetColumnDecryptor(
throw HiddenColumnException("HiddenColumnException, path=" + column_path);
}

int key_len = static_cast<int>(column_key.size());
auto key_len = static_cast<int32_t>(column_key.size());
std::lock_guard<std::mutex> lock(mutex_);
auto aes_decryptor =
encryption::AesDecryptor::Make(algorithm_, key_len, metadata, &all_decryptors_);
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/parquet/encryption/internal_file_decryptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ class PARQUET_EXPORT Decryptor {
void UpdateAad(const std::string& aad) { aad_ = aad; }
::arrow::MemoryPool* pool() { return pool_; }

[[nodiscard]] int PlaintextLength(int ciphertext_len) const;
[[nodiscard]] int CiphertextLength(int plaintext_len) const;
int Decrypt(::arrow::util::span<const uint8_t> ciphertext,
::arrow::util::span<uint8_t> plaintext);
[[nodiscard]] int32_t PlaintextLength(int32_t ciphertext_len) const;
[[nodiscard]] int32_t CiphertextLength(int32_t plaintext_len) const;
int32_t Decrypt(::arrow::util::span<const uint8_t> ciphertext,
::arrow::util::span<uint8_t> plaintext);

private:
std::shared_ptr<encryption::AesDecryptor> aes_decryptor_;
Expand Down
Loading

0 comments on commit f078942

Please sign in to comment.