From 4a5ec04a29cba07c5960792528bd66b0f99ee3ee Mon Sep 17 00:00:00 2001 From: Botahamec Date: Fri, 7 Feb 2025 17:48:26 -0500 Subject: Fix lifetimes for poison guards --- src/collection/boxed.rs | 93 +++++++++++++++++++++++++++++++++++++++++++++++++ src/collection/owned.rs | 4 +-- src/collection/ref.rs | 26 +++++++------- src/collection/utils.rs | 2 +- src/poisonable.rs | 84 +++++++++++++++++++++++++++++++++++++++++--- 5 files changed, 189 insertions(+), 20 deletions(-) (limited to 'src') diff --git a/src/collection/boxed.rs b/src/collection/boxed.rs index 72489bf..0597e90 100644 --- a/src/collection/boxed.rs +++ b/src/collection/boxed.rs @@ -627,6 +627,69 @@ mod tests { use super::*; use crate::{Mutex, RwLock, ThreadKey}; + #[test] + fn from_iterator() { + let key = ThreadKey::get().unwrap(); + let collection: BoxedLockCollection>> = + [Mutex::new("foo"), Mutex::new("bar"), Mutex::new("baz")] + .into_iter() + .collect(); + let guard = collection.lock(key); + // TODO impl PartialEq for MutexRef + assert_eq!(*guard[0], "foo"); + assert_eq!(*guard[1], "bar"); + assert_eq!(*guard[2], "baz"); + } + + #[test] + fn from() { + let key = ThreadKey::get().unwrap(); + let collection = + BoxedLockCollection::from([Mutex::new("foo"), Mutex::new("bar"), Mutex::new("baz")]); + let guard = collection.lock(key); + // TODO impl PartialEq for MutexRef + assert_eq!(*guard[0], "foo"); + assert_eq!(*guard[1], "bar"); + assert_eq!(*guard[2], "baz"); + } + + #[test] + fn into_owned_iterator() { + let collection = BoxedLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]); + for (i, mutex) in collection.into_iter().enumerate() { + assert_eq!(mutex.into_inner(), i); + } + } + + #[test] + fn into_ref_iterator() { + let mut key = ThreadKey::get().unwrap(); + let collection = BoxedLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]); + for (i, mutex) in (&collection).into_iter().enumerate() { + assert_eq!(*mutex.lock(&mut key), i); + } + } + + #[test] + fn ref_iterator() { + let mut key = ThreadKey::get().unwrap(); + let collection = BoxedLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]); + for (i, mutex) in collection.iter().enumerate() { + assert_eq!(*mutex.lock(&mut key), i); + } + } + + #[test] + #[allow(clippy::float_cmp)] + fn uses_correct_default() { + let collection = + BoxedLockCollection::<(Mutex, Mutex>, Mutex)>::default(); + let tuple = collection.into_inner(); + assert_eq!(tuple.0, 0.0); + assert!(tuple.1.is_none()); + assert_eq!(tuple.2, 0) + } + #[test] fn non_duplicates_allowed() { let mutex1 = Mutex::new(0); @@ -732,6 +795,36 @@ mod tests { assert!(guard.is_ok()); } + #[test] + fn unlock_collection_works() { + let key = ThreadKey::get().unwrap(); + let mutex1 = Mutex::new("foo"); + let mutex2 = Mutex::new("bar"); + let collection = BoxedLockCollection::try_new((&mutex1, &mutex2)).unwrap(); + let guard = collection.lock(key); + let key = BoxedLockCollection::<(&Mutex<_>, &Mutex<_>)>::unlock(guard); + + assert!(mutex1.try_lock(key).is_ok()) + } + + #[test] + fn read_unlock_collection_works() { + let key = ThreadKey::get().unwrap(); + let lock1 = RwLock::new("foo"); + let lock2 = RwLock::new("bar"); + let collection = BoxedLockCollection::try_new((&lock1, &lock2)).unwrap(); + let guard = collection.read(key); + let key = BoxedLockCollection::<(&RwLock<_>, &RwLock<_>)>::unlock_read(guard); + + assert!(lock1.try_write(key).is_ok()) + } + + #[test] + fn into_inner_works() { + let collection = BoxedLockCollection::new((Mutex::new("Hello"), Mutex::new(47))); + assert_eq!(collection.into_inner(), ("Hello", 47)) + } + #[test] fn works_in_collection() { let key = ThreadKey::get().unwrap(); diff --git a/src/collection/owned.rs b/src/collection/owned.rs index 4a0d1ef..c345b43 100644 --- a/src/collection/owned.rs +++ b/src/collection/owned.rs @@ -197,7 +197,7 @@ impl OwnedLockCollection { /// *guard.0 += 1; /// *guard.1 = "1"; /// ``` - pub fn lock<'g, 'key, Key: Keyable + 'key>( + pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>( &'g self, key: Key, ) -> LockGuard<'key, L::Guard<'g>, Key> { @@ -315,7 +315,7 @@ impl OwnedLockCollection { /// assert_eq!(*guard.0, 0); /// assert_eq!(*guard.1, ""); /// ``` - pub fn read<'g, 'key, Key: Keyable + 'key>( + pub fn read<'g, 'key: 'g, Key: Keyable + 'key>( &'g self, key: Key, ) -> LockGuard<'key, L::ReadGuard<'g>, Key> { diff --git a/src/collection/ref.rs b/src/collection/ref.rs index 37973f6..c86f298 100644 --- a/src/collection/ref.rs +++ b/src/collection/ref.rs @@ -257,10 +257,10 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { /// *guard.0 += 1; /// *guard.1 = "1"; /// ``` - pub fn lock<'key: 'a, Key: Keyable + 'key>( - &'a self, + pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>( + &'g self, key: Key, - ) -> LockGuard<'key, L::Guard<'a>, Key> { + ) -> LockGuard<'key, L::Guard<'g>, Key> { let guard = unsafe { // safety: we have the thread key self.raw_lock(); @@ -306,10 +306,10 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { /// }; /// /// ``` - pub fn try_lock<'key: 'a, Key: Keyable + 'key>( - &'a self, + pub fn try_lock<'g, 'key: 'a, Key: Keyable + 'key>( + &'g self, key: Key, - ) -> Result, Key>, Key> { + ) -> Result, Key>, Key> { let guard = unsafe { if !self.raw_try_lock() { return Err(key); @@ -345,7 +345,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { /// let key = RefLockCollection::<(Mutex, Mutex<&str>)>::unlock(guard); /// ``` #[allow(clippy::missing_const_for_fn)] - pub fn unlock<'key: 'a, Key: Keyable + 'key>(guard: LockGuard<'key, L::Guard<'a>, Key>) -> Key { + pub fn unlock<'g, 'key, Key: Keyable + 'key>(guard: LockGuard<'key, L::Guard<'g>, Key>) -> Key { drop(guard.guard); guard.key } @@ -372,10 +372,10 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> { /// assert_eq!(*guard.0, 0); /// assert_eq!(*guard.1, ""); /// ``` - pub fn read<'key: 'a, Key: Keyable + 'key>( - &'a self, + pub fn read<'g, 'key: 'g, Key: Keyable + 'key>( + &'g self, key: Key, - ) -> LockGuard<'key, L::ReadGuard<'a>, Key> { + ) -> LockGuard<'key, L::ReadGuard<'g>, Key> { unsafe { // safety: we have the thread key self.raw_read(); @@ -420,10 +420,10 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> { /// }; /// /// ``` - pub fn try_read<'key: 'a, Key: Keyable + 'key>( - &'a self, + pub fn try_read<'g, 'key: 'g, Key: Keyable + 'key>( + &'g self, key: Key, - ) -> Option, Key>> { + ) -> Option, Key>> { let guard = unsafe { // safety: we have the thread key if !self.raw_try_read() { diff --git a/src/collection/utils.rs b/src/collection/utils.rs index 7f29037..d6d50f4 100644 --- a/src/collection/utils.rs +++ b/src/collection/utils.rs @@ -109,7 +109,7 @@ pub unsafe fn attempt_to_recover_reads_from_panic(locked: &[&dyn RawLock]) { handle_unwind( || { // safety: the caller assumes these are already locked - locked.iter().for_each(|lock| lock.raw_unlock()); + locked.iter().for_each(|lock| lock.raw_unlock_read()); }, // if we get another panic in here, we'll just have to poison what remains || locked.iter().for_each(|l| l.poison()), diff --git a/src/poisonable.rs b/src/poisonable.rs index e577ce9..bb6ad28 100644 --- a/src/poisonable.rs +++ b/src/poisonable.rs @@ -55,7 +55,7 @@ pub struct PoisonRef<'a, G> { /// An RAII guard for a [`Poisonable`]. /// /// This is created by calling methods like [`Poisonable::lock`]. -pub struct PoisonGuard<'a, 'key, G, Key> { +pub struct PoisonGuard<'a, 'key: 'a, G, Key: 'key> { guard: PoisonRef<'a, G>, key: Key, _phantom: PhantomData<&'key ()>, @@ -69,7 +69,7 @@ pub struct PoisonError { /// An enumeration of possible errors associated with /// [`TryLockPoisonableResult`] which can occur while trying to acquire a lock /// (i.e.: [`Poisonable::try_lock`]). -pub enum TryLockPoisonableError<'flag, 'key, G, Key: 'key> { +pub enum TryLockPoisonableError<'flag, 'key: 'flag, G, Key: 'key> { Poisoned(PoisonError>), WouldBlock(Key), } @@ -94,6 +94,8 @@ pub type TryLockPoisonableResult<'flag, 'key, G, Key> = #[cfg(test)] mod tests { + use std::sync::Arc; + use super::*; use crate::lockable::Lockable; use crate::{LockCollection, Mutex, ThreadKey}; @@ -120,8 +122,8 @@ mod tests { let guard1 = guard.0.as_ref().unwrap(); let guard2 = guard.1.as_ref().unwrap(); let guard3 = guard.2.as_ref().unwrap(); - assert!(guard1 < guard2); - assert!(guard2 > guard1); + assert_eq!(guard1.cmp(guard2), std::cmp::Ordering::Less); + assert_eq!(guard2.cmp(guard1), std::cmp::Ordering::Greater); assert!(guard2 == guard3); assert!(guard1 != guard3); } @@ -137,6 +139,80 @@ mod tests { assert!(std::ptr::addr_eq(lock_ptrs[0], &poisonable.inner)); } + #[test] + fn clear_poison_for_poisoned_mutex() { + let mutex = Arc::new(Poisonable::new(Mutex::new(0))); + let c_mutex = Arc::clone(&mutex); + + let _ = std::thread::spawn(move || { + let key = ThreadKey::get().unwrap(); + let _lock = c_mutex.lock(key).unwrap(); + panic!(); // the mutex gets poisoned + }) + .join(); + + assert!(mutex.is_poisoned()); + + let key = ThreadKey::get().unwrap(); + let _ = mutex.lock(key).unwrap_or_else(|mut e| { + **e.get_mut() = 1; + mutex.clear_poison(); + e.into_inner() + }); + + assert!(!mutex.is_poisoned()); + } + + #[test] + fn error_as_ref() { + let mutex = Poisonable::new(Mutex::new("foo")); + + let _ = std::panic::catch_unwind(|| { + let key = ThreadKey::get().unwrap(); + #[allow(unused_variables)] + let guard = mutex.lock(key); + panic!(); + + #[allow(unknown_lints)] + #[allow(unreachable_code)] + drop(guard); + }); + + assert!(mutex.is_poisoned()); + + let key = ThreadKey::get().unwrap(); + let error = mutex.lock(key).unwrap_err(); + assert_eq!(&***error.as_ref(), "foo"); + } + + #[test] + fn error_as_mut() { + let mutex = Poisonable::new(Mutex::new("foo")); + + let _ = std::panic::catch_unwind(|| { + let key = ThreadKey::get().unwrap(); + #[allow(unused_variables)] + let guard = mutex.lock(key); + panic!(); + + #[allow(unknown_lints)] + #[allow(unreachable_code)] + drop(guard); + }); + + assert!(mutex.is_poisoned()); + + let mut key = ThreadKey::get().unwrap(); + let mut error = mutex.lock(&mut key).unwrap_err(); + let error1 = error.as_mut(); + **error1 = "bar"; + drop(error); + + mutex.clear_poison(); + let guard = mutex.lock(&mut key).unwrap(); + assert_eq!(&**guard, "bar"); + } + #[test] fn new_poisonable_is_not_poisoned() { let mutex = Poisonable::new(Mutex::new(42)); -- cgit v1.2.3