Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 68 additions & 8 deletions src/fixed_vector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use derivative::Derivative;
use serde::Deserialize;
use serde_derive::Serialize;
use std::marker::PhantomData;
use std::mem;
use std::ops::{Deref, DerefMut, Index, IndexMut};
use std::slice::SliceIndex;
use tree_hash::Hash256;
Expand Down Expand Up @@ -302,6 +303,25 @@ where
len: 0,
expected: 1,
})
} else if mem::size_of::<T>() == 1 && mem::align_of::<T>() == 1 {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using size_of and align_of removes the need to add any bounds on T. Other crates like bytemuck would have required us to constrain T quite a lot, which sort of defeats the point of a generic impl.

Even using TypeId would have required us to add 'static.

The other advantage of this is that it works for types that encode as u8, like the ParticipationFlags in the BeaconState.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

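Couldn't T be a bool as well? A bool also has size 1 and alignment 1, so it would take this branch, but not every byte value is a valid bool.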

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I guess it could be, but we never use bools in any consensus data structures, do we?

Would be a good test case

if bytes.len() != fixed_len {
return Err(ssz::DecodeError::BytesInvalid(format!(
"FixedVector of {} items has {} items",
fixed_len,
bytes.len(),
)));
}

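// Safety: the branch condition only verifies that T has size 1 and alignment 1 (the same
// memory layout as u8); it does not verify that T is u8 itself.
// NOTE(review): reinterpreting arbitrary bytes as T is only sound if every bit pattern is a
// valid T (not the case for e.g. bool) -- confirm for every T that can reach this path.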
let vec_u8 = bytes.to_vec();
let vec_t = unsafe { std::mem::transmute::<Vec<u8>, Vec<T>>(vec_u8) };
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should add a test for a type that is not u8

Self::new(vec_t).map_err(|e| {
ssz::DecodeError::BytesInvalid(format!(
"Wrong number of FixedVector elements: {:?}",
e
))
})
} else if T::is_ssz_fixed_len() {
let num_items = bytes
.len()
Expand All @@ -311,17 +331,24 @@ where
if num_items != fixed_len {
return Err(ssz::DecodeError::BytesInvalid(format!(
"FixedVector of {} items has {} items",
num_items, fixed_len
fixed_len, num_items
)));
}

// Check that we have a whole number of items and that it is safe to use chunks_exact
if bytes.len() % T::ssz_fixed_len() != 0 {
return Err(ssz::DecodeError::BytesInvalid(format!(
"FixedVector of {} items has {} bytes",
num_items,
bytes.len()
)));
}

let vec = bytes.chunks(T::ssz_fixed_len()).try_fold(
Vec::with_capacity(num_items),
|mut vec, chunk| {
vec.push(T::from_ssz_bytes(chunk)?);
Ok(vec)
},
)?;
let mut vec = Vec::with_capacity(num_items);
for chunk in bytes.chunks_exact(T::ssz_fixed_len()) {
vec.push(T::from_ssz_bytes(chunk)?);
}

Self::new(vec).map_err(|e| {
ssz::DecodeError::BytesInvalid(format!(
"Wrong number of FixedVector elements: {:?}",
Expand Down Expand Up @@ -476,6 +503,39 @@ mod test {
ssz_round_trip::<FixedVector<u16, U8>>(vec![0; 8].try_into().unwrap());
}

// Round-trip coverage for byte decoding: that code path is specialised and uses unsafe code,
// so it NEEDS to be exercised by a test.
#[test]
fn ssz_round_trip_u8_len_1024() {
    // Run the same full-length round trip with a non-zero and an all-zero payload.
    for fill in [42u8, 0u8].iter().copied() {
        let vector: FixedVector<u8, U1024> = vec![fill; 1024].try_into().unwrap();
        ssz_round_trip::<FixedVector<u8, U1024>>(vector);
    }
}

/// Decoding 1025 bytes into a 1024-element `u8` vector must be rejected with a length error.
#[test]
fn ssz_u8_len_1024_too_long() {
    // One byte more than the vector can hold.
    let oversized = vec![42u8; 1025];
    let expected =
        ssz::DecodeError::BytesInvalid("FixedVector of 1024 items has 1025 items".into());

    let actual = FixedVector::<u8, U1024>::from_ssz_bytes(&oversized).unwrap_err();
    assert_eq!(actual, expected);
}

/// Decoding 1025 items' worth of bytes into a 1024-element `u64` vector must be rejected.
#[test]
fn ssz_u64_len_1024_too_long() {
    // 8 bytes per u64 item, for 1025 items: one item too many.
    let oversized = vec![42u8; 8 * 1025];
    let expected =
        ssz::DecodeError::BytesInvalid("FixedVector of 1024 items has 1025 items".into());

    let actual = FixedVector::<u64, U1024>::from_ssz_bytes(&oversized).unwrap_err();
    assert_eq!(actual, expected);
}

// An input carrying invalid trailing bytes MUST be rejected by the decoder.
// NOTE(review): the function name says `u64`, but the element type decoded below is `u32`.
#[test]
fn ssz_bytes_u64_trailing() {
    // Two complete 4-byte items followed by a single stray byte.
    let input = [1, 0, 0, 0, 2, 0, 0, 0, 1];
    let expected = ssz::DecodeError::BytesInvalid("FixedVector of 2 items has 9 bytes".into());

    let actual = FixedVector::<u32, U2>::from_ssz_bytes(&input).unwrap_err();
    assert_eq!(actual, expected);
}

#[test]
fn tree_hash_u8() {
let fixed: FixedVector<u8, U0> = FixedVector::try_from(vec![]).unwrap();
Expand Down
95 changes: 84 additions & 11 deletions src/variable_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,26 @@ where
return Ok(Self::default());
}

if std::mem::size_of::<T>() == 1 && std::mem::align_of::<T>() == 1 {
if bytes.len() > max_len {
return Err(ssz::DecodeError::BytesInvalid(format!(
"VariableList of {} items exceeds maximum of {}",
bytes.len(),
max_len
)));
}

// Safety: We've verified T has the same memory layout as u8, so Vec<T> *is* Vec<u8>.
let vec_u8 = bytes.to_vec();
let vec_t = unsafe { std::mem::transmute::<Vec<u8>, Vec<T>>(vec_u8) };
return Self::new(vec_t).map_err(|e| {
ssz::DecodeError::BytesInvalid(format!(
"Wrong number of VariableList elements: {:?}",
e
))
});
}

if T::is_ssz_fixed_len() {
let num_items = bytes
.len()
Expand All @@ -301,13 +321,25 @@ where
)));
}

bytes.chunks(T::ssz_fixed_len()).try_fold(
Vec::with_capacity(num_items),
|mut vec, chunk| {
vec.push(T::from_ssz_bytes(chunk)?);
Ok(vec)
},
)
// Check that we have a whole number of items and that it is safe to use chunks_exact
if bytes.len() % T::ssz_fixed_len() != 0 {
return Err(ssz::DecodeError::BytesInvalid(format!(
"VariableList of {} items has {} bytes",
num_items,
bytes.len()
)));
}

let mut vec = Vec::with_capacity(num_items);
for chunk in bytes.chunks_exact(T::ssz_fixed_len()) {
vec.push(T::from_ssz_bytes(chunk)?);
}
Self::new(vec).map_err(|e| {
ssz::DecodeError::BytesInvalid(format!(
"Wrong number of VariableList elements: {:?}",
e
))
})
} else {
ssz::decode_list_of_variable_length_items(bytes, Some(max_len))
}?
Expand Down Expand Up @@ -431,17 +463,43 @@ mod test {
assert_eq!(<VariableList<u16, U2> as Encode>::ssz_fixed_len(), 4);
}

// NOTE(review): this span shows both sides of a rename in the diff: the first `fn` line is the
// `round_trip` header and the second is the `ssz_round_trip` header with the same signature.
//
// Test helper: encodes `item` to SSZ bytes, asserts that `ssz_bytes_len()` equals the encoded
// length, and asserts that decoding those bytes yields a value equal to `item`.
fn round_trip<T: Encode + Decode + std::fmt::Debug + PartialEq>(item: T) {
fn ssz_round_trip<T: Encode + Decode + std::fmt::Debug + PartialEq>(item: T) {
let encoded = &item.as_ssz_bytes();
// The length reported up front must agree with what was actually written.
assert_eq!(item.ssz_bytes_len(), encoded.len());
// Decoding consumes `item` on the expected side, so this is the last use.
assert_eq!(T::from_ssz_bytes(encoded), Ok(item));
}

// Round-trips `VariableList<u16, U8>` values: a full list of 42s, a full list of zeros, and an
// empty list.
// NOTE(review): the body interleaves both sides of the diff: the three `round_trip` calls and
// the three `ssz_round_trip` calls are the same cases under the two helper names.
#[test]
fn u16_len_8() {
round_trip::<VariableList<u16, U8>>(vec![42; 8].try_into().unwrap());
round_trip::<VariableList<u16, U8>>(vec![0; 8].try_into().unwrap());
round_trip::<VariableList<u16, U8>>(vec![].try_into().unwrap());
ssz_round_trip::<VariableList<u16, U8>>(vec![42; 8].try_into().unwrap());
ssz_round_trip::<VariableList<u16, U8>>(vec![0; 8].try_into().unwrap());
ssz_round_trip::<VariableList<u16, U8>>(vec![].try_into().unwrap());
}

/// Round-trips maximum-length `VariableList<u8, U1024>` values through SSZ encode/decode.
#[test]
fn ssz_round_trip_u8_len_1024() {
    // Run the same full-length round trip with a non-zero and an all-zero payload.
    for fill in [42u8, 0u8].iter().copied() {
        let list: VariableList<u8, U1024> = vec![fill; 1024].try_into().unwrap();
        ssz_round_trip::<VariableList<u8, U1024>>(list);
    }
}

/// Decoding 1025 bytes into a `u8` list capped at 1024 items must be rejected.
#[test]
fn ssz_u8_len_1024_too_long() {
    // One byte more than the list's maximum length.
    let oversized = vec![42u8; 1025];
    let expected = ssz::DecodeError::BytesInvalid(
        "VariableList of 1025 items exceeds maximum of 1024".into(),
    );

    let actual = VariableList::<u8, U1024>::from_ssz_bytes(&oversized).unwrap_err();
    assert_eq!(actual, expected);
}

/// Decoding 1025 items' worth of bytes into a `u64` list capped at 1024 items must be rejected.
#[test]
fn ssz_u64_len_1024_too_long() {
    // 8 bytes per u64 item, for 1025 items: one item over the maximum.
    let oversized = vec![42u8; 8 * 1025];
    let expected = ssz::DecodeError::BytesInvalid(
        "VariableList of 1025 items exceeds maximum of 1024".into(),
    );

    let actual = VariableList::<u64, U1024>::from_ssz_bytes(&oversized).unwrap_err();
    assert_eq!(actual, expected);
}

#[test]
Expand All @@ -452,6 +510,21 @@ mod test {
assert_eq!(VariableList::from_ssz_bytes(&[]).unwrap(), empty_list);
}

/// Inputs whose length is not a whole number of `u32` items must be rejected, both below and
/// at the list's maximum item count.
#[test]
fn ssz_bytes_u32_trailing() {
    // Each case pairs an input with the error message the decoder is expected to produce.
    let cases: [(&[u8], &str); 2] = [
        // One complete item plus two stray bytes.
        (&[1, 0, 0, 0, 2, 0], "VariableList of 1 items has 6 bytes"),
        // Two complete items plus one stray byte.
        (
            &[1, 0, 0, 0, 2, 0, 0, 0, 3],
            "VariableList of 2 items has 9 bytes",
        ),
    ];

    for &(input, message) in &cases {
        assert_eq!(
            VariableList::<u32, U2>::from_ssz_bytes(input).unwrap_err(),
            ssz::DecodeError::BytesInvalid(message.into())
        );
    }
}

fn root_with_length(bytes: &[u8], len: usize) -> Hash256 {
let root = merkle_root(bytes, 0);
tree_hash::mix_in_length(&root, len)
Expand Down
Loading