summaryrefslogtreecommitdiff
path: root/src/key.rs
blob: 887ac295a41294ce5e52cd695e47d3197049f4de (plain)
use std::fmt::{self, Debug};
use std::marker::PhantomData;
use std::sync::atomic::{AtomicBool, Ordering};

use once_cell::sync::Lazy;
use thread_local::ThreadLocal;

use sealed::Sealed;

mod sealed {
	use super::ThreadKey;

	pub trait Sealed {}
	impl Sealed for ThreadKey {}
	impl Sealed for &mut ThreadKey {}
}

static KEY: Lazy<ThreadLocal<AtomicLock>> = Lazy::new(ThreadLocal::new);

/// The key for the current thread.
///
/// Only one of these exist per thread. To get the current thread's key, call
/// [`ThreadKey::lock`]. If the `ThreadKey` is dropped, it can be reobtained.
pub struct ThreadKey {
	phantom: PhantomData<*const ()>, // implement !Send and !Sync
}

/// Allows the type to be used as a key for a lock
///
/// # Safety
///
/// Only one value which implements this trait may be allowed to exist at a
/// time. Creating a new `Keyable` value requires making any other `Keyable`
/// values invalid.
pub unsafe trait Keyable: Sealed {}
unsafe impl Keyable for ThreadKey {}
unsafe impl Keyable for &mut ThreadKey {}

impl Debug for ThreadKey {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		write!(f, "ThreadKey")
	}
}

impl Drop for ThreadKey {
	fn drop(&mut self) {
		unsafe { KEY.get().unwrap().force_unlock() }
	}
}

impl ThreadKey {
	/// Get the current thread's `ThreadKey`, if it's not already taken.
	///
	/// The first time this is called, it will successfully return a
	/// `ThreadKey`. However, future calls to this function on the same thread
	/// will return [`None`], unless the key is dropped or unlocked first.
	#[must_use]
	pub fn lock() -> Option<Self> {
		// safety: we just acquired the lock
		KEY.get_or_default().try_lock().then_some(Self {
			phantom: PhantomData,
		})
	}

	/// Unlocks the `ThreadKey`.
	///
	/// After this method is called, a call to [`ThreadKey::lock`] will return
	/// this `ThreadKey`.
	pub fn unlock(self) {
		drop(self);
	}
}

/// A dumb lock that's just a wrapper for an [`AtomicBool`].
#[derive(Debug, Default)]
struct AtomicLock {
	is_locked: AtomicBool,
}

impl AtomicLock {
	/// Attempt to lock the `AtomicLock`. This is not a fair lock.
	#[must_use]
	pub fn try_lock(&'static self) -> bool {
		!self.is_locked.swap(true, Ordering::Acquire)
	}

	/// Forcibly unlocks the `AtomicLock`. This should only be called if the
	/// key to the lock has been "lost".
	pub unsafe fn force_unlock(&self) {
		self.is_locked.store(false, Ordering::Release);
	}
}