Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'lib/smol_str/src/lib.rs')
| -rw-r--r-- | lib/smol_str/src/lib.rs | 138 |
1 files changed, 102 insertions, 36 deletions
diff --git a/lib/smol_str/src/lib.rs b/lib/smol_str/src/lib.rs index 0d1f01a32b..b30eefc098 100644 --- a/lib/smol_str/src/lib.rs +++ b/lib/smol_str/src/lib.rs @@ -100,6 +100,24 @@ impl SmolStr { pub const fn is_heap_allocated(&self) -> bool { matches!(self.0, Repr::Heap(..)) } + + /// Constructs a `SmolStr` from a byte slice, returning an error if the slice is not valid + /// UTF-8. + #[inline] + pub fn from_utf8(bytes: &[u8]) -> Result<SmolStr, core::str::Utf8Error> { + core::str::from_utf8(bytes).map(SmolStr::new) + } + + /// Constructs a `SmolStr` from a byte slice without checking that the bytes are valid UTF-8. + /// + /// # Safety + /// + /// `bytes` must be valid UTF-8. + #[inline] + pub unsafe fn from_utf8_unchecked(bytes: &[u8]) -> SmolStr { + // SAFETY: caller guarantees bytes are valid UTF-8 + SmolStr::new(unsafe { core::str::from_utf8_unchecked(bytes) }) + } } impl Clone for SmolStr { @@ -116,7 +134,10 @@ impl Clone for SmolStr { return cold_clone(self); } - // SAFETY: We verified that the payload of `Repr` is a POD + // SAFETY: The non-heap variants (`Repr::Inline` and `Repr::Static`) contain only + // `Copy` data (a `[u8; 23]` + `InlineSize` enum, or a `&'static str` fat pointer) + // and carry no drop glue, so a raw `ptr::read` bitwise copy is sound. + // The heap variant (`Repr::Heap`) is excluded above. unsafe { core::ptr::read(self as *const SmolStr) } } } @@ -142,7 +163,12 @@ impl ops::Deref for SmolStr { impl Eq for SmolStr {} impl PartialEq<SmolStr> for SmolStr { fn eq(&self, other: &SmolStr) -> bool { - self.0.ptr_eq(&other.0) || self.as_str() == other.as_str() + match (&self.0, &other.0) { + (Repr::Inline { len: l_len, buf: l_buf }, Repr::Inline { len: r_len, buf: r_buf }) => { + l_len == r_len && l_buf == r_buf + } + _ => self.as_str() == other.as_str(), + } } } @@ -483,11 +509,15 @@ enum InlineSize { } impl InlineSize { - /// SAFETY: `value` must be less than or equal to [`INLINE_CAP`] + /// # Safety + /// + /// `value` must be in the range `0..=23` (i.e. a valid `InlineSize` discriminant). + /// Values outside this range would produce an invalid enum discriminant, which is UB. #[inline(always)] const unsafe fn transmute_from_u8(value: u8) -> Self { debug_assert!(value <= InlineSize::_V23 as u8); - // SAFETY: The caller is responsible to uphold this invariant + // SAFETY: The caller guarantees `value` is a valid discriminant for this + // `#[repr(u8)]` enum (0..=23), so the transmute produces a valid `InlineSize`. unsafe { mem::transmute::<u8, Self>(value) } } } @@ -563,24 +593,15 @@ impl Repr { Repr::Static(data) => data, Repr::Inline { len, buf } => { let len = *len as usize; - // SAFETY: len is guaranteed to be <= INLINE_CAP + // SAFETY: `len` is an `InlineSize` discriminant (0..=23) which is always + // <= INLINE_CAP (23), so `..len` is always in bounds of `buf: [u8; 23]`. let buf = unsafe { buf.get_unchecked(..len) }; - // SAFETY: buf is guaranteed to be valid utf8 for ..len bytes + // SAFETY: All constructors that produce `Repr::Inline` copy from valid + // UTF-8 sources (`&str` or char encoding), so `buf[..len]` is valid UTF-8. unsafe { ::core::str::from_utf8_unchecked(buf) } } } } - - fn ptr_eq(&self, other: &Self) -> bool { - match (self, other) { - (Self::Heap(l0), Self::Heap(r0)) => Arc::ptr_eq(l0, r0), - (Self::Static(l0), Self::Static(r0)) => core::ptr::eq(l0, r0), - (Self::Inline { len: l_len, buf: l_buf }, Self::Inline { len: r_len, buf: r_buf }) => { - l_len == r_len && l_buf == r_buf - } - _ => false, - } - } } /// Convert value to [`SmolStr`] using [`fmt::Display`], potentially without allocating. @@ -666,7 +687,7 @@ impl StrExt for str { buf[..len].copy_from_slice(self.as_bytes()); buf[..len].make_ascii_lowercase(); SmolStr(Repr::Inline { - // SAFETY: `len` is in bounds + // SAFETY: `len` is guarded to be <= INLINE_CAP (23), a valid `InlineSize` discriminant. len: unsafe { InlineSize::transmute_from_u8(len as u8) }, buf, }) @@ -683,7 +704,7 @@ impl StrExt for str { buf[..len].copy_from_slice(self.as_bytes()); buf[..len].make_ascii_uppercase(); SmolStr(Repr::Inline { - // SAFETY: `len` is in bounds + // SAFETY: `len` is guarded to be <= INLINE_CAP (23), a valid `InlineSize` discriminant. len: unsafe { InlineSize::transmute_from_u8(len as u8) }, buf, }) @@ -703,8 +724,11 @@ impl StrExt for str { if let [from_u8] = from.as_bytes() && let [to_u8] = to.as_bytes() { + // SAFETY: `from` and `to` are single-byte `&str`s. In valid UTF-8, a single-byte + // code unit is always in the range 0x00..=0x7F (i.e. ASCII). The closure only + // replaces the matching ASCII byte with another ASCII byte, and returns all + // other bytes unchanged, so UTF-8 validity is preserved. return if self.len() <= count { - // SAFETY: `from_u8` & `to_u8` are ascii unsafe { replacen_1_ascii(self, |b| if b == from_u8 { *to_u8 } else { *b }) } } else { unsafe { @@ -736,7 +760,11 @@ impl StrExt for str { } } -/// SAFETY: `map` fn must only replace ascii with ascii or return unchanged bytes. +/// # Safety +/// +/// `map` must satisfy: for every byte `b` in `src`, if `b <= 0x7F` (ASCII) then `map(b)` must +/// also be `<= 0x7F` (ASCII). If `b > 0x7F` (part of a multi-byte UTF-8 sequence), `map` must +/// return `b` unchanged. This ensures the output is valid UTF-8 whenever the input is. #[inline] unsafe fn replacen_1_ascii(src: &str, mut map: impl FnMut(&u8) -> u8) -> SmolStr { if src.len() <= INLINE_CAP { @@ -745,13 +773,16 @@ unsafe fn replacen_1_ascii(src: &str, mut map: impl FnMut(&u8) -> u8) -> SmolStr buf[idx] = map(b); } SmolStr(Repr::Inline { - // SAFETY: `len` is in bounds + // SAFETY: `src` is a `&str` so `src.len()` <= INLINE_CAP <= 23, which is a + // valid `InlineSize` discriminant. len: unsafe { InlineSize::transmute_from_u8(src.len() as u8) }, buf, }) } else { let out = src.as_bytes().iter().map(map).collect(); - // SAFETY: We replaced ascii with ascii on valid utf8 strings. + // SAFETY: The caller guarantees `map` only substitutes ASCII bytes with ASCII + // bytes and leaves multi-byte UTF-8 continuation bytes untouched, so the + // output byte sequence is valid UTF-8. unsafe { String::from_utf8_unchecked(out).into() } } } @@ -773,9 +804,11 @@ fn inline_convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> ([u8; INLINE_C let mut is_ascii = [false; N]; while slice.len() >= N { - // SAFETY: checked in loop condition + // SAFETY: The loop condition guarantees `slice.len() >= N`, so `..N` is in bounds. let chunk = unsafe { slice.get_unchecked(..N) }; - // SAFETY: out_slice has at least same length as input slice and gets sliced with the same offsets + // SAFETY: `out_slice` starts with the same length as `slice` (both derived from + // `s.len()`) and both are advanced by the same offset `N` each iteration, so + // `out_slice.len() >= N` holds whenever `slice.len() >= N`. let out_chunk = unsafe { out_slice.get_unchecked_mut(..N) }; for j in 0..N { @@ -794,6 +827,7 @@ fn inline_convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> ([u8; INLINE_C out_chunk[j] = convert(&chunk[j]); } + // SAFETY: Same reasoning as above — both slices have len >= N at this point. slice = unsafe { slice.get_unchecked(N..) }; out_slice = unsafe { out_slice.get_unchecked_mut(N..) }; } @@ -804,7 +838,9 @@ fn inline_convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> ([u8; INLINE_C if byte > 127 { break; } - // SAFETY: out_slice has at least same length as input slice + // SAFETY: `out_slice` is always the same length as `slice` (both start equal and + // are advanced by 1 together), and `slice` is non-empty per the loop condition, + // so index 0 and `1..` are in bounds for both. unsafe { *out_slice.get_unchecked_mut(0) = convert(&byte); } @@ -813,8 +849,10 @@ fn inline_convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> ([u8; INLINE_C } unsafe { - // SAFETY: we know this is a valid char boundary - // since we only skipped over leading ascii bytes + // SAFETY: We only advanced past bytes that satisfy `b <= 127`, i.e. ASCII bytes. + // In UTF-8, ASCII bytes (0x00..=0x7F) are always single-byte code points and + // never appear as continuation bytes, so the remaining `slice` starts at a valid + // UTF-8 char boundary. let rest = core::str::from_utf8_unchecked(slice); (out, rest) } @@ -875,9 +913,9 @@ impl SmolStrBuilder { /// Builds a [`SmolStr`] from `self`. #[must_use] - pub fn finish(&self) -> SmolStr { - SmolStr(match &self.0 { - &SmolStrBuilderRepr::Inline { len, buf } => { + pub fn finish(self) -> SmolStr { + SmolStr(match self.0 { + SmolStrBuilderRepr::Inline { len, buf } => { debug_assert!(len <= INLINE_CAP); Repr::Inline { // SAFETY: We know that `value.len` is less than or equal to the maximum value of `InlineSize` @@ -885,7 +923,7 @@ impl SmolStrBuilder { buf, } } - SmolStrBuilderRepr::Heap(heap) => Repr::new(heap), + SmolStrBuilderRepr::Heap(heap) => Repr::new(&heap), }) } @@ -900,8 +938,10 @@ impl SmolStrBuilder { *len += char_len; } else { let mut heap = String::with_capacity(new_len); - // copy existing inline bytes over to the heap - // SAFETY: inline data is guaranteed to be valid utf8 for `old_len` bytes + // SAFETY: `buf[..*len]` was built by prior `push`/`push_str` calls + // that only wrote valid UTF-8 (from `char::encode_utf8` or `&str` + // byte copies), so extending the Vec with these bytes preserves the + // String's UTF-8 invariant. unsafe { heap.as_mut_vec().extend_from_slice(&buf[..*len]) }; heap.push(c); self.0 = SmolStrBuilderRepr::Heap(heap); @@ -926,8 +966,10 @@ impl SmolStrBuilder { let mut heap = String::with_capacity(*len); - // copy existing inline bytes over to the heap - // SAFETY: inline data is guaranteed to be valid utf8 for `old_len` bytes + // SAFETY: `buf[..old_len]` was built by prior `push`/`push_str` calls + // that only wrote valid UTF-8 (from `char::encode_utf8` or `&str` byte + // copies), so extending the Vec with these bytes preserves the String's + // UTF-8 invariant. unsafe { heap.as_mut_vec().extend_from_slice(&buf[..old_len]) }; heap.push_str(s); self.0 = SmolStrBuilderRepr::Heap(heap); @@ -945,6 +987,30 @@ impl fmt::Write for SmolStrBuilder { } } +impl iter::Extend<char> for SmolStrBuilder { + fn extend<I: iter::IntoIterator<Item = char>>(&mut self, iter: I) { + for c in iter { + self.push(c); + } + } +} + +impl<'a> iter::Extend<&'a str> for SmolStrBuilder { + fn extend<I: iter::IntoIterator<Item = &'a str>>(&mut self, iter: I) { + for s in iter { + self.push_str(s); + } + } +} + +impl<'a> iter::Extend<&'a String> for SmolStrBuilder { + fn extend<I: iter::IntoIterator<Item = &'a String>>(&mut self, iter: I) { + for s in iter { + self.push_str(s); + } + } +} + impl From<SmolStrBuilder> for SmolStr { fn from(value: SmolStrBuilder) -> Self { value.finish() |