From 4ba03be97e6cc7e790bbc9bfc18caaa228c8a262 Mon Sep 17 00:00:00 2001 From: Botahamec Date: Fri, 28 Feb 2025 16:09:11 -0500 Subject: Scoped lock API --- src/collection/boxed.rs | 170 ++++++++++++++++++++++++++++++++---------------- 1 file changed, 113 insertions(+), 57 deletions(-) (limited to 'src/collection/boxed.rs') diff --git a/src/collection/boxed.rs b/src/collection/boxed.rs index 0597e90..364ec97 100644 --- a/src/collection/boxed.rs +++ b/src/collection/boxed.rs @@ -1,25 +1,12 @@ use std::cell::UnsafeCell; use std::fmt::Debug; -use std::marker::PhantomData; use crate::lockable::{Lockable, LockableIntoInner, OwnedLockable, RawLock, Sharable}; -use crate::Keyable; +use crate::{Keyable, ThreadKey}; +use super::utils::ordered_contains_duplicates; use super::{utils, BoxedLockCollection, LockGuard}; -/// returns `true` if the sorted list contains a duplicate -#[must_use] -fn contains_duplicates(l: &[&dyn RawLock]) -> bool { - if l.is_empty() { - // Return early to prevent panic in the below call to `windows` - return false; - } - - l.windows(2) - // NOTE: addr_eq is necessary because eq would also compare the v-table pointers - .any(|window| std::ptr::addr_eq(window[0], window[1])) -} - unsafe impl RawLock for BoxedLockCollection { #[mutants::skip] // this should never be called #[cfg(not(tarpaulin_include))] @@ -65,6 +52,11 @@ unsafe impl Lockable for BoxedLockCollection { where Self: 'g; + type DataMut<'a> + = L::DataMut<'a> + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { ptrs.extend(self.locks()) } @@ -72,6 +64,10 @@ unsafe impl Lockable for BoxedLockCollection { unsafe fn guard(&self) -> Self::Guard<'_> { self.child().guard() } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.child().data_mut() + } } unsafe impl Sharable for BoxedLockCollection { @@ -80,9 +76,18 @@ unsafe impl Sharable for BoxedLockCollection { where Self: 'g; + type DataRef<'a> + = L::DataRef<'a> + where + Self: 'a; + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { self.child().read_guard() } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + self.child().data_ref() + } } unsafe impl OwnedLockable for BoxedLockCollection {} @@ -352,13 +357,53 @@ impl BoxedLockCollection { // safety: we are checking for duplicates before returning unsafe { let this = Self::new_unchecked(data); - if contains_duplicates(this.locks()) { + if ordered_contains_duplicates(this.locks()) { return None; } Some(this) } } + pub fn scoped_lock(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_lock(); + + // safety: the data was just locked + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_lock( + &self, + key: Key, + f: impl Fn(L::DataMut<'_>) -> R, + ) -> Result { + unsafe { + // safety: we have the thread key + if !self.raw_try_lock() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection /// /// This function returns a guard that can be used to access the underlying @@ -378,10 +423,8 @@ impl BoxedLockCollection { /// *guard.0 += 1; /// *guard.1 = "1"; /// ``` - pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::Guard<'g>, Key> { + #[must_use] + pub fn lock(&self, key: ThreadKey) -> LockGuard> { unsafe { // safety: we have the thread key self.raw_lock(); @@ -390,7 +433,6 @@ impl BoxedLockCollection { // safety: we've already acquired the lock guard: self.child().guard(), key, - _phantom: PhantomData, } } } @@ -424,10 +466,7 @@ impl BoxedLockCollection { /// }; /// /// ``` - pub fn try_lock<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Result, Key>, Key> { + pub fn try_lock(&self, key: ThreadKey) -> Result>, ThreadKey> { let guard = unsafe { if !self.raw_try_lock() { return Err(key); @@ -437,11 +476,7 @@ impl BoxedLockCollection { self.child().guard() }; - Ok(LockGuard { - guard, - key, - _phantom: PhantomData, - }) + Ok(LockGuard { guard, key }) } /// Unlocks the underlying lockable data type, returning the key that's @@ -461,13 +496,53 @@ impl BoxedLockCollection { /// *guard.1 = "1"; /// let key = LockCollection::<(Mutex, Mutex<&str>)>::unlock(guard); /// ``` - pub fn unlock<'key, Key: Keyable + 'key>(guard: LockGuard<'key, L::Guard<'_>, Key>) -> Key { + pub fn unlock(guard: LockGuard>) -> ThreadKey { drop(guard.guard); guard.key } } impl BoxedLockCollection { + pub fn scoped_read(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_read(); + + // safety: the data was just locked + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_read( + &self, + key: Key, + f: impl Fn(L::DataRef<'_>) -> R, + ) -> Result { + unsafe { + // safety: we have the thread key + if !self.raw_try_read() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection, so that other threads can still read from it /// /// This function returns a guard that can be used to access the underlying @@ -487,10 +562,8 @@ impl BoxedLockCollection { /// assert_eq!(*guard.0, 0); /// assert_eq!(*guard.1, ""); /// ``` - pub fn read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::ReadGuard<'g>, Key> { + #[must_use] + pub fn read(&self, key: ThreadKey) -> LockGuard> { unsafe { // safety: we have the thread key self.raw_read(); @@ -499,7 +572,6 @@ impl BoxedLockCollection { // safety: we've already acquired the lock guard: self.child().read_guard(), key, - _phantom: PhantomData, } } } @@ -534,10 +606,7 @@ impl BoxedLockCollection { /// }; /// /// ``` - pub fn try_read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Result, Key>, Key> { + pub fn try_read(&self, key: ThreadKey) -> Result>, ThreadKey> { let guard = unsafe { // safety: we have the thread key if !self.raw_try_read() { @@ -548,11 +617,7 @@ impl BoxedLockCollection { self.child().read_guard() }; - Ok(LockGuard { - guard, - key, - _phantom: PhantomData, - }) + Ok(LockGuard { guard, key }) } /// Unlocks the underlying lockable data type, returning the key that's @@ -570,9 +635,7 @@ impl BoxedLockCollection { /// let mut guard = lock.read(key); /// let key = LockCollection::<(RwLock, RwLock<&str>)>::unlock_read(guard); /// ``` - pub fn unlock_read<'key, Key: Keyable + 'key>( - guard: LockGuard<'key, L::ReadGuard<'_>, Key>, - ) -> Key { + pub fn unlock_read(guard: LockGuard>) -> ThreadKey { drop(guard.guard); guard.key } @@ -635,7 +698,6 @@ mod tests { .into_iter() .collect(); let guard = collection.lock(key); - // TODO impl PartialEq for MutexRef assert_eq!(*guard[0], "foo"); assert_eq!(*guard[1], "bar"); assert_eq!(*guard[2], "baz"); @@ -647,7 +709,6 @@ mod tests { let collection = BoxedLockCollection::from([Mutex::new("foo"), Mutex::new("bar"), Mutex::new("baz")]); let guard = collection.lock(key); - // TODO impl PartialEq for MutexRef assert_eq!(*guard[0], "foo"); assert_eq!(*guard[1], "bar"); assert_eq!(*guard[2], "baz"); @@ -666,7 +727,7 @@ mod tests { let mut key = ThreadKey::get().unwrap(); let collection = BoxedLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]); for (i, mutex) in (&collection).into_iter().enumerate() { - assert_eq!(*mutex.lock(&mut key), i); + mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i)) } } @@ -675,7 +736,7 @@ mod tests { let mut key = ThreadKey::get().unwrap(); let collection = BoxedLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]); for (i, mutex) in collection.iter().enumerate() { - assert_eq!(*mutex.lock(&mut key), i); + mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i)) } } @@ -703,11 +764,6 @@ mod tests { assert!(BoxedLockCollection::try_new([&mutex1, &mutex1]).is_none()) } - #[test] - fn contains_duplicates_empty() { - assert!(!contains_duplicates(&[])) - } - #[test] fn try_lock_works() { let key = ThreadKey::get().unwrap(); -- cgit v1.2.3