Skip to content

Commit

Permalink
borsh feature
Browse files Browse the repository at this point in the history
  • Loading branch information
CorinJG committed Apr 17, 2024
1 parent 9971a3f commit 01cecbc
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 1 deletion.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "smol_str"
version = "0.2.1"
version = "0.2.2"
description = "small-string optimized string type with O(1) clone"
license = "MIT OR Apache-2.0"
repository = "https://github.com/rust-analyzer/smol_str"
Expand All @@ -10,6 +10,7 @@ edition = "2018"
[dependencies]
serde = { version = "1.0.136", optional = true, default_features = false }
arbitrary = { version = "1.1.0", optional = true }
borsh = { version = "1.4.0", features = ["unstable__schema"] , optional = true }

[dev-dependencies]
proptest = "1.0.0"
Expand Down
62 changes: 62 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -804,3 +804,65 @@ mod serde {
}
}
}

/// Borsh support for [`SmolStr`], compiled only when the `borsh` feature is enabled.
///
/// Serialization and the schema both delegate to `str`, so a `SmolStr` shares the
/// wire format and schema declaration of a plain string. Deserialization reads the
/// `u32` length prefix itself so that short strings can be built without going
/// through an intermediate `String`.
#[cfg(feature = "borsh")]
mod borsh {
    use crate::{Repr, SmolStr, INLINE_CAP};
    use alloc::{
        collections::BTreeMap,
        string::{String, ToString},
    };
    use borsh::{
        io::{Error, ErrorKind, Read, Write},
        schema::{Declaration, Definition},
        BorshDeserialize, BorshSchema, BorshSerialize,
    };
    // `core::mem::transmute` is the canonical path; `core::intrinsics` is meant for
    // the compiler rather than for user code.
    use core::mem::transmute;

    impl BorshSerialize for SmolStr {
        /// Writes the string exactly as `str` would be written.
        fn serialize<W: Write>(&self, writer: &mut W) -> borsh::io::Result<()> {
            self.as_str().serialize(writer)
        }
    }

    impl BorshDeserialize for SmolStr {
        /// Reads a `u32` length prefix followed by that many bytes of UTF-8.
        ///
        /// # Errors
        ///
        /// Returns any I/O error produced by `reader`, and an
        /// [`ErrorKind::InvalidData`] error when the payload is not valid UTF-8.
        #[inline]
        fn deserialize_reader<R: Read>(reader: &mut R) -> borsh::io::Result<Self> {
            let len = u32::deserialize_reader(reader)?;
            let byte_len = len as usize;

            if byte_len < INLINE_CAP {
                // Short payload: read straight into the inline buffer. The bytes past
                // `byte_len` stay zeroed.
                let mut buf = [0u8; INLINE_CAP];
                reader.read_exact(&mut buf[..byte_len])?;

                // Only the validation result matters here; the bytes stay in `buf`.
                if let Err(err) = core::str::from_utf8(&buf[..byte_len]) {
                    let msg = err.to_string();
                    return Err(Error::new(ErrorKind::InvalidData, msg));
                }

                Ok(SmolStr(Repr::Inline {
                    // SAFETY: `byte_len < INLINE_CAP` was checked above, so `len as u8`
                    // is lossless. This relies on the inline length type being a
                    // `u8`-sized type for which every value in `0..INLINE_CAP` is valid
                    // (NOTE(review): its definition is outside this module — confirm).
                    // The first `byte_len` bytes of `buf` were validated as UTF-8 above.
                    len: unsafe { transmute(len as u8) },
                    buf,
                }))
            } else {
                // u8::vec_from_reader always returns Some on success in current implementation
                let vec = u8::vec_from_reader(len, reader)?.ok_or_else(|| {
                    Error::new(
                        ErrorKind::Other,
                        "u8::vec_from_reader unexpectedly returned None".to_string(),
                    )
                })?;
                let string = String::from_utf8(vec).map_err(|err| {
                    let msg = err.to_string();
                    Error::new(ErrorKind::InvalidData, msg)
                })?;
                Ok(SmolStr::from(string))
            }
        }
    }

    impl BorshSchema for SmolStr {
        /// Registers the same definitions as `str`.
        fn add_definitions_recursively(definitions: &mut BTreeMap<Declaration, Definition>) {
            str::add_definitions_recursively(definitions)
        }

        /// Uses the same declaration as `str`.
        fn declaration() -> Declaration {
            str::declaration()
        }
    }
}
54 changes: 54 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -313,3 +313,57 @@ mod test_str_ext {
assert!(!uppercase.is_heap_allocated());
}
}

/// Round-trip and malformed-input tests for the borsh impls of `SmolStr`.
#[cfg(feature = "borsh")]
mod borsh_tests {
    use borsh::BorshDeserialize;
    use smol_str::{SmolStr, ToSmolStr};
    use std::io::Cursor;

    /// Serializes `value` with borsh, decodes it again and returns the decoded string.
    fn round_trip(value: &SmolStr) -> SmolStr {
        let mut buffer = Vec::new();
        borsh::BorshSerialize::serialize(value, &mut buffer).unwrap();
        let mut cursor = Cursor::new(buffer);
        borsh::BorshDeserialize::deserialize_reader(&mut cursor).unwrap()
    }

    /// Builds a string payload by hand: a borsh-encoded `u32` length prefix followed
    /// by `bytes` verbatim.
    ///
    /// This mirrors what the deserializer reads and lets the tests feed it invalid
    /// UTF-8 without ever constructing an invalid `String`, which would be undefined
    /// behaviour.
    fn raw_string_payload(bytes: &[u8]) -> Vec<u8> {
        let mut buffer = Vec::new();
        let len = u32::try_from(bytes.len()).expect("test payload length fits in u32");
        borsh::BorshSerialize::serialize(&len, &mut buffer).unwrap();
        buffer.extend_from_slice(bytes);
        buffer
    }

    #[test]
    fn borsh_serialize_stack() {
        let smolstr_on_stack = "aßΔCaßδc".to_smolstr();
        let decoded = round_trip(&smolstr_on_stack);
        assert_eq!(smolstr_on_stack, decoded);
    }

    #[test]
    fn borsh_serialize_heap() {
        let smolstr_on_heap = "aßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδcaßΔCaßδc".to_smolstr();
        let decoded = round_trip(&smolstr_on_heap);
        assert_eq!(smolstr_on_heap, decoded);
    }

    #[test]
    fn borsh_non_utf8_stack() {
        // Incomplete UTF-8 sequence, short enough for the small-string path
        // (assuming the inline capacity exceeds 3 bytes).
        let invalid_utf8: [u8; 3] = [0xF0, 0x9F, 0x8F];

        let mut cursor = Cursor::new(raw_string_payload(&invalid_utf8));
        let result = SmolStr::deserialize_reader(&mut cursor);
        assert!(result.is_err());
    }

    #[test]
    fn borsh_non_utf8_heap() {
        // 30 bytes of invalid UTF-8, intended to exercise the non-inline path
        // (assuming the inline capacity is below 30 bytes).
        let invalid_utf8: [u8; 30] = [
            0xC1, 0x8A, 0x5F, 0xE2, 0x3A, 0x9E, 0x3B, 0xAA, 0x01, 0x08, 0x6F, 0x2F, 0xC0, 0x32,
            0xAB, 0xE1, 0x9A, 0x2F, 0x4A, 0x3F, 0x25, 0x0D, 0x8A, 0x2A, 0x19, 0x11, 0xF0, 0x7F,
            0x0E, 0x80,
        ];

        let mut cursor = Cursor::new(raw_string_payload(&invalid_utf8));
        let result = SmolStr::deserialize_reader(&mut cursor);
        assert!(result.is_err());
    }
}

0 comments on commit 01cecbc

Please sign in to comment.