From 4ba03be97e6cc7e790bbc9bfc18caaa228c8a262 Mon Sep 17 00:00:00 2001 From: Botahamec Date: Fri, 28 Feb 2025 16:09:11 -0500 Subject: Scoped lock API --- src/rwlock/rwlock.rs | 130 ++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 107 insertions(+), 23 deletions(-) (limited to 'src/rwlock/rwlock.rs') diff --git a/src/rwlock/rwlock.rs b/src/rwlock/rwlock.rs index 038e6c7..905ecf8 100644 --- a/src/rwlock/rwlock.rs +++ b/src/rwlock/rwlock.rs @@ -6,10 +6,10 @@ use std::panic::AssertUnwindSafe; use lock_api::RawRwLock; use crate::handle_unwind::handle_unwind; -use crate::key::Keyable; use crate::lockable::{ Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable, }; +use crate::{Keyable, ThreadKey}; use super::{PoisonFlag, RwLock, RwLockReadGuard, RwLockReadRef, RwLockWriteGuard, RwLockWriteRef}; @@ -79,6 +79,11 @@ unsafe impl Lockable for RwLock { where Self: 'g; + type DataMut<'a> + = &'a mut T + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { ptrs.push(self); } @@ -86,6 +91,10 @@ unsafe impl Lockable for RwLock { unsafe fn guard(&self) -> Self::Guard<'_> { RwLockWriteRef::new(self) } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.data.get().as_mut().unwrap_unchecked() + } } unsafe impl Sharable for RwLock { @@ -94,9 +103,18 @@ unsafe impl Sharable for RwLock { where Self: 'g; + type DataRef<'a> + = &'a T + where + Self: 'a; + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { RwLockReadRef::new(self) } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + self.data.get().as_ref().unwrap_unchecked() + } } unsafe impl OwnedLockable for RwLock {} @@ -230,6 +248,86 @@ impl RwLock { } impl RwLock { + pub fn scoped_read(&self, key: impl Keyable, f: impl Fn(&T) -> Ret) -> Ret { + unsafe { + // safety: we have the thread key + self.raw_read(); + + // safety: the rwlock was just locked + let r = f(self.data.get().as_ref().unwrap_unchecked()); + + // safety: the rwlock is already locked + self.raw_unlock_read(); + + drop(key); // ensure the key stays valid for long enough + + r + } + } + + pub fn scoped_try_read( + &self, + key: Key, + f: impl Fn(&T) -> Ret, + ) -> Result { + unsafe { + // safety: we have the thread key + if !self.raw_try_read() { + return Err(key); + } + + // safety: the rwlock was just locked + let r = f(self.data.get().as_ref().unwrap_unchecked()); + + // safety: the rwlock is already locked + self.raw_unlock_read(); + + drop(key); // ensure the key stays valid for long enough + + Ok(r) + } + } + + pub fn scoped_write(&self, key: impl Keyable, f: impl Fn(&mut T) -> Ret) -> Ret { + unsafe { + // safety: we have the thread key + self.raw_lock(); + + // safety: we just locked the rwlock + let r = f(self.data.get().as_mut().unwrap_unchecked()); + + // safety: the rwlock is already locked + self.raw_unlock(); + + drop(key); // ensure the key stays valid for long enough + + r + } + } + + pub fn scoped_try_write( + &self, + key: Key, + f: impl Fn(&mut T) -> Ret, + ) -> Result { + unsafe { + // safety: we have the thread key + if !self.raw_try_lock() { + return Err(key); + } + + // safety: the rwlock was just locked + let r = f(self.data.get().as_mut().unwrap_unchecked()); + + // safety: the rwlock is already locked + self.raw_unlock(); + + drop(key); // ensure the key stays valid for long enough + + Ok(r) + } + } + /// Locks this `RwLock` with shared read access, blocking the current /// thread until it can be acquired. /// @@ -264,10 +362,7 @@ impl RwLock { /// ``` /// /// [`ThreadKey`]: `crate::ThreadKey` - pub fn read<'s, 'key: 's, Key: Keyable>( - &'s self, - key: Key, - ) -> RwLockReadGuard<'s, 'key, T, Key, R> { + pub fn read(&self, key: ThreadKey) -> RwLockReadGuard<'_, T, R> { unsafe { self.raw_read(); @@ -305,10 +400,7 @@ impl RwLock { /// Err(_) => unreachable!(), /// }; /// ``` - pub fn try_read<'s, 'key: 's, Key: Keyable>( - &'s self, - key: Key, - ) -> Result, Key> { + pub fn try_read(&self, key: ThreadKey) -> Result, ThreadKey> { unsafe { if self.raw_try_read() { // safety: the lock is locked first @@ -369,10 +461,7 @@ impl RwLock { /// ``` /// /// [`ThreadKey`]: `crate::ThreadKey` - pub fn write<'s, 'key: 's, Key: Keyable>( - &'s self, - key: Key, - ) -> RwLockWriteGuard<'s, 'key, T, Key, R> { + pub fn write(&self, key: ThreadKey) -> RwLockWriteGuard<'_, T, R> { unsafe { self.raw_lock(); @@ -407,10 +496,7 @@ impl RwLock { /// let n = lock.read(key); /// assert_eq!(*n, 1); /// ``` - pub fn try_write<'s, 'key: 's, Key: Keyable>( - &'s self, - key: Key, - ) -> Result, Key> { + pub fn try_write(&self, key: ThreadKey) -> Result, ThreadKey> { unsafe { if self.raw_try_lock() { // safety: the lock is locked first @@ -445,9 +531,8 @@ impl RwLock { /// assert_eq!(*guard, 0); /// let key = RwLock::unlock_read(guard); /// ``` - pub fn unlock_read<'key, Key: Keyable + 'key>( - guard: RwLockReadGuard<'_, 'key, T, Key, R>, - ) -> Key { + #[must_use] + pub fn unlock_read(guard: RwLockReadGuard<'_, T, R>) -> ThreadKey { unsafe { guard.rwlock.0.raw_unlock_read(); } @@ -473,9 +558,8 @@ impl RwLock { /// *guard += 20; /// let key = RwLock::unlock_write(guard); /// ``` - pub fn unlock_write<'key, Key: Keyable + 'key>( - guard: RwLockWriteGuard<'_, 'key, T, Key, R>, - ) -> Key { + #[must_use] + pub fn unlock_write(guard: RwLockWriteGuard<'_, T, R>) -> ThreadKey { unsafe { guard.rwlock.0.raw_unlock(); } -- cgit v1.2.3