Unnamed repository; edit this file 'description' to name the repository.
Implement rough symbol interning infra
Lukas Wirth 2024-07-12
parent ffbc5ad · commit 6275eb1
-rw-r--r--.typos.toml2
-rw-r--r--Cargo.lock7
-rw-r--r--crates/intern/Cargo.toml3
-rw-r--r--crates/intern/src/lib.rs3
-rw-r--r--crates/intern/src/symbol.rs293
-rw-r--r--crates/intern/src/symbol/symbols.rs236
6 files changed, 543 insertions, 1 deletions
diff --git a/.typos.toml b/.typos.toml
index c2e8b26521..e7e764ce03 100644
--- a/.typos.toml
+++ b/.typos.toml
@@ -14,6 +14,8 @@ extend-ignore-re = [
"\\w*\\.{3,4}\\w*",
'"flate2"',
"raison d'ĂȘtre",
+ "inout",
+ "optin"
]
[default.extend-words]
diff --git a/Cargo.lock b/Cargo.lock
index e9ebe26f42..b165697724 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -835,6 +835,7 @@ dependencies = [
"dashmap",
"hashbrown",
"rustc-hash",
+ "sptr",
"triomphe",
]
@@ -1886,6 +1887,12 @@ dependencies = [
]
[[package]]
+name = "sptr"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3b9b39299b249ad65f3b7e96443bad61c02ca5cd3589f46cb6d610a0fd6c0d6a"
+
+[[package]]
name = "stable_deref_trait"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/crates/intern/Cargo.toml b/crates/intern/Cargo.toml
index 67b4164ce1..c08ecb5c30 100644
--- a/crates/intern/Cargo.toml
+++ b/crates/intern/Cargo.toml
@@ -18,6 +18,7 @@ dashmap.workspace = true
hashbrown.workspace = true
rustc-hash.workspace = true
triomphe.workspace = true
+sptr = "0.3.2"
[lints]
-workspace = true \ No newline at end of file
+workspace = true
diff --git a/crates/intern/src/lib.rs b/crates/intern/src/lib.rs
index 40d18b1cf8..868d03caff 100644
--- a/crates/intern/src/lib.rs
+++ b/crates/intern/src/lib.rs
@@ -20,6 +20,9 @@ type Guard<T> = dashmap::RwLockWriteGuard<
HashMap<Arc<T>, SharedValue<()>, BuildHasherDefault<FxHasher>>,
>;
+mod symbol;
+pub use self::symbol::{symbols, Symbol};
+
pub struct Interned<T: Internable + ?Sized> {
arc: Arc<T>,
}
diff --git a/crates/intern/src/symbol.rs b/crates/intern/src/symbol.rs
new file mode 100644
index 0000000000..a1cffd0662
--- /dev/null
+++ b/crates/intern/src/symbol.rs
@@ -0,0 +1,293 @@
+//! Attempt at flexible symbol interning, allowing to intern and free strings at runtime while also
+//! supporting
+
+use std::{
+ borrow::Borrow,
+ fmt,
+ hash::{BuildHasherDefault, Hash, Hasher},
+ mem,
+ ptr::NonNull,
+ sync::OnceLock,
+};
+
+use dashmap::{DashMap, SharedValue};
+use hashbrown::{hash_map::RawEntryMut, HashMap};
+use rustc_hash::FxHasher;
+use sptr::Strict;
+use triomphe::Arc;
+
+pub mod symbols;
+
+// some asserts for layout compatibility
+const _: () = assert!(std::mem::size_of::<Box<str>>() == std::mem::size_of::<&str>());
+const _: () = assert!(std::mem::align_of::<Box<str>>() == std::mem::align_of::<&str>());
+
+const _: () = assert!(std::mem::size_of::<Arc<Box<str>>>() == std::mem::size_of::<&&str>());
+const _: () = assert!(std::mem::align_of::<Arc<Box<str>>>() == std::mem::align_of::<&&str>());
+
+/// A pointer that points to a pointer to a `str`, it may be backed as a `&'static &'static str` or
+/// `Arc<Box<str>>` but its size is that of a thin pointer. The active variant is encoded as a tag
+/// in the LSB of the alignment niche.
+#[derive(PartialEq, Eq, Hash, Copy, Clone, Debug)]
+struct TaggedArcPtr {
+ packed: NonNull<*const str>,
+}
+
+unsafe impl Send for TaggedArcPtr {}
+unsafe impl Sync for TaggedArcPtr {}
+
+impl TaggedArcPtr {
+ const BOOL_BITS: usize = true as usize;
+
+ const fn non_arc(r: &&str) -> Self {
+ Self {
+ // SAFETY: The pointer is non-null as it is derived from a reference
+ // Ideally we would call out to `pack_arc` but for a `false` tag, unfortunately the
+ // packing stuff requires reading out the pointer to an integer which is not supported
+ // in const contexts, so here we make use of the fact that for the non-arc version the
+ // tag is false (0) and thus does not need touching the actual pointer value.ext)
+ packed: unsafe {
+ NonNull::new_unchecked((r as *const &str).cast::<*const str>().cast_mut())
+ },
+ }
+ }
+
+ fn arc(arc: Arc<Box<str>>) -> Self {
+ Self {
+ packed: Self::pack_arc(
+ // Safety: `Arc::into_raw` always returns a non null pointer
+ unsafe { NonNull::new_unchecked(Arc::into_raw(arc).cast_mut().cast()) },
+ ),
+ }
+ }
+
+ /// Retrieves the tag.
+ #[inline]
+ pub(crate) fn try_as_arc_owned(self) -> Option<Arc<Box<str>>> {
+ // Unpack the tag from the alignment niche
+ let tag = Strict::addr(self.packed.as_ptr()) & Self::BOOL_BITS;
+ if tag != 0 {
+ // Safety: We checked that the tag is non-zero -> true, so we are pointing to the data offset of an `Arc`
+ Some(unsafe { Arc::from_raw(self.pointer().as_ptr().cast::<Box<str>>()) })
+ } else {
+ None
+ }
+ }
+
+ #[inline]
+ const fn pack_arc(ptr: NonNull<*const str>) -> NonNull<*const str> {
+ let packed_tag = true as usize;
+
+ // can't use this strict provenance stuff here due to trait methods not being const
+ // unsafe {
+ // // Safety: The pointer is derived from a non-null
+ // NonNull::new_unchecked(Strict::map_addr(ptr.as_ptr(), |addr| {
+ // // Safety:
+ // // - The pointer is `NonNull` => it's address is `NonZero<usize>`
+ // // - `P::BITS` least significant bits are always zero (`Pointer` contract)
+ // // - `T::BITS <= P::BITS` (from `Self::ASSERTION`)
+ // //
+ // // Thus `addr >> T::BITS` is guaranteed to be non-zero.
+ // //
+ // // `{non_zero} | packed_tag` can't make the value zero.
+
+ // (addr >> Self::BOOL_BITS) | packed_tag
+ // }))
+ // }
+ // so what follows is roughly what the above looks like but inlined
+
+ let self_addr = unsafe { core::mem::transmute::<*const _, usize>(ptr.as_ptr()) };
+ let addr = self_addr | packed_tag;
+ let dest_addr = addr as isize;
+ let offset = dest_addr.wrapping_sub(self_addr as isize);
+
+ // SAFETY: The resulting pointer is guaranteed to be NonNull as we only modify the niche bytes
+ unsafe { NonNull::new_unchecked(ptr.as_ptr().cast::<u8>().wrapping_offset(offset).cast()) }
+ }
+
+ #[inline]
+ pub(crate) fn pointer(self) -> NonNull<*const str> {
+ // SAFETY: The resulting pointer is guaranteed to be NonNull as we only modify the niche bytes
+ unsafe {
+ NonNull::new_unchecked(Strict::map_addr(self.packed.as_ptr(), |addr| {
+ addr & !Self::BOOL_BITS
+ }))
+ }
+ }
+
+ #[inline]
+ pub(crate) fn as_str(&self) -> &str {
+ // SAFETY: We always point to a pointer to a str no matter what variant is active
+ unsafe { *self.pointer().as_ptr().cast::<&str>() }
+ }
+}
+
+#[derive(PartialEq, Eq, Hash, Clone, Debug)]
+pub struct Symbol {
+ repr: TaggedArcPtr,
+}
+const _: () = assert!(std::mem::size_of::<Symbol>() == std::mem::size_of::<NonNull<()>>());
+const _: () = assert!(std::mem::align_of::<Symbol>() == std::mem::align_of::<NonNull<()>>());
+
+static MAP: OnceLock<DashMap<SymbolProxy, (), BuildHasherDefault<FxHasher>>> = OnceLock::new();
+
+impl Symbol {
+ pub fn intern(s: &str) -> Self {
+ let (mut shard, hash) = Self::select_shard(s);
+ // Atomically,
+ // - check if `obj` is already in the map
+ // - if so, copy out its entry, conditionally bumping the backing Arc and return it
+ // - if not, put it into a box and then into an Arc, insert it, bump the ref-count and return the copy
+ // This needs to be atomic (locking the shard) to avoid races with other thread, which could
+ // insert the same object between us looking it up and inserting it.
+ match shard.raw_entry_mut().from_key_hashed_nocheck(hash, s) {
+ RawEntryMut::Occupied(occ) => Self { repr: increase_arc_refcount(occ.key().0) },
+ RawEntryMut::Vacant(vac) => Self {
+ repr: increase_arc_refcount(
+ vac.insert_hashed_nocheck(
+ hash,
+ SymbolProxy(TaggedArcPtr::arc(Arc::new(Box::<str>::from(s)))),
+ SharedValue::new(()),
+ )
+ .0
+ .0,
+ ),
+ },
+ }
+ }
+
+ pub fn as_str(&self) -> &str {
+ self.repr.as_str()
+ }
+
+ #[inline]
+ fn select_shard(
+ s: &str,
+ ) -> (
+ dashmap::RwLockWriteGuard<
+ 'static,
+ HashMap<SymbolProxy, SharedValue<()>, BuildHasherDefault<FxHasher>>,
+ >,
+ u64,
+ ) {
+ let storage = MAP.get_or_init(symbols::prefill);
+ let hash = {
+ let mut hasher = std::hash::BuildHasher::build_hasher(storage.hasher());
+ s.hash(&mut hasher);
+ hasher.finish()
+ };
+ let shard_idx = storage.determine_shard(hash as usize);
+ let shard = &storage.shards()[shard_idx];
+ (shard.write(), hash)
+ }
+
+ #[cold]
+ fn drop_slow(arc: &Arc<Box<str>>) {
+ let (mut shard, hash) = Self::select_shard(arc);
+
+ if Arc::count(arc) != 2 {
+ // Another thread has interned another copy
+ return;
+ }
+
+ match shard.raw_entry_mut().from_key_hashed_nocheck::<str>(hash, arc.as_ref()) {
+ RawEntryMut::Occupied(occ) => occ.remove_entry(),
+ RawEntryMut::Vacant(_) => unreachable!(),
+ }
+ .0
+ .0
+ .try_as_arc_owned()
+ .unwrap();
+
+ // Shrink the backing storage if the shard is less than 50% occupied.
+ if shard.len() * 2 < shard.capacity() {
+ shard.shrink_to_fit();
+ }
+ }
+}
+
+impl Drop for Symbol {
+ #[inline]
+ fn drop(&mut self) {
+ let Some(arc) = self.repr.try_as_arc_owned() else {
+ return;
+ };
+ // When the last `Ref` is dropped, remove the object from the global map.
+ if Arc::count(&arc) == 2 {
+ // Only `self` and the global map point to the object.
+
+ Self::drop_slow(&arc);
+ }
+ // decrement the ref count
+ drop(arc);
+ }
+}
+
+fn increase_arc_refcount(repr: TaggedArcPtr) -> TaggedArcPtr {
+ let Some(arc) = repr.try_as_arc_owned() else {
+ return repr;
+ };
+ // increase the ref count
+ mem::forget(arc.clone());
+ mem::forget(arc);
+ repr
+}
+
+impl fmt::Display for Symbol {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ self.as_str().fmt(f)
+ }
+}
+
+// only exists so we can use `from_key_hashed_nocheck` with a &str
+#[derive(Debug, PartialEq, Eq)]
+struct SymbolProxy(TaggedArcPtr);
+
+impl Hash for SymbolProxy {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ self.0.as_str().hash(state);
+ }
+}
+
+impl Borrow<str> for SymbolProxy {
+ fn borrow(&self) -> &str {
+ self.0.as_str()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn smoke_test() {
+ Symbol::intern("isize");
+ let base_len = MAP.get().unwrap().len();
+ let hello = Symbol::intern("hello");
+ let world = Symbol::intern("world");
+ let bang = Symbol::intern("!");
+ let q = Symbol::intern("?");
+ assert_eq!(MAP.get().unwrap().len(), base_len + 4);
+ let bang2 = Symbol::intern("!");
+ assert_eq!(MAP.get().unwrap().len(), base_len + 4);
+ drop(bang2);
+ assert_eq!(MAP.get().unwrap().len(), base_len + 4);
+ drop(q);
+ assert_eq!(MAP.get().unwrap().len(), base_len + 3);
+ let default = Symbol::intern("default");
+ assert_eq!(MAP.get().unwrap().len(), base_len + 3);
+ assert_eq!(
+ "hello default world!",
+ format!("{} {} {}{}", hello.as_str(), default.as_str(), world.as_str(), bang.as_str())
+ );
+ drop(default);
+ assert_eq!(
+ "hello world!",
+ format!("{} {}{}", hello.as_str(), world.as_str(), bang.as_str())
+ );
+ drop(hello);
+ drop(world);
+ drop(bang);
+ assert_eq!(MAP.get().unwrap().len(), base_len);
+ }
+}
diff --git a/crates/intern/src/symbol/symbols.rs b/crates/intern/src/symbol/symbols.rs
new file mode 100644
index 0000000000..d2ca4401b6
--- /dev/null
+++ b/crates/intern/src/symbol/symbols.rs
@@ -0,0 +1,236 @@
+#![allow(non_upper_case_globals)]
+
+use std::hash::{BuildHasherDefault, Hash as _, Hasher as _};
+
+use dashmap::{DashMap, SharedValue};
+use rustc_hash::FxHasher;
+
+use crate::{
+ symbol::{SymbolProxy, TaggedArcPtr},
+ Symbol,
+};
+macro_rules! define_symbols {
+ ($($name:ident),* $(,)?) => {
+ $(
+ pub const $name: Symbol = Symbol { repr: TaggedArcPtr::non_arc(&stringify!($name)) };
+ )*
+
+
+ pub(super) fn prefill() -> DashMap<SymbolProxy, (), BuildHasherDefault<FxHasher>> {
+ let mut dashmap_ = <DashMap<SymbolProxy, (), BuildHasherDefault<FxHasher>>>::with_hasher(BuildHasherDefault::default());
+
+ let hash_thing_ = |hasher_: &BuildHasherDefault<FxHasher>, it_: &SymbolProxy| {
+ let mut hasher_ = std::hash::BuildHasher::build_hasher(hasher_);
+ it_.hash(&mut hasher_);
+ hasher_.finish()
+ };
+ {
+ $(
+
+ let proxy_ = SymbolProxy($name.repr);
+ let hash_ = hash_thing_(dashmap_.hasher(), &proxy_);
+ let shard_idx_ = dashmap_.determine_shard(hash_ as usize);
+ dashmap_.shards_mut()[shard_idx_].get_mut().raw_entry_mut().from_hash(hash_, |k| k == &proxy_).insert(proxy_, SharedValue::new(()));
+ )*
+ }
+ dashmap_
+ }
+ };
+}
+define_symbols! {
+ add_assign,
+ add,
+ alloc,
+ as_str,
+ asm,
+ assert,
+ bench,
+ bitand_assign,
+ bitand,
+ bitor_assign,
+ bitor,
+ bitxor_assign,
+ bitxor,
+ bool,
+ Box,
+ boxed,
+ branch,
+ call_mut,
+ call_once,
+ call,
+ Center,
+ cfg_accessible,
+ cfg_attr,
+ cfg_eval,
+ cfg,
+ char,
+ Clone,
+ column,
+ compile_error,
+ concat_bytes,
+ concat_idents,
+ concat,
+ const_format_args,
+ Copy,
+ core_panic,
+ core,
+ crate_type,
+ Debug,
+ default,
+ Default,
+ deref_mut,
+ deref,
+ derive_const,
+ derive,
+ div_assign,
+ div,
+ doc,
+ drop,
+ env,
+ eq,
+ Eq,
+ f128,
+ f16,
+ f32,
+ f64,
+ feature,
+ file,
+ filter_map,
+ fmt,
+ fn_mut,
+ fn_once,
+ format_args_nl,
+ format_args,
+ format,
+ from_usize,
+ future_trait,
+ future,
+ Future,
+ ge,
+ global_allocator,
+ global_asm,
+ gt,
+ Hash,
+ i128,
+ i16,
+ i32,
+ i64,
+ i8,
+ Implied,
+ include_bytes,
+ include_str,
+ include,
+ index_mut,
+ index,
+ Index,
+ into_future,
+ IntoFuture,
+ IntoIter,
+ IntoIterator,
+ is_empty,
+ Is,
+ isize,
+ Item,
+ iter_mut,
+ iter,
+ Iterator,
+ le,
+ Left,
+ len,
+ line,
+ llvm_asm,
+ log_syntax,
+ lt,
+ macro_rules,
+ module_path,
+ mul_assign,
+ mul,
+ ne,
+ neg,
+ Neg,
+ new_binary,
+ new_debug,
+ new_display,
+ new_lower_exp,
+ new_lower_hex,
+ new_octal,
+ new_pointer,
+ new_upper_exp,
+ new_upper_hex,
+ new_v1_formatted,
+ new,
+ next,
+ no_core,
+ no_std,
+ none,
+ None,
+ not,
+ Not,
+ Ok,
+ ops,
+ option_env,
+ option,
+ Option,
+ Ord,
+ Output,
+ owned_box,
+ panic_2015,
+ panic_2021,
+ Param,
+ partial_ord,
+ PartialEq,
+ PartialOrd,
+ pieces,
+ poll,
+ prelude,
+ quote,
+ r#fn,
+ Range,
+ RangeFrom,
+ RangeFull,
+ RangeInclusive,
+ RangeTo,
+ RangeToInclusive,
+ recursion_limit,
+ register_attr,
+ register_tool,
+ rem_assign,
+ rem,
+ result,
+ Result,
+ Right,
+ rust_2015,
+ rust_2018,
+ rust_2021,
+ rust_2024,
+ shl_assign,
+ shl,
+ shr_assign,
+ shr,
+ std_panic,
+ std,
+ str,
+ string,
+ String,
+ stringify,
+ sub_assign,
+ sub,
+ Target,
+ test_case,
+ test,
+ trace_macros,
+ Try,
+ u128,
+ u16,
+ u32,
+ u64,
+ u8,
+ Unknown,
+ unreachable_2015,
+ unreachable_2021,
+ unreachable,
+ unsafe_cell,
+ usize,
+ v1,
+ va_list
+}