diff --git a/Cargo.toml b/Cargo.toml index 1d04d43..422ec7a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "smol_str" -version = "0.2.1" +version = "0.2.2" description = "small-string optimized string type with O(1) clone" license = "MIT OR Apache-2.0" repository = "https://github.com/rust-analyzer/smol_str" @@ -10,6 +10,7 @@ edition = "2018" [dependencies] serde = { version = "1.0.136", optional = true, default_features = false } arbitrary = { version = "1.1.0", optional = true } +borsh = { version = "1.4.0", features = ["unstable__schema"] , optional = true } [dev-dependencies] proptest = "1.0.0" diff --git a/src/lib.rs b/src/lib.rs index 375a4a5..9087dc2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -804,3 +804,65 @@ mod serde { } } } + +#[cfg(feature = "borsh")] +mod borsh { + use crate::{Repr, SmolStr, INLINE_CAP}; + use alloc::{ + collections::BTreeMap, + string::{String, ToString}, + }; + use borsh::{ + io::{Error, ErrorKind, Read, Write}, + schema::{Declaration, Definition}, + BorshDeserialize, BorshSchema, BorshSerialize, + }; + use core::intrinsics::transmute; + + impl BorshSerialize for SmolStr { + fn serialize(&self, writer: &mut W) -> borsh::io::Result<()> { + self.as_str().serialize(writer) + } + } + + impl BorshDeserialize for SmolStr { + #[inline] + fn deserialize_reader(reader: &mut R) -> borsh::io::Result { + let len = u32::deserialize_reader(reader)?; + if (len as usize) < INLINE_CAP { + let mut buf = [0u8; INLINE_CAP]; + reader.read_exact(&mut buf[..len as usize])?; + _ = core::str::from_utf8(&buf[..len as usize]).map_err(|err| { + let msg = err.to_string(); + Error::new(ErrorKind::InvalidData, msg) + })?; + Ok(SmolStr(Repr::Inline { + len: unsafe { transmute(len as u8) }, + buf, + })) + } else { + // u8::vec_from_reader always returns Some on success in current implementation + let vec = u8::vec_from_reader(len, reader)?.ok_or_else(|| { + Error::new( + ErrorKind::Other, + "u8::vec_from_reader unexpectedly returned None".to_string(), + ) + })?; + Ok(SmolStr::from(String::from_utf8(vec).map_err(|err| { + let msg = err.to_string(); + Error::new(ErrorKind::InvalidData, msg) + })?)) + } + } + } + + impl BorshSchema for SmolStr { + fn add_definitions_recursively(definitions: &mut BTreeMap) { + str::add_definitions_recursively(definitions) + } + + fn declaration() -> Declaration { + str::declaration() + } + } +} diff --git a/tests/test.rs b/tests/test.rs index 11b7df7..25a4a3b 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -313,3 +313,57 @@ mod test_str_ext { assert!(!uppercase.is_heap_allocated()); } } + +#[cfg(feature = "borsh")] +mod borsh_tests { + use borsh::BorshDeserialize; + use smol_str::{SmolStr, ToSmolStr}; + use std::io::Cursor; + + #[test] + fn borsh_serialize_stack() { + let smolstr_on_stack = "aßΔCaßδc".to_smolstr(); + let mut buffer = Vec::new(); + borsh::BorshSerialize::serialize(&smolstr_on_stack, &mut buffer).unwrap(); + let mut cursor = Cursor::new(buffer); + let decoded: SmolStr = borsh::BorshDeserialize::deserialize_reader(&mut cursor).unwrap(); + assert_eq!(smolstr_on_stack, decoded); + } + + #[test] + fn borsh_serialize_heap() { + let smolstr_on_heap = "aßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδc".to_smolstr(); + let mut buffer = Vec::new(); + borsh::BorshSerialize::serialize(&smolstr_on_heap, &mut buffer).unwrap(); + let mut cursor = Cursor::new(buffer); + let decoded: SmolStr = borsh::BorshDeserialize::deserialize_reader(&mut cursor).unwrap(); + assert_eq!(smolstr_on_heap, decoded); + } + + #[test] + fn borsh_non_utf8_stack() { + let invalid_utf8: Vec = vec![0xF0, 0x9F, 0x8F]; // Incomplete UTF-8 sequence + + let wrong_utf8 = SmolStr::from(unsafe { String::from_utf8_unchecked(invalid_utf8) }); + let mut buffer = Vec::new(); + borsh::BorshSerialize::serialize(&wrong_utf8, &mut buffer).unwrap(); + let mut cursor = Cursor::new(buffer); + let result = SmolStr::deserialize_reader(&mut cursor); + assert!(result.is_err()); + } + + #[test] + fn borsh_non_utf8_heap() { + let invalid_utf8: Vec = vec![ + 0xC1, 0x8A, 0x5F, 0xE2, 0x3A, 0x9E, 0x3B, 0xAA, 0x01, 0x08, 0x6F, 0x2F, 0xC0, 0x32, + 0xAB, 0xE1, 0x9A, 0x2F, 0x4A, 0x3F, 0x25, 0x0D, 0x8A, 0x2A, 0x19, 0x11, 0xF0, 0x7F, + 0x0E, 0x80, + ]; + let wrong_utf8 = SmolStr::from(unsafe { String::from_utf8_unchecked(invalid_utf8) }); + let mut buffer = Vec::new(); + borsh::BorshSerialize::serialize(&wrong_utf8, &mut buffer).unwrap(); + let mut cursor = Cursor::new(buffer); + let result = SmolStr::deserialize_reader(&mut cursor); + assert!(result.is_err()); + } +}