diff --git a/core/src/memory_region.rs b/core/src/memory_region.rs index 70c927c..fd6378a 100644 --- a/core/src/memory_region.rs +++ b/core/src/memory_region.rs @@ -103,6 +103,56 @@ impl ArrayMemoryRegion { self.data.set_len(data_len); self.data.as_mut_ptr().copy_from(data_ptr, data_len); } + + /// Try to build a [ArrayMemoryRegion] from an [IntoIterator] + pub fn try_from_iter>( + iter: I, + ) -> Result { + use MemoryRegionFromIterError::*; + let mut iter = iter.into_iter(); + + match iter.next() { + Some(MEMORY_REGION_IDENTIFIER) => {} + Some(id) => return Err(InvalidIdentifier(id)), + None => return Err(NotEnoughItems), + } + + let start_address = u64::from_le_bytes([ + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + ]); + + let length = u64::from_le_bytes([ + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + ]); + + if length > SIZE as u64 { + return Err(LengthTooBig(length)); + } + + // This call panics if length > SIZE + // and `iter` does indeed contain more than SIZE items, + // but we've just covered that case + let data = ArrayVec::from_iter(iter.take(length as usize)); + + Ok(Self { + start_address, + data, + }) + } } #[cfg(feature = "std")] @@ -134,42 +184,7 @@ impl<'a, const SIZE: usize> FromIterator<&'a u8> for ArrayMemoryRegion { impl FromIterator for ArrayMemoryRegion { fn from_iter>(iter: T) -> Self { - let mut iter = iter.into_iter(); - - assert_eq!( - iter.next().unwrap(), - MEMORY_REGION_IDENTIFIER, - "The given iterator is not for a memory region" - ); - - let start_address = u64::from_le_bytes([ - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - ]); - - let length = u64::from_le_bytes([ - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - ]); - - let data = ArrayVec::from_iter(iter.take(length as usize)); - - Self { - start_address, - data, - } + Self::try_from_iter(iter).unwrap() } } @@ -234,6 +249,49 @@ impl VecMemoryRegion { self.data.as_mut_ptr().copy_from(data_ptr, data_len); } + + /// Try to build a [VecMemoryRegion] from an [IntoIterator] + pub fn try_from_iter>( + iter: I, + ) -> Result { + use MemoryRegionFromIterError::*; + let mut iter = iter.into_iter(); + + match iter.next() { + Some(MEMORY_REGION_IDENTIFIER) => {} + Some(id) => return Err(InvalidIdentifier(id)), + None => return Err(NotEnoughItems), + } + + let start_address = u64::from_le_bytes([ + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + ]); + + let length = u64::from_le_bytes([ + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + ]); + + let data = Vec::from_iter(iter.take(length as usize)); + + Ok(Self { + start_address, + data, + }) + } } #[cfg(feature = "std")] @@ -267,42 +325,7 @@ impl<'a> FromIterator<&'a u8> for VecMemoryRegion { #[cfg(feature = "std")] impl FromIterator for VecMemoryRegion { fn from_iter>(iter: T) -> Self { - let mut iter = iter.into_iter(); - - assert_eq!( - iter.next().unwrap(), - MEMORY_REGION_IDENTIFIER, - "The given iterator is not for a memory region" - ); - - let start_address = u64::from_le_bytes([ - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - ]); - - let length = u64::from_le_bytes([ - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - iter.next().unwrap(), - ]); - - let data = Vec::from_iter(iter.take(length as usize)); - - Self { - start_address, - data, - } + Self::try_from_iter(iter).unwrap() } } @@ -433,6 +456,32 @@ impl<'a> Iterator for MemoryRegionIterator<'a> { impl<'a> ExactSizeIterator for MemoryRegionIterator<'a> {} +#[derive(Debug)] +/// Specifies what went wrong building a [MemoryRegion] from an iterator +pub enum MemoryRegionFromIterError { + /// The given iterator is not for a memory region. + /// First item from iterator yielded invalid identifier. Expected [MEMORY_REGION_IDENTIFIER] + InvalidIdentifier(u8), + /// Iterator specified length too big for declared region + LengthTooBig(u64), + /// Iterator did not yield enough items to build memory region + NotEnoughItems, +} + +impl core::fmt::Display for MemoryRegionFromIterError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + use MemoryRegionFromIterError::*; + match self { + InvalidIdentifier(id) => write!(f, "Iterator is not for a memory region. Started with {id}, expected {MEMORY_REGION_IDENTIFIER}"), + LengthTooBig(len) => write!(f, "Iterator specified length too big for declared region: {len}"), + NotEnoughItems => write!(f, "Iterator did not yield enough items to build memory region"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for MemoryRegionFromIterError {} + #[cfg(test)] mod tests { use super::*; diff --git a/core/src/register_data.rs b/core/src/register_data.rs index ac7b421..b16abe5 100644 --- a/core/src/register_data.rs +++ b/core/src/register_data.rs @@ -95,27 +95,36 @@ impl RegisterData for ArrayRegisterD } } -impl FromIterator for ArrayRegisterData +impl ArrayRegisterData where RB: funty::Integral, RB::Bytes: for<'a> TryFrom<&'a [u8]>, { - fn from_iter>(iter: T) -> Self { - // Get the iterator. We assume that it is in the same format as the bytes function outputs + /// Try to build a [ArrayRegisterData] from an [IntoIterator] + pub fn try_from_iter>( + iter: I, + ) -> Result { + use RegisterDataFromIterError::*; + // Get the iterator. let mut iter = iter.into_iter(); - assert_eq!( - iter.next().unwrap(), - REGISTER_DATA_IDENTIFIER, - "The given iterator is not for register data" - ); + match iter.next() { + Some(REGISTER_DATA_IDENTIFIER) => {} + Some(id) => return Err(InvalidIdentifier(id)), + None => return Err(NotEnoughItems), + } // First the starting number is encoded - let starting_register_number = - u16::from_le_bytes([iter.next().unwrap(), iter.next().unwrap()]); + let starting_register_number = u16::from_le_bytes([ + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + ]); // Second is how many registers there are - let register_count = u16::from_le_bytes([iter.next().unwrap(), iter.next().unwrap()]); + let register_count = u16::from_le_bytes([ + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + ]); // Create the buffer we're storing the registers in let mut registers = ArrayVec::new(); @@ -123,28 +132,48 @@ where // We process everything byte-by-byte generically so every register has an unknown length // So we need to store the bytes temporarily until we have enough to fully read the bytes as a register let register_size = core::mem::size_of::(); - let mut register_bytes_buffer = ArrayVec::::new(); - for byte in (0..register_count as usize * register_size).map(|_| iter.next().unwrap()) { - register_bytes_buffer.push(byte); + // Check that all register bytes will fit in `registers` + if register_count as usize > SIZE { + return Err(LengthTooBig(register_count, register_size)); + } + + let mut register_bytes_buffer = ArrayVec::::new(); + + let register_bytes_count = register_count as usize * register_size; + for byte in (0..register_bytes_count).map(|_| iter.next().ok_or(NotEnoughItems)) { + let byte = byte?; + register_bytes_buffer.try_push(byte).map_err(|_| Corrupt)?; if register_bytes_buffer.len() == register_size { registers.push(RB::from_le_bytes( register_bytes_buffer .as_slice() .try_into() - .unwrap_or_else(|_| panic!()), + .map_err(|_| Corrupt)?, )); register_bytes_buffer = ArrayVec::new(); } } - assert!(register_bytes_buffer.is_empty()); + if !register_bytes_buffer.is_empty() { + return Err(Corrupt); + } - Self { + Ok(Self { starting_register_number, registers, - } + }) + } +} + +impl FromIterator for ArrayRegisterData +where + RB: funty::Integral, + RB::Bytes: for<'a> TryFrom<&'a [u8]>, +{ + fn from_iter>(iter: T) -> Self { + Self::try_from_iter(iter).unwrap() } } @@ -223,50 +252,76 @@ impl RegisterData for VecRegisterData { } #[cfg(feature = "std")] -impl FromIterator for VecRegisterData +impl VecRegisterData where RB: funty::Integral, RB::Bytes: for<'a> TryFrom<&'a [u8]>, { - fn from_iter>(iter: T) -> Self { + /// Try to build a [VecRegisterData] from an [IntoIterator] + pub fn try_from_iter>( + iter: I, + ) -> Result { + use RegisterDataFromIterError::*; + let mut iter = iter.into_iter(); - assert_eq!( - iter.next().unwrap(), - REGISTER_DATA_IDENTIFIER, - "The given iterator is not for register data" - ); + match iter.next() { + Some(REGISTER_DATA_IDENTIFIER) => {} + Some(id) => return Err(InvalidIdentifier(id)), + None => return Err(NotEnoughItems), + } - let starting_register_number = - u16::from_le_bytes([iter.next().unwrap(), iter.next().unwrap()]); + let starting_register_number = u16::from_le_bytes([ + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + ]); - let register_count = u16::from_le_bytes([iter.next().unwrap(), iter.next().unwrap()]); + let register_count = u16::from_le_bytes([ + iter.next().ok_or(NotEnoughItems)?, + iter.next().ok_or(NotEnoughItems)?, + ]); let mut registers = Vec::new(); let register_size = core::mem::size_of::(); let mut register_bytes_buffer = ArrayVec::::new(); - for byte in (0..register_count as usize * register_size).map(|_| iter.next().unwrap()) { - register_bytes_buffer.push(byte); + for byte in + (0..register_count as usize * register_size).map(|_| iter.next().ok_or(NotEnoughItems)) + { + let byte = byte?; + register_bytes_buffer.try_push(byte).map_err(|_| Corrupt)?; if register_bytes_buffer.len() == register_size { registers.push(RB::from_le_bytes( register_bytes_buffer .as_slice() .try_into() - .unwrap_or_else(|_| panic!()), + .map_err(|_| Corrupt)?, )); register_bytes_buffer.clear(); } } - assert!(register_bytes_buffer.is_empty()); + if !register_bytes_buffer.is_empty() { + return Err(Corrupt); + } - Self { + Ok(Self { starting_register_number, registers, - } + }) + } +} + +#[cfg(feature = "std")] +impl FromIterator for VecRegisterData +where + RB: funty::Integral, + RB::Bytes: for<'a> TryFrom<&'a [u8]>, +{ + fn from_iter>(iter: T) -> Self { + Self::try_from_iter(iter).unwrap() } } @@ -324,6 +379,35 @@ impl<'a, RB: funty::Integral> Iterator for RegisterDataBytesIterator<'a, RB> { impl<'a, RB: funty::Integral> ExactSizeIterator for RegisterDataBytesIterator<'a, RB> {} +#[derive(Debug)] +/// Specifies what went wrong building a [RegisterData] from an iterator +pub enum RegisterDataFromIterError { + /// The given iterator is not for a register set. + /// First item from iterator yielded invalid identifier. Expected [REGISTER_DATA_IDENTIFIER] + InvalidIdentifier(u8), + /// Iterator specified length too big for declared register set + LengthTooBig(u16, usize), + /// Iterator did not yield enough items to build register set + NotEnoughItems, + /// Iterator data is corrupt in some other way + Corrupt, +} + +impl core::fmt::Display for RegisterDataFromIterError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + use RegisterDataFromIterError::*; + match self { + InvalidIdentifier(id) => write!(f, "Iterator is not for a register set. Started with {id}, expected {REGISTER_DATA_IDENTIFIER}"), + LengthTooBig(count, size) => write!(f, "Iterator specified length too big for register set: {len}", len = *count as usize * size), + NotEnoughItems => write!(f, "Iterator did not yield enough items to build register set"), + Corrupt => write!(f, "Iterator data is corrupt") + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for RegisterDataFromIterError {} + #[cfg(test)] mod tests { use super::*;