Unnamed repository; edit this file 'description' to name the repository.
Diffstat (limited to 'lib/smol_str/src/lib.rs')
-rw-r--r-- lib/smol_str/src/lib.rs 104
1 files changed, 98 insertions, 6 deletions
diff --git a/lib/smol_str/src/lib.rs b/lib/smol_str/src/lib.rs
index ff25651f54..f2f021a7b5 100644
--- a/lib/smol_str/src/lib.rs
+++ b/lib/smol_str/src/lib.rs
@@ -233,8 +233,17 @@ impl iter::FromIterator<char> for SmolStr {
}
}
-fn from_char_iter(mut iter: impl Iterator<Item = char>) -> SmolStr {
- let (min_size, _) = iter.size_hint();
+#[inline]
+fn from_char_iter(iter: impl Iterator<Item = char>) -> SmolStr {
+ from_buf_and_chars([0; _], 0, iter) // no prefilled bytes: zeroed buffer with length 0
+}
+
+fn from_buf_and_chars(
+ mut buf: [u8; INLINE_CAP],
+ buf_len: usize,
+ mut iter: impl Iterator<Item = char>,
+) -> SmolStr {
+ let min_size = iter.size_hint().0 + buf_len;
if min_size > INLINE_CAP {
let heap: String = iter.collect();
if heap.len() <= INLINE_CAP {
@@ -243,8 +252,7 @@ fn from_char_iter(mut iter: impl Iterator<Item = char>) -> SmolStr {
}
return SmolStr(Repr::Heap(heap.into_boxed_str().into()));
}
- let mut len = 0;
- let mut buf = [0u8; INLINE_CAP];
+ let mut len = buf_len;
while let Some(ch) = iter.next() {
let size = ch.len_utf8();
if size + len > INLINE_CAP {
@@ -634,12 +642,32 @@ pub trait StrExt: private::Sealed {
impl StrExt for str {
#[inline]
fn to_lowercase_smolstr(&self) -> SmolStr {
- from_char_iter(self.chars().flat_map(|c| c.to_lowercase()))
+ let len = self.len(); // length in bytes, not chars
+ if len <= INLINE_CAP { // satisfies the length precondition of inline_convert_while_ascii
+ let (buf, rest) = inline_convert_while_ascii(self, u8::to_ascii_lowercase);
+ from_buf_and_chars(
+ buf, // leading ASCII bytes, already lowercased
+ len - rest.len(), // number of bytes of `buf` that were filled
+ rest.chars().flat_map(|c| c.to_lowercase()), // the unconverted rest gets full Unicode lowercasing
+ )
+ } else {
+ self.to_lowercase().into() // too long for the inline fast path: lowercase into a String, then convert
+ }
}
#[inline]
fn to_uppercase_smolstr(&self) -> SmolStr {
- from_char_iter(self.chars().flat_map(|c| c.to_uppercase()))
+ let len = self.len(); // length in bytes, not chars
+ if len <= INLINE_CAP { // satisfies the length precondition of inline_convert_while_ascii
+ let (buf, rest) = inline_convert_while_ascii(self, u8::to_ascii_uppercase);
+ from_buf_and_chars(
+ buf, // leading ASCII bytes, already uppercased
+ len - rest.len(), // number of bytes of `buf` that were filled
+ rest.chars().flat_map(|c| c.to_uppercase()), // the unconverted rest gets full Unicode uppercasing
+ )
+ } else {
+ self.to_uppercase().into() // too long for the inline fast path: uppercase into a String, then convert
+ }
}
#[inline]
@@ -699,6 +727,70 @@ impl StrExt for str {
}
}
+/// Inline-buffer version of std fn `convert_while_ascii`: applies `convert` to the leading ASCII bytes of `s` and returns the filled buffer plus the unconverted rest of `s`. `s.len()` must be <= `INLINE_CAP`.
+#[inline]
+fn inline_convert_while_ascii(s: &str, convert: fn(&u8) -> u8) -> ([u8; INLINE_CAP], &str) {
+ // Process the input in chunks of 16 bytes to enable auto-vectorization.
+ // Previously the chunk size depended on the size of `usize`,
+ // but 16 is also the better choice on 32-bit platforms with sse or neon.
+ // The only downside on other platforms would be a bit more loop-unrolling.
+ const N: usize = 16;
+
+ debug_assert!(s.len() <= INLINE_CAP, "only for inline-able strings");
+
+ let mut slice = s.as_bytes();
+ let mut out = [0u8; INLINE_CAP];
+ let mut out_slice = &mut out[..slice.len()]; // same length as `slice`; both are advanced in lockstep below
+ let mut is_ascii = [false; N];
+
+ while slice.len() >= N {
+ // SAFETY: the loop condition guarantees `slice.len() >= N`
+ let chunk = unsafe { slice.get_unchecked(..N) };
+ // SAFETY: out_slice has at least the same length as the input slice and gets sliced with the same offsets
+ let out_chunk = unsafe { out_slice.get_unchecked_mut(..N) };
+
+ for j in 0..N {
+ is_ascii[j] = chunk[j] <= 127;
+ }
+
+ // Auto-vectorization for this check is a bit fragile; summing and comparing against the chunk
+ // size gives the best result, specifically a pmovmsk instruction on x86.
+ // See https://github.com/llvm/llvm-project/issues/96395 for why llvm does not
+ // currently recognize other similar idioms.
+ if is_ascii.iter().map(|x| *x as u8).sum::<u8>() as usize != N {
+ break;
+ }
+
+ for j in 0..N {
+ out_chunk[j] = convert(&chunk[j]);
+ }
+
+ slice = unsafe { slice.get_unchecked(N..) }; // SAFETY: `slice.len() >= N` per the loop condition
+ out_slice = unsafe { out_slice.get_unchecked_mut(N..) }; // SAFETY: same length as `slice`, so `N..` is in bounds
+ }
+
+ // handle the remainder (and any chunk that contained a non-ASCII byte) as individual bytes
+ while !slice.is_empty() {
+ let byte = slice[0];
+ if byte > 127 {
+ break;
+ }
+ // SAFETY: out_slice has at least the same length as the input slice, which is non-empty here
+ unsafe {
+ *out_slice.get_unchecked_mut(0) = convert(&byte);
+ }
+ slice = unsafe { slice.get_unchecked(1..) }; // SAFETY: `slice` is non-empty per the loop condition
+ out_slice = unsafe { out_slice.get_unchecked_mut(1..) }; // SAFETY: same length as `slice`, so `1..` is in bounds
+ }
+
+ unsafe {
+ // SAFETY: we know this is a valid char boundary
+ // since we only skipped over leading ascii bytes
+ let rest = core::str::from_utf8_unchecked(slice);
+ (out, rest)
+ }
+}
+
impl<T> ToSmolStr for T
where
T: fmt::Display + ?Sized,