From 732a427441fafe66d201b774e366532249e9fb19 Mon Sep 17 00:00:00 2001 From: A5 Pickle <5342825+a5-pickle@users.noreply.github.com> Date: Sat, 27 Apr 2024 23:09:06 -0500 Subject: [PATCH] io: make `WriteableBytes` and `TypePrefixedPayload` more generic (#59) * universal/io: make WriteableBytes generic * uptick wormhole-io 0.2.0 * universal/io: trim traits * uptick wormhole-io version 0.2.0-alpha.3 --------- Co-authored-by: A5 Pickle --- universal/Cargo.toml | 2 +- universal/io/src/payload.rs | 85 +++++------ universal/io/src/read_write.rs | 270 ++++++++++++++++++++------------- 3 files changed, 207 insertions(+), 150 deletions(-) diff --git a/universal/Cargo.toml b/universal/Cargo.toml index 552ba93..962644d 100644 --- a/universal/Cargo.toml +++ b/universal/Cargo.toml @@ -6,7 +6,7 @@ members = [ resolver = "2" [workspace.package] -version = "0.1.3" +version = "0.2.0-alpha.3" edition = "2021" authors = ["Wormhole Contributors"] license = "Apache-2.0" diff --git a/universal/io/src/payload.rs b/universal/io/src/payload.rs index 1c64dce..3869ad5 100644 --- a/universal/io/src/payload.rs +++ b/universal/io/src/payload.rs @@ -5,34 +5,34 @@ use crate::{Readable, Writeable}; /// Trait to capture common payload behavior. We do not recommend overwriting /// any trait methods. Simply set the type constant and implement [`Readable`] /// and [`Writeable`]. -pub trait TypePrefixedPayload: Readable + Writeable + Clone + std::fmt::Debug { - const TYPE: Option; - - /// Read the payload, including the type prefix. - fn read_typed(reader: &mut R) -> Result { - let payload_type = u8::read(reader)?; - if payload_type == Self::TYPE.expect("Called write_typed on untyped payload") { - Self::read(reader) - } else { - Err(io::Error::new( - io::ErrorKind::InvalidData, - "Invalid payload type", - )) - } - } +pub trait TypePrefixedPayload: + Readable + Writeable + Clone + std::fmt::Debug +{ + const TYPE: Option<[u8; N]>; + + fn written_size(&self) -> usize; - /// Write the payload, including the type prefix. - fn write_typed(&self, writer: &mut W) -> Result<(), io::Error> { - Self::TYPE - .expect("Called write_typed on untyped payload") - .write(writer)?; - Writeable::write(self, writer) + /// Returns the size of the payload, including the type prefix. + fn payload_written_size(&self) -> usize { + match Self::TYPE { + Some(_) => self.written_size() + N, + None => self.written_size(), + } } /// Read the payload, including the type prefix if applicable. fn read_payload(reader: &mut R) -> Result { match Self::TYPE { - Some(_) => Self::read_typed(reader), + Some(id) => { + if id != <[u8; N]>::read(reader)? { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Invalid payload type", + )); + } + + Readable::read(reader) + } None => Readable::read(reader), } } @@ -59,20 +59,15 @@ pub trait TypePrefixedPayload: Readable + Writeable + Clone + std::fmt::Debug { /// Write the payload, including the type prefix if applicable. fn write_payload(&self, writer: &mut W) -> Result<(), io::Error> { match Self::TYPE { - Some(_) => self.write_typed(writer), + Some(id) => { + id.write(writer)?; + Writeable::write(self, writer) + } None => Writeable::write(self, writer), } } - /// Returns the size of the payload, including the type prefix. - fn payload_written_size(&self) -> usize { - match Self::TYPE { - Some(_) => self.written_size() + 1, - None => self.written_size(), - } - } - - fn to_vec_payload(&self) -> Vec { + fn to_vec(&self) -> Vec { let mut buf = Vec::with_capacity(self.payload_written_size()); self.write_payload(&mut buf).expect("no alloc failure"); buf @@ -91,18 +86,20 @@ mod test { pub struct Message { pub a: u32, pub b: NineteenBytes, - pub c: WriteableBytes, + pub c: WriteableBytes, pub d: [u64; 4], pub e: bool, } - impl TypePrefixedPayload for Message { - const TYPE: Option = Some(69); + impl TypePrefixedPayload<1> for Message { + const TYPE: Option<[u8; 1]> = Some([69]); + + fn written_size(&self) -> usize { + 88 + } } impl Readable for Message { - const SIZE: Option = Some(88); - fn read(reader: &mut R) -> std::io::Result where Self: Sized, @@ -119,10 +116,6 @@ mod test { } impl Writeable for Message { - fn written_size(&self) -> usize { - ::SIZE.unwrap() - } - fn write(&self, writer: &mut W) -> std::io::Result<()> where W: std::io::Write, @@ -141,18 +134,18 @@ mod test { let msg = Message { a: 420, b: NineteenBytes(hex!("ba5edba5edba5edba5edba5edba5edba5edba5")), - c: b"Somebody set us up the bomb.".to_vec().into(), + c: b"Somebody set us up the bomb.".to_vec().try_into().unwrap(), d: [0x45; 4], e: true, }; - let mut encoded = msg.to_vec_payload(); + let mut encoded = msg.to_vec(); assert_eq!(encoded, hex!("45000001a4ba5edba5edba5edba5edba5edba5edba5edba50000001c536f6d65626f6479207365742075732075702074686520626f6d622e000000000000004500000000000000450000000000000045000000000000004501")); - assert_eq!(encoded.capacity(), msg.payload_written_size()); + assert_eq!(encoded.capacity(), 1 + msg.written_size()); assert_eq!(encoded.capacity(), encoded.len()); let mut cursor = std::io::Cursor::new(&mut encoded); - let decoded = Message::read_typed(&mut cursor).unwrap(); + let decoded = Message::read_payload(&mut cursor).unwrap(); assert_eq!(msg, decoded); } @@ -173,7 +166,7 @@ mod test { let expected = Message { a: 420, b: NineteenBytes(hex!("ba5edba5edba5edba5edba5edba5edba5edba5")), - c: b"Somebody set us up the bomb.".to_vec().into(), + c: b"Somebody set us up the bomb.".to_vec().try_into().unwrap(), d: [0x45; 4], e: true, }; diff --git a/universal/io/src/read_write.rs b/universal/io/src/read_write.rs index b264f2c..16b495c 100644 --- a/universal/io/src/read_write.rs +++ b/universal/io/src/read_write.rs @@ -1,8 +1,6 @@ -use std::io; +use std::{io, marker::PhantomData}; pub trait Readable { - const SIZE: Option; - fn read(reader: &mut R) -> io::Result where Self: Sized, @@ -13,19 +11,9 @@ pub trait Writeable { fn write(&self, writer: &mut W) -> io::Result<()> where W: io::Write; - - fn written_size(&self) -> usize; - - fn to_vec(&self) -> Vec { - let mut buf = Vec::with_capacity(self.written_size()); - self.write(&mut buf).expect("no alloc failure"); - buf - } } impl Readable for u8 { - const SIZE: Option = Some(1); - fn read(reader: &mut R) -> io::Result where R: io::Read, @@ -37,10 +25,6 @@ impl Readable for u8 { } impl Writeable for u8 { - fn written_size(&self) -> usize { - ::SIZE.unwrap() - } - fn write(&self, writer: &mut W) -> io::Result<()> where W: io::Write, @@ -50,8 +34,6 @@ impl Writeable for u8 { } impl Readable for bool { - const SIZE: Option = ::SIZE; - fn read(reader: &mut R) -> io::Result where R: io::Read, @@ -68,10 +50,6 @@ impl Readable for bool { } impl Writeable for bool { - fn written_size(&self) -> usize { - ::SIZE.unwrap() - } - fn write(&self, writer: &mut W) -> io::Result<()> where W: io::Write, @@ -83,8 +61,6 @@ impl Writeable for bool { macro_rules! impl_for_int { ($type:ty) => { impl Readable for $type { - const SIZE: Option = Some(std::mem::size_of::<$type>()); - fn read(reader: &mut R) -> io::Result where R: io::Read, @@ -96,10 +72,6 @@ macro_rules! impl_for_int { } impl Writeable for $type { - fn written_size(&self) -> usize { - ::SIZE.unwrap() - } - fn write(&self, writer: &mut W) -> io::Result<()> where W: io::Write, @@ -122,8 +94,6 @@ impl_for_int!(i64); impl_for_int!(i128); impl Readable for [u8; N] { - const SIZE: Option = Some(N); - fn read(reader: &mut R) -> io::Result where Self: Sized, @@ -136,10 +106,6 @@ impl Readable for [u8; N] { } impl Writeable for [u8; N] { - fn written_size(&self) -> usize { - ::SIZE.unwrap() - } - fn write(&self, writer: &mut W) -> io::Result<()> where W: io::Write, @@ -151,8 +117,6 @@ impl Writeable for [u8; N] { macro_rules! impl_for_int_array { ($type:ty) => { impl Readable for [$type; N] { - const SIZE: Option = Some(N * std::mem::size_of::<$type>()); - fn read(reader: &mut R) -> io::Result where R: io::Read, @@ -166,10 +130,6 @@ macro_rules! impl_for_int_array { } impl Writeable for [$type; N] { - fn written_size(&self) -> usize { - ::SIZE.unwrap() - } - fn write(&self, writer: &mut W) -> io::Result<()> where W: io::Write, @@ -196,68 +156,167 @@ impl_for_int_array!(i128); /// Wrapper for `Vec`. Encoding is similar to Borsh, where the length is encoded as u32 (but in /// this case, it's big endian). -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct WriteableBytes(Vec); +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct WriteableBytes +where + u32: From, + L: Sized + Readable + Writeable + TryFrom, +{ + phantom: PhantomData, + inner: Vec, +} -impl From> for WriteableBytes { - fn from(vec: Vec) -> Self { - Self(vec) +impl WriteableBytes +where + u32: From, + L: Sized + Readable + Writeable + TryFrom, +{ + pub fn new(inner: Vec) -> Self { + Self { + phantom: PhantomData, + inner, + } + } + + pub fn try_encoded_len(&self) -> io::Result { + match L::try_from(self.inner.len()) { + Ok(len) => Ok(len), + Err(_) => Err(io::Error::new( + io::ErrorKind::InvalidData, + "L overflow when converting from usize", + )), + } } } -impl From for Vec { - fn from(bytes: WriteableBytes) -> Self { - bytes.0 +impl TryFrom> for WriteableBytes +where + u32: From, + L: Sized + Readable + Writeable + TryFrom, +{ + type Error = >::Error; + + fn try_from(vec: Vec) -> Result { + match L::try_from(vec.len()) { + Ok(_) => Ok(Self { + phantom: PhantomData, + inner: vec, + }), + Err(e) => Err(e), + } } } -impl std::ops::Deref for WriteableBytes { +impl From> for Vec +where + u32: From, + L: Sized + Readable + Writeable + TryFrom, +{ + fn from(bytes: WriteableBytes) -> Self { + bytes.inner + } +} + +impl std::ops::Deref for WriteableBytes +where + L: Sized + Readable + Writeable, + u32: From, + L: TryFrom, +{ type Target = Vec; fn deref(&self) -> &Self::Target { - &self.0 + &self.inner } } -impl std::ops::DerefMut for WriteableBytes { +impl std::ops::DerefMut for WriteableBytes +where + u32: From, + L: Sized + Readable + Writeable + TryFrom, +{ fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 + &mut self.inner } } -impl Readable for WriteableBytes { - const SIZE: Option = None; +impl Readable for WriteableBytes { + fn read(reader: &mut R) -> io::Result + where + Self: Sized, + R: io::Read, + { + let len = u8::read(reader)?; + let mut inner: Vec = vec![0u8; len.into()]; + reader.read_exact(&mut inner)?; + Ok(Self { + phantom: PhantomData, + inner, + }) + } +} +impl Readable for WriteableBytes { fn read(reader: &mut R) -> io::Result where Self: Sized, R: io::Read, { - let len = u32::read(reader)?; - let mut buf = vec![0u8; len.try_into().expect("usize overflow")]; - reader.read_exact(&mut buf)?; - Ok(Self(buf)) + let len = u16::read(reader)?; + let mut inner = vec![0u8; len.into()]; + reader.read_exact(&mut inner)?; + Ok(Self { + phantom: PhantomData, + inner, + }) } } -impl Writeable for WriteableBytes { - fn written_size(&self) -> usize { - 4 + self.0.len() +impl Readable for WriteableBytes { + fn read(reader: &mut R) -> io::Result + where + Self: Sized, + R: io::Read, + { + let len = u32::read(reader)?; + match len.try_into() { + Ok(len) => { + let mut inner = vec![0u8; len]; + reader.read_exact(&mut inner)?; + Ok(Self { + phantom: PhantomData, + inner, + }) + } + Err(_) => Err(io::Error::new( + io::ErrorKind::InvalidData, + "u32 overflow when converting to usize", + )), + } } +} +impl Writeable for WriteableBytes +where + u32: From, + L: Sized + Readable + Writeable + TryFrom, +{ fn write(&self, writer: &mut W) -> io::Result<()> where W: io::Write, { - (u32::try_from(self.0.len()).expect("u32 overflow")).write(writer)?; - writer.write_all(&self.0) + match self.try_encoded_len() { + Ok(len) => { + len.write(writer)?; + writer.write_all(&self.inner) + } + Err(e) => Err(e), + } } } #[cfg(feature = "alloy")] impl Readable for alloy_primitives::FixedBytes { - const SIZE: Option = Some(N); - fn read(reader: &mut R) -> io::Result where Self: Sized, @@ -269,10 +328,6 @@ impl Readable for alloy_primitives::FixedBytes { #[cfg(feature = "alloy")] impl Writeable for alloy_primitives::FixedBytes { - fn written_size(&self) -> usize { - ::SIZE.unwrap() - } - fn write(&self, writer: &mut W) -> io::Result<()> where W: io::Write, @@ -283,8 +338,6 @@ impl Writeable for alloy_primitives::FixedBytes { #[cfg(feature = "alloy")] impl Readable for alloy_primitives::Uint { - const SIZE: Option = { Some(BITS * 8) }; - fn read(reader: &mut R) -> io::Result where Self: Sized, @@ -299,10 +352,6 @@ impl Readable for alloy_primitives::Uint< #[cfg(feature = "alloy")] impl Writeable for alloy_primitives::Uint { - fn written_size(&self) -> usize { - ::SIZE.unwrap() - } - fn write(&self, writer: &mut W) -> io::Result<()> where W: io::Write, @@ -313,8 +362,6 @@ impl Writeable for alloy_primitives::Uint #[cfg(feature = "alloy")] impl Readable for alloy_primitives::Address { - const SIZE: Option = Some(20); - fn read(reader: &mut R) -> io::Result where Self: Sized, @@ -326,10 +373,6 @@ impl Readable for alloy_primitives::Address { #[cfg(feature = "alloy")] impl Writeable for alloy_primitives::Address { - fn written_size(&self) -> usize { - ::SIZE.unwrap() - } - fn write(&self, writer: &mut W) -> io::Result<()> where W: io::Write, @@ -346,80 +389,101 @@ pub mod test { #[test] fn u8_read_write() { const EXPECTED_SIZE: usize = 1; - assert_eq!(u8::SIZE, Some(EXPECTED_SIZE)); let value = 69u8; - assert_eq!(value.written_size(), EXPECTED_SIZE); - let mut encoded = Vec::::with_capacity(value.written_size()); + let mut encoded = Vec::::with_capacity(EXPECTED_SIZE); let mut writer = std::io::Cursor::new(&mut encoded); value.write(&mut writer).unwrap(); let expected = hex!("45"); assert_eq!(encoded, expected); - assert_eq!(value.to_vec(), expected.to_vec()); } #[test] fn u64_read_write() { const EXPECTED_SIZE: usize = 8; - assert_eq!(u64::SIZE, Some(EXPECTED_SIZE)); let value = 69u64; - assert_eq!(value.written_size(), EXPECTED_SIZE); - - let mut encoded = Vec::::with_capacity(value.written_size()); + let mut encoded = Vec::::with_capacity(EXPECTED_SIZE); let mut writer = std::io::Cursor::new(&mut encoded); value.write(&mut writer).unwrap(); let expected = hex!("0000000000000045"); assert_eq!(encoded, expected); - assert_eq!(value.to_vec(), expected.to_vec()); } #[test] fn u8_array_read_write() { let data = [1, 2, 8, 16, 32, 64, 69u8]; - assert_eq!(<[u8; 7]>::SIZE, Some(data.len())); - assert_eq!(data.written_size(), data.len()); - let mut encoded = Vec::::with_capacity(data.written_size()); + let mut encoded = Vec::::with_capacity(data.len()); let mut writer = std::io::Cursor::new(&mut encoded); data.write(&mut writer).unwrap(); let expected = hex!("01020810204045"); assert_eq!(encoded, expected); - assert_eq!(data.to_vec(), expected.to_vec()); } #[test] fn u64_array_read_write() { let data = [1, 2, 8, 16, 32, 64, 69u64]; const EXPECTED_SIZE: usize = 56; - assert_eq!(<[u64; 7]>::SIZE, Some(EXPECTED_SIZE)); - assert_eq!(data.written_size(), EXPECTED_SIZE); - let mut encoded = Vec::::with_capacity(data.written_size()); + let mut encoded = Vec::::with_capacity(EXPECTED_SIZE); let mut writer = std::io::Cursor::new(&mut encoded); data.write(&mut writer).unwrap(); let expected = hex!("0000000000000001000000000000000200000000000000080000000000000010000000000000002000000000000000400000000000000045"); assert_eq!(encoded, expected); - assert_eq!(data.to_vec(), expected.to_vec()); } #[test] - fn variable_bytes_read_write() { + fn variable_bytes_read_write_u8() { + let data = b"All your base are belong to us."; + let bytes = WriteableBytes::::new(data.to_vec()); + + let mut encoded = Vec::::with_capacity(1 + data.len()); + let mut writer = std::io::Cursor::new(&mut encoded); + bytes.write(&mut writer).unwrap(); + + let expected = hex!("1f416c6c20796f75722062617365206172652062656c6f6e6720746f2075732e"); + assert_eq!(encoded, expected); + } + + #[test] + fn variable_bytes_read_write_u16() { let data = b"All your base are belong to us."; - let bytes = WriteableBytes(data.to_vec()); + let bytes = WriteableBytes::::new(data.to_vec()); - let mut encoded = Vec::::with_capacity(bytes.written_size()); + let mut encoded = Vec::::with_capacity(2 + data.len()); + let mut writer = std::io::Cursor::new(&mut encoded); + bytes.write(&mut writer).unwrap(); + + let expected = hex!("001f416c6c20796f75722062617365206172652062656c6f6e6720746f2075732e"); + assert_eq!(encoded, expected); + } + + #[test] + fn variable_bytes_read_write_u32() { + let data = b"All your base are belong to us."; + let bytes = WriteableBytes::::new(data.to_vec()); + + let mut encoded = Vec::::with_capacity(4 + data.len()); let mut writer = std::io::Cursor::new(&mut encoded); bytes.write(&mut writer).unwrap(); let expected = hex!("0000001f416c6c20796f75722062617365206172652062656c6f6e6720746f2075732e"); assert_eq!(encoded, expected); - assert_eq!(bytes.to_vec(), expected.to_vec()); + } + + #[test] + fn mem_take() { + let data = b"All your base are belong to us."; + let mut bytes = WriteableBytes::::new(data.to_vec()); + + let taken = std::mem::take(&mut bytes); + assert_eq!(taken.as_slice(), data); } }