diff options
Diffstat (limited to 'src/collection/retry.rs')
| -rw-r--r-- | src/collection/retry.rs | 176 |
1 files changed, 127 insertions, 49 deletions
diff --git a/src/collection/retry.rs b/src/collection/retry.rs index 331b669..775ea29 100644 --- a/src/collection/retry.rs +++ b/src/collection/retry.rs @@ -1,24 +1,18 @@ use std::cell::Cell; use std::collections::HashSet; -use std::marker::PhantomData; use crate::collection::utils; use crate::handle_unwind::handle_unwind; use crate::lockable::{ Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable, }; -use crate::Keyable; +use crate::{Keyable, ThreadKey}; -use super::utils::{attempt_to_recover_locks_from_panic, attempt_to_recover_reads_from_panic}; +use super::utils::{ + attempt_to_recover_locks_from_panic, attempt_to_recover_reads_from_panic, get_locks_unsorted, +}; use super::{LockGuard, RetryingLockCollection}; -/// Get all raw locks in the collection -fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> { - let mut locks = Vec::new(); - data.get_ptrs(&mut locks); - locks -} - /// Checks that a collection contains no duplicate references to a lock. fn contains_duplicates<L: Lockable>(data: L) -> bool { let mut locks = Vec::new(); @@ -40,14 +34,14 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> { #[mutants::skip] // this should never run #[cfg(not(tarpaulin_include))] fn poison(&self) { - let locks = get_locks(&self.data); + let locks = get_locks_unsorted(&self.data); for lock in locks { lock.poison(); } } unsafe fn raw_lock(&self) { - let locks = get_locks(&self.data); + let locks = get_locks_unsorted(&self.data); if locks.is_empty() { // this probably prevents a panic later @@ -109,7 +103,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> { } unsafe fn raw_try_lock(&self) -> bool { - let locks = get_locks(&self.data); + let locks = get_locks_unsorted(&self.data); if locks.is_empty() { // this is an interesting case, but it doesn't give us access to @@ -139,7 +133,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> { } unsafe fn raw_unlock(&self) { - let locks = get_locks(&self.data); + let locks = get_locks_unsorted(&self.data); for lock in locks { lock.raw_unlock(); @@ -147,7 +141,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> { } unsafe fn raw_read(&self) { - let locks = get_locks(&self.data); + let locks = get_locks_unsorted(&self.data); if locks.is_empty() { // this probably prevents a panic later @@ -200,7 +194,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> { } unsafe fn raw_try_read(&self) -> bool { - let locks = get_locks(&self.data); + let locks = get_locks_unsorted(&self.data); if locks.is_empty() { // this is an interesting case, but it doesn't give us access to @@ -229,7 +223,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> { } unsafe fn raw_unlock_read(&self) { - let locks = get_locks(&self.data); + let locks = get_locks_unsorted(&self.data); for lock in locks { lock.raw_unlock_read(); @@ -243,6 +237,11 @@ unsafe impl<L: Lockable> Lockable for RetryingLockCollection<L> { where Self: 'g; + type DataMut<'a> + = L::DataMut<'a> + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { self.data.get_ptrs(ptrs) } @@ -250,6 +249,10 @@ unsafe impl<L: Lockable> Lockable for RetryingLockCollection<L> { unsafe fn guard(&self) -> Self::Guard<'_> { self.data.guard() } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.data.data_mut() + } } unsafe impl<L: Sharable> Sharable for RetryingLockCollection<L> { @@ -258,9 +261,18 @@ unsafe impl<L: Sharable> Sharable for RetryingLockCollection<L> { where Self: 'g; + type DataRef<'a> + = L::DataRef<'a> + where + Self: 'a; + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { self.data.read_guard() } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + self.data.data_ref() + } } unsafe impl<L: OwnedLockable> OwnedLockable for RetryingLockCollection<L> {} @@ -516,6 +528,46 @@ impl<L: Lockable> RetryingLockCollection<L> { (!contains_duplicates(&data)).then_some(Self { data }) } + pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_lock(); + + // safety: the data was just locked + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_lock<Key: Keyable, R>( + &self, + key: Key, + f: impl Fn(L::DataMut<'_>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_lock() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection /// /// This function returns a guard that can be used to access the underlying @@ -536,10 +588,7 @@ impl<L: Lockable> RetryingLockCollection<L> { /// *guard.0 += 1; /// *guard.1 = "1"; /// ``` - pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::Guard<'g>, Key> { + pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> { unsafe { // safety: we're taking the thread key self.raw_lock(); @@ -548,7 +597,6 @@ impl<L: Lockable> RetryingLockCollection<L> { // safety: we just locked the collection guard: self.guard(), key, - _phantom: PhantomData, } } } @@ -583,10 +631,7 @@ impl<L: Lockable> RetryingLockCollection<L> { /// }; /// /// ``` - pub fn try_lock<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Result<LockGuard<'key, L::Guard<'g>, Key>, Key> { + pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> { unsafe { // safety: we're taking the thread key if self.raw_try_lock() { @@ -594,7 +639,6 @@ impl<L: Lockable> RetryingLockCollection<L> { // safety: we just succeeded in locking everything guard: self.guard(), key, - _phantom: PhantomData, }) } else { Err(key) @@ -620,13 +664,53 @@ impl<L: Lockable> RetryingLockCollection<L> { /// *guard.1 = "1"; /// let key = RetryingLockCollection::<(Mutex<i32>, Mutex<&str>)>::unlock(guard); /// ``` - pub fn unlock<'key, Key: Keyable + 'key>(guard: LockGuard<'key, L::Guard<'_>, Key>) -> Key { + pub fn unlock(guard: LockGuard<L::Guard<'_>>) -> ThreadKey { drop(guard.guard); guard.key } } impl<L: Sharable> RetryingLockCollection<L> { + pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_read(); + + // safety: the data was just locked + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_read<Key: Keyable, R>( + &self, + key: Key, + f: impl Fn(L::DataRef<'_>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_read() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection, so that other threads can still read from it /// /// This function returns a guard that can be used to access the underlying @@ -647,10 +731,7 @@ impl<L: Sharable> RetryingLockCollection<L> { /// assert_eq!(*guard.0, 0); /// assert_eq!(*guard.1, ""); /// ``` - pub fn read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::ReadGuard<'g>, Key> { + pub fn read(&self, key: ThreadKey) -> LockGuard<L::ReadGuard<'_>> { unsafe { // safety: we're taking the thread key self.raw_read(); @@ -659,7 +740,6 @@ impl<L: Sharable> RetryingLockCollection<L> { // safety: we just locked the collection guard: self.read_guard(), key, - _phantom: PhantomData, } } } @@ -687,25 +767,25 @@ impl<L: Sharable> RetryingLockCollection<L> { /// let lock = RetryingLockCollection::new(data); /// /// match lock.try_read(key) { - /// Some(mut guard) => { + /// Ok(mut guard) => { /// assert_eq!(*guard.0, 5); /// assert_eq!(*guard.1, "6"); /// }, - /// None => unreachable!(), + /// Err(_) => unreachable!(), /// }; /// /// ``` - pub fn try_read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Option<LockGuard<'key, L::ReadGuard<'g>, Key>> { + pub fn try_read(&self, key: ThreadKey) -> Result<LockGuard<L::ReadGuard<'_>>, ThreadKey> { unsafe { // safety: we're taking the thread key - self.raw_try_lock().then(|| LockGuard { + if !self.raw_try_lock() { + return Err(key); + } + + Ok(LockGuard { // safety: we just succeeded in locking everything guard: self.read_guard(), key, - _phantom: PhantomData, }) } } @@ -726,9 +806,7 @@ impl<L: Sharable> RetryingLockCollection<L> { /// let mut guard = lock.read(key); /// let key = RetryingLockCollection::<(RwLock<i32>, RwLock<&str>)>::unlock_read(guard); /// ``` - pub fn unlock_read<'key, Key: Keyable + 'key>( - guard: LockGuard<'key, L::ReadGuard<'_>, Key>, - ) -> Key { + pub fn unlock_read(guard: LockGuard<L::ReadGuard<'_>>) -> ThreadKey { drop(guard.guard); guard.key } @@ -833,7 +911,7 @@ where mod tests { use super::*; use crate::collection::BoxedLockCollection; - use crate::{Mutex, RwLock, ThreadKey}; + use crate::{LockCollection, Mutex, RwLock, ThreadKey}; #[test] fn nonduplicate_lock_references_are_allowed() { @@ -869,7 +947,6 @@ mod tests { let rwlock1 = RwLock::new(0); let rwlock2 = RwLock::new(0); let collection = RetryingLockCollection::try_new([&rwlock1, &rwlock2]).unwrap(); - // TODO Poisonable::read let guard = collection.read(key); @@ -909,13 +986,14 @@ mod tests { #[test] fn lock_empty_lock_collection() { - let mut key = ThreadKey::get().unwrap(); + let key = ThreadKey::get().unwrap(); let collection: RetryingLockCollection<[RwLock<i32>; 0]> = RetryingLockCollection::new([]); - let guard = collection.lock(&mut key); + let guard = collection.lock(key); assert!(guard.len() == 0); + let key = LockCollection::<[RwLock<_>; 0]>::unlock(guard); - let guard = collection.read(&mut key); + let guard = collection.read(key); assert!(guard.len() == 0); } } |
