diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 366b85b..c0e0c72 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -20,7 +20,7 @@ jobs: steps: - uses: actions/checkout@v3 - uses: dtolnay/rust-toolchain@stable - - run: cargo test + - run: cargo test --features arrayvec clippy: runs-on: ubuntu-latest diff --git a/.vscode/settings.json b/.vscode/settings.json index 3437c77..2b0ee9f 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,4 +1,5 @@ { + "rust-analyzer.cargo.allTargets": false, "rust-analyzer.linkedProjects": [ "Cargo.toml", "fuzz/Cargo.toml", diff --git a/CHANGELOG.md b/CHANGELOG.md index 391d436..2dd14cb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ and carries a predefined error subtype. - Added `erase_all` function as a helper to erase the flash in a region. - *Breaking:* Changed the way that queue iteration works. Now there's an `iter` function instead of two separate `peek_many` and `pop_many` functions. The new iter returns an entry from which you can get the data that was just peeked. If you want to pop it, then call the pop function on the entry. +- Added `arrayvec` feature that when activated impls the `Key` trait for `ArrayVec` and `ArrayString`. ## 1.0.0 01-03-24 diff --git a/Cargo.toml b/Cargo.toml index be63cdb..124ff65 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ embedded-storage-async = "0.4.1" defmt = { version = "0.3", optional = true } futures = { version = "0.3.30", features = ["executor"], optional = true } approx = { version = "0.5.1", optional = true } +arrayvec = { version = "0.7.4", default-features = false, optional = true } [dev-dependencies] approx = "0.5.1" @@ -25,4 +26,6 @@ futures-test = "0.3.30" [features] defmt-03 = ["dep:defmt"] std = [] -_test = ["dep:futures", "dep:approx", "std"] +# Enable the implementation of the map Key trait for ArrayVec and ArrayString +arrayvec = ["dep:arrayvec"] +_test = ["dep:futures", "dep:approx", "std", "arrayvec"] diff --git a/src/arrayvec_impl.rs b/src/arrayvec_impl.rs new file mode 100644 index 0000000..da43f54 --- /dev/null +++ b/src/arrayvec_impl.rs @@ -0,0 +1,123 @@ +use arrayvec::{ArrayString, ArrayVec}; + +use crate::map::{Key, SerializationError}; + +impl Key for ArrayVec { + fn serialize_into(&self, buffer: &mut [u8]) -> Result { + if buffer.len() < self.len() + 2 { + return Err(SerializationError::BufferTooSmall); + } + + if self.len() > u16::MAX as usize { + return Err(SerializationError::InvalidData); + } + + buffer[..2].copy_from_slice(&(self.len() as u16).to_le_bytes()); + buffer[2..][..self.len()].copy_from_slice(self); + + Ok(self.len() + 2) + } + + fn deserialize_from(buffer: &[u8]) -> Result<(Self, usize), SerializationError> { + let total_len = Self::get_len(buffer)?; + + if buffer.len() < total_len { + return Err(SerializationError::BufferTooSmall); + } + + let data_len = total_len - 2; + + let mut output = ArrayVec::new(); + output + .try_extend_from_slice(&buffer[2..][..data_len]) + .map_err(|_| SerializationError::InvalidFormat)?; + + Ok((output, total_len)) + } + + fn get_len(buffer: &[u8]) -> Result { + if buffer.len() < 2 { + return Err(SerializationError::BufferTooSmall); + } + + let len = u16::from_le_bytes(buffer[..2].try_into().unwrap()); + + Ok(len as usize + 2) + } +} + +impl Key for ArrayString { + fn serialize_into(&self, buffer: &mut [u8]) -> Result { + if buffer.len() < self.len() + 2 { + return Err(SerializationError::BufferTooSmall); + } + + if self.len() > u16::MAX as usize { + return Err(SerializationError::InvalidData); + } + + buffer[..2].copy_from_slice(&(self.len() as u16).to_le_bytes()); + buffer[2..][..self.len()].copy_from_slice(self.as_bytes()); + + Ok(self.len() + 2) + } + + fn deserialize_from(buffer: &[u8]) -> Result<(Self, usize), SerializationError> { + let total_len = Self::get_len(buffer)?; + + if buffer.len() < total_len { + return Err(SerializationError::BufferTooSmall); + } + + let data_len = total_len - 2; + + let mut output = ArrayString::new(); + output + .try_push_str( + core::str::from_utf8(&buffer[2..][..data_len]) + .map_err(|_| SerializationError::InvalidFormat)?, + ) + .map_err(|_| SerializationError::InvalidFormat)?; + + Ok((output, total_len)) + } + + fn get_len(buffer: &[u8]) -> Result { + if buffer.len() < 2 { + return Err(SerializationError::BufferTooSmall); + } + + let len = u16::from_le_bytes(buffer[..2].try_into().unwrap()); + + Ok(len as usize + 2) + } +} + +#[cfg(test)] +mod tests { + use core::str::FromStr; + + use super::*; + + #[test] + fn serde_arrayvec() { + let mut buffer = [0; 128]; + + let val = ArrayVec::::from_iter([0xAA; 12]); + val.serialize_into(&mut buffer).unwrap(); + let new_val = ArrayVec::::deserialize_from(&buffer).unwrap(); + + assert_eq!((val, 14), new_val); + } + + #[test] + fn serde_arraystring() { + let mut buffer = [0; 128]; + + let val = ArrayString::<45>::from_str("Hello world!").unwrap(); + val.serialize_into(&mut buffer).unwrap(); + let new_val = ArrayString::<45>::deserialize_from(&buffer).unwrap(); + + assert_eq!((val, 14), new_val); + } +} diff --git a/src/lib.rs b/src/lib.rs index 07a0ff2..21db652 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,8 +13,10 @@ use core::{ ops::{Deref, DerefMut, Range}, }; use embedded_storage_async::nor_flash::NorFlash; -use map::MapValueError; +use map::SerializationError; +#[cfg(feature = "arrayvec")] +mod arrayvec_impl; pub mod cache; mod item; pub mod map; @@ -377,8 +379,14 @@ pub enum Error { BufferTooBig, /// A provided buffer was to small to be used (usize is size needed) BufferTooSmall(usize), - /// A map value error - MapValueError(MapValueError), + /// A serialization error (from the key or value) + SerializationError(SerializationError), +} + +impl From for Error { + fn from(v: SerializationError) -> Self { + Self::SerializationError(v) + } } impl PartialEq for Error { @@ -407,7 +415,7 @@ where f, "A provided buffer was to small to be used. Needed was {needed}" ), - Error::MapValueError(value) => write!(f, "Map value error: {value}"), + Error::SerializationError(value) => write!(f, "Map value error: {value}"), } } } diff --git a/src/map.rs b/src/map.rs index 43d6c69..5eb88c9 100644 --- a/src/map.rs +++ b/src/map.rs @@ -143,21 +143,23 @@ pub async fn fetch_item<'d, K: Key, V: Value<'d>, S: NorFlash>( repair = try_repair::(flash, flash_range.clone(), cache, data_buffer).await? ); - let Some((item, _)) = result? else { + let Some((item, _, item_key_len)) = result? else { return Ok(None); }; - let item = item.reborrow(data_buffer); - - let data_len = item.data().len(); + let data_len = item.header.length as usize; + let item_key_len = match item_key_len { + Some(item_key_len) => item_key_len, + None => K::get_len(&data_buffer[..data_len])?, + }; Ok(Some( - V::deserialize_from(&item.destruct().1[K::LEN..][..data_len - K::LEN]) - .map_err(Error::MapValueError)?, + V::deserialize_from(&data_buffer[item_key_len..][..data_len - item_key_len]) + .map_err(Error::SerializationError)?, )) } -/// Fetch the item, but with the address and header +/// Fetch the item, but with the item unborrowed, the address of the item and the length of the key #[allow(clippy::type_complexity)] async fn fetch_item_with_location<'d, K: Key, S: NorFlash>( flash: &mut S, @@ -165,7 +167,7 @@ async fn fetch_item_with_location<'d, K: Key, S: NorFlash>( cache: &mut impl PrivateKeyCacheImpl, data_buffer: &'d mut [u8], search_key: K, -) -> Result, Error> { +) -> Result)>, Error> { assert_eq!(flash_range.start % S::ERASE_SIZE as u32, 0); assert_eq!(flash_range.end % S::ERASE_SIZE as u32, 0); assert!(flash_range.end - flash_range.start >= S::ERASE_SIZE as u32 * 2); @@ -210,7 +212,7 @@ async fn fetch_item_with_location<'d, K: Key, S: NorFlash>( break 'cache; } item::MaybeItem::Present(item) => { - return Ok(Some((item.unborrow(), cached_location))); + return Ok(Some((item.unborrow(), cached_location, None))); } } } @@ -253,7 +255,7 @@ async fn fetch_item_with_location<'d, K: Key, S: NorFlash>( // If we don't find it in the current page, then we check again in the previous page if that page is closed. let mut current_page_to_check = last_used_page.unwrap(); - let mut newest_found_item_address = None; + let mut newest_found_item_data = None; loop { let page_data_start_address = @@ -265,13 +267,14 @@ async fn fetch_item_with_location<'d, K: Key, S: NorFlash>( let mut it = ItemIter::new(page_data_start_address, page_data_end_address); while let Some((item, address)) = it.next(flash, data_buffer).await? { - if K::deserialize_from(&item.data()[..K::LEN]) == search_key { - newest_found_item_address = Some(address); + let (found_key, found_key_len) = K::deserialize_from(item.data())?; + if found_key == search_key { + newest_found_item_data = Some((address, found_key_len)); } } // We've found the item! We can stop searching - if let Some(newest_found_item_address) = newest_found_item_address.as_ref() { + if let Some((newest_found_item_address, _)) = newest_found_item_data.as_ref() { cache.notice_key_location(search_key, *newest_found_item_address, false); break; @@ -295,7 +298,7 @@ async fn fetch_item_with_location<'d, K: Key, S: NorFlash>( // We now need to reread the item because we lost all its data other than its address - if let Some(newest_found_item_address) = newest_found_item_address { + if let Some((newest_found_item_address, newest_found_item_key_len)) = newest_found_item_data { let item = ItemHeader::read_new(flash, newest_found_item_address, u32::MAX) .await? .ok_or_else(|| { @@ -308,7 +311,11 @@ async fn fetch_item_with_location<'d, K: Key, S: NorFlash>( .read_item(flash, data_buffer, newest_found_item_address, u32::MAX) .await?; - Ok(Some((item.unwrap()?.unborrow(), newest_found_item_address))) + Ok(Some(( + item.unwrap()?.unborrow(), + newest_found_item_address, + Some(newest_found_item_key_len), + ))) } else { Ok(None) } @@ -412,11 +419,11 @@ async fn store_item_inner<'d, K: Key, S: NorFlash>( calculate_page_end_address::(flash_range.clone(), partial_open_page) - S::WORD_SIZE as u32; - key.serialize_into(&mut data_buffer[..K::LEN]); - let item_data_length = K::LEN + let key_len = key.serialize_into(data_buffer)?; + let item_data_length = key_len + item - .serialize_into(&mut data_buffer[K::LEN..]) - .map_err(Error::MapValueError)?; + .serialize_into(&mut data_buffer[key_len..]) + .map_err(Error::SerializationError)?; let free_spot_address = find_next_free_item_spot( flash, @@ -604,7 +611,7 @@ async fn remove_item_inner( item::MaybeItem::Corrupted(_, _) => continue, item::MaybeItem::Erased(_, _) => continue, item::MaybeItem::Present(item) => { - let item_key = K::deserialize_from(&item.data()[..K::LEN]); + let (item_key, _) = K::deserialize_from(item.data())?; // If this item has the same key as the key we're trying to erase, then erase the item. // But keep going! We need to erase everything. @@ -626,36 +633,48 @@ async fn remove_item_inner( /// Anything implementing this trait can be used as a key in the map functions. /// -/// It provides a way to serialize and deserialize the key as well as provide a -/// constant for the serialized length. -/// -/// The functions don't return a result. Keys should be simple and trivial. +/// It provides a way to serialize and deserialize the key. /// /// The `Eq` bound is used because we need to be able to compare keys and the /// `Clone` bound helps us pass the key around. +/// +/// The key cannot have a lifetime like the [Value] pub trait Key: Eq + Clone + Sized { - /// The serialized length of the key - const LEN: usize; - /// Serialize the key into the given buffer. - /// The buffer is always of the same length as the [Self::LEN] constant. - fn serialize_into(&self, buffer: &mut [u8]); + /// The serialized size is returned. + fn serialize_into(&self, buffer: &mut [u8]) -> Result; /// Deserialize the key from the given buffer. - /// The buffer is always of the same length as the [Self::LEN] constant. - fn deserialize_from(buffer: &[u8]) -> Self; + /// The key is returned together with the serialized length. + fn deserialize_from(buffer: &[u8]) -> Result<(Self, usize), SerializationError>; + /// Get the length of the key from the buffer. + /// This is an optimized version of [Self::deserialize_from] that doesn't have to deserialize everything. + fn get_len(buffer: &[u8]) -> Result { + Self::deserialize_from(buffer).map(|(_, len)| len) + } } macro_rules! impl_key_num { ($int:ty) => { impl Key for $int { - const LEN: usize = core::mem::size_of::(); - - fn serialize_into(&self, buffer: &mut [u8]) { - buffer.copy_from_slice(&self.to_le_bytes()); + fn serialize_into(&self, buffer: &mut [u8]) -> Result { + let len = core::mem::size_of::(); + if buffer.len() < len { + return Err(SerializationError::BufferTooSmall); + } + buffer[..len].copy_from_slice(&self.to_le_bytes()); + Ok(len) } - fn deserialize_from(buffer: &[u8]) -> Self { - Self::from_le_bytes(buffer.try_into().unwrap()) + fn deserialize_from(buffer: &[u8]) -> Result<(Self, usize), SerializationError> { + let len = core::mem::size_of::(); + if buffer.len() < len { + return Err(SerializationError::BufferTooSmall); + } + + Ok(( + Self::from_le_bytes(buffer[..len].try_into().unwrap()), + core::mem::size_of::(), + )) } } }; @@ -673,14 +692,20 @@ impl_key_num!(i64); impl_key_num!(i128); impl Key for [u8; N] { - const LEN: usize = N; - - fn serialize_into(&self, buffer: &mut [u8]) { - buffer.copy_from_slice(self); + fn serialize_into(&self, buffer: &mut [u8]) -> Result { + if buffer.len() < N { + return Err(SerializationError::BufferTooSmall); + } + buffer[..N].copy_from_slice(self); + Ok(N) } - fn deserialize_from(buffer: &[u8]) -> Self { - buffer[..Self::LEN].try_into().unwrap() + fn deserialize_from(buffer: &[u8]) -> Result<(Self, usize), SerializationError> { + if buffer.len() < N { + return Err(SerializationError::BufferTooSmall); + } + + Ok((buffer[..N].try_into().unwrap(), N)) } } @@ -691,28 +716,29 @@ impl Key for [u8; N] { pub trait Value<'a> { /// Serialize the value into the given buffer. If everything went ok, this function returns the length /// of the used part of the buffer. - fn serialize_into(&self, buffer: &mut [u8]) -> Result; + fn serialize_into(&self, buffer: &mut [u8]) -> Result; /// Deserialize the value from the buffer. Because of the added lifetime, the implementation can borrow from the /// buffer which opens up some zero-copy possibilities. /// /// The buffer will be the same length as the serialize function returned for this value. Though note that the length - /// is written from flash, so bitflips can affect that (though the length is separately crc-protected). - fn deserialize_from(buffer: &'a [u8]) -> Result + /// is written from flash, so bitflips can affect that (though the length is separately crc-protected) and the key deserialization might + /// return a wrong length. + fn deserialize_from(buffer: &'a [u8]) -> Result where Self: Sized; } impl<'a> Value<'a> for &'a [u8] { - fn serialize_into(&self, buffer: &mut [u8]) -> Result { + fn serialize_into(&self, buffer: &mut [u8]) -> Result { if buffer.len() < self.len() { - return Err(MapValueError::BufferTooSmall); + return Err(SerializationError::BufferTooSmall); } buffer[..self.len()].copy_from_slice(self); Ok(self.len()) } - fn deserialize_from(buffer: &'a [u8]) -> Result + fn deserialize_from(buffer: &'a [u8]) -> Result where Self: Sized, { @@ -721,36 +747,38 @@ impl<'a> Value<'a> for &'a [u8] { } impl<'a, const N: usize> Value<'a> for [u8; N] { - fn serialize_into(&self, buffer: &mut [u8]) -> Result { + fn serialize_into(&self, buffer: &mut [u8]) -> Result { if buffer.len() < self.len() { - return Err(MapValueError::BufferTooSmall); + return Err(SerializationError::BufferTooSmall); } buffer[..self.len()].copy_from_slice(self); Ok(self.len()) } - fn deserialize_from(buffer: &'a [u8]) -> Result + fn deserialize_from(buffer: &'a [u8]) -> Result where Self: Sized, { - buffer.try_into().map_err(|_| MapValueError::BufferTooSmall) + buffer + .try_into() + .map_err(|_| SerializationError::BufferTooSmall) } } macro_rules! impl_map_item_num { ($int:ty) => { impl<'a> Value<'a> for $int { - fn serialize_into(&self, buffer: &mut [u8]) -> Result { + fn serialize_into(&self, buffer: &mut [u8]) -> Result { buffer[..core::mem::size_of::()].copy_from_slice(&self.to_le_bytes()); Ok(core::mem::size_of::()) } - fn deserialize_from(buffer: &[u8]) -> Result { + fn deserialize_from(buffer: &[u8]) -> Result { Ok(Self::from_le_bytes( buffer[..core::mem::size_of::()] .try_into() - .map_err(|_| MapValueError::BufferTooSmall)?, + .map_err(|_| SerializationError::BufferTooSmall)?, )) } } @@ -776,7 +804,7 @@ impl_map_item_num!(f64); #[non_exhaustive] #[derive(Debug, PartialEq, Eq, Clone)] #[cfg_attr(feature = "defmt-03", derive(defmt::Format))] -pub enum MapValueError { +pub enum SerializationError { /// The provided buffer was too small. BufferTooSmall, /// The serialization could not succeed because the data was not in order. (e.g. too big to fit) @@ -788,13 +816,13 @@ pub enum MapValueError { Custom(i32), } -impl core::fmt::Display for MapValueError { +impl core::fmt::Display for SerializationError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { - MapValueError::BufferTooSmall => write!(f, "Buffer too small"), - MapValueError::InvalidData => write!(f, "Invalid data"), - MapValueError::InvalidFormat => write!(f, "Invalid format"), - MapValueError::Custom(val) => write!(f, "Custom error: {val}"), + SerializationError::BufferTooSmall => write!(f, "Buffer too small"), + SerializationError::InvalidData => write!(f, "Invalid data"), + SerializationError::InvalidFormat => write!(f, "Invalid format"), + SerializationError::Custom(val) => write!(f, "Custom error: {val}"), } } } @@ -818,14 +846,14 @@ async fn migrate_items( calculate_page_end_address::(flash_range.clone(), source_page) - S::WORD_SIZE as u32, ); while let Some((item, item_address)) = it.next(flash, data_buffer).await? { - let key = K::deserialize_from(&item.data()[..K::LEN]); + let (key, _) = K::deserialize_from(item.data())?; let (_, data_buffer) = item.destruct(); // We're in a decent state here cache.unmark_dirty(); // Search for the newest item with the key we found - let Some((found_item, found_address)) = fetch_item_with_location::( + let Some((found_item, found_address, _)) = fetch_item_with_location::( flash, flash_range.clone(), cache,