From d1d72d13a4298f47230735c8f98434edb6bedbf4 Mon Sep 17 00:00:00 2001 From: levonpetrosyan93 <45027856+levonpetrosyan93@users.noreply.github.com> Date: Tue, 19 Dec 2023 10:42:28 +0400 Subject: [PATCH] Add size and type checks to coin deserialization (#1379) * Add size checks to coin deserialization * Fix * Add AEAD deserialization test * Unittests fixed --------- Co-authored-by: Aaron Feickert <66188213+AaronFeickert@users.noreply.github.com> --- src/libspark/aead.h | 12 +++++++++++ src/libspark/coin.h | 17 ++++++++++++++++ src/libspark/mint_transaction.cpp | 4 ++-- src/libspark/test/aead_test.cpp | 20 ++++++++++++------ src/libspark/util.cpp | 3 +++ src/test/spark_state_test.cpp | 34 +++++++++++++++++++++++++++++-- src/txmempool.cpp | 4 +++- src/validation.cpp | 4 +++- 8 files changed, 86 insertions(+), 12 deletions(-) diff --git a/src/libspark/aead.h b/src/libspark/aead.h index ce8470a17d..6f43a3d334 100644 --- a/src/libspark/aead.h +++ b/src/libspark/aead.h @@ -15,8 +15,20 @@ struct AEADEncryptedData { template inline void SerializationOp(Stream& s, Operation ser_action) { READWRITE(ciphertext); + + // Tag must be the correct size READWRITE(tag); + if (tag.size() != AEAD_TAG_SIZE) { + std::cout << "Bad tag size " << tag.size() << std::endl; + throw std::invalid_argument("Cannot deserialize AEAD data due to bad tag"); + } + + // Key commitment must be the correct size, which also includes an encoded size READWRITE(key_commitment); + if (key_commitment.size() != AEAD_COMMIT_SIZE) { + std::cout << "Bad keycom size " << key_commitment.size() << std::endl; + throw std::invalid_argument("Cannot deserialize AEAD data due to bad key commitment size"); + } } }; diff --git a/src/libspark/coin.h b/src/libspark/coin.h index cdb42d336f..e8e85c1ecd 100644 --- a/src/libspark/coin.h +++ b/src/libspark/coin.h @@ -108,11 +108,28 @@ class Coin { ADD_SERIALIZE_METHODS; template inline void SerializationOp(Stream& s, Operation ser_action) { + // The type must be valid READWRITE(type); + if (type != COIN_TYPE_MINT && type != COIN_TYPE_SPEND) { + throw std::invalid_argument("Cannot deserialize coin due to bad type"); + } READWRITE(S); READWRITE(K); READWRITE(C); + + // Encrypted coin data is always of a fixed size that depends on coin type + // Its tag and key commitment sizes are enforced during its deserialization + // For mint coins: encrypted diversifier (with size), encoded nonce, padded memo (with size) + // For spend coins: encoded value, encrypted diversifier (with size), encoded nonce, padded memo (with size) READWRITE(r_); + if (type == COIN_TYPE_MINT && r_.ciphertext.size() != (1 + AES_BLOCKSIZE) + SCALAR_ENCODING + (1 + params->get_memo_bytes())) { + std::cout << "Data size " << r_.ciphertext.size() << " but expected " << (AES_BLOCKSIZE + SCALAR_ENCODING + params->get_memo_bytes()) << std::endl; + throw std::invalid_argument("Cannot deserialize mint coin due to bad encrypted data"); + } + if (type == COIN_TYPE_SPEND && r_.ciphertext.size() != 8 + (1 + AES_BLOCKSIZE) + SCALAR_ENCODING + (1 + params->get_memo_bytes())) { + std::cout << "Data size " << r_.ciphertext.size() << " but expected " << (8 + AES_BLOCKSIZE + SCALAR_ENCODING + params->get_memo_bytes()) << std::endl; + throw std::invalid_argument("Cannot deserialize spend coin due to bad encrypted data"); + } if (type == COIN_TYPE_MINT) { READWRITE(v); diff --git a/src/libspark/mint_transaction.cpp b/src/libspark/mint_transaction.cpp index 7acac3cb73..f52a3b094e 100644 --- a/src/libspark/mint_transaction.cpp +++ b/src/libspark/mint_transaction.cpp @@ -42,10 +42,10 @@ MintTransaction::MintTransaction( value_statement.emplace_back(this->coins[j].C + this->params->get_G().inverse()*Scalar(this->coins[j].v)); value_witness.emplace_back(SparkUtils::hash_val(k)); } else { - Coin coin; + Coin coin(params); coin.type = 0; coin.r_.ciphertext.resize(82); // max possible size - coin.r_.key_commitment.resize(64); + coin.r_.key_commitment.resize(32); coin.r_.tag.resize(16); coin.v = 0; this->coins.emplace_back(coin); diff --git a/src/libspark/test/aead_test.cpp b/src/libspark/test/aead_test.cpp index 2a3901326d..78e30bffe8 100644 --- a/src/libspark/test/aead_test.cpp +++ b/src/libspark/test/aead_test.cpp @@ -13,20 +13,28 @@ BOOST_AUTO_TEST_CASE(complete) GroupElement prekey; prekey.randomize(); - // Serialize + // Serialize message int message = 12345; - CDataStream ser(SER_NETWORK, PROTOCOL_VERSION); - ser << message; + CDataStream ser_message(SER_NETWORK, PROTOCOL_VERSION); + ser_message << message; // Encrypt - AEADEncryptedData data = AEAD::encrypt(prekey, "Associated data", ser); + AEADEncryptedData data = AEAD::encrypt(prekey, "Associated data", ser_message); + + // Serialize encrypted data + CDataStream ser_data(SER_NETWORK, PROTOCOL_VERSION); + ser_data << data; + + // Deserialize encrypted data + AEADEncryptedData data_deser; + ser_data >> data_deser; // Decrypt - ser = AEAD::decrypt_and_verify(prekey, "Associated data", data); + ser_message = AEAD::decrypt_and_verify(prekey, "Associated data", data_deser); // Deserialize int message_; - ser >> message_; + ser_message >> message_; BOOST_CHECK_EQUAL(message_, message); } diff --git a/src/libspark/util.cpp b/src/libspark/util.cpp index 4547251320..17212cc1fd 100644 --- a/src/libspark/util.cpp +++ b/src/libspark/util.cpp @@ -36,6 +36,9 @@ uint64_t SparkUtils::diversifier_decrypt(const std::vector& key, if (key.size() != AES256_KEYSIZE) { throw std::invalid_argument("Bad diversifier decryption key size"); } + if (d.size() != AES_BLOCKSIZE) { + throw std::invalid_argument("Bad diversifier ciphertext size"); + } std::vector iv; iv.resize(AES_BLOCKSIZE); diff --git a/src/test/spark_state_test.cpp b/src/test/spark_state_test.cpp index e39049c3c6..f4178e6347 100644 --- a/src/test/spark_state_test.cpp +++ b/src/test/spark_state_test.cpp @@ -17,6 +17,16 @@ basic_ostream& operator<<(basic_ostream& os, const p } // namespace std +// Generate a random char vector from a random scalar +static std::vector random_char_vector() { + Scalar temp; + temp.randomize(); + std::vector result; + result.resize(spark::SCALAR_ENCODING); + temp.serialize(result.data()); + + return result; +} class SparkStateTests : public SparkTestingSetup { @@ -173,8 +183,28 @@ BOOST_AUTO_TEST_CASE(mempool) // - can not add on-chain coin BOOST_CHECK(!sparkState->CanAddMintToMempool(pwalletMain->sparkWallet->getCoinFromMeta(mint))); - // - can not add duplicated coin - spark::Coin randMint; + // Generate keys + const spark::Params* params = spark::Params::get_default(); + spark::SpendKey spend_key(params); + spark::FullViewKey full_view_key(spend_key); + spark::IncomingViewKey incoming_view_key(full_view_key); + + // Generate address + spark::Address address(incoming_view_key, 1); + + // Generate coin + Scalar k; + k.randomize(); + spark::Coin randMint = spark::Coin( + params, + spark::COIN_TYPE_MINT, + k, + address, + 100, + "memo", + random_char_vector() + ); + BOOST_CHECK(sparkState->CanAddMintToMempool(randMint)); sparkState->AddMintsToMempool({randMint}); BOOST_CHECK(!sparkState->CanAddMintToMempool(randMint)); diff --git a/src/txmempool.cpp b/src/txmempool.cpp index 37ee90a610..a21edd2f3f 100644 --- a/src/txmempool.cpp +++ b/src/txmempool.cpp @@ -630,7 +630,9 @@ void CTxMemPool::removeUnchecked(txiter it, MemPoolRemovalReason reason) { if (txout.scriptPubKey.IsSparkMint() || txout.scriptPubKey.IsSparkSMint()) { try { - spark::Coin txCoin; + const spark::Params* params = spark::Params::get_default(); + + spark::Coin txCoin(params); spark::ParseSparkMintCoin(txout.scriptPubKey, txCoin); sparkState.RemoveMintFromMempool(txCoin); } diff --git a/src/validation.cpp b/src/validation.cpp index f5c542cd53..4d932554db 100644 --- a/src/validation.cpp +++ b/src/validation.cpp @@ -3360,7 +3360,9 @@ void static RemoveConflictingPrivacyTransactionsFromMempool(const CBlock &block) if (txout.scriptPubKey.IsSparkMint() || txout.scriptPubKey.IsSparkSMint()) { try { - spark::Coin txCoin; + const spark::Params* params = spark::Params::get_default(); + + spark::Coin txCoin(params); spark::ParseSparkMintCoin(txout.scriptPubKey, txCoin); sparkState->RemoveMintFromMempool(txCoin); } catch (std::invalid_argument&) {