From 295e0b90d126eb20bb3173e98cfb5e7f55074fd4 Mon Sep 17 00:00:00 2001 From: Lukas Wirth Date: Mon, 2 Sep 2024 21:36:34 +0200 Subject: [PATCH] Various cleanups --- Cargo.toml | 1 - src/lib.rs | 200 ++++++++++++++++++++++++-------------------------- tests/test.rs | 20 +---- 3 files changed, 95 insertions(+), 126 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 659c797..c1e34e7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,4 +23,3 @@ serde = { version = "1.0", features = ["derive"] } [features] default = ["std"] std = ["serde?/std"] -serde = ["dep:serde"] diff --git a/src/lib.rs b/src/lib.rs index ca9944c..9845e4d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,9 +8,7 @@ use core::{ borrow::Borrow, cmp::{self, Ordering}, convert::Infallible, - fmt, hash, iter, - mem::transmute, - ops::Deref, + fmt, hash, iter, mem, ops, str::FromStr, }; @@ -34,53 +32,23 @@ use core::{ /// `WS`: A string of 32 newlines followed by 128 spaces. pub struct SmolStr(Repr); -impl Clone for SmolStr { - #[inline] - fn clone(&self) -> Self { - if !self.is_heap_allocated() { - return unsafe { core::ptr::read(self as *const SmolStr) }; - } - Self(self.0.clone()) - } -} - impl SmolStr { - #[deprecated = "Use `new_inline` instead"] - pub const fn new_inline_from_ascii(len: usize, bytes: &[u8]) -> SmolStr { - assert!(len <= INLINE_CAP); - - const ZEROS: &[u8] = &[0; INLINE_CAP]; - - let mut buf = [0; INLINE_CAP]; - macro_rules! s { - ($($idx:literal),*) => ( $(s!(set $idx);)* ); - (set $idx:literal) => ({ - let src: &[u8] = [ZEROS, bytes][($idx < len) as usize]; - let byte = src[$idx]; - let _is_ascii = [(); 128][byte as usize]; - buf[$idx] = byte - }); - } - s!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22); - SmolStr(Repr::Inline { - // SAFETY: We know that `len` is less than or equal to the maximum value of `InlineSize` - // as we asserted it. - len: unsafe { InlineSize::transmute_from_u8(len as u8) }, - buf, - }) - } - - /// Constructs inline variant of `SmolStr`. + /// Constructs an inline variant of `SmolStr`. + /// + /// This never allocates. + /// + /// # Panics /// /// Panics if `text.len() > 23`. #[inline] pub const fn new_inline(text: &str) -> SmolStr { - assert!(text.len() <= INLINE_CAP); // avoids checks in loop + assert!(text.len() <= INLINE_CAP); // avoids bounds checks in loop + let text = text.as_bytes(); let mut buf = [0; INLINE_CAP]; let mut i = 0; while i < text.len() { - buf[i] = text.as_bytes()[i]; + buf[i] = text[i]; i += 1 } SmolStr(Repr::Inline { @@ -102,68 +70,45 @@ impl SmolStr { SmolStr(Repr::Static(text)) } - pub fn new(text: T) -> SmolStr - where - T: AsRef, - { - SmolStr(Repr::new(text)) + /// Constructs a `SmolStr` from a `str`, heap-allocating if necessary. + #[inline(always)] + pub fn new(text: impl AsRef) -> SmolStr { + SmolStr(Repr::new(text.as_ref())) } + /// Returns a `&str` slice of this `SmolStr`. #[inline(always)] pub fn as_str(&self) -> &str { self.0.as_str() } - #[allow(clippy::inherent_to_string_shadow_display)] - #[inline(always)] - pub fn to_string(&self) -> String { - use alloc::borrow::ToOwned; - - self.as_str().to_owned() - } - + /// Returns the length of `self` in bytes. #[inline(always)] pub fn len(&self) -> usize { self.0.len() } + /// Returns `true` if `self` has a length of zero bytes. #[inline(always)] pub fn is_empty(&self) -> bool { self.0.is_empty() } + /// Returns `true` if `self` is heap-allocated. #[inline(always)] pub const fn is_heap_allocated(&self) -> bool { matches!(self.0, Repr::Heap(..)) } +} - fn from_char_iter>(mut iter: I) -> SmolStr { - let (min_size, _) = iter.size_hint(); - if min_size > INLINE_CAP { - let heap: String = iter.collect(); - return SmolStr(Repr::Heap(heap.into_boxed_str().into())); - } - let mut len = 0; - let mut buf = [0u8; INLINE_CAP]; - while let Some(ch) = iter.next() { - let size = ch.len_utf8(); - if size + len > INLINE_CAP { - let (min_remaining, _) = iter.size_hint(); - let mut heap = String::with_capacity(size + len + min_remaining); - heap.push_str(core::str::from_utf8(&buf[..len]).unwrap()); - heap.push(ch); - heap.extend(iter); - return SmolStr(Repr::Heap(heap.into_boxed_str().into())); - } - ch.encode_utf8(&mut buf[len..]); - len += size; +impl Clone for SmolStr { + #[inline] + fn clone(&self) -> Self { + if !self.is_heap_allocated() { + // SAFETY: We verified that the payload of `Repr` is a POD + return unsafe { core::ptr::read(self as *const SmolStr) }; } - SmolStr(Repr::Inline { - // SAFETY: We know that `len` is less than or equal to the maximum value of `InlineSize` - // as we otherwise return early. - len: unsafe { InlineSize::transmute_from_u8(len as u8) }, - buf, - }) + Self(self.0.clone()) } } @@ -177,7 +122,7 @@ impl Default for SmolStr { } } -impl Deref for SmolStr { +impl ops::Deref for SmolStr { type Target = str; #[inline(always)] @@ -186,61 +131,71 @@ impl Deref for SmolStr { } } +// region: PartialEq implementations + +impl Eq for SmolStr {} impl PartialEq for SmolStr { fn eq(&self, other: &SmolStr) -> bool { self.0.ptr_eq(&other.0) || self.as_str() == other.as_str() } } -impl Eq for SmolStr {} - impl PartialEq for SmolStr { + #[inline(always)] fn eq(&self, other: &str) -> bool { self.as_str() == other } } impl PartialEq for str { + #[inline(always)] fn eq(&self, other: &SmolStr) -> bool { other == self } } impl<'a> PartialEq<&'a str> for SmolStr { + #[inline(always)] fn eq(&self, other: &&'a str) -> bool { self == *other } } impl<'a> PartialEq for &'a str { + #[inline(always)] fn eq(&self, other: &SmolStr) -> bool { *self == other } } impl PartialEq for SmolStr { + #[inline(always)] fn eq(&self, other: &String) -> bool { self.as_str() == other } } impl PartialEq for String { + #[inline(always)] fn eq(&self, other: &SmolStr) -> bool { other == self } } impl<'a> PartialEq<&'a String> for SmolStr { + #[inline(always)] fn eq(&self, other: &&'a String) -> bool { self == *other } } impl<'a> PartialEq for &'a String { + #[inline(always)] fn eq(&self, other: &SmolStr) -> bool { *self == other } } +// endregion: PartialEq implementations impl Ord for SmolStr { fn cmp(&self, other: &SmolStr) -> Ordering { @@ -274,9 +229,41 @@ impl fmt::Display for SmolStr { impl iter::FromIterator for SmolStr { fn from_iter>(iter: I) -> SmolStr { - let iter = iter.into_iter(); - Self::from_char_iter(iter) + from_char_iter(iter.into_iter()) + } +} + +fn from_char_iter(mut iter: impl Iterator) -> SmolStr { + let (min_size, _) = iter.size_hint(); + if min_size > INLINE_CAP { + let heap: String = iter.collect(); + if heap.len() <= INLINE_CAP { + // size hint lied + return SmolStr::new_inline(&heap); + } + return SmolStr(Repr::Heap(heap.into_boxed_str().into())); + } + let mut len = 0; + let mut buf = [0u8; INLINE_CAP]; + while let Some(ch) = iter.next() { + let size = ch.len_utf8(); + if size + len > INLINE_CAP { + let (min_remaining, _) = iter.size_hint(); + let mut heap = String::with_capacity(size + len + min_remaining); + heap.push_str(core::str::from_utf8(&buf[..len]).unwrap()); + heap.push(ch); + heap.extend(iter); + return SmolStr(Repr::Heap(heap.into_boxed_str().into())); + } + ch.encode_utf8(&mut buf[len..]); + len += size; } + SmolStr(Repr::Inline { + // SAFETY: We know that `len` is less than or equal to the maximum value of `InlineSize` + // as we otherwise return early. + len: unsafe { InlineSize::transmute_from_u8(len as u8) }, + buf, + }) } fn build_from_str_iter(mut iter: impl Iterator) -> SmolStr @@ -415,14 +402,6 @@ impl FromStr for SmolStr { } } -#[cfg(feature = "arbitrary")] -impl<'a> arbitrary::Arbitrary<'a> for SmolStr { - fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> Result { - let s = <&str>::arbitrary(u)?; - Ok(SmolStr::new(s)) - } -} - const INLINE_CAP: usize = InlineSize::_V23 as usize; const N_NEWLINES: usize = 32; const N_SPACES: usize = 128; @@ -434,6 +413,7 @@ const _: () = { assert!(WS.as_bytes()[N_NEWLINES] == b' '); }; +/// A [`u8`] with a bunch of niches. #[derive(Clone, Copy, Debug, PartialEq)] #[repr(u8)] enum InlineSize { @@ -464,10 +444,12 @@ enum InlineSize { } impl InlineSize { + /// SAFETY: `value` must be less than or equal to [`INLINE_CAP`] #[inline(always)] const unsafe fn transmute_from_u8(value: u8) -> Self { debug_assert!(value <= InlineSize::_V23 as u8); - unsafe { transmute::(value) } + // SAFETY: The caller is responsible to uphold this invariant + unsafe { mem::transmute::(value) } } } @@ -518,11 +500,8 @@ impl Repr { None } - fn new(text: T) -> Self - where - T: AsRef, - { - Self::new_on_stack(text.as_ref()).unwrap_or_else(|| Repr::Heap(text.as_ref().into())) + fn new(text: &str) -> Self { + Self::new_on_stack(text).unwrap_or_else(|| Repr::Heap(Arc::from(text))) } #[inline(always)] @@ -539,7 +518,7 @@ impl Repr { match self { Repr::Heap(data) => data.is_empty(), Repr::Static(data) => data.is_empty(), - Repr::Inline { len, .. } => *len as u8 == 0, + &Repr::Inline { len, .. } => len as u8 == 0, } } @@ -550,7 +529,8 @@ impl Repr { Repr::Static(data) => data, Repr::Inline { len, buf } => { let len = *len as usize; - let buf = &buf[..len]; + // SAFETY: len is guaranteed to be <= INLINE_CAP + let buf = unsafe { buf.get_unchecked(..len) }; // SAFETY: buf is guaranteed to be valid utf8 for ..len bytes unsafe { ::core::str::from_utf8_unchecked(buf) } } @@ -633,22 +613,22 @@ pub trait StrExt: private::Sealed { impl StrExt for str { #[inline] fn to_lowercase_smolstr(&self) -> SmolStr { - SmolStr::from_char_iter(self.chars().flat_map(|c| c.to_lowercase())) + from_char_iter(self.chars().flat_map(|c| c.to_lowercase())) } #[inline] fn to_uppercase_smolstr(&self) -> SmolStr { - SmolStr::from_char_iter(self.chars().flat_map(|c| c.to_uppercase())) + from_char_iter(self.chars().flat_map(|c| c.to_uppercase())) } #[inline] fn to_ascii_lowercase_smolstr(&self) -> SmolStr { - SmolStr::from_char_iter(self.chars().map(|c| c.to_ascii_lowercase())) + from_char_iter(self.chars().map(|c| c.to_ascii_lowercase())) } #[inline] fn to_ascii_uppercase_smolstr(&self) -> SmolStr { - SmolStr::from_char_iter(self.chars().map(|c| c.to_ascii_uppercase())) + from_char_iter(self.chars().map(|c| c.to_ascii_uppercase())) } #[inline] @@ -754,7 +734,7 @@ impl From for SmolStr { buf: value.inline, } } else { - Repr::new(value.heap) + Repr::new(&value.heap) }) } } @@ -768,5 +748,13 @@ where } } +#[cfg(feature = "arbitrary")] +impl<'a> arbitrary::Arbitrary<'a> for SmolStr { + fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> Result { + let s = <&str>::arbitrary(u)?; + Ok(SmolStr::new(s)) + } +} + #[cfg(feature = "serde")] mod serde; diff --git a/tests/test.rs b/tests/test.rs index 2e2914d..631f7d7 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -44,20 +44,6 @@ fn const_fn_ctor() { assert_eq!(LONG, SmolStr::from("ABCDEFGHIZKLMNOPQRSTUVW")); } -#[allow(deprecated)] -#[test] -fn old_const_fn_ctor() { - const EMPTY: SmolStr = SmolStr::new_inline_from_ascii(0, b""); - const A: SmolStr = SmolStr::new_inline_from_ascii(1, b"A"); - const HELLO: SmolStr = SmolStr::new_inline_from_ascii(5, b"HELLO"); - const LONG: SmolStr = SmolStr::new_inline_from_ascii(23, b"ABCDEFGHIZKLMNOPQRSTUVW"); - - assert_eq!(EMPTY, SmolStr::from("")); - assert_eq!(A, SmolStr::from("A")); - assert_eq!(HELLO, SmolStr::from("HELLO")); - assert_eq!(LONG, SmolStr::from("ABCDEFGHIZKLMNOPQRSTUVW")); -} - #[cfg(not(miri))] fn check_props(std_str: &str, smol: SmolStr) -> Result<(), proptest::test_runner::TestCaseError> { prop_assert_eq!(smol.as_str(), std_str); @@ -253,11 +239,7 @@ fn test_bad_size_hint_char_iter() { let collected: SmolStr = BadSizeHint(data.chars()).collect(); let new = SmolStr::new(data); - // Because of the bad size hint, `collected` will be heap allocated, but `new` will be inline - - // If we try to use the type of the string (inline/heap) to quickly test for equality, we need to ensure - // `collected` is inline allocated instead - assert!(collected.is_heap_allocated()); + assert!(!collected.is_heap_allocated()); assert!(!new.is_heap_allocated()); assert_eq!(new, collected); }