use std::cell::UnsafeCell;
use std::mem::MaybeUninit;
use std::ptr;
use std::sync::atomic::{AtomicU8, Ordering};
/// The cell is empty and no thread is currently initializing it.
const STATE_INIT: u8 = 0;
/// A thread has claimed the cell and is filling the slot; the value is not readable yet.
const STATE_LOCKED: u8 = 1;
/// The slot has been written and published; the value is readable.
const STATE_READY: u8 = 2;
/// A cell whose value can be initialized at most once through a shared reference and read
/// afterwards.
///
/// Invariant: the slot in `value` holds an initialized `T` exactly when `state` is
/// `STATE_READY`. While a writer fills the slot, `state` is `STATE_LOCKED` and other
/// writers spin until it changes.
pub struct OnceCell<T> {
    /// Storage for the value; only initialized once `state` reaches `STATE_READY`.
    value: UnsafeCell<MaybeUninit<T>>,
    /// One of `STATE_INIT`, `STATE_LOCKED` or `STATE_READY`.
    state: AtomicU8,
}
impl<T> OnceCell<T> {
#[must_use]
pub const fn new() -> Self {
Self {
value: UnsafeCell::new(MaybeUninit::uninit()),
state: AtomicU8::new(STATE_INIT),
}
}
pub const fn new_init(value: T) -> Self {
Self {
value: UnsafeCell::new(MaybeUninit::new(value)),
state: AtomicU8::new(STATE_READY),
}
}
pub fn get(&self) -> Option<&T> {
if self.state.load(Ordering::Acquire) == STATE_READY {
// SAFETY: won't be overwritten for the lifetime of this reference
Some(unsafe { (*self.value.get()).assume_init_ref() })
} else {
None
}
}
pub fn get_or_wait(&self) -> Option<&T> {
loop {
match self.state.load(Ordering::Acquire) {
STATE_INIT => return None,
STATE_LOCKED => (), // continue
STATE_READY => return Some(unsafe { (*self.value.get()).assume_init_ref() }),
x => unreachable!("invalid state {x}"),
}
}
}
pub fn get_or_init<F: Fn() -> T>(&self, init: F) -> &T {
loop {
match self.state.compare_exchange(
STATE_INIT,
STATE_LOCKED,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(..) => {
let value = init();
let written = &*unsafe { &mut *self.value.get() }.write(value);
self.state.store(STATE_READY, Ordering::Release);
return written;
}
Err(STATE_READY) => return unsafe { (*self.value.get()).assume_init_ref() },
Err(..) => (), // locked or spurious failure
}
}
}
pub fn set(&self, value: T) -> Result<&T, T> {
// don't set state to STATE_READY on success because we have to release afterward anyway
match self.state.compare_exchange(
STATE_INIT,
STATE_LOCKED,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(..) => {
// SAFETY: unique because only one thread can lock the atomic state
let written = &*unsafe { &mut *self.value.get() }.write(value);
self.state.store(STATE_READY, Ordering::Release);
Ok(written)
}
// SAFETY: guaranteed to be initialized & protected by acquire ordering
Err(STATE_READY) => return Ok(unsafe { (*self.value.get()).assume_init_ref() }),
Err(..) => Err(value), // locked or spurious failure
}
}
pub fn set_mut(&mut self, mut value: T) -> Result<Option<T>, T> {
// don't set state to STATE_READY on success because we have to release afterward anyway
match self.state.compare_exchange(
STATE_INIT,
STATE_LOCKED,
Ordering::AcqRel,
Ordering::Acquire,
) {
Ok(..) => {
self.value.get_mut().write(value);
self.state.store(STATE_READY, Ordering::Release);
Ok(None)
}
Err(STATE_READY) => {
// SAFETY: guaranteed to be initialized & protected by acquire ordering
std::mem::swap(
unsafe { self.value.get_mut().assume_init_mut() },
&mut value,
);
// ensure changes are visible to others acquiring the atomic state
self.state.store(STATE_READY, Ordering::Release);
// we've swapped the previous value into this variable
Ok(Some(value))
}
Err(..) => Err(value), // locked or spurious failure
}
}
pub fn into_inner(mut self) -> Option<T> {
// must be atomic so potential writes during the drop see a valid state
let inner = match self.state.load(Ordering::Acquire) {
STATE_INIT => None,
STATE_LOCKED => unreachable!("consumed cell during initialization"),
// SAFETY: initialized & we'll forget about it afterwards
STATE_READY => Some(unsafe { self.value.get_mut().assume_init_read() }),
x => unreachable!("invalid state {x}"),
};
// SAFETY: just in case AtomicU8 has a drop handler
unsafe {
ptr::drop_in_place(std::ptr::addr_of_mut!(self.state));
}
std::mem::forget(self);
inner
}
}
impl<T> Default for OnceCell<T> {
fn default() -> Self {
OnceCell::new()
}
}
impl<T> From<T> for OnceCell<T> {
fn from(value: T) -> Self {
OnceCell::new_init(value)
}
}
impl<T> From<OnceCell<T>> for Option<T> {
fn from(value: OnceCell<T>) -> Self {
value.into_inner()
}
}
impl<T> Drop for OnceCell<T> {
fn drop(&mut self) {
match *self.state.get_mut() {
STATE_INIT => (),
STATE_LOCKED => unreachable!("dropped cell during initialization"),
// MaybeUninit requires us to manually drop the value
STATE_READY => unsafe { self.value.get_mut().assume_init_drop() },
x => unreachable!("invalid state {x}"),
}
}
}
// SAFETY: the cell owns its `T` (it drops it in `Drop`), so sending the cell to another thread
// sends the `T` with it; that is sound exactly when `T: Send`.
unsafe impl<T: Send> Send for OnceCell<T> {}
// SAFETY: a shared `&OnceCell<T>` lets other threads read `&T` (requires `T: Sync`) and also
// store a `T` created on their own thread through `set`/`get_or_init`. The owning thread later
// drops that value, which moves it across threads (requires `T: Send`). These are the same
// bounds `std::sync::OnceLock` uses.
unsafe impl<T: Send + Sync> Sync for OnceCell<T> {}