Unnamed repository; edit this file 'description' to name the repository.
| -rw-r--r-- | lib/smol_str/src/borsh.rs | 7 | ||||
| -rw-r--r-- | lib/smol_str/src/lib.rs | 258 | ||||
| -rw-r--r-- | lib/smol_str/src/serde.rs | 2 | ||||
| -rw-r--r-- | lib/smol_str/tests/test.rs | 24 |
4 files changed, 247 insertions, 44 deletions
diff --git a/lib/smol_str/src/borsh.rs b/lib/smol_str/src/borsh.rs index b684a4910c..44ae513ed4 100644 --- a/lib/smol_str/src/borsh.rs +++ b/lib/smol_str/src/borsh.rs @@ -16,7 +16,7 @@ impl BorshDeserialize for SmolStr { #[inline] fn deserialize_reader<R: Read>(reader: &mut R) -> borsh::io::Result<Self> { let len = u32::deserialize_reader(reader)?; - if (len as usize) < INLINE_CAP { + if (len as usize) <= INLINE_CAP { let mut buf = [0u8; INLINE_CAP]; reader.read_exact(&mut buf[..len as usize])?; _ = core::str::from_utf8(&buf[..len as usize]).map_err(|err| { @@ -29,9 +29,8 @@ impl BorshDeserialize for SmolStr { })) } else { // u8::vec_from_reader always returns Some on success in current implementation - let vec = u8::vec_from_reader(len, reader)?.ok_or_else(|| { - Error::new(ErrorKind::Other, "u8::vec_from_reader unexpectedly returned None") - })?; + let vec = u8::vec_from_reader(len, reader)? + .ok_or_else(|| Error::other("u8::vec_from_reader unexpectedly returned None"))?; Ok(SmolStr::from(String::from_utf8(vec).map_err(|err| { let msg = err.to_string(); Error::new(ErrorKind::InvalidData, msg) diff --git a/lib/smol_str/src/lib.rs b/lib/smol_str/src/lib.rs index 0d1f01a32b..55ede286c2 100644 --- a/lib/smol_str/src/lib.rs +++ b/lib/smol_str/src/lib.rs @@ -34,13 +34,17 @@ use core::{ pub struct SmolStr(Repr); impl SmolStr { + /// The maximum byte length of a string that can be stored inline + /// without heap allocation. + pub const INLINE_CAP: usize = INLINE_CAP; + /// Constructs an inline variant of `SmolStr`. /// /// This never allocates. /// /// # Panics /// - /// Panics if `text.len() > 23`. + /// Panics if `text.len() > `[`SmolStr::INLINE_CAP`]. #[inline] pub const fn new_inline(text: &str) -> SmolStr { assert!(text.len() <= INLINE_CAP); // avoids bounds checks in loop @@ -100,6 +104,24 @@ impl SmolStr { pub const fn is_heap_allocated(&self) -> bool { matches!(self.0, Repr::Heap(..)) } + + /// Constructs a `SmolStr` from a byte slice, returning an error if the slice is not valid + /// UTF-8. + #[inline] + pub fn from_utf8(bytes: &[u8]) -> Result<SmolStr, core::str::Utf8Error> { + core::str::from_utf8(bytes).map(SmolStr::new) + } + + /// Constructs a `SmolStr` from a byte slice without checking that the bytes are valid UTF-8. + /// + /// # Safety + /// + /// `bytes` must be valid UTF-8. + #[inline] + pub unsafe fn from_utf8_unchecked(bytes: &[u8]) -> SmolStr { + // SAFETY: caller guarantees bytes are valid UTF-8 + SmolStr::new(unsafe { core::str::from_utf8_unchecked(bytes) }) + } } impl Clone for SmolStr { @@ -116,7 +138,10 @@ impl Clone for SmolStr { return cold_clone(self); } - // SAFETY: We verified that the payload of `Repr` is a POD + // SAFETY: The non-heap variants (`Repr::Inline` and `Repr::Static`) contain only + // `Copy` data (a `[u8; 23]` + `InlineSize` enum, or a `&'static str` fat pointer) + // and carry no drop glue, so a raw `ptr::read` bitwise copy is sound. + // The heap variant (`Repr::Heap`) is excluded above. unsafe { core::ptr::read(self as *const SmolStr) } } } @@ -142,7 +167,12 @@ impl ops::Deref for SmolStr { impl Eq for SmolStr {} impl PartialEq<SmolStr> for SmolStr { fn eq(&self, other: &SmolStr) -> bool { - self.0.ptr_eq(&other.0) || self.as_str() == other.as_str() + match (&self.0, &other.0) { + (Repr::Inline { len: l_len, buf: l_buf }, Repr::Inline { len: r_len, buf: r_buf }) => { + l_len == r_len && l_buf == r_buf + } + _ => self.as_str() == other.as_str(), + } } } @@ -215,6 +245,48 @@ impl PartialOrd for SmolStr { } } +impl PartialOrd<str> for SmolStr { + fn partial_cmp(&self, other: &str) -> Option<Ordering> { + Some(self.as_str().cmp(other)) + } +} + +impl<'a> PartialOrd<&'a str> for SmolStr { + fn partial_cmp(&self, other: &&'a str) -> Option<Ordering> { + Some(self.as_str().cmp(*other)) + } +} + +impl PartialOrd<SmolStr> for &str { + fn partial_cmp(&self, other: &SmolStr) -> Option<Ordering> { + Some((*self).cmp(other.as_str())) + } +} + +impl PartialOrd<String> for SmolStr { + fn partial_cmp(&self, other: &String) -> Option<Ordering> { + Some(self.as_str().cmp(other.as_str())) + } +} + +impl PartialOrd<SmolStr> for String { + fn partial_cmp(&self, other: &SmolStr) -> Option<Ordering> { + Some(self.as_str().cmp(other.as_str())) + } +} + +impl<'a> PartialOrd<&'a String> for SmolStr { + fn partial_cmp(&self, other: &&'a String) -> Option<Ordering> { + Some(self.as_str().cmp(other.as_str())) + } +} + +impl PartialOrd<SmolStr> for &String { + fn partial_cmp(&self, other: &SmolStr) -> Option<Ordering> { + Some(self.as_str().cmp(other.as_str())) + } +} + impl hash::Hash for SmolStr { fn hash<H: hash::Hasher>(&self, hasher: &mut H) { self.as_str().hash(hasher); @@ -359,6 +431,20 @@ impl AsRef<std::path::Path> for SmolStr { } } +impl From<char> for SmolStr { + #[inline] + fn from(c: char) -> SmolStr { + let mut buf = [0; INLINE_CAP]; + let len = c.len_utf8(); + c.encode_utf8(&mut buf); + SmolStr(Repr::Inline { + // SAFETY: A char is at most 4 bytes, which is always <= INLINE_CAP (23). + len: unsafe { InlineSize::transmute_from_u8(len as u8) }, + buf, + }) + } +} + impl From<&str> for SmolStr { #[inline] fn from(s: &str) -> SmolStr { @@ -483,11 +569,15 @@ enum InlineSize { } impl InlineSize { - /// SAFETY: `value` must be less than or equal to [`INLINE_CAP`] + /// # Safety + /// + /// `value` must be in the range `0..=23` (i.e. a valid `InlineSize` discriminant). + /// Values outside this range would produce an invalid enum discriminant, which is UB. #[inline(always)] const unsafe fn transmute_from_u8(value: u8) -> Self { debug_assert!(value <= InlineSize::_V23 as u8); - // SAFETY: The caller is responsible to uphold this invariant + // SAFETY: The caller guarantees `value` is a valid discriminant for this + // `#[repr(u8)]` enum (0..=23), so the transmute produces a valid `InlineSize`. unsafe { mem::transmute::<u8, Self>(value) } } } @@ -563,24 +653,15 @@ impl Repr { Repr::Static(data) => data, Repr::Inline { len, buf } => { let len = *len as usize; - // SAFETY: len is guaranteed to be <= INLINE_CAP + // SAFETY: `len` is an `InlineSize` discriminant (0..=23) which is always + // <= INLINE_CAP (23), so `..len` is always in bounds of `buf: [u8; 23]`. let buf = unsafe { buf.get_unchecked(..len) }; - // SAFETY: buf is guaranteed to be valid utf8 for ..len bytes + // SAFETY: All constructors that produce `Repr::Inline` copy from valid + // UTF-8 sources (`&str` or char encoding), so `buf[..len]` is valid UTF-8. unsafe { ::core::str::from_utf8_unchecked(buf) } } } } - - fn ptr_eq(&self, other: &Self) -> bool { - match (self, other) { - (Self::Heap(l0), Self::Heap(r0)) => Arc::ptr_eq(l0, r0), - (Self::Static(l0), Self::Static(r0)) => core::ptr::eq(l0, r0), - (Self::Inline { len: l_len, buf: l_buf }, Self::Inline { len: r_len, buf: r_buf }) => { - l_len == r_len && l_buf == r_buf - } - _ => false, - } - } } /// Convert value to [`SmolStr`] using [`fmt::Display`], potentially without allocating. @@ -666,7 +747,7 @@ impl StrExt for str { buf[..len].copy_from_slice(self.as_bytes()); buf[..len].make_ascii_lowercase(); SmolStr(Repr::Inline { - // SAFETY: `len` is in bounds + // SAFETY: `len` is guarded to be <= INLINE_CAP (23), a valid `InlineSize` discriminant. len: unsafe { InlineSize::transmute_from_u8(len as u8) }, buf, }) @@ -683,7 +764,7 @@ impl StrExt for str { buf[..len].copy_from_slice(self.as_bytes()); buf[..len].make_ascii_uppercase(); SmolStr(Repr::Inline { - // SAFETY: `len` is in bounds + // SAFETY: `len` is guarded to be <= INLINE_CAP (23), a valid `InlineSize` discriminant. len: unsafe { InlineSize::transmute_from_u8(len as u8) }, buf, }) @@ -703,8 +784,11 @@ impl StrExt for str { if let [from_u8] = from.as_bytes() && let [to_u8] = to.as_bytes() { + // SAFETY: `from` and `to` are single-byte `&str`s. In valid UTF-8, a single-byte + // code unit is always in the range 0x00..=0x7F (i.e. ASCII). The closure only + // replaces the matching ASCII byte with another ASCII byte, and returns all + // other bytes unchanged, so UTF-8 validity is preserved. return if self.len() <= count { - // SAFETY: `from_u8` & `to_u8` are ascii unsafe { replacen_1_ascii(self, |b| if b == from_u8 { *to_u8 } else { *b }) } } else { unsafe { @@ -736,7 +820,11 @@ impl StrExt for str { } } -/// SAFETY: `map` fn must only replace ascii with ascii or return unchanged bytes. +/// # Safety +/// +/// `map` must satisfy: for every byte `b` in `src`, if `b <= 0x7F` (ASCII) then `map(b)` must +/// also be `<= 0x7F` (ASCII). If `b > 0x7F` (part of a multi-byte UTF-8 sequence), `map` must +/// return `b` unchanged. This ensures the output is valid UTF-8 whenever the input is. #[inline] unsafe fn replacen_1_ascii(src: &str, mut map: impl FnMut(&u8) -> u8) -> SmolStr { if src.len() <= INLINE_CAP { @@ -745,13 +833,16 @@ unsafe fn replacen_1_ascii(src: &str, mut map: impl FnMut(&u8) -> u8) -> SmolStr buf[idx] = map(b); } SmolStr(Repr::Inline { - // SAFETY: `len` is in bounds + // SAFETY: `src` is a `&str` so `src.len()` <= INLINE_CAP <= 23, which is a + // valid `InlineSize` discriminant. len: unsafe { InlineSize::transmute_from_u8(src.len() as u8) }, buf, }) } else { let out = src.as_bytes().iter().map(map).collect(); - // SAFETY: We replaced ascii with ascii on valid utf8 strings. + // SAFETY: The caller guarantees `map` only substitutes ASCII bytes with ASCII + // bytes and leaves multi-byte UTF-8 continuation bytes untouched, so the + // output byte sequence is valid UTF-8. unsafe { String::from_utf8_unchecked(out).into() } } } @@ -773,9 +864,11 @@ fn inline_convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> ([u8; INLINE_C let mut is_ascii = [false; N]; while slice.len() >= N { - // SAFETY: checked in loop condition + // SAFETY: The loop condition guarantees `slice.len() >= N`, so `..N` is in bounds. let chunk = unsafe { slice.get_unchecked(..N) }; - // SAFETY: out_slice has at least same length as input slice and gets sliced with the same offsets + // SAFETY: `out_slice` starts with the same length as `slice` (both derived from + // `s.len()`) and both are advanced by the same offset `N` each iteration, so + // `out_slice.len() >= N` holds whenever `slice.len() >= N`. let out_chunk = unsafe { out_slice.get_unchecked_mut(..N) }; for j in 0..N { @@ -794,6 +887,7 @@ fn inline_convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> ([u8; INLINE_C out_chunk[j] = convert(&chunk[j]); } + // SAFETY: Same reasoning as above — both slices have len >= N at this point. slice = unsafe { slice.get_unchecked(N..) }; out_slice = unsafe { out_slice.get_unchecked_mut(N..) }; } @@ -804,7 +898,9 @@ fn inline_convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> ([u8; INLINE_C if byte > 127 { break; } - // SAFETY: out_slice has at least same length as input slice + // SAFETY: `out_slice` is always the same length as `slice` (both start equal and + // are advanced by 1 together), and `slice` is non-empty per the loop condition, + // so index 0 and `1..` are in bounds for both. unsafe { *out_slice.get_unchecked_mut(0) = convert(&byte); } @@ -813,8 +909,10 @@ fn inline_convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> ([u8; INLINE_C } unsafe { - // SAFETY: we know this is a valid char boundary - // since we only skipped over leading ascii bytes + // SAFETY: We only advanced past bytes that satisfy `b <= 127`, i.e. ASCII bytes. + // In UTF-8, ASCII bytes (0x00..=0x7F) are always single-byte code points and + // never appear as continuation bytes, so the remaining `slice` starts at a valid + // UTF-8 char boundary. let rest = core::str::from_utf8_unchecked(slice); (out, rest) } @@ -850,10 +948,18 @@ macro_rules! format_smolstr { /// A builder that can be used to efficiently build a [`SmolStr`]. /// /// This won't allocate if the final string fits into the inline buffer. -#[derive(Clone, Default, Debug, PartialEq, Eq)] +#[derive(Clone, Default, Debug)] pub struct SmolStrBuilder(SmolStrBuilderRepr); -#[derive(Clone, Debug, PartialEq, Eq)] +impl PartialEq for SmolStrBuilder { + fn eq(&self, other: &Self) -> bool { + self.as_str() == other.as_str() + } +} + +impl Eq for SmolStrBuilder {} + +#[derive(Clone, Debug)] enum SmolStrBuilderRepr { Inline { len: usize, buf: [u8; INLINE_CAP] }, Heap(String), @@ -873,11 +979,57 @@ impl SmolStrBuilder { Self(SmolStrBuilderRepr::Inline { buf: [0; INLINE_CAP], len: 0 }) } + /// Creates a new empty [`SmolStrBuilder`] with at least the specified capacity. + /// + /// If `capacity` is less than or equal to [`SmolStr::INLINE_CAP`], the builder + /// will use inline storage and not allocate. Otherwise, it will pre-allocate a + /// heap buffer of the requested capacity. + #[must_use] + pub fn with_capacity(capacity: usize) -> Self { + if capacity <= INLINE_CAP { + Self::new() + } else { + Self(SmolStrBuilderRepr::Heap(String::with_capacity(capacity))) + } + } + + /// Returns the number of bytes accumulated in the builder so far. + #[inline] + pub fn len(&self) -> usize { + match &self.0 { + SmolStrBuilderRepr::Inline { len, .. } => *len, + SmolStrBuilderRepr::Heap(heap) => heap.len(), + } + } + + /// Returns `true` if the builder has a length of zero bytes. + #[inline] + pub fn is_empty(&self) -> bool { + match &self.0 { + SmolStrBuilderRepr::Inline { len, .. } => *len == 0, + SmolStrBuilderRepr::Heap(heap) => heap.is_empty(), + } + } + + /// Returns a `&str` slice of the builder's current contents. + #[inline] + pub fn as_str(&self) -> &str { + match &self.0 { + SmolStrBuilderRepr::Inline { len, buf } => { + // SAFETY: `buf[..*len]` was built by prior `push`/`push_str` calls + // that only wrote valid UTF-8, and `*len <= INLINE_CAP` is maintained + // by the inline branch logic. + unsafe { core::str::from_utf8_unchecked(&buf[..*len]) } + } + SmolStrBuilderRepr::Heap(heap) => heap.as_str(), + } + } + /// Builds a [`SmolStr`] from `self`. #[must_use] - pub fn finish(&self) -> SmolStr { - SmolStr(match &self.0 { - &SmolStrBuilderRepr::Inline { len, buf } => { + pub fn finish(self) -> SmolStr { + SmolStr(match self.0 { + SmolStrBuilderRepr::Inline { len, buf } => { debug_assert!(len <= INLINE_CAP); Repr::Inline { // SAFETY: We know that `value.len` is less than or equal to the maximum value of `InlineSize` @@ -885,7 +1037,7 @@ impl SmolStrBuilder { buf, } } - SmolStrBuilderRepr::Heap(heap) => Repr::new(heap), + SmolStrBuilderRepr::Heap(heap) => Repr::new(&heap), }) } @@ -900,8 +1052,10 @@ impl SmolStrBuilder { *len += char_len; } else { let mut heap = String::with_capacity(new_len); - // copy existing inline bytes over to the heap - // SAFETY: inline data is guaranteed to be valid utf8 for `old_len` bytes + // SAFETY: `buf[..*len]` was built by prior `push`/`push_str` calls + // that only wrote valid UTF-8 (from `char::encode_utf8` or `&str` + // byte copies), so extending the Vec with these bytes preserves the + // String's UTF-8 invariant. unsafe { heap.as_mut_vec().extend_from_slice(&buf[..*len]) }; heap.push(c); self.0 = SmolStrBuilderRepr::Heap(heap); @@ -926,8 +1080,10 @@ impl SmolStrBuilder { let mut heap = String::with_capacity(*len); - // copy existing inline bytes over to the heap - // SAFETY: inline data is guaranteed to be valid utf8 for `old_len` bytes + // SAFETY: `buf[..old_len]` was built by prior `push`/`push_str` calls + // that only wrote valid UTF-8 (from `char::encode_utf8` or `&str` byte + // copies), so extending the Vec with these bytes preserves the String's + // UTF-8 invariant. unsafe { heap.as_mut_vec().extend_from_slice(&buf[..old_len]) }; heap.push_str(s); self.0 = SmolStrBuilderRepr::Heap(heap); @@ -945,6 +1101,30 @@ impl fmt::Write for SmolStrBuilder { } } +impl iter::Extend<char> for SmolStrBuilder { + fn extend<I: iter::IntoIterator<Item = char>>(&mut self, iter: I) { + for c in iter { + self.push(c); + } + } +} + +impl<'a> iter::Extend<&'a str> for SmolStrBuilder { + fn extend<I: iter::IntoIterator<Item = &'a str>>(&mut self, iter: I) { + for s in iter { + self.push_str(s); + } + } +} + +impl<'a> iter::Extend<&'a String> for SmolStrBuilder { + fn extend<I: iter::IntoIterator<Item = &'a String>>(&mut self, iter: I) { + for s in iter { + self.push_str(s); + } + } +} + impl From<SmolStrBuilder> for SmolStr { fn from(value: SmolStrBuilder) -> Self { value.finish() diff --git a/lib/smol_str/src/serde.rs b/lib/smol_str/src/serde.rs index 66cbcd3bad..9d82d64805 100644 --- a/lib/smol_str/src/serde.rs +++ b/lib/smol_str/src/serde.rs @@ -16,7 +16,7 @@ where impl<'a> Visitor<'a> for SmolStrVisitor { type Value = SmolStr; - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { formatter.write_str("a string") } diff --git a/lib/smol_str/tests/test.rs b/lib/smol_str/tests/test.rs index 00fab2ee1c..83648edeec 100644 --- a/lib/smol_str/tests/test.rs +++ b/lib/smol_str/tests/test.rs @@ -10,6 +10,7 @@ use smol_str::{SmolStr, SmolStrBuilder}; #[cfg(target_pointer_width = "64")] fn smol_str_is_smol() { assert_eq!(::std::mem::size_of::<SmolStr>(), ::std::mem::size_of::<String>(),); + assert_eq!(::std::mem::size_of::<Option<SmolStr>>(), ::std::mem::size_of::<SmolStr>(),); } #[test] @@ -332,6 +333,29 @@ fn test_builder_push() { assert_eq!("a".repeat(24), s); } +#[test] +fn test_from_char() { + // ASCII char + let s: SmolStr = 'a'.into(); + assert_eq!(s, "a"); + assert!(!s.is_heap_allocated()); + + // Multi-byte char (2 bytes) + let s: SmolStr = SmolStr::from('ñ'); + assert_eq!(s, "ñ"); + assert!(!s.is_heap_allocated()); + + // 3-byte char + let s: SmolStr = '€'.into(); + assert_eq!(s, "€"); + assert!(!s.is_heap_allocated()); + + // 4-byte char (emoji) + let s: SmolStr = '🦀'.into(); + assert_eq!(s, "🦀"); + assert!(!s.is_heap_allocated()); +} + #[cfg(test)] mod test_str_ext { use smol_str::StrExt; |