From 51b089b995c23b88a42ee8d934ba69e9dcc8455c Mon Sep 17 00:00:00 2001 From: Ratan Kaliani Date: Fri, 10 Nov 2023 11:09:13 -0800 Subject: [PATCH] fix: eddsa + sha512 pad fixes (#286) --- .../frontend/ecc/curve25519/ed25519/eddsa.rs | 125 +++++++++++++++--- .../src/frontend/hash/sha/sha512/curta.rs | 40 +++++- .../core/src/frontend/hash/sha/sha512/pad.rs | 5 +- 3 files changed, 148 insertions(+), 22 deletions(-) diff --git a/plonky2x/core/src/frontend/ecc/curve25519/ed25519/eddsa.rs b/plonky2x/core/src/frontend/ecc/curve25519/ed25519/eddsa.rs index f23bed0f3..5b7ee6995 100644 --- a/plonky2x/core/src/frontend/ecc/curve25519/ed25519/eddsa.rs +++ b/plonky2x/core/src/frontend/ecc/curve25519/ed25519/eddsa.rs @@ -91,6 +91,8 @@ impl, const D: usize> CircuitBuilder { assert!(signatures.len() == NUM_SIGS); assert!(pubkeys.len() == NUM_SIGS); + let max_msg_byte_length = self.constant::(MAX_MSG_LENGTH_BYTES as u32); + let (dummy_pub_key, dummy_sig, dummy_msg, dummy_msg_byte_length) = self.get_dummy_variables::(); @@ -103,28 +105,27 @@ impl, const D: usize> CircuitBuilder { msg_vec.push(self.select(is_active[i], messages[i], dummy_msg)); if let Some(ref msg_lens) = message_byte_lengths { msg_len_vec.push(self.select(is_active[i], msg_lens[i], dummy_msg_byte_length)); + } else { + msg_len_vec.push(self.select( + is_active[i], + max_msg_byte_length, + dummy_msg_byte_length, + )) } sig_vec.push(self.select(is_active[i], signatures[i].clone(), dummy_sig.clone())); pub_key_vec.push(self.select(is_active[i], pubkeys[i].clone(), dummy_pub_key.clone())); } + let msg_len_vec = ArrayVariable::::from( + msg_len_vec.into_iter().collect::>(), + ); let msg_array = ArrayVariable::, NUM_SIGS>::from(msg_vec); - let msg_len_array = ArrayVariable::::from(msg_len_vec); let sig_array = ArrayVariable::::from(sig_vec); let pub_key_array = ArrayVariable::::from(pub_key_vec); - self.curta_eddsa_verify_sigs( - msg_array, - if message_byte_lengths.is_none() { - None - } else { - Some(msg_len_array) - }, - sig_array, - pub_key_array, - ); + self.curta_eddsa_verify_sigs(msg_array, Some(msg_len_vec), sig_array, pub_key_array); } /// This function will verify a set of eddsa signatures. If message_byte_lengths is None, then @@ -210,12 +211,49 @@ mod tests { use crate::frontend::ecc::curve25519::ed25519::eddsa::{ EDDSASignatureVariable, EDDSASignatureVariableValue, }; - use crate::prelude::{ArrayVariable, BytesVariable, DefaultBuilder, U32Variable}; + use crate::prelude::{ArrayVariable, BoolVariable, BytesVariable, DefaultBuilder, U32Variable}; use crate::utils; - const MAX_MSG_LEN_BYTES: usize = 192; + const MAX_MSG_LEN_BYTES: usize = 174; const NUM_SIGS: usize = 3; + fn test_curta_eddsa_verify_sigs( + test_pub_keys: Vec, + test_signatures: Vec>, + test_messages: Vec<[u8; MAX_MSG_LEN_BYTES]>, + test_message_lens: Vec, + variable_msg_len: bool, + ) { + utils::setup_logger(); + + let mut builder = DefaultBuilder::new(); + + let pkeys = builder.read::>(); + let signatures = builder.read::>(); + let messages = builder.read::, NUM_SIGS>>(); + if variable_msg_len { + let message_lens = builder.read::>(); + builder.curta_eddsa_verify_sigs(messages, Some(message_lens), signatures, pkeys); + } else { + builder.curta_eddsa_verify_sigs(messages, None, signatures, pkeys); + } + + let circuit = builder.build(); + + let mut input = circuit.input(); + input.write::>(test_pub_keys); + input.write::>(test_signatures); + input.write::, NUM_SIGS>>( + test_messages.to_vec(), + ); + if variable_msg_len { + input.write::>(test_message_lens); + } + + let (proof, output) = circuit.prove(&input); + circuit.verify(&proof, &input, &output); + } + #[test] #[cfg_attr(feature = "ci", ignore)] fn test_curta_eddsa_verify_sigs_constant_msg_len() { @@ -341,10 +379,11 @@ mod tests { ); } - fn test_curta_eddsa_verify_sigs( + fn test_curta_eddsa_verify_sigs_conditional( + test_is_active: Vec, test_pub_keys: Vec, test_signatures: Vec>, - test_messages: Vec<[u8; 192]>, + test_messages: Vec<[u8; MAX_MSG_LEN_BYTES]>, test_message_lens: Vec, variable_msg_len: bool, ) { @@ -352,19 +391,28 @@ mod tests { let mut builder = DefaultBuilder::new(); + let is_active = builder.read::>(); let pkeys = builder.read::>(); let signatures = builder.read::>(); let messages = builder.read::, NUM_SIGS>>(); if variable_msg_len { let message_lens = builder.read::>(); - builder.curta_eddsa_verify_sigs(messages, Some(message_lens), signatures, pkeys); + builder.curta_eddsa_verify_sigs_conditional( + is_active, + Some(message_lens), + messages, + signatures, + pkeys, + ); } else { - builder.curta_eddsa_verify_sigs(messages, None, signatures, pkeys); + builder + .curta_eddsa_verify_sigs_conditional(is_active, None, messages, signatures, pkeys); } let circuit = builder.build(); let mut input = circuit.input(); + input.write::>(test_is_active); input.write::>(test_pub_keys); input.write::>(test_signatures); input.write::, NUM_SIGS>>( @@ -377,4 +425,47 @@ mod tests { let (proof, output) = circuit.prove(&input); circuit.verify(&proof, &input, &output); } + + #[test] + #[cfg_attr(feature = "ci", ignore)] + fn test_curta_eddsa_verify_sigs_constant_msg_len_conditional() { + // Generate random messages and private keys + let mut test_messages: Vec<[u8; MAX_MSG_LEN_BYTES]> = Vec::new(); + let test_message_lens = Vec::new(); + let mut test_is_active = Vec::new(); + let mut test_pub_keys = Vec::new(); + let mut test_signatures = Vec::new(); + + let mut csprng = OsRng; + for _i in 0..NUM_SIGS { + // Generate random length + let msg_len = MAX_MSG_LEN_BYTES as u32; + let mut test_message = Vec::new(); + for _ in 0..msg_len { + test_message.push(rand::thread_rng().gen_range(0..255)); + } + + let test_signing_key = SigningKey::generate(&mut csprng); + let test_pub_key = test_signing_key.verifying_key(); + let test_signature = test_signing_key.sign(&test_message); + + test_message.resize(MAX_MSG_LEN_BYTES, 0); + test_messages.push(test_message.try_into().unwrap()); + test_is_active.push(true); + test_pub_keys.push(CompressedEdwardsY(test_pub_key.to_bytes())); + test_signatures.push(EDDSASignatureVariableValue { + r: CompressedEdwardsY(*test_signature.r_bytes()), + s: U256::from_little_endian(test_signature.s_bytes()), + }); + } + + test_curta_eddsa_verify_sigs_conditional( + test_is_active, + test_pub_keys, + test_signatures, + test_messages, + test_message_lens, + false, + ); + } } diff --git a/plonky2x/core/src/frontend/hash/sha/sha512/curta.rs b/plonky2x/core/src/frontend/hash/sha/sha512/curta.rs index 891b2ff29..b0c5b66ed 100644 --- a/plonky2x/core/src/frontend/hash/sha/sha512/curta.rs +++ b/plonky2x/core/src/frontend/hash/sha/sha512/curta.rs @@ -306,21 +306,53 @@ mod tests { setup_logger(); let mut builder = DefaultBuilder::new(); - let max_number_of_chunks = 20; + let max_number_of_chunks = 2; let total_message_length = 128 * max_number_of_chunks; - let max_len = (total_message_length - 18) / 128; + let max_len = total_message_length - 8; let mut rng = thread_rng(); let total_message = (0..total_message_length) .map(|_| rng.gen::()) .collect::>(); + let message = total_message + .iter() + .map(|b| builder.constant::(*b)) + .collect::>(); for i in 0..max_len { - let message = &total_message[..i]; - let expected_digest = sha512(message); + let expected_digest = sha512(&total_message[..i]); + + let length = builder.constant::(i as u32); + + let digest = builder.curta_sha512_variable(&message, length); + let expected_digest = builder.constant::>(expected_digest); + builder.assert_is_equal(digest, expected_digest); + } + + let circuit = builder.build(); + let input = circuit.input(); + let (proof, output) = circuit.prove(&input); + circuit.verify(&proof, &input, &output); + } + + #[test] + #[cfg_attr(feature = "ci", ignore)] + fn test_sha512_variable_length_max_size() { + // This test checks that sha512_variable_pad works as intended, especially when the max + // input length is (length % 128 > 128 - 17). + setup_logger(); + let mut builder = DefaultBuilder::new(); + + let max_number_of_chunks = 1; + let total_message_length = 128 * max_number_of_chunks; + + for i in 127 - 20..total_message_length + 1 { + let mut rng = thread_rng(); + let total_message = (0..i).map(|_| rng.gen::()).collect::>(); let message = total_message .iter() .map(|b| builder.constant::(*b)) .collect::>(); + let expected_digest = sha512(&total_message); let length = builder.constant::(i as u32); diff --git a/plonky2x/core/src/frontend/hash/sha/sha512/pad.rs b/plonky2x/core/src/frontend/hash/sha/sha512/pad.rs index abb57277a..c62f7d025 100644 --- a/plonky2x/core/src/frontend/hash/sha/sha512/pad.rs +++ b/plonky2x/core/src/frontend/hash/sha/sha512/pad.rs @@ -67,8 +67,11 @@ impl, const D: usize> CircuitBuilder { ) -> Vec { let last_chunk = self.compute_sha512_last_chunk(input_byte_length); + // Calculate the number of chunks needed to store the input. 17 is the number of bytes added + // by the padding and LE length representation. + let max_num_chunks = ceil_div_usize(input.len() + 17, SHA512_CHUNK_SIZE_BYTES_128); + // Extend input to size max_num_chunks * 128 before padding. - let max_num_chunks = ceil_div_usize(input.len(), SHA512_CHUNK_SIZE_BYTES_128); let mut padded_input = input.to_vec(); padded_input.resize(max_num_chunks * SHA512_CHUNK_SIZE_BYTES_128, self.zero());