diff --git a/src/fixed_vector.rs b/src/fixed_vector.rs index bd25ab0..4594541 100644 --- a/src/fixed_vector.rs +++ b/src/fixed_vector.rs @@ -4,6 +4,7 @@ use derivative::Derivative; use serde::Deserialize; use serde_derive::Serialize; use std::marker::PhantomData; +use std::mem; use std::ops::{Deref, DerefMut, Index, IndexMut}; use std::slice::SliceIndex; use tree_hash::Hash256; @@ -302,6 +303,25 @@ where len: 0, expected: 1, }) + } else if mem::size_of::() == 1 && mem::align_of::() == 1 { + if bytes.len() != fixed_len { + return Err(ssz::DecodeError::BytesInvalid(format!( + "FixedVector of {} items has {} items", + fixed_len, + bytes.len(), + ))); + } + + // Safety: We've verified T is u8, so Vec is Vec + // and bytes.to_vec() produces Vec + let vec_u8 = bytes.to_vec(); + let vec_t = unsafe { std::mem::transmute::, Vec>(vec_u8) }; + Self::new(vec_t).map_err(|e| { + ssz::DecodeError::BytesInvalid(format!( + "Wrong number of FixedVector elements: {:?}", + e + )) + }) } else if T::is_ssz_fixed_len() { let num_items = bytes .len() @@ -311,17 +331,24 @@ where if num_items != fixed_len { return Err(ssz::DecodeError::BytesInvalid(format!( "FixedVector of {} items has {} items", - num_items, fixed_len + fixed_len, num_items + ))); + } + + // Check that we have a whole number of items and that it is safe to use chunks_exact + if bytes.len() % T::ssz_fixed_len() != 0 { + return Err(ssz::DecodeError::BytesInvalid(format!( + "FixedVector of {} items has {} bytes", + num_items, + bytes.len() ))); } - let vec = bytes.chunks(T::ssz_fixed_len()).try_fold( - Vec::with_capacity(num_items), - |mut vec, chunk| { - vec.push(T::from_ssz_bytes(chunk)?); - Ok(vec) - }, - )?; + let mut vec = Vec::with_capacity(num_items); + for chunk in bytes.chunks_exact(T::ssz_fixed_len()) { + vec.push(T::from_ssz_bytes(chunk)?); + } + Self::new(vec).map_err(|e| { ssz::DecodeError::BytesInvalid(format!( "Wrong number of FixedVector elements: {:?}", @@ -476,6 +503,39 @@ mod test { ssz_round_trip::>(vec![0; 8].try_into().unwrap()); } + // Test byte decoding (we have a specialised code path with unsafe code that NEEDS coverage). + #[test] + fn ssz_round_trip_u8_len_1024() { + ssz_round_trip::>(vec![42; 1024].try_into().unwrap()); + ssz_round_trip::>(vec![0; 1024].try_into().unwrap()); + } + + #[test] + fn ssz_u8_len_1024_too_long() { + assert_eq!( + FixedVector::::from_ssz_bytes(&vec![42; 1025]).unwrap_err(), + ssz::DecodeError::BytesInvalid("FixedVector of 1024 items has 1025 items".into()) + ); + } + + #[test] + fn ssz_u64_len_1024_too_long() { + assert_eq!( + FixedVector::::from_ssz_bytes(&vec![42; 8 * 1025]).unwrap_err(), + ssz::DecodeError::BytesInvalid("FixedVector of 1024 items has 1025 items".into()) + ); + } + + // Decoding an input with invalid trailing bytes MUST fail. + #[test] + fn ssz_bytes_u64_trailing() { + let bytes = [1, 0, 0, 0, 2, 0, 0, 0, 1]; + assert_eq!( + FixedVector::::from_ssz_bytes(&bytes).unwrap_err(), + ssz::DecodeError::BytesInvalid("FixedVector of 2 items has 9 bytes".into()) + ); + } + #[test] fn tree_hash_u8() { let fixed: FixedVector = FixedVector::try_from(vec![]).unwrap(); diff --git a/src/variable_list.rs b/src/variable_list.rs index 00a80f4..6c5241e 100644 --- a/src/variable_list.rs +++ b/src/variable_list.rs @@ -288,6 +288,26 @@ where return Ok(Self::default()); } + if std::mem::size_of::() == 1 && std::mem::align_of::() == 1 { + if bytes.len() > max_len { + return Err(ssz::DecodeError::BytesInvalid(format!( + "VariableList of {} items exceeds maximum of {}", + bytes.len(), + max_len + ))); + } + + // Safety: We've verified T has the same memory layout as u8, so Vec *is* Vec. + let vec_u8 = bytes.to_vec(); + let vec_t = unsafe { std::mem::transmute::, Vec>(vec_u8) }; + return Self::new(vec_t).map_err(|e| { + ssz::DecodeError::BytesInvalid(format!( + "Wrong number of VariableList elements: {:?}", + e + )) + }); + } + if T::is_ssz_fixed_len() { let num_items = bytes .len() @@ -301,13 +321,25 @@ where ))); } - bytes.chunks(T::ssz_fixed_len()).try_fold( - Vec::with_capacity(num_items), - |mut vec, chunk| { - vec.push(T::from_ssz_bytes(chunk)?); - Ok(vec) - }, - ) + // Check that we have a whole number of items and that it is safe to use chunks_exact + if bytes.len() % T::ssz_fixed_len() != 0 { + return Err(ssz::DecodeError::BytesInvalid(format!( + "VariableList of {} items has {} bytes", + num_items, + bytes.len() + ))); + } + + let mut vec = Vec::with_capacity(num_items); + for chunk in bytes.chunks_exact(T::ssz_fixed_len()) { + vec.push(T::from_ssz_bytes(chunk)?); + } + Self::new(vec).map_err(|e| { + ssz::DecodeError::BytesInvalid(format!( + "Wrong number of VariableList elements: {:?}", + e + )) + }) } else { ssz::decode_list_of_variable_length_items(bytes, Some(max_len)) }? @@ -431,7 +463,7 @@ mod test { assert_eq!( as Encode>::ssz_fixed_len(), 4); } - fn round_trip(item: T) { + fn ssz_round_trip(item: T) { let encoded = &item.as_ssz_bytes(); assert_eq!(item.ssz_bytes_len(), encoded.len()); assert_eq!(T::from_ssz_bytes(encoded), Ok(item)); @@ -439,9 +471,35 @@ mod test { #[test] fn u16_len_8() { - round_trip::>(vec![42; 8].try_into().unwrap()); - round_trip::>(vec![0; 8].try_into().unwrap()); - round_trip::>(vec![].try_into().unwrap()); + ssz_round_trip::>(vec![42; 8].try_into().unwrap()); + ssz_round_trip::>(vec![0; 8].try_into().unwrap()); + ssz_round_trip::>(vec![].try_into().unwrap()); + } + + #[test] + fn ssz_round_trip_u8_len_1024() { + ssz_round_trip::>(vec![42; 1024].try_into().unwrap()); + ssz_round_trip::>(vec![0; 1024].try_into().unwrap()); + } + + #[test] + fn ssz_u8_len_1024_too_long() { + assert_eq!( + VariableList::::from_ssz_bytes(&vec![42; 1025]).unwrap_err(), + ssz::DecodeError::BytesInvalid( + "VariableList of 1025 items exceeds maximum of 1024".into() + ) + ); + } + + #[test] + fn ssz_u64_len_1024_too_long() { + assert_eq!( + VariableList::::from_ssz_bytes(&vec![42; 8 * 1025]).unwrap_err(), + ssz::DecodeError::BytesInvalid( + "VariableList of 1025 items exceeds maximum of 1024".into() + ) + ); } #[test] @@ -452,6 +510,21 @@ mod test { assert_eq!(VariableList::from_ssz_bytes(&[]).unwrap(), empty_list); } + #[test] + fn ssz_bytes_u32_trailing() { + let bytes = [1, 0, 0, 0, 2, 0]; + assert_eq!( + VariableList::::from_ssz_bytes(&bytes).unwrap_err(), + ssz::DecodeError::BytesInvalid("VariableList of 1 items has 6 bytes".into()) + ); + + let bytes = [1, 0, 0, 0, 2, 0, 0, 0, 3]; + assert_eq!( + VariableList::::from_ssz_bytes(&bytes).unwrap_err(), + ssz::DecodeError::BytesInvalid("VariableList of 2 items has 9 bytes".into()) + ); + } + fn root_with_length(bytes: &[u8], len: usize) -> Hash256 { let root = merkle_root(bytes, 0); tree_hash::mix_in_length(&root, len)