diff options
| author | Botahamec <botahamec@outlook.com> | 2025-02-28 16:09:11 -0500 |
|---|---|---|
| committer | Botahamec <botahamec@outlook.com> | 2025-02-28 16:09:11 -0500 |
| commit | 4ba03be97e6cc7e790bbc9bfc18caaa228c8a262 (patch) | |
| tree | a257184577a93ddf240aba698755c2886188788b /src | |
| parent | 4a5ec04a29cba07c5960792528bd66b0f99ee3ee (diff) | |
Scoped lock API
Diffstat (limited to 'src')
| -rw-r--r-- | src/collection.rs | 8 | ||||
| -rw-r--r-- | src/collection/boxed.rs | 170 | ||||
| -rw-r--r-- | src/collection/guard.rs | 83 | ||||
| -rw-r--r-- | src/collection/owned.rs | 262 | ||||
| -rw-r--r-- | src/collection/ref.rs | 261 | ||||
| -rw-r--r-- | src/collection/retry.rs | 176 | ||||
| -rw-r--r-- | src/collection/utils.rs | 40 | ||||
| -rw-r--r-- | src/lockable.rs | 126 | ||||
| -rw-r--r-- | src/mutex.rs | 58 | ||||
| -rw-r--r-- | src/mutex/guard.rs | 71 | ||||
| -rw-r--r-- | src/mutex/mutex.rs | 61 | ||||
| -rw-r--r-- | src/poisonable.rs | 229 | ||||
| -rw-r--r-- | src/poisonable/error.rs | 14 | ||||
| -rw-r--r-- | src/poisonable/guard.rs | 64 | ||||
| -rw-r--r-- | src/poisonable/poisonable.rs | 157 | ||||
| -rw-r--r-- | src/rwlock.rs | 57 | ||||
| -rw-r--r-- | src/rwlock/read_guard.rs | 71 | ||||
| -rw-r--r-- | src/rwlock/read_lock.rs | 34 | ||||
| -rw-r--r-- | src/rwlock/rwlock.rs | 130 | ||||
| -rw-r--r-- | src/rwlock/write_guard.rs | 75 | ||||
| -rw-r--r-- | src/rwlock/write_lock.rs | 25 |
21 files changed, 1438 insertions, 734 deletions
diff --git a/src/collection.rs b/src/collection.rs index db68382..e50cc30 100644 --- a/src/collection.rs +++ b/src/collection.rs @@ -1,7 +1,6 @@ use std::cell::UnsafeCell; -use std::marker::PhantomData; -use crate::{key::Keyable, lockable::RawLock}; +use crate::{lockable::RawLock, ThreadKey}; mod boxed; mod guard; @@ -122,8 +121,7 @@ pub struct RetryingLockCollection<L> { /// A RAII guard for a generic [`Lockable`] type. /// /// [`Lockable`]: `crate::lockable::Lockable` -pub struct LockGuard<'key, Guard, Key: Keyable + 'key> { +pub struct LockGuard<Guard> { guard: Guard, - key: Key, - _phantom: PhantomData<&'key ()>, + key: ThreadKey, } diff --git a/src/collection/boxed.rs b/src/collection/boxed.rs index 0597e90..364ec97 100644 --- a/src/collection/boxed.rs +++ b/src/collection/boxed.rs @@ -1,25 +1,12 @@ use std::cell::UnsafeCell; use std::fmt::Debug; -use std::marker::PhantomData; use crate::lockable::{Lockable, LockableIntoInner, OwnedLockable, RawLock, Sharable}; -use crate::Keyable; +use crate::{Keyable, ThreadKey}; +use super::utils::ordered_contains_duplicates; use super::{utils, BoxedLockCollection, LockGuard}; -/// returns `true` if the sorted list contains a duplicate -#[must_use] -fn contains_duplicates(l: &[&dyn RawLock]) -> bool { - if l.is_empty() { - // Return early to prevent panic in the below call to `windows` - return false; - } - - l.windows(2) - // NOTE: addr_eq is necessary because eq would also compare the v-table pointers - .any(|window| std::ptr::addr_eq(window[0], window[1])) -} - unsafe impl<L: Lockable> RawLock for BoxedLockCollection<L> { #[mutants::skip] // this should never be called #[cfg(not(tarpaulin_include))] @@ -65,6 +52,11 @@ unsafe impl<L: Lockable> Lockable for BoxedLockCollection<L> { where Self: 'g; + type DataMut<'a> + = L::DataMut<'a> + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { ptrs.extend(self.locks()) } @@ -72,6 +64,10 @@ unsafe impl<L: Lockable> Lockable for BoxedLockCollection<L> { unsafe fn guard(&self) -> Self::Guard<'_> { self.child().guard() } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.child().data_mut() + } } unsafe impl<L: Sharable> Sharable for BoxedLockCollection<L> { @@ -80,9 +76,18 @@ unsafe impl<L: Sharable> Sharable for BoxedLockCollection<L> { where Self: 'g; + type DataRef<'a> + = L::DataRef<'a> + where + Self: 'a; + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { self.child().read_guard() } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + self.child().data_ref() + } } unsafe impl<L: OwnedLockable> OwnedLockable for BoxedLockCollection<L> {} @@ -352,13 +357,53 @@ impl<L: Lockable> BoxedLockCollection<L> { // safety: we are checking for duplicates before returning unsafe { let this = Self::new_unchecked(data); - if contains_duplicates(this.locks()) { + if ordered_contains_duplicates(this.locks()) { return None; } Some(this) } } + pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_lock(); + + // safety: the data was just locked + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_lock<Key: Keyable, R>( + &self, + key: Key, + f: impl Fn(L::DataMut<'_>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_lock() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection /// /// This function returns a guard that can be used to access the underlying @@ -378,10 +423,8 @@ impl<L: Lockable> BoxedLockCollection<L> { /// *guard.0 += 1; /// *guard.1 = "1"; /// ``` - pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::Guard<'g>, Key> { + #[must_use] + pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> { unsafe { // safety: we have the thread key self.raw_lock(); @@ -390,7 +433,6 @@ impl<L: Lockable> BoxedLockCollection<L> { // safety: we've already acquired the lock guard: self.child().guard(), key, - _phantom: PhantomData, } } } @@ -424,10 +466,7 @@ impl<L: Lockable> BoxedLockCollection<L> { /// }; /// /// ``` - pub fn try_lock<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Result<LockGuard<'key, L::Guard<'g>, Key>, Key> { + pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> { let guard = unsafe { if !self.raw_try_lock() { return Err(key); @@ -437,11 +476,7 @@ impl<L: Lockable> BoxedLockCollection<L> { self.child().guard() }; - Ok(LockGuard { - guard, - key, - _phantom: PhantomData, - }) + Ok(LockGuard { guard, key }) } /// Unlocks the underlying lockable data type, returning the key that's @@ -461,13 +496,53 @@ impl<L: Lockable> BoxedLockCollection<L> { /// *guard.1 = "1"; /// let key = LockCollection::<(Mutex<i32>, Mutex<&str>)>::unlock(guard); /// ``` - pub fn unlock<'key, Key: Keyable + 'key>(guard: LockGuard<'key, L::Guard<'_>, Key>) -> Key { + pub fn unlock(guard: LockGuard<L::Guard<'_>>) -> ThreadKey { drop(guard.guard); guard.key } } impl<L: Sharable> BoxedLockCollection<L> { + pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_read(); + + // safety: the data was just locked + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_read<Key: Keyable, R>( + &self, + key: Key, + f: impl Fn(L::DataRef<'_>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_read() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection, so that other threads can still read from it /// /// This function returns a guard that can be used to access the underlying @@ -487,10 +562,8 @@ impl<L: Sharable> BoxedLockCollection<L> { /// assert_eq!(*guard.0, 0); /// assert_eq!(*guard.1, ""); /// ``` - pub fn read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::ReadGuard<'g>, Key> { + #[must_use] + pub fn read(&self, key: ThreadKey) -> LockGuard<L::ReadGuard<'_>> { unsafe { // safety: we have the thread key self.raw_read(); @@ -499,7 +572,6 @@ impl<L: Sharable> BoxedLockCollection<L> { // safety: we've already acquired the lock guard: self.child().read_guard(), key, - _phantom: PhantomData, } } } @@ -534,10 +606,7 @@ impl<L: Sharable> BoxedLockCollection<L> { /// }; /// /// ``` - pub fn try_read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Result<LockGuard<'key, L::ReadGuard<'g>, Key>, Key> { + pub fn try_read(&self, key: ThreadKey) -> Result<LockGuard<L::ReadGuard<'_>>, ThreadKey> { let guard = unsafe { // safety: we have the thread key if !self.raw_try_read() { @@ -548,11 +617,7 @@ impl<L: Sharable> BoxedLockCollection<L> { self.child().read_guard() }; - Ok(LockGuard { - guard, - key, - _phantom: PhantomData, - }) + Ok(LockGuard { guard, key }) } /// Unlocks the underlying lockable data type, returning the key that's @@ -570,9 +635,7 @@ impl<L: Sharable> BoxedLockCollection<L> { /// let mut guard = lock.read(key); /// let key = LockCollection::<(RwLock<i32>, RwLock<&str>)>::unlock_read(guard); /// ``` - pub fn unlock_read<'key, Key: Keyable + 'key>( - guard: LockGuard<'key, L::ReadGuard<'_>, Key>, - ) -> Key { + pub fn unlock_read(guard: LockGuard<L::ReadGuard<'_>>) -> ThreadKey { drop(guard.guard); guard.key } @@ -635,7 +698,6 @@ mod tests { .into_iter() .collect(); let guard = collection.lock(key); - // TODO impl PartialEq<T> for MutexRef<T> assert_eq!(*guard[0], "foo"); assert_eq!(*guard[1], "bar"); assert_eq!(*guard[2], "baz"); @@ -647,7 +709,6 @@ mod tests { let collection = BoxedLockCollection::from([Mutex::new("foo"), Mutex::new("bar"), Mutex::new("baz")]); let guard = collection.lock(key); - // TODO impl PartialEq<T> for MutexRef<T> assert_eq!(*guard[0], "foo"); assert_eq!(*guard[1], "bar"); assert_eq!(*guard[2], "baz"); @@ -666,7 +727,7 @@ mod tests { let mut key = ThreadKey::get().unwrap(); let collection = BoxedLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]); for (i, mutex) in (&collection).into_iter().enumerate() { - assert_eq!(*mutex.lock(&mut key), i); + mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i)) } } @@ -675,7 +736,7 @@ mod tests { let mut key = ThreadKey::get().unwrap(); let collection = BoxedLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]); for (i, mutex) in collection.iter().enumerate() { - assert_eq!(*mutex.lock(&mut key), i); + mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i)) } } @@ -704,11 +765,6 @@ mod tests { } #[test] - fn contains_duplicates_empty() { - assert!(!contains_duplicates(&[])) - } - - #[test] fn try_lock_works() { let key = ThreadKey::get().unwrap(); let collection = BoxedLockCollection::new([Mutex::new(1), Mutex::new(2)]); diff --git a/src/collection/guard.rs b/src/collection/guard.rs index eea13ed..78d9895 100644 --- a/src/collection/guard.rs +++ b/src/collection/guard.rs @@ -2,41 +2,11 @@ use std::fmt::{Debug, Display}; use std::hash::Hash; use std::ops::{Deref, DerefMut}; -use crate::key::Keyable; - use super::LockGuard; -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl<Guard: PartialEq, Key: Keyable> PartialEq for LockGuard<'_, Guard, Key> { - fn eq(&self, other: &Self) -> bool { - self.guard.eq(&other.guard) - } -} - -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl<Guard: PartialOrd, Key: Keyable> PartialOrd for LockGuard<'_, Guard, Key> { - fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { - self.guard.partial_cmp(&other.guard) - } -} - -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl<Guard: Eq, Key: Keyable> Eq for LockGuard<'_, Guard, Key> {} - -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl<Guard: Ord, Key: Keyable> Ord for LockGuard<'_, Guard, Key> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.guard.cmp(&other.guard) - } -} - #[mutants::skip] // hashing involves RNG and is hard to test #[cfg(not(tarpaulin_include))] -impl<Guard: Hash, Key: Keyable> Hash for LockGuard<'_, Guard, Key> { +impl<Guard: Hash> Hash for LockGuard<Guard> { fn hash<H: std::hash::Hasher>(&self, state: &mut H) { self.guard.hash(state) } @@ -44,19 +14,19 @@ impl<Guard: Hash, Key: Keyable> Hash for LockGuard<'_, Guard, Key> { #[mutants::skip] #[cfg(not(tarpaulin_include))] -impl<Guard: Debug, Key: Keyable> Debug for LockGuard<'_, Guard, Key> { +impl<Guard: Debug> Debug for LockGuard<Guard> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Debug::fmt(&**self, f) } } -impl<Guard: Display, Key: Keyable> Display for LockGuard<'_, Guard, Key> { +impl<Guard: Display> Display for LockGuard<Guard> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Display::fmt(&**self, f) } } -impl<Guard, Key: Keyable> Deref for LockGuard<'_, Guard, Key> { +impl<Guard> Deref for LockGuard<Guard> { type Target = Guard; fn deref(&self) -> &Self::Target { @@ -64,19 +34,19 @@ impl<Guard, Key: Keyable> Deref for LockGuard<'_, Guard, Key> { } } -impl<Guard, Key: Keyable> DerefMut for LockGuard<'_, Guard, Key> { +impl<Guard> DerefMut for LockGuard<Guard> { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.guard } } -impl<Guard, Key: Keyable> AsRef<Guard> for LockGuard<'_, Guard, Key> { +impl<Guard> AsRef<Guard> for LockGuard<Guard> { fn as_ref(&self) -> &Guard { &self.guard } } -impl<Guard, Key: Keyable> AsMut<Guard> for LockGuard<'_, Guard, Key> { +impl<Guard> AsMut<Guard> for LockGuard<Guard> { fn as_mut(&mut self) -> &mut Guard { &mut self.guard } @@ -97,56 +67,53 @@ mod tests { #[test] fn deref_mut_works() { - let mut key = ThreadKey::get().unwrap(); + let key = ThreadKey::get().unwrap(); let locks = (Mutex::new(1), Mutex::new(2)); let lock = LockCollection::new_ref(&locks); - let mut guard = lock.lock(&mut key); + let mut guard = lock.lock(key); *guard.0 = 3; - drop(guard); + let key = LockCollection::<(Mutex<_>, Mutex<_>)>::unlock(guard); - let guard = locks.0.lock(&mut key); + let guard = locks.0.lock(key); assert_eq!(*guard, 3); - drop(guard); + let key = Mutex::unlock(guard); - let guard = locks.1.lock(&mut key); + let guard = locks.1.lock(key); assert_eq!(*guard, 2); - drop(guard); } #[test] fn as_ref_works() { - let mut key = ThreadKey::get().unwrap(); + let key = ThreadKey::get().unwrap(); let locks = (Mutex::new(1), Mutex::new(2)); let lock = LockCollection::new_ref(&locks); - let mut guard = lock.lock(&mut key); + let mut guard = lock.lock(key); *guard.0 = 3; - drop(guard); + let key = LockCollection::<(Mutex<_>, Mutex<_>)>::unlock(guard); - let guard = locks.0.lock(&mut key); + let guard = locks.0.lock(key); assert_eq!(guard.as_ref(), &3); - drop(guard); + let key = Mutex::unlock(guard); - let guard = locks.1.lock(&mut key); + let guard = locks.1.lock(key); assert_eq!(guard.as_ref(), &2); - drop(guard); } #[test] fn as_mut_works() { - let mut key = ThreadKey::get().unwrap(); + let key = ThreadKey::get().unwrap(); let locks = (Mutex::new(1), Mutex::new(2)); let lock = LockCollection::new_ref(&locks); - let mut guard = lock.lock(&mut key); + let mut guard = lock.lock(key); let guard_mut = guard.as_mut(); *guard_mut.0 = 3; - drop(guard); + let key = LockCollection::<(Mutex<_>, Mutex<_>)>::unlock(guard); - let guard = locks.0.lock(&mut key); + let guard = locks.0.lock(key); assert_eq!(guard.as_ref(), &3); - drop(guard); + let key = Mutex::unlock(guard); - let guard = locks.1.lock(&mut key); + let guard = locks.1.lock(key); assert_eq!(guard.as_ref(), &2); - drop(guard); } } diff --git a/src/collection/owned.rs b/src/collection/owned.rs index c345b43..b9cf313 100644 --- a/src/collection/owned.rs +++ b/src/collection/owned.rs @@ -1,57 +1,47 @@ -use std::marker::PhantomData; - use crate::lockable::{ Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable, }; -use crate::Keyable; +use crate::{Keyable, ThreadKey}; use super::{utils, LockGuard, OwnedLockCollection}; -#[mutants::skip] // it's hard to test individual locks in an OwnedLockCollection -#[cfg(not(tarpaulin_include))] -fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> { - let mut locks = Vec::new(); - data.get_ptrs(&mut locks); - locks -} - unsafe impl<L: Lockable> RawLock for OwnedLockCollection<L> { #[mutants::skip] // this should never run #[cfg(not(tarpaulin_include))] fn poison(&self) { - let locks = get_locks(&self.data); + let locks = utils::get_locks_unsorted(&self.data); for lock in locks { lock.poison(); } } unsafe fn raw_lock(&self) { - utils::ordered_lock(&get_locks(&self.data)) + utils::ordered_lock(&utils::get_locks_unsorted(&self.data)) } unsafe fn raw_try_lock(&self) -> bool { - let locks = get_locks(&self.data); + let locks = utils::get_locks_unsorted(&self.data); utils::ordered_try_lock(&locks) } unsafe fn raw_unlock(&self) { - let locks = get_locks(&self.data); + let locks = utils::get_locks_unsorted(&self.data); for lock in locks { lock.raw_unlock(); } } unsafe fn raw_read(&self) { - utils::ordered_read(&get_locks(&self.data)) + utils::ordered_read(&utils::get_locks_unsorted(&self.data)) } unsafe fn raw_try_read(&self) -> bool { - let locks = get_locks(&self.data); + let locks = utils::get_locks_unsorted(&self.data); utils::ordered_try_read(&locks) } unsafe fn raw_unlock_read(&self) { - let locks = get_locks(&self.data); + let locks = utils::get_locks_unsorted(&self.data); for lock in locks { lock.raw_unlock_read(); } @@ -64,6 +54,11 @@ unsafe impl<L: Lockable> Lockable for OwnedLockCollection<L> { where Self: 'g; + type DataMut<'a> + = L::DataMut<'a> + where + Self: 'a; + #[mutants::skip] // It's hard to test lkocks in an OwnedLockCollection, because they're owned #[cfg(not(tarpaulin_include))] fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { @@ -73,6 +68,10 @@ unsafe impl<L: Lockable> Lockable for OwnedLockCollection<L> { unsafe fn guard(&self) -> Self::Guard<'_> { self.data.guard() } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.data.data_mut() + } } impl<L: LockableGetMut> LockableGetMut for OwnedLockCollection<L> { @@ -100,9 +99,18 @@ unsafe impl<L: Sharable> Sharable for OwnedLockCollection<L> { where Self: 'g; + type DataRef<'a> + = L::DataRef<'a> + where + Self: 'a; + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { self.data.read_guard() } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + self.data.data_ref() + } } unsafe impl<L: OwnedLockable> OwnedLockable for OwnedLockCollection<L> {} @@ -177,6 +185,46 @@ impl<L: OwnedLockable> OwnedLockCollection<L> { Self { data } } + pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_lock(); + + // safety: the data was just locked + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_lock<Key: Keyable, R>( + &self, + key: Key, + f: impl Fn(L::DataMut<'_>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_lock() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection /// /// This function returns a guard that can be used to access the underlying @@ -197,10 +245,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> { /// *guard.0 += 1; /// *guard.1 = "1"; /// ``` - pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::Guard<'g>, Key> { + pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> { let guard = unsafe { // safety: we have the thread key, and these locks happen in a // predetermined order @@ -210,11 +255,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> { self.data.guard() }; - LockGuard { - guard, - key, - _phantom: PhantomData, - } + LockGuard { guard, key } } /// Attempts to lock the without blocking. @@ -247,10 +288,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> { /// }; /// /// ``` - pub fn try_lock<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Result<LockGuard<'key, L::Guard<'g>, Key>, Key> { + pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> { let guard = unsafe { if !self.raw_try_lock() { return Err(key); @@ -260,11 +298,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> { self.data.guard() }; - Ok(LockGuard { - guard, - key, - _phantom: PhantomData, - }) + Ok(LockGuard { guard, key }) } /// Unlocks the underlying lockable data type, returning the key that's @@ -286,15 +320,53 @@ impl<L: OwnedLockable> OwnedLockCollection<L> { /// let key = OwnedLockCollection::<(Mutex<i32>, Mutex<&str>)>::unlock(guard); /// ``` #[allow(clippy::missing_const_for_fn)] - pub fn unlock<'g, 'key: 'g, Key: Keyable + 'key>( - guard: LockGuard<'key, L::Guard<'g>, Key>, - ) -> Key { + pub fn unlock(guard: LockGuard<L::Guard<'_>>) -> ThreadKey { drop(guard.guard); guard.key } } impl<L: Sharable> OwnedLockCollection<L> { + pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_read(); + + // safety: the data was just locked + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_read<Key: Keyable, R>( + &self, + key: Key, + f: impl Fn(L::DataRef<'_>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_read() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection, so that other threads can still read from it /// /// This function returns a guard that can be used to access the underlying @@ -315,10 +387,7 @@ impl<L: Sharable> OwnedLockCollection<L> { /// assert_eq!(*guard.0, 0); /// assert_eq!(*guard.1, ""); /// ``` - pub fn read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::ReadGuard<'g>, Key> { + pub fn read(&self, key: ThreadKey) -> LockGuard<L::ReadGuard<'_>> { unsafe { // safety: we have the thread key self.raw_read(); @@ -327,7 +396,6 @@ impl<L: Sharable> OwnedLockCollection<L> { // safety: we've already acquired the lock guard: self.data.read_guard(), key, - _phantom: PhantomData, } } } @@ -355,33 +423,26 @@ impl<L: Sharable> OwnedLockCollection<L> { /// let lock = OwnedLockCollection::new(data); /// /// match lock.try_read(key) { - /// Some(mut guard) => { + /// Ok(mut guard) => { /// assert_eq!(*guard.0, 5); /// assert_eq!(*guard.1, "6"); /// }, - /// None => unreachable!(), + /// Err(_) => unreachable!(), /// }; /// /// ``` - pub fn try_read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Option<LockGuard<'key, L::ReadGuard<'g>, Key>> { + pub fn try_read(&self, key: ThreadKey) -> Result<LockGuard<L::ReadGuard<'_>>, ThreadKey> { let guard = unsafe { // safety: we have the thread key if !self.raw_try_read() { - return None; + return Err(key); } // safety: we've acquired the locks self.data.read_guard() }; - Some(LockGuard { - guard, - key, - _phantom: PhantomData, - }) + Ok(LockGuard { guard, key }) } /// Unlocks the underlying lockable data type, returning the key that's @@ -401,9 +462,7 @@ impl<L: Sharable> OwnedLockCollection<L> { /// let key = OwnedLockCollection::<(RwLock<i32>, RwLock<&str>)>::unlock_read(guard); /// ``` #[allow(clippy::missing_const_for_fn)] - pub fn unlock_read<'g, 'key: 'g, Key: Keyable + 'key>( - guard: LockGuard<'key, L::ReadGuard<'g>, Key>, - ) -> Key { + pub fn unlock_read(guard: LockGuard<L::ReadGuard<'_>>) -> ThreadKey { drop(guard.guard); guard.key } @@ -495,7 +554,90 @@ impl<L: LockableIntoInner> OwnedLockCollection<L> { #[cfg(test)] mod tests { use super::*; - use crate::Mutex; + use crate::{Mutex, ThreadKey}; + + #[test] + fn get_mut_applies_changes() { + let key = ThreadKey::get().unwrap(); + let mut collection = OwnedLockCollection::new([Mutex::new("foo"), Mutex::new("bar")]); + assert_eq!(*collection.get_mut()[0], "foo"); + assert_eq!(*collection.get_mut()[1], "bar"); + *collection.get_mut()[0] = "baz"; + + let guard = collection.lock(key); + assert_eq!(*guard[0], "baz"); + assert_eq!(*guard[1], "bar"); + } + + #[test] + fn into_inner_works() { + let key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::from([Mutex::new("foo")]); + let mut guard = collection.lock(key); + *guard[0] = "bar"; + drop(guard); + + let array = collection.into_inner(); + assert_eq!(array.len(), 1); + assert_eq!(array[0], "bar"); + } + + #[test] + fn from_into_iter_is_correct() { + let array = [Mutex::new(0), Mutex::new(1), Mutex::new(2), Mutex::new(3)]; + let mut collection: OwnedLockCollection<Vec<Mutex<usize>>> = array.into_iter().collect(); + assert_eq!(collection.get_mut().len(), 4); + for (i, lock) in collection.into_iter().enumerate() { + assert_eq!(lock.into_inner(), i); + } + } + + #[test] + fn from_iter_is_correct() { + let array = [Mutex::new(0), Mutex::new(1), Mutex::new(2), Mutex::new(3)]; + let mut collection: OwnedLockCollection<Vec<Mutex<usize>>> = array.into_iter().collect(); + let collection: &mut Vec<_> = collection.as_mut(); + assert_eq!(collection.len(), 4); + for (i, lock) in collection.iter_mut().enumerate() { + assert_eq!(*lock.get_mut(), i); + } + } + + #[test] + fn try_lock_works_on_unlocked() { + let key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new((Mutex::new(0), Mutex::new(1))); + let guard = collection.try_lock(key).unwrap(); + assert_eq!(*guard.0, 0); + assert_eq!(*guard.1, 1); + } + + #[test] + fn try_lock_fails_on_locked() { + let key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new((Mutex::new(0), Mutex::new(1))); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + #[allow(unused)] + let guard = collection.lock(key); + std::mem::forget(guard); + }); + }); + + assert!(collection.try_lock(key).is_err()); + } + + #[test] + fn default_works() { + type MyCollection = OwnedLockCollection<(Mutex<i32>, Mutex<Option<i32>>, Mutex<String>)>; + let collection = MyCollection::default(); + let inner = collection.into_inner(); + assert_eq!(inner.0, 0); + assert_eq!(inner.1, None); + assert_eq!(inner.2, String::new()); + } #[test] fn can_be_extended() { diff --git a/src/collection/ref.rs b/src/collection/ref.rs index c86f298..b68b72f 100644 --- a/src/collection/ref.rs +++ b/src/collection/ref.rs @@ -1,32 +1,11 @@ use std::fmt::Debug; -use std::marker::PhantomData; use crate::lockable::{Lockable, OwnedLockable, RawLock, Sharable}; -use crate::Keyable; +use crate::{Keyable, ThreadKey}; +use super::utils::{get_locks, ordered_contains_duplicates}; use super::{utils, LockGuard, RefLockCollection}; -#[must_use] -pub fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> { - let mut locks = Vec::new(); - data.get_ptrs(&mut locks); - locks.sort_by_key(|lock| &raw const **lock); - locks -} - -/// returns `true` if the sorted list contains a duplicate -#[must_use] -fn contains_duplicates(l: &[&dyn RawLock]) -> bool { - if l.is_empty() { - // Return early to prevent panic in the below call to `windows` - return false; - } - - l.windows(2) - // NOTE: addr_eq is necessary because eq would also compare the v-table pointers - .any(|window| std::ptr::addr_eq(window[0], window[1])) -} - impl<'a, L> IntoIterator for &'a RefLockCollection<'a, L> where &'a L: IntoIterator, @@ -83,6 +62,11 @@ unsafe impl<L: Lockable> Lockable for RefLockCollection<'_, L> { where Self: 'g; + type DataMut<'a> + = L::DataMut<'a> + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { ptrs.extend_from_slice(&self.locks); } @@ -90,6 +74,10 @@ unsafe impl<L: Lockable> Lockable for RefLockCollection<'_, L> { unsafe fn guard(&self) -> Self::Guard<'_> { self.data.guard() } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.data.data_mut() + } } unsafe impl<L: Sharable> Sharable for RefLockCollection<'_, L> { @@ -98,9 +86,18 @@ unsafe impl<L: Sharable> Sharable for RefLockCollection<'_, L> { where Self: 'g; + type DataRef<'a> + = L::DataRef<'a> + where + Self: 'a; + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { self.data.read_guard() } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + self.data.data_ref() + } } impl<T, L: AsRef<T>> AsRef<T> for RefLockCollection<'_, L> { @@ -230,13 +227,53 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { #[must_use] pub fn try_new(data: &'a L) -> Option<Self> { let locks = get_locks(data); - if contains_duplicates(&locks) { + if ordered_contains_duplicates(&locks) { return None; } Some(Self { data, locks }) } + pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_lock(); + + // safety: the data was just locked + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_lock<Key: Keyable, R>( + &self, + key: Key, + f: impl Fn(L::DataMut<'_>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_lock() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection /// /// This function returns a guard that can be used to access the underlying @@ -257,10 +294,8 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { /// *guard.0 += 1; /// *guard.1 = "1"; /// ``` - pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::Guard<'g>, Key> { + #[must_use] + pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> { let guard = unsafe { // safety: we have the thread key self.raw_lock(); @@ -269,11 +304,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { self.data.guard() }; - LockGuard { - guard, - key, - _phantom: PhantomData, - } + LockGuard { guard, key } } /// Attempts to lock the without blocking. @@ -306,10 +337,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { /// }; /// /// ``` - pub fn try_lock<'g, 'key: 'a, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Result<LockGuard<'key, L::Guard<'g>, Key>, Key> { + pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> { let guard = unsafe { if !self.raw_try_lock() { return Err(key); @@ -319,11 +347,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { self.data.guard() }; - Ok(LockGuard { - guard, - key, - _phantom: PhantomData, - }) + Ok(LockGuard { guard, key }) } /// Unlocks the underlying lockable data type, returning the key that's @@ -345,13 +369,53 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> { /// let key = RefLockCollection::<(Mutex<i32>, Mutex<&str>)>::unlock(guard); /// ``` #[allow(clippy::missing_const_for_fn)] - pub fn unlock<'g, 'key, Key: Keyable + 'key>(guard: LockGuard<'key, L::Guard<'g>, Key>) -> Key { + pub fn unlock(guard: LockGuard<L::Guard<'_>>) -> ThreadKey { drop(guard.guard); guard.key } } -impl<'a, L: Sharable> RefLockCollection<'a, L> { +impl<L: Sharable> RefLockCollection<'_, L> { + pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_read(); + + // safety: the data was just locked + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_read<Key: Keyable, R>( + &self, + key: Key, + f: impl Fn(L::DataRef<'_>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_read() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection, so that other threads can still read from it /// /// This function returns a guard that can be used to access the underlying @@ -372,10 +436,8 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> { /// assert_eq!(*guard.0, 0); /// assert_eq!(*guard.1, ""); /// ``` - pub fn read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::ReadGuard<'g>, Key> { + #[must_use] + pub fn read(&self, key: ThreadKey) -> LockGuard<L::ReadGuard<'_>> { unsafe { // safety: we have the thread key self.raw_read(); @@ -384,7 +446,6 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> { // safety: we've already acquired the lock guard: self.data.read_guard(), key, - _phantom: PhantomData, } } } @@ -412,33 +473,26 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> { /// let lock = RefLockCollection::new(&data); /// /// match lock.try_read(key) { - /// Some(mut guard) => { + /// Ok(mut guard) => { /// assert_eq!(*guard.0, 5); /// assert_eq!(*guard.1, "6"); /// }, - /// None => unreachable!(), + /// Err(_) => unreachable!(), /// }; /// /// ``` - pub fn try_read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Option<LockGuard<'key, L::ReadGuard<'g>, Key>> { + pub fn try_read(&self, key: ThreadKey) -> Result<LockGuard<L::ReadGuard<'_>>, ThreadKey> { let guard = unsafe { // safety: we have the thread key if !self.raw_try_read() { - return None; + return Err(key); } // safety: we've acquired the locks self.data.read_guard() }; - Some(LockGuard { - guard, - key, - _phantom: PhantomData, - }) + Ok(LockGuard { guard, key }) } /// Unlocks the underlying lockable data type, returning the key that's @@ -458,9 +512,7 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> { /// let key = RefLockCollection::<(RwLock<i32>, RwLock<&str>)>::unlock_read(guard); /// ``` #[allow(clippy::missing_const_for_fn)] - pub fn unlock_read<'key: 'a, Key: Keyable + 'key>( - guard: LockGuard<'key, L::ReadGuard<'a>, Key>, - ) -> Key { + pub fn unlock_read(guard: LockGuard<L::ReadGuard<'_>>) -> ThreadKey { drop(guard.guard); guard.key } @@ -497,7 +549,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{Mutex, ThreadKey}; + use crate::{Mutex, RwLock, ThreadKey}; #[test] fn non_duplicates_allowed() { @@ -513,6 +565,85 @@ mod tests { } #[test] + fn try_lock_succeeds_for_unlocked_collection() { + let key = ThreadKey::get().unwrap(); + let mutexes = [Mutex::new(24), Mutex::new(42)]; + let collection = RefLockCollection::new(&mutexes); + let guard = collection.try_lock(key).unwrap(); + assert_eq!(*guard[0], 24); + assert_eq!(*guard[1], 42); + } + + #[test] + fn try_lock_fails_for_locked_collection() { + let key = ThreadKey::get().unwrap(); + let mutexes = [Mutex::new(24), Mutex::new(42)]; + let collection = RefLockCollection::new(&mutexes); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let guard = mutexes[1].lock(key); + assert_eq!(*guard, 42); + std::mem::forget(guard); + }); + }); + + let guard = collection.try_lock(key); + assert!(guard.is_err()); + } + + #[test] + fn try_read_succeeds_for_unlocked_collection() { + let key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let collection = RefLockCollection::new(&mutexes); + let guard = collection.try_read(key).unwrap(); + assert_eq!(*guard[0], 24); + assert_eq!(*guard[1], 42); + } + + #[test] + fn try_read_fails_for_locked_collection() { + let key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let collection = RefLockCollection::new(&mutexes); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let guard = mutexes[1].write(key); + assert_eq!(*guard, 42); + std::mem::forget(guard); + }); + }); + + let guard = collection.try_read(key); + assert!(guard.is_err()); + } + + #[test] + fn can_read_twice_on_different_threads() { + let key = ThreadKey::get().unwrap(); + let mutexes = [RwLock::new(24), RwLock::new(42)]; + let collection = RefLockCollection::new(&mutexes); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let guard = collection.read(key); + assert_eq!(*guard[0], 24); + assert_eq!(*guard[1], 42); + std::mem::forget(guard); + }); + }); + + let guard = collection.try_read(key).unwrap(); + assert_eq!(*guard[0], 24); + assert_eq!(*guard[1], 42); + } + + #[test] fn works_in_collection() { let key = ThreadKey::get().unwrap(); let mutex1 = Mutex::new(0); diff --git a/src/collection/retry.rs b/src/collection/retry.rs index 331b669..775ea29 100644 --- a/src/collection/retry.rs +++ b/src/collection/retry.rs @@ -1,24 +1,18 @@ use std::cell::Cell; use std::collections::HashSet; -use std::marker::PhantomData; use crate::collection::utils; use crate::handle_unwind::handle_unwind; use crate::lockable::{ Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable, }; -use crate::Keyable; +use crate::{Keyable, ThreadKey}; -use super::utils::{attempt_to_recover_locks_from_panic, attempt_to_recover_reads_from_panic}; +use super::utils::{ + attempt_to_recover_locks_from_panic, attempt_to_recover_reads_from_panic, get_locks_unsorted, +}; use super::{LockGuard, RetryingLockCollection}; -/// Get all raw locks in the collection -fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> { - let mut locks = Vec::new(); - data.get_ptrs(&mut locks); - locks -} - /// Checks that a collection contains no duplicate references to a lock. fn contains_duplicates<L: Lockable>(data: L) -> bool { let mut locks = Vec::new(); @@ -40,14 +34,14 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> { #[mutants::skip] // this should never run #[cfg(not(tarpaulin_include))] fn poison(&self) { - let locks = get_locks(&self.data); + let locks = get_locks_unsorted(&self.data); for lock in locks { lock.poison(); } } unsafe fn raw_lock(&self) { - let locks = get_locks(&self.data); + let locks = get_locks_unsorted(&self.data); if locks.is_empty() { // this probably prevents a panic later @@ -109,7 +103,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> { } unsafe fn raw_try_lock(&self) -> bool { - let locks = get_locks(&self.data); + let locks = get_locks_unsorted(&self.data); if locks.is_empty() { // this is an interesting case, but it doesn't give us access to @@ -139,7 +133,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> { } unsafe fn raw_unlock(&self) { - let locks = get_locks(&self.data); + let locks = get_locks_unsorted(&self.data); for lock in locks { lock.raw_unlock(); @@ -147,7 +141,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> { } unsafe fn raw_read(&self) { - let locks = get_locks(&self.data); + let locks = get_locks_unsorted(&self.data); if locks.is_empty() { // this probably prevents a panic later @@ -200,7 +194,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> { } unsafe fn raw_try_read(&self) -> bool { - let locks = get_locks(&self.data); + let locks = get_locks_unsorted(&self.data); if locks.is_empty() { // this is an interesting case, but it doesn't give us access to @@ -229,7 +223,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> { } unsafe fn raw_unlock_read(&self) { - let locks = get_locks(&self.data); + let locks = get_locks_unsorted(&self.data); for lock in locks { lock.raw_unlock_read(); @@ -243,6 +237,11 @@ unsafe impl<L: Lockable> Lockable for RetryingLockCollection<L> { where Self: 'g; + type DataMut<'a> + = L::DataMut<'a> + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { self.data.get_ptrs(ptrs) } @@ -250,6 +249,10 @@ unsafe impl<L: Lockable> Lockable for RetryingLockCollection<L> { unsafe fn guard(&self) -> Self::Guard<'_> { self.data.guard() } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.data.data_mut() + } } unsafe impl<L: Sharable> Sharable for RetryingLockCollection<L> { @@ -258,9 +261,18 @@ unsafe impl<L: Sharable> Sharable for RetryingLockCollection<L> { where Self: 'g; + type DataRef<'a> + = L::DataRef<'a> + where + Self: 'a; + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { self.data.read_guard() } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + self.data.data_ref() + } } unsafe impl<L: OwnedLockable> OwnedLockable for RetryingLockCollection<L> {} @@ -516,6 +528,46 @@ impl<L: Lockable> RetryingLockCollection<L> { (!contains_duplicates(&data)).then_some(Self { data }) } + pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_lock(); + + // safety: the data was just locked + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_lock<Key: Keyable, R>( + &self, + key: Key, + f: impl Fn(L::DataMut<'_>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_lock() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection /// /// This function returns a guard that can be used to access the underlying @@ -536,10 +588,7 @@ impl<L: Lockable> RetryingLockCollection<L> { /// *guard.0 += 1; /// *guard.1 = "1"; /// ``` - pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::Guard<'g>, Key> { + pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> { unsafe { // safety: we're taking the thread key self.raw_lock(); @@ -548,7 +597,6 @@ impl<L: Lockable> RetryingLockCollection<L> { // safety: we just locked the collection guard: self.guard(), key, - _phantom: PhantomData, } } } @@ -583,10 +631,7 @@ impl<L: Lockable> RetryingLockCollection<L> { /// }; /// /// ``` - pub fn try_lock<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Result<LockGuard<'key, L::Guard<'g>, Key>, Key> { + pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> { unsafe { // safety: we're taking the thread key if self.raw_try_lock() { @@ -594,7 +639,6 @@ impl<L: Lockable> RetryingLockCollection<L> { // safety: we just succeeded in locking everything guard: self.guard(), key, - _phantom: PhantomData, }) } else { Err(key) @@ -620,13 +664,53 @@ impl<L: Lockable> RetryingLockCollection<L> { /// *guard.1 = "1"; /// let key = RetryingLockCollection::<(Mutex<i32>, Mutex<&str>)>::unlock(guard); /// ``` - pub fn unlock<'key, Key: Keyable + 'key>(guard: LockGuard<'key, L::Guard<'_>, Key>) -> Key { + pub fn unlock(guard: LockGuard<L::Guard<'_>>) -> ThreadKey { drop(guard.guard); guard.key } } impl<L: Sharable> RetryingLockCollection<L> { + pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_read(); + + // safety: the data was just locked + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_read<Key: Keyable, R>( + &self, + key: Key, + f: impl Fn(L::DataRef<'_>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_read() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection, so that other threads can still read from it /// /// This function returns a guard that can be used to access the underlying @@ -647,10 +731,7 @@ impl<L: Sharable> RetryingLockCollection<L> { /// assert_eq!(*guard.0, 0); /// assert_eq!(*guard.1, ""); /// ``` - pub fn read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::ReadGuard<'g>, Key> { + pub fn read(&self, key: ThreadKey) -> LockGuard<L::ReadGuard<'_>> { unsafe { // safety: we're taking the thread key self.raw_read(); @@ -659,7 +740,6 @@ impl<L: Sharable> RetryingLockCollection<L> { // safety: we just locked the collection guard: self.read_guard(), key, - _phantom: PhantomData, } } } @@ -687,25 +767,25 @@ impl<L: Sharable> RetryingLockCollection<L> { /// let lock = RetryingLockCollection::new(data); /// /// match lock.try_read(key) { - /// Some(mut guard) => { + /// Ok(mut guard) => { /// assert_eq!(*guard.0, 5); /// assert_eq!(*guard.1, "6"); /// }, - /// None => unreachable!(), + /// Err(_) => unreachable!(), /// }; /// /// ``` - pub fn try_read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Option<LockGuard<'key, L::ReadGuard<'g>, Key>> { + pub fn try_read(&self, key: ThreadKey) -> Result<LockGuard<L::ReadGuard<'_>>, ThreadKey> { unsafe { // safety: we're taking the thread key - self.raw_try_lock().then(|| LockGuard { + if !self.raw_try_lock() { + return Err(key); + } + + Ok(LockGuard { // safety: we just succeeded in locking everything guard: self.read_guard(), key, - _phantom: PhantomData, }) } } @@ -726,9 +806,7 @@ impl<L: Sharable> RetryingLockCollection<L> { /// let mut guard = lock.read(key); /// let key = RetryingLockCollection::<(RwLock<i32>, RwLock<&str>)>::unlock_read(guard); /// ``` - pub fn unlock_read<'key, Key: Keyable + 'key>( - guard: LockGuard<'key, L::ReadGuard<'_>, Key>, - ) -> Key { + pub fn unlock_read(guard: LockGuard<L::ReadGuard<'_>>) -> ThreadKey { drop(guard.guard); guard.key } @@ -833,7 +911,7 @@ where mod tests { use super::*; use crate::collection::BoxedLockCollection; - use crate::{Mutex, RwLock, ThreadKey}; + use crate::{LockCollection, Mutex, RwLock, ThreadKey}; #[test] fn nonduplicate_lock_references_are_allowed() { @@ -869,7 +947,6 @@ mod tests { let rwlock1 = RwLock::new(0); let rwlock2 = RwLock::new(0); let collection = RetryingLockCollection::try_new([&rwlock1, &rwlock2]).unwrap(); - // TODO Poisonable::read let guard = collection.read(key); @@ -909,13 +986,14 @@ mod tests { #[test] fn lock_empty_lock_collection() { - let mut key = ThreadKey::get().unwrap(); + let key = ThreadKey::get().unwrap(); let collection: RetryingLockCollection<[RwLock<i32>; 0]> = RetryingLockCollection::new([]); - let guard = collection.lock(&mut key); + let guard = collection.lock(key); assert!(guard.len() == 0); + let key = LockCollection::<[RwLock<_>; 0]>::unlock(guard); - let guard = collection.read(&mut key); + let guard = collection.read(key); assert!(guard.len() == 0); } } diff --git a/src/collection/utils.rs b/src/collection/utils.rs index d6d50f4..1d96e5c 100644 --- a/src/collection/utils.rs +++ b/src/collection/utils.rs @@ -1,7 +1,35 @@ use std::cell::Cell; use crate::handle_unwind::handle_unwind; -use crate::lockable::RawLock; +use crate::lockable::{Lockable, RawLock}; + +#[must_use] +pub fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> { + let mut locks = Vec::new(); + data.get_ptrs(&mut locks); + locks.sort_by_key(|lock| &raw const **lock); + locks +} + +#[must_use] +pub fn get_locks_unsorted<L: Lockable>(data: &L) -> Vec<&dyn RawLock> { + let mut locks = Vec::new(); + data.get_ptrs(&mut locks); + locks +} + +/// returns `true` if the sorted list contains a duplicate +#[must_use] +pub fn ordered_contains_duplicates(l: &[&dyn RawLock]) -> bool { + if l.is_empty() { + // Return early to prevent panic in the below call to `windows` + return false; + } + + l.windows(2) + // NOTE: addr_eq is necessary because eq would also compare the v-table pointers + .any(|window| std::ptr::addr_eq(window[0], window[1])) +} /// Lock a set of locks in the given order. It's UB to call this without a `ThreadKey` pub unsafe fn ordered_lock(locks: &[&dyn RawLock]) { @@ -115,3 +143,13 @@ pub unsafe fn attempt_to_recover_reads_from_panic(locked: &[&dyn RawLock]) { || locked.iter().for_each(|l| l.poison()), ) } + +#[cfg(test)] +mod tests { + use crate::collection::utils::ordered_contains_duplicates; + + #[test] + fn empty_array_does_not_contain_duplicates() { + assert!(!ordered_contains_duplicates(&[])) + } +} diff --git a/src/lockable.rs b/src/lockable.rs index 4f7cfe5..f125c02 100644 --- a/src/lockable.rs +++ b/src/lockable.rs @@ -111,6 +111,10 @@ pub unsafe trait Lockable { where Self: 'g; + type DataMut<'a> + where + Self: 'a; + /// Yields a list of references to the [`RawLock`]s contained within this /// value. /// @@ -129,6 +133,9 @@ pub unsafe trait Lockable { /// unlocked until this guard is dropped. #[must_use] unsafe fn guard(&self) -> Self::Guard<'_>; + + #[must_use] + unsafe fn data_mut(&self) -> Self::DataMut<'_>; } /// Allows a lock to be accessed by multiple readers. @@ -145,6 +152,10 @@ pub unsafe trait Sharable: Lockable { where Self: 'g; + type DataRef<'a> + where + Self: 'a; + /// Returns a guard that can be used to immutably access the underlying /// data. /// @@ -155,6 +166,9 @@ pub unsafe trait Sharable: Lockable { /// unlocked until this guard is dropped. #[must_use] unsafe fn read_guard(&self) -> Self::ReadGuard<'_>; + + #[must_use] + unsafe fn data_ref(&self) -> Self::DataRef<'_>; } /// A type that may be locked and unlocked, and is known to be the only valid @@ -210,6 +224,11 @@ unsafe impl<T: Lockable> Lockable for &T { where Self: 'g; + type DataMut<'a> + = T::DataMut<'a> + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { (*self).get_ptrs(ptrs); } @@ -217,6 +236,10 @@ unsafe impl<T: Lockable> Lockable for &T { unsafe fn guard(&self) -> Self::Guard<'_> { (*self).guard() } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + (*self).data_mut() + } } unsafe impl<T: Sharable> Sharable for &T { @@ -225,9 +248,18 @@ unsafe impl<T: Sharable> Sharable for &T { where Self: 'g; + type DataRef<'a> + = T::DataRef<'a> + where + Self: 'a; + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { (*self).read_guard() } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + (*self).data_ref() + } } unsafe impl<T: Lockable> Lockable for &mut T { @@ -236,6 +268,11 @@ unsafe impl<T: Lockable> Lockable for &mut T { where Self: 'g; + type DataMut<'a> + = T::DataMut<'a> + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { (**self).get_ptrs(ptrs) } @@ -243,6 +280,10 @@ unsafe impl<T: Lockable> Lockable for &mut T { unsafe fn guard(&self) -> Self::Guard<'_> { (**self).guard() } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + (**self).data_mut() + } } impl<T: LockableGetMut> LockableGetMut for &mut T { @@ -262,9 +303,18 @@ unsafe impl<T: Sharable> Sharable for &mut T { where Self: 'g; + type DataRef<'a> + = T::DataRef<'a> + where + Self: 'a; + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { (**self).read_guard() } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + (**self).data_ref() + } } unsafe impl<T: OwnedLockable> OwnedLockable for &mut T {} @@ -276,6 +326,8 @@ macro_rules! tuple_impls { unsafe impl<$($generic: Lockable,)*> Lockable for ($($generic,)*) { type Guard<'g> = ($($generic::Guard<'g>,)*) where Self: 'g; + type DataMut<'a> = ($($generic::DataMut<'a>,)*) where Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { $(self.$value.get_ptrs(ptrs));* } @@ -285,6 +337,10 @@ macro_rules! tuple_impls { // I don't think any other way of doing it compiles ($(self.$value.guard(),)*) } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + ($(self.$value.data_mut(),)*) + } } impl<$($generic: LockableGetMut,)*> LockableGetMut for ($($generic,)*) { @@ -306,9 +362,15 @@ macro_rules! tuple_impls { unsafe impl<$($generic: Sharable,)*> Sharable for ($($generic,)*) { type ReadGuard<'g> = ($($generic::ReadGuard<'g>,)*) where Self: 'g; + type DataRef<'a> = ($($generic::DataRef<'a>,)*) where Self: 'a; + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { ($(self.$value.read_guard(),)*) } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + ($(self.$value.data_ref(),)*) + } } unsafe impl<$($generic: OwnedLockable,)*> OwnedLockable for ($($generic,)*) {} @@ -329,6 +391,11 @@ unsafe impl<T: Lockable, const N: usize> Lockable for [T; N] { where Self: 'g; + type DataMut<'a> + = [T::DataMut<'a>; N] + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { for lock in self { lock.get_ptrs(ptrs); @@ -345,6 +412,15 @@ unsafe impl<T: Lockable, const N: usize> Lockable for [T; N] { guards.map(|g| g.assume_init()) } + + unsafe fn data_mut<'a>(&'a self) -> Self::DataMut<'a> { + let mut guards = MaybeUninit::<[MaybeUninit<T::DataMut<'a>>; N]>::uninit().assume_init(); + for i in 0..N { + guards[i].write(self[i].data_mut()); + } + + guards.map(|g| g.assume_init()) + } } impl<T: LockableGetMut, const N: usize> LockableGetMut for [T; N] { @@ -386,6 +462,11 @@ unsafe impl<T: Sharable, const N: usize> Sharable for [T; N] { where Self: 'g; + type DataRef<'a> + = [T::DataRef<'a>; N] + where + Self: 'a; + unsafe fn read_guard<'g>(&'g self) -> Self::ReadGuard<'g> { let mut guards = MaybeUninit::<[MaybeUninit<T::ReadGuard<'g>>; N]>::uninit().assume_init(); for i in 0..N { @@ -394,6 +475,15 @@ unsafe impl<T: Sharable, const N: usize> Sharable for [T; N] { guards.map(|g| g.assume_init()) } + + unsafe fn data_ref<'a>(&'a self) -> Self::DataRef<'a> { + let mut guards = MaybeUninit::<[MaybeUninit<T::DataRef<'a>>; N]>::uninit().assume_init(); + for i in 0..N { + guards[i].write(self[i].data_ref()); + } + + guards.map(|g| g.assume_init()) + } } unsafe impl<T: OwnedLockable, const N: usize> OwnedLockable for [T; N] {} @@ -404,6 +494,11 @@ unsafe impl<T: Lockable> Lockable for Box<[T]> { where Self: 'g; + type DataMut<'a> + = Box<[T::DataMut<'a>]> + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { for lock in self { lock.get_ptrs(ptrs); @@ -413,6 +508,10 @@ unsafe impl<T: Lockable> Lockable for Box<[T]> { unsafe fn guard(&self) -> Self::Guard<'_> { self.iter().map(|lock| lock.guard()).collect() } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.iter().map(|lock| lock.data_mut()).collect() + } } impl<T: LockableGetMut + 'static> LockableGetMut for Box<[T]> { @@ -432,9 +531,18 @@ unsafe impl<T: Sharable> Sharable for Box<[T]> { where Self: 'g; + type DataRef<'a> + = Box<[T::DataRef<'a>]> + where + Self: 'a; + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { self.iter().map(|lock| lock.read_guard()).collect() } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + self.iter().map(|lock| lock.data_ref()).collect() + } } unsafe impl<T: Sharable> Sharable for Vec<T> { @@ -443,9 +551,18 @@ unsafe impl<T: Sharable> Sharable for Vec<T> { where Self: 'g; + type DataRef<'a> + = Box<[T::DataRef<'a>]> + where + Self: 'a; + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { self.iter().map(|lock| lock.read_guard()).collect() } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + self.iter().map(|lock| lock.data_ref()).collect() + } } unsafe impl<T: OwnedLockable> OwnedLockable for Box<[T]> {} @@ -457,6 +574,11 @@ unsafe impl<T: Lockable> Lockable for Vec<T> { where Self: 'g; + type DataMut<'a> + = Box<[T::DataMut<'a>]> + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { for lock in self { lock.get_ptrs(ptrs); @@ -466,6 +588,10 @@ unsafe impl<T: Lockable> Lockable for Vec<T> { unsafe fn guard(&self) -> Self::Guard<'_> { self.iter().map(|lock| lock.guard()).collect() } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.iter().map(|lock| lock.data_mut()).collect() + } } // I'd make a generic impl<T: Lockable, I: IntoIterator<Item=T>> Lockable for I diff --git a/src/mutex.rs b/src/mutex.rs index d6cba7d..2022501 100644 --- a/src/mutex.rs +++ b/src/mutex.rs @@ -3,8 +3,8 @@ use std::marker::PhantomData; use lock_api::RawMutex; -use crate::key::Keyable; use crate::poisonable::PoisonFlag; +use crate::ThreadKey; mod guard; mod mutex; @@ -87,20 +87,19 @@ pub type ParkingMutex<T> = Mutex<T, parking_lot::RawMutex>; /// let mut key = ThreadKey::get().unwrap(); /// /// // Here we use a block to limit the lifetime of the lock guard. -/// let result = { -/// let mut data = data_mutex_clone.lock(&mut key); +/// let result = data_mutex_clone.scoped_lock(&mut key, |data| { /// let result = data.iter().fold(0, |acc, x| acc + x * 2); /// data.push(result); /// result /// // The mutex guard gets dropped here, so the lock is released -/// }; +/// }); /// // The thread key is available again /// *res_mutex_clone.lock(key) += result; /// })); /// }); /// -/// let mut key = ThreadKey::get().unwrap(); -/// let mut data = data_mutex.lock(&mut key); +/// let key = ThreadKey::get().unwrap(); +/// let mut data = data_mutex.lock(key); /// let result = data.iter().fold(0, |acc, x| acc + x * 2); /// data.push(result); /// @@ -108,12 +107,12 @@ pub type ParkingMutex<T> = Mutex<T, parking_lot::RawMutex>; /// // allows other threads to start working on the data immediately. Dropping /// // the data also gives us access to the thread key, so we can lock /// // another mutex. -/// drop(data); +/// let key = Mutex::unlock(data); /// /// // Here the mutex guard is not assigned to a variable and so, even if the /// // scope does not end after this line, the mutex is still released: there is /// // no deadlock. -/// *res_mutex.lock(&mut key) += result; +/// *res_mutex.lock(key) += result; /// /// threads.into_iter().for_each(|thread| { /// thread @@ -121,6 +120,7 @@ pub type ParkingMutex<T> = Mutex<T, parking_lot::RawMutex>; /// .expect("The thread creating or execution failed !") /// }); /// +/// let key = ThreadKey::get().unwrap(); /// assert_eq!(*res_mutex.lock(key), 800); /// ``` /// @@ -153,10 +153,9 @@ pub struct MutexRef<'a, T: ?Sized + 'a, R: RawMutex>( // // This is the most lifetime-intensive thing I've ever written. Can I graduate // from borrow checker university now? -pub struct MutexGuard<'a, 'key: 'a, T: ?Sized + 'a, Key: Keyable + 'key, R: RawMutex> { +pub struct MutexGuard<'a, T: ?Sized + 'a, R: RawMutex> { mutex: MutexRef<'a, T, R>, // this way we don't need to re-implement Drop - thread_key: Key, - _phantom: PhantomData<&'key ()>, + thread_key: ThreadKey, } #[cfg(test)] @@ -194,61 +193,46 @@ mod tests { #[test] fn display_works_for_ref() { let mutex: crate::Mutex<_> = Mutex::new("Hello, world!"); - let guard = unsafe { mutex.try_lock_no_key().unwrap() }; // TODO lock_no_key + let guard = unsafe { mutex.try_lock_no_key().unwrap() }; assert_eq!(guard.to_string(), "Hello, world!".to_string()); } #[test] - fn ord_works() { - let key = ThreadKey::get().unwrap(); - let mutex1: crate::Mutex<_> = Mutex::new(1); - let mutex2: crate::Mutex<_> = Mutex::new(2); - let mutex3: crate::Mutex<_> = Mutex::new(2); - let collection = LockCollection::try_new((&mutex1, &mutex2, &mutex3)).unwrap(); - - let guard = collection.lock(key); - assert!(guard.0 < guard.1); - assert!(guard.1 > guard.0); - assert!(guard.1 == guard.2); - assert!(guard.0 != guard.2) - } - - #[test] fn ref_as_mut() { - let mut key = ThreadKey::get().unwrap(); + let key = ThreadKey::get().unwrap(); let collection = LockCollection::new(crate::Mutex::new(0)); - let mut guard = collection.lock(&mut key); + let mut guard = collection.lock(key); let guard_mut = guard.as_mut().as_mut(); *guard_mut = 3; - drop(guard); + let key = LockCollection::<crate::Mutex<_>>::unlock(guard); - let guard = collection.lock(&mut key); + let guard = collection.lock(key); assert_eq!(guard.as_ref().as_ref(), &3); } #[test] fn guard_as_mut() { - let mut key = ThreadKey::get().unwrap(); + let key = ThreadKey::get().unwrap(); let mutex = crate::Mutex::new(0); - let mut guard = mutex.lock(&mut key); + let mut guard = mutex.lock(key); let guard_mut = guard.as_mut(); *guard_mut = 3; - drop(guard); + let key = Mutex::unlock(guard); - let guard = mutex.lock(&mut key); + let guard = mutex.lock(key); assert_eq!(guard.as_ref(), &3); } #[test] fn dropping_guard_releases_mutex() { - let mut key = ThreadKey::get().unwrap(); + let key = ThreadKey::get().unwrap(); let mutex: crate::Mutex<_> = Mutex::new("Hello, world!"); - let guard = mutex.lock(&mut key); + let guard = mutex.lock(key); drop(guard); assert!(!mutex.is_locked()); diff --git a/src/mutex/guard.rs b/src/mutex/guard.rs index 4e4d5f1..22e59c1 100644 --- a/src/mutex/guard.rs +++ b/src/mutex/guard.rs @@ -5,34 +5,14 @@ use std::ops::{Deref, DerefMut}; use lock_api::RawMutex; -use crate::key::Keyable; use crate::lockable::RawLock; +use crate::ThreadKey; use super::{Mutex, MutexGuard, MutexRef}; // These impls make things slightly easier because now you can use // `println!("{guard}")` instead of `println!("{}", *guard)` -impl<T: PartialEq + ?Sized, R: RawMutex> PartialEq for MutexRef<'_, T, R> { - fn eq(&self, other: &Self) -> bool { - self.deref().eq(&**other) - } -} - -impl<T: Eq + ?Sized, R: RawMutex> Eq for MutexRef<'_, T, R> {} - -impl<T: PartialOrd + ?Sized, R: RawMutex> PartialOrd for MutexRef<'_, T, R> { - fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { - self.deref().partial_cmp(&**other) - } -} - -impl<T: Ord + ?Sized, R: RawMutex> Ord for MutexRef<'_, T, R> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.deref().cmp(&**other) - } -} - #[mutants::skip] // hashing involves RNG and is hard to test #[cfg(not(tarpaulin_include))] impl<T: Hash + ?Sized, R: RawMutex> Hash for MutexRef<'_, T, R> { @@ -107,39 +87,9 @@ impl<'a, T: ?Sized, R: RawMutex> MutexRef<'a, T, R> { // it's kinda annoying to re-implement some of this stuff on guards // there's nothing i can do about that -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl<T: PartialEq + ?Sized, R: RawMutex, Key: Keyable> PartialEq for MutexGuard<'_, '_, T, Key, R> { - fn eq(&self, other: &Self) -> bool { - self.deref().eq(&**other) - } -} - -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl<T: Eq + ?Sized, R: RawMutex, Key: Keyable> Eq for MutexGuard<'_, '_, T, Key, R> {} - -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl<T: PartialOrd + ?Sized, R: RawMutex, Key: Keyable> PartialOrd - for MutexGuard<'_, '_, T, Key, R> -{ - fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { - self.deref().partial_cmp(&**other) - } -} - -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl<T: Ord + ?Sized, R: RawMutex, Key: Keyable> Ord for MutexGuard<'_, '_, T, Key, R> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.deref().cmp(&**other) - } -} - #[mutants::skip] // hashing involves RNG and is hard to test #[cfg(not(tarpaulin_include))] -impl<T: Hash + ?Sized, R: RawMutex, Key: Keyable> Hash for MutexGuard<'_, '_, T, Key, R> { +impl<T: Hash + ?Sized, R: RawMutex> Hash for MutexGuard<'_, T, R> { fn hash<H: std::hash::Hasher>(&self, state: &mut H) { self.deref().hash(state) } @@ -147,19 +97,19 @@ impl<T: Hash + ?Sized, R: RawMutex, Key: Keyable> Hash for MutexGuard<'_, '_, T, #[mutants::skip] #[cfg(not(tarpaulin_include))] -impl<T: Debug + ?Sized, Key: Keyable, R: RawMutex> Debug for MutexGuard<'_, '_, T, Key, R> { +impl<T: Debug + ?Sized, R: RawMutex> Debug for MutexGuard<'_, T, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Debug::fmt(&**self, f) } } -impl<T: Display + ?Sized, Key: Keyable, R: RawMutex> Display for MutexGuard<'_, '_, T, Key, R> { +impl<T: Display + ?Sized, R: RawMutex> Display for MutexGuard<'_, T, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Display::fmt(&**self, f) } } -impl<T: ?Sized, Key: Keyable, R: RawMutex> Deref for MutexGuard<'_, '_, T, Key, R> { +impl<T: ?Sized, R: RawMutex> Deref for MutexGuard<'_, T, R> { type Target = T; fn deref(&self) -> &Self::Target { @@ -167,33 +117,32 @@ impl<T: ?Sized, Key: Keyable, R: RawMutex> Deref for MutexGuard<'_, '_, T, Key, } } -impl<T: ?Sized, Key: Keyable, R: RawMutex> DerefMut for MutexGuard<'_, '_, T, Key, R> { +impl<T: ?Sized, R: RawMutex> DerefMut for MutexGuard<'_, T, R> { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.mutex } } -impl<T: ?Sized, Key: Keyable, R: RawMutex> AsRef<T> for MutexGuard<'_, '_, T, Key, R> { +impl<T: ?Sized, R: RawMutex> AsRef<T> for MutexGuard<'_, T, R> { fn as_ref(&self) -> &T { self } } -impl<T: ?Sized, Key: Keyable, R: RawMutex> AsMut<T> for MutexGuard<'_, '_, T, Key, R> { +impl<T: ?Sized, R: RawMutex> AsMut<T> for MutexGuard<'_, T, R> { fn as_mut(&mut self) -> &mut T { self } } -impl<'a, T: ?Sized, Key: Keyable, R: RawMutex> MutexGuard<'a, '_, T, Key, R> { +impl<'a, T: ?Sized, R: RawMutex> MutexGuard<'a, T, R> { /// Create a guard to the given mutex. Undefined if multiple guards to the /// same mutex exist at once. #[must_use] - pub(super) unsafe fn new(mutex: &'a Mutex<T, R>, thread_key: Key) -> Self { + pub(super) unsafe fn new(mutex: &'a Mutex<T, R>, thread_key: ThreadKey) -> Self { Self { mutex: MutexRef(mutex, PhantomData), thread_key, - _phantom: PhantomData, } } } diff --git a/src/mutex/mutex.rs b/src/mutex/mutex.rs index 0bd5286..1d8ce8b 100644 --- a/src/mutex/mutex.rs +++ b/src/mutex/mutex.rs @@ -6,9 +6,9 @@ use std::panic::AssertUnwindSafe; use lock_api::RawMutex; use crate::handle_unwind::handle_unwind; -use crate::key::Keyable; use crate::lockable::{Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock}; use crate::poisonable::PoisonFlag; +use crate::{Keyable, ThreadKey}; use super::{Mutex, MutexGuard, MutexRef}; @@ -62,6 +62,11 @@ unsafe impl<T: Send, R: RawMutex + Send + Sync> Lockable for Mutex<T, R> { where Self: 'g; + type DataMut<'a> + = &'a mut T + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { ptrs.push(self); } @@ -69,6 +74,10 @@ unsafe impl<T: Send, R: RawMutex + Send + Sync> Lockable for Mutex<T, R> { unsafe fn guard(&self) -> Self::Guard<'_> { MutexRef::new(self) } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.data.get().as_mut().unwrap_unchecked() + } } impl<T: Send, R: RawMutex + Send + Sync> LockableIntoInner for Mutex<T, R> { @@ -214,6 +223,46 @@ impl<T: ?Sized, R> Mutex<T, R> { } impl<T: ?Sized, R: RawMutex> Mutex<T, R> { + pub fn scoped_lock<Ret>(&self, key: impl Keyable, f: impl FnOnce(&mut T) -> Ret) -> Ret { + unsafe { + // safety: we have the thread key + self.raw_lock(); + + // safety: the mutex was just locked + let r = f(self.data.get().as_mut().unwrap_unchecked()); + + // safety: we locked the mutex already + self.raw_unlock(); + + drop(key); // ensures we drop the key in the correct place + + r + } + } + + pub fn scoped_try_lock<Key: Keyable, Ret>( + &self, + key: Key, + f: impl FnOnce(&mut T) -> Ret, + ) -> Result<Ret, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_lock() { + return Err(key); + } + + // safety: the mutex was just locked + let r = f(self.data.get().as_mut().unwrap_unchecked()); + + // safety: we locked the mutex already + self.raw_unlock(); + + drop(key); // ensures we drop the key in the correct place + + Ok(r) + } + } + /// Block the thread until this mutex can be locked, and lock it. /// /// Upon returning, the thread is the only thread with a lock on the @@ -237,7 +286,7 @@ impl<T: ?Sized, R: RawMutex> Mutex<T, R> { /// let key = ThreadKey::get().unwrap(); /// assert_eq!(*mutex.lock(key), 10); /// ``` - pub fn lock<'s, 'k: 's, Key: Keyable>(&'s self, key: Key) -> MutexGuard<'s, 'k, T, Key, R> { + pub fn lock(&self, key: ThreadKey) -> MutexGuard<'_, T, R> { unsafe { // safety: we have the thread key self.raw_lock(); @@ -280,10 +329,7 @@ impl<T: ?Sized, R: RawMutex> Mutex<T, R> { /// let key = ThreadKey::get().unwrap(); /// assert_eq!(*mutex.lock(key), 10); /// ``` - pub fn try_lock<'s, 'k: 's, Key: Keyable>( - &'s self, - key: Key, - ) -> Result<MutexGuard<'s, 'k, T, Key, R>, Key> { + pub fn try_lock(&self, key: ThreadKey) -> Result<MutexGuard<'_, T, R>, ThreadKey> { unsafe { // safety: we have the key to the mutex if self.raw_try_lock() { @@ -322,7 +368,8 @@ impl<T: ?Sized, R: RawMutex> Mutex<T, R> { /// /// let key = Mutex::unlock(guard); /// ``` - pub fn unlock<'a, 'k: 'a, Key: Keyable + 'k>(guard: MutexGuard<'a, 'k, T, Key, R>) -> Key { + #[must_use] + pub fn unlock(guard: MutexGuard<'_, T, R>) -> ThreadKey { unsafe { guard.mutex.0.raw_unlock(); } diff --git a/src/poisonable.rs b/src/poisonable.rs index bb6ad28..8d5a810 100644 --- a/src/poisonable.rs +++ b/src/poisonable.rs @@ -1,6 +1,8 @@ use std::marker::PhantomData; use std::sync::atomic::AtomicBool; +use crate::ThreadKey; + mod error; mod flag; mod guard; @@ -55,10 +57,9 @@ pub struct PoisonRef<'a, G> { /// An RAII guard for a [`Poisonable`]. /// /// This is created by calling methods like [`Poisonable::lock`]. -pub struct PoisonGuard<'a, 'key: 'a, G, Key: 'key> { +pub struct PoisonGuard<'a, G> { guard: PoisonRef<'a, G>, - key: Key, - _phantom: PhantomData<&'key ()>, + key: ThreadKey, } /// A type of error which can be returned when acquiring a [`Poisonable`] lock. @@ -69,9 +70,9 @@ pub struct PoisonError<Guard> { /// An enumeration of possible errors associated with /// [`TryLockPoisonableResult`] which can occur while trying to acquire a lock /// (i.e.: [`Poisonable::try_lock`]). -pub enum TryLockPoisonableError<'flag, 'key: 'flag, G, Key: 'key> { - Poisoned(PoisonError<PoisonGuard<'flag, 'key, G, Key>>), - WouldBlock(Key), +pub enum TryLockPoisonableError<'flag, G> { + Poisoned(PoisonError<PoisonGuard<'flag, G>>), + WouldBlock(ThreadKey), } /// A type alias for the result of a lock method which can poisoned. @@ -89,8 +90,8 @@ pub type PoisonResult<Guard> = Result<Guard, PoisonError<Guard>>; /// For more information, see [`PoisonResult`]. A `TryLockPoisonableResult` /// doesn't necessarily hold the associated guard in the [`Err`] type as the /// lock might not have been acquired for other reasons. -pub type TryLockPoisonableResult<'flag, 'key, G, Key> = - Result<PoisonGuard<'flag, 'key, G, Key>, TryLockPoisonableError<'flag, 'key, G, Key>>; +pub type TryLockPoisonableResult<'flag, G> = + Result<PoisonGuard<'flag, G>, TryLockPoisonableError<'flag, G>>; #[cfg(test)] mod tests { @@ -101,6 +102,105 @@ mod tests { use crate::{LockCollection, Mutex, ThreadKey}; #[test] + fn locking_poisoned_mutex_returns_error_in_collection() { + let key = ThreadKey::get().unwrap(); + let mutex = LockCollection::new(Poisonable::new(Mutex::new(42))); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + let mut guard1 = mutex.lock(key); + let guard = guard1.as_deref_mut().unwrap(); + assert_eq!(**guard, 42); + panic!(); + + #[allow(unreachable_code)] + drop(guard1); + }) + .join() + .unwrap_err(); + }); + + let error = mutex.lock(key); + let error = error.as_deref().unwrap_err(); + assert_eq!(***error.get_ref(), 42); + } + + #[test] + fn non_poisoned_get_mut_is_ok() { + let mut mutex = Poisonable::new(Mutex::new(42)); + let guard = mutex.get_mut(); + assert!(guard.is_ok()); + assert_eq!(*guard.unwrap(), 42); + } + + #[test] + fn non_poisoned_get_mut_is_err() { + let mut mutex = Poisonable::new(Mutex::new(42)); + + let _ = std::panic::catch_unwind(|| { + let key = ThreadKey::get().unwrap(); + #[allow(unused_variables)] + let guard = mutex.lock(key); + panic!(); + #[allow(unreachable_code)] + drop(guard); + }); + + let guard = mutex.get_mut(); + assert!(guard.is_err()); + assert_eq!(**guard.unwrap_err().get_ref(), 42); + } + + #[test] + fn unpoisoned_into_inner() { + let mutex = Poisonable::new(Mutex::new("foo")); + assert_eq!(mutex.into_inner().unwrap(), "foo"); + } + + #[test] + fn poisoned_into_inner() { + let mutex = Poisonable::from(Mutex::new("foo")); + + std::panic::catch_unwind(|| { + let key = ThreadKey::get().unwrap(); + #[allow(unused_variables)] + let guard = mutex.lock(key); + panic!(); + #[allow(unreachable_code)] + drop(guard); + }) + .unwrap_err(); + + let error = mutex.into_inner().unwrap_err(); + assert_eq!(error.into_inner(), "foo"); + } + + #[test] + fn unpoisoned_into_child() { + let mutex = Poisonable::new(Mutex::new("foo")); + assert_eq!(mutex.into_child().unwrap().into_inner(), "foo"); + } + + #[test] + fn poisoned_into_child() { + let mutex = Poisonable::from(Mutex::new("foo")); + + std::panic::catch_unwind(|| { + let key = ThreadKey::get().unwrap(); + #[allow(unused_variables)] + let guard = mutex.lock(key); + panic!(); + #[allow(unreachable_code)] + drop(guard); + }) + .unwrap_err(); + + let error = mutex.into_child().unwrap_err(); + assert_eq!(error.into_inner().into_inner(), "foo"); + } + + #[test] fn display_works() { let key = ThreadKey::get().unwrap(); let mutex = Poisonable::new(Mutex::new("Hello, world!")); @@ -111,21 +211,75 @@ mod tests { } #[test] - fn ord_works() { + fn ref_as_ref() { + let key = ThreadKey::get().unwrap(); + let collection = LockCollection::new(Poisonable::new(Mutex::new("foo"))); + let guard = collection.lock(key); + let Ok(ref guard) = guard.as_ref() else { + panic!() + }; + assert_eq!(**guard.as_ref(), "foo"); + } + + #[test] + fn ref_as_mut() { let key = ThreadKey::get().unwrap(); - let lock1 = Poisonable::new(Mutex::new(1)); - let lock2 = Poisonable::new(Mutex::new(3)); - let lock3 = Poisonable::new(Mutex::new(3)); - let collection = LockCollection::try_new((&lock1, &lock2, &lock3)).unwrap(); + let collection = LockCollection::new(Poisonable::new(Mutex::new("foo"))); + let mut guard1 = collection.lock(key); + let Ok(ref mut guard) = guard1.as_mut() else { + panic!() + }; + let guard = guard.as_mut(); + **guard = "bar"; + + let key = LockCollection::<Poisonable<Mutex<_>>>::unlock(guard1); + let guard = collection.lock(key); + let guard = guard.as_deref().unwrap(); + assert_eq!(*guard.as_ref(), "bar"); + } + #[test] + fn guard_as_ref() { + let key = ThreadKey::get().unwrap(); + let collection = Poisonable::new(Mutex::new("foo")); let guard = collection.lock(key); - let guard1 = guard.0.as_ref().unwrap(); - let guard2 = guard.1.as_ref().unwrap(); - let guard3 = guard.2.as_ref().unwrap(); - assert_eq!(guard1.cmp(guard2), std::cmp::Ordering::Less); - assert_eq!(guard2.cmp(guard1), std::cmp::Ordering::Greater); - assert!(guard2 == guard3); - assert!(guard1 != guard3); + let Ok(ref guard) = guard.as_ref() else { + panic!() + }; + assert_eq!(**guard.as_ref(), "foo"); + } + + #[test] + fn guard_as_mut() { + let key = ThreadKey::get().unwrap(); + let mutex = Poisonable::new(Mutex::new("foo")); + let mut guard1 = mutex.lock(key); + let Ok(ref mut guard) = guard1.as_mut() else { + panic!() + }; + let guard = guard.as_mut(); + **guard = "bar"; + + let key = Poisonable::<Mutex<_>>::unlock(guard1.unwrap()); + let guard = mutex.lock(key); + let guard = guard.as_deref().unwrap(); + assert_eq!(*guard, "bar"); + } + + #[test] + fn deref_mut_in_collection() { + let key = ThreadKey::get().unwrap(); + let collection = LockCollection::new(Poisonable::new(Mutex::new(42))); + let mut guard1 = collection.lock(key); + let Ok(ref mut guard) = guard1.as_mut() else { + panic!() + }; + // TODO make this more convenient + assert_eq!(***guard, 42); + ***guard = 24; + + let key = LockCollection::<Poisonable<Mutex<_>>>::unlock(guard1); + _ = collection.lock(key); } #[test] @@ -202,18 +356,45 @@ mod tests { assert!(mutex.is_poisoned()); - let mut key = ThreadKey::get().unwrap(); - let mut error = mutex.lock(&mut key).unwrap_err(); + let key: ThreadKey = ThreadKey::get().unwrap(); + let mut error = mutex.lock(key).unwrap_err(); let error1 = error.as_mut(); **error1 = "bar"; - drop(error); + let key = Poisonable::<Mutex<_>>::unlock(error.into_inner()); mutex.clear_poison(); - let guard = mutex.lock(&mut key).unwrap(); + let guard = mutex.lock(key).unwrap(); assert_eq!(&**guard, "bar"); } #[test] + fn try_error_from_lock_error() { + let mutex = Poisonable::new(Mutex::new("foo")); + + let _ = std::panic::catch_unwind(|| { + let key = ThreadKey::get().unwrap(); + #[allow(unused_variables)] + let guard = mutex.lock(key); + panic!(); + + #[allow(unknown_lints)] + #[allow(unreachable_code)] + drop(guard); + }); + + assert!(mutex.is_poisoned()); + + let key = ThreadKey::get().unwrap(); + let error = mutex.lock(key).unwrap_err(); + let error = TryLockPoisonableError::from(error); + + let TryLockPoisonableError::Poisoned(error) = error else { + panic!() + }; + assert_eq!(&**error.into_inner(), "foo"); + } + + #[test] fn new_poisonable_is_not_poisoned() { let mutex = Poisonable::new(Mutex::new(42)); assert!(!mutex.is_poisoned()); diff --git a/src/poisonable/error.rs b/src/poisonable/error.rs index bff011d..b69df5d 100644 --- a/src/poisonable/error.rs +++ b/src/poisonable/error.rs @@ -109,7 +109,7 @@ impl<Guard> PoisonError<Guard> { /// /// let key = ThreadKey::get().unwrap(); /// let p_err = mutex.lock(key).unwrap_err(); - /// let data: &PoisonGuard<_, _> = p_err.get_ref(); + /// let data: &PoisonGuard<_> = p_err.get_ref(); /// println!("recovered {} items", data.len()); /// ``` #[must_use] @@ -154,7 +154,7 @@ impl<Guard> PoisonError<Guard> { #[mutants::skip] #[cfg(not(tarpaulin_include))] -impl<G, Key> fmt::Debug for TryLockPoisonableError<'_, '_, G, Key> { +impl<G> fmt::Debug for TryLockPoisonableError<'_, G> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match *self { Self::Poisoned(..) => "Poisoned(..)".fmt(f), @@ -163,7 +163,7 @@ impl<G, Key> fmt::Debug for TryLockPoisonableError<'_, '_, G, Key> { } } -impl<G, Key> fmt::Display for TryLockPoisonableError<'_, '_, G, Key> { +impl<G> fmt::Display for TryLockPoisonableError<'_, G> { #[cfg_attr(test, mutants::skip)] #[cfg(not(tarpaulin_include))] fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -175,12 +175,10 @@ impl<G, Key> fmt::Display for TryLockPoisonableError<'_, '_, G, Key> { } } -impl<G, Key> Error for TryLockPoisonableError<'_, '_, G, Key> {} +impl<G> Error for TryLockPoisonableError<'_, G> {} -impl<'flag, 'key, G, Key> From<PoisonError<PoisonGuard<'flag, 'key, G, Key>>> - for TryLockPoisonableError<'flag, 'key, G, Key> -{ - fn from(value: PoisonError<PoisonGuard<'flag, 'key, G, Key>>) -> Self { +impl<'flag, G> From<PoisonError<PoisonGuard<'flag, G>>> for TryLockPoisonableError<'flag, G> { + fn from(value: PoisonError<PoisonGuard<'flag, G>>) -> Self { Self::Poisoned(value) } } diff --git a/src/poisonable/guard.rs b/src/poisonable/guard.rs index 3f85d25..b887e2d 100644 --- a/src/poisonable/guard.rs +++ b/src/poisonable/guard.rs @@ -3,8 +3,6 @@ use std::hash::Hash; use std::marker::PhantomData; use std::ops::{Deref, DerefMut}; -use crate::Keyable; - use super::{PoisonFlag, PoisonGuard, PoisonRef}; impl<'a, Guard> PoisonRef<'a, Guard> { @@ -28,26 +26,6 @@ impl<Guard> Drop for PoisonRef<'_, Guard> { } } -impl<Guard: PartialEq> PartialEq for PoisonRef<'_, Guard> { - fn eq(&self, other: &Self) -> bool { - self.guard.eq(&other.guard) - } -} - -impl<Guard: PartialOrd> PartialOrd for PoisonRef<'_, Guard> { - fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { - self.guard.partial_cmp(&other.guard) - } -} - -impl<Guard: Eq> Eq for PoisonRef<'_, Guard> {} - -impl<Guard: Ord> Ord for PoisonRef<'_, Guard> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.guard.cmp(&other.guard) - } -} - #[mutants::skip] // hashing involves RNG and is hard to test #[cfg(not(tarpaulin_include))] impl<Guard: Hash> Hash for PoisonRef<'_, Guard> { @@ -96,37 +74,9 @@ impl<Guard> AsMut<Guard> for PoisonRef<'_, Guard> { } } -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl<Guard: PartialEq, Key: Keyable> PartialEq for PoisonGuard<'_, '_, Guard, Key> { - fn eq(&self, other: &Self) -> bool { - self.guard.eq(&other.guard) - } -} - -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl<Guard: PartialOrd, Key: Keyable> PartialOrd for PoisonGuard<'_, '_, Guard, Key> { - fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { - self.guard.partial_cmp(&other.guard) - } -} - -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl<Guard: Eq, Key: Keyable> Eq for PoisonGuard<'_, '_, Guard, Key> {} - -#[mutants::skip] // it's hard to get two guards safely -#[cfg(not(tarpaulin_include))] -impl<Guard: Ord, Key: Keyable> Ord for PoisonGuard<'_, '_, Guard, Key> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.guard.cmp(&other.guard) - } -} - #[mutants::skip] // hashing involves RNG and is hard to test #[cfg(not(tarpaulin_include))] -impl<Guard: Hash, Key: Keyable> Hash for PoisonGuard<'_, '_, Guard, Key> { +impl<Guard: Hash> Hash for PoisonGuard<'_, Guard> { fn hash<H: std::hash::Hasher>(&self, state: &mut H) { self.guard.hash(state) } @@ -134,19 +84,19 @@ impl<Guard: Hash, Key: Keyable> Hash for PoisonGuard<'_, '_, Guard, Key> { #[mutants::skip] #[cfg(not(tarpaulin_include))] -impl<Guard: Debug, Key: Keyable> Debug for PoisonGuard<'_, '_, Guard, Key> { +impl<Guard: Debug> Debug for PoisonGuard<'_, Guard> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Debug::fmt(&self.guard, f) } } -impl<Guard: Display, Key: Keyable> Display for PoisonGuard<'_, '_, Guard, Key> { +impl<Guard: Display> Display for PoisonGuard<'_, Guard> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Display::fmt(&self.guard, f) } } -impl<T, Guard: Deref<Target = T>, Key: Keyable> Deref for PoisonGuard<'_, '_, Guard, Key> { +impl<T, Guard: Deref<Target = T>> Deref for PoisonGuard<'_, Guard> { type Target = T; fn deref(&self) -> &Self::Target { @@ -155,20 +105,20 @@ impl<T, Guard: Deref<Target = T>, Key: Keyable> Deref for PoisonGuard<'_, '_, Gu } } -impl<T, Guard: DerefMut<Target = T>, Key: Keyable> DerefMut for PoisonGuard<'_, '_, Guard, Key> { +impl<T, Guard: DerefMut<Target = T>> DerefMut for PoisonGuard<'_, Guard> { fn deref_mut(&mut self) -> &mut Self::Target { #[allow(clippy::explicit_auto_deref)] // fixing this results in a compiler error &mut *self.guard.guard } } -impl<Guard, Key: Keyable> AsRef<Guard> for PoisonGuard<'_, '_, Guard, Key> { +impl<Guard> AsRef<Guard> for PoisonGuard<'_, Guard> { fn as_ref(&self) -> &Guard { &self.guard.guard } } -impl<Guard, Key: Keyable> AsMut<Guard> for PoisonGuard<'_, '_, Guard, Key> { +impl<Guard> AsMut<Guard> for PoisonGuard<'_, Guard> { fn as_mut(&mut self) -> &mut Guard { &mut self.guard.guard } diff --git a/src/poisonable/poisonable.rs b/src/poisonable/poisonable.rs index 2dac4bb..efe4ed0 100644 --- a/src/poisonable/poisonable.rs +++ b/src/poisonable/poisonable.rs @@ -1,10 +1,9 @@ -use std::marker::PhantomData; use std::panic::{RefUnwindSafe, UnwindSafe}; use crate::lockable::{ Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable, }; -use crate::Keyable; +use crate::{Keyable, ThreadKey}; use super::{ PoisonError, PoisonFlag, PoisonGuard, PoisonRef, PoisonResult, Poisonable, @@ -49,6 +48,11 @@ unsafe impl<L: Lockable> Lockable for Poisonable<L> { where Self: 'g; + type DataMut<'a> + = PoisonResult<L::DataMut<'a>> + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { self.inner.get_ptrs(ptrs) } @@ -62,6 +66,14 @@ unsafe impl<L: Lockable> Lockable for Poisonable<L> { Ok(ref_guard) } } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + if self.is_poisoned() { + Err(PoisonError::new(self.inner.data_mut())) + } else { + Ok(self.inner.data_mut()) + } + } } unsafe impl<L: Sharable> Sharable for Poisonable<L> { @@ -70,6 +82,11 @@ unsafe impl<L: Sharable> Sharable for Poisonable<L> { where Self: 'g; + type DataRef<'a> + = PoisonResult<L::DataRef<'a>> + where + Self: 'a; + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { let ref_guard = PoisonRef::new(&self.poisoned, self.inner.read_guard()); @@ -79,6 +96,14 @@ unsafe impl<L: Sharable> Sharable for Poisonable<L> { Ok(ref_guard) } } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + if self.is_poisoned() { + Err(PoisonError::new(self.inner.data_ref())) + } else { + Ok(self.inner.data_ref()) + } + } } unsafe impl<L: OwnedLockable> OwnedLockable for Poisonable<L> {} @@ -266,14 +291,10 @@ impl<L> Poisonable<L> { impl<L: Lockable> Poisonable<L> { /// Creates a guard for the poisonable, without locking it - unsafe fn guard<'flag, 'key, Key: Keyable + 'key>( - &'flag self, - key: Key, - ) -> PoisonResult<PoisonGuard<'flag, 'key, L::Guard<'flag>, Key>> { + unsafe fn guard(&self, key: ThreadKey) -> PoisonResult<PoisonGuard<'_, L::Guard<'_>>> { let guard = PoisonGuard { guard: PoisonRef::new(&self.poisoned, self.inner.guard()), key, - _phantom: PhantomData, }; if self.is_poisoned() { @@ -285,6 +306,50 @@ impl<L: Lockable> Poisonable<L> { } impl<L: Lockable + RawLock> Poisonable<L> { + pub fn scoped_lock<'a, R>( + &'a self, + key: impl Keyable, + f: impl Fn(<Self as Lockable>::DataMut<'a>) -> R, + ) -> R { + unsafe { + // safety: we have the thread key + self.raw_lock(); + + // safety: the data was just locked + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_lock<'a, Key: Keyable, R>( + &'a self, + key: Key, + f: impl Fn(<Self as Lockable>::DataMut<'a>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_lock() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Acquires the lock, blocking the current thread until it is ok to do so. /// /// This function will block the current thread until it is available to @@ -316,10 +381,7 @@ impl<L: Lockable + RawLock> Poisonable<L> { /// let key = ThreadKey::get().unwrap(); /// assert_eq!(*mutex.lock(key).unwrap(), 10); /// ``` - pub fn lock<'flag, 'key, Key: Keyable + 'key>( - &'flag self, - key: Key, - ) -> PoisonResult<PoisonGuard<'flag, 'key, L::Guard<'flag>, Key>> { + pub fn lock(&self, key: ThreadKey) -> PoisonResult<PoisonGuard<'_, L::Guard<'_>>> { unsafe { self.inner.raw_lock(); self.guard(key) @@ -370,10 +432,7 @@ impl<L: Lockable + RawLock> Poisonable<L> { /// /// [`Poisoned`]: `TryLockPoisonableError::Poisoned` /// [`WouldBlock`]: `TryLockPoisonableError::WouldBlock` - pub fn try_lock<'flag, 'key, Key: Keyable + 'key>( - &'flag self, - key: Key, - ) -> TryLockPoisonableResult<'flag, 'key, L::Guard<'flag>, Key> { + pub fn try_lock(&self, key: ThreadKey) -> TryLockPoisonableResult<'_, L::Guard<'_>> { unsafe { if self.inner.raw_try_lock() { Ok(self.guard(key)?) @@ -398,23 +457,17 @@ impl<L: Lockable + RawLock> Poisonable<L> { /// /// let key = Poisonable::<Mutex<_>>::unlock(guard); /// ``` - pub fn unlock<'flag, 'key, Key: Keyable + 'key>( - guard: PoisonGuard<'flag, 'key, L::Guard<'flag>, Key>, - ) -> Key { + pub fn unlock<'flag>(guard: PoisonGuard<'flag, L::Guard<'flag>>) -> ThreadKey { drop(guard.guard); guard.key } } impl<L: Sharable + RawLock> Poisonable<L> { - unsafe fn read_guard<'flag, 'key, Key: Keyable + 'key>( - &'flag self, - key: Key, - ) -> PoisonResult<PoisonGuard<'flag, 'key, L::ReadGuard<'flag>, Key>> { + unsafe fn read_guard(&self, key: ThreadKey) -> PoisonResult<PoisonGuard<'_, L::ReadGuard<'_>>> { let guard = PoisonGuard { guard: PoisonRef::new(&self.poisoned, self.inner.read_guard()), key, - _phantom: PhantomData, }; if self.is_poisoned() { @@ -424,6 +477,50 @@ impl<L: Sharable + RawLock> Poisonable<L> { Ok(guard) } + pub fn scoped_read<'a, R>( + &'a self, + key: impl Keyable, + f: impl Fn(<Self as Sharable>::DataRef<'a>) -> R, + ) -> R { + unsafe { + // safety: we have the thread key + self.raw_read(); + + // safety: the data was just locked + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_read<'a, Key: Keyable, R>( + &'a self, + key: Key, + f: impl Fn(<Self as Sharable>::DataRef<'a>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_read() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks with shared read access, blocking the current thread until it can /// be acquired. /// @@ -457,10 +554,7 @@ impl<L: Sharable + RawLock> Poisonable<L> { /// assert!(c_lock.read(key).is_ok()); /// }).join().expect("thread::spawn failed"); /// ``` - pub fn read<'flag, 'key, Key: Keyable + 'key>( - &'flag self, - key: Key, - ) -> PoisonResult<PoisonGuard<'flag, 'key, L::ReadGuard<'flag>, Key>> { + pub fn read(&self, key: ThreadKey) -> PoisonResult<PoisonGuard<'_, L::ReadGuard<'_>>> { unsafe { self.inner.raw_read(); self.read_guard(key) @@ -504,10 +598,7 @@ impl<L: Sharable + RawLock> Poisonable<L> { /// /// [`Poisoned`]: `TryLockPoisonableError::Poisoned` /// [`WouldBlock`]: `TryLockPoisonableError::WouldBlock` - pub fn try_read<'flag, 'key, Key: Keyable + 'key>( - &'flag self, - key: Key, - ) -> TryLockPoisonableResult<'flag, 'key, L::ReadGuard<'flag>, Key> { + pub fn try_read(&self, key: ThreadKey) -> TryLockPoisonableResult<'_, L::ReadGuard<'_>> { unsafe { if self.inner.raw_try_read() { Ok(self.read_guard(key)?) @@ -530,9 +621,7 @@ impl<L: Sharable + RawLock> Poisonable<L> { /// let mut guard = lock.read(key).unwrap(); /// let key = Poisonable::<RwLock<_>>::unlock_read(guard); /// ``` - pub fn unlock_read<'flag, 'key, Key: Keyable + 'key>( - guard: PoisonGuard<'flag, 'key, L::ReadGuard<'flag>, Key>, - ) -> Key { + pub fn unlock_read<'flag>(guard: PoisonGuard<'flag, L::ReadGuard<'flag>>) -> ThreadKey { drop(guard.guard); guard.key } diff --git a/src/rwlock.rs b/src/rwlock.rs index f78e648..b604370 100644 --- a/src/rwlock.rs +++ b/src/rwlock.rs @@ -3,8 +3,8 @@ use std::marker::PhantomData; use lock_api::RawRwLock; -use crate::key::Keyable; use crate::poisonable::PoisonFlag; +use crate::ThreadKey; mod rwlock; @@ -95,10 +95,9 @@ pub struct RwLockWriteRef<'a, T: ?Sized, R: RawRwLock>( /// /// [`read`]: `RwLock::read` /// [`try_read`]: `RwLock::try_read` -pub struct RwLockReadGuard<'a, 'key: 'a, T: ?Sized, Key: Keyable + 'key, R: RawRwLock> { +pub struct RwLockReadGuard<'a, T: ?Sized, R: RawRwLock> { rwlock: RwLockReadRef<'a, T, R>, - thread_key: Key, - _phantom: PhantomData<&'key ()>, + thread_key: ThreadKey, } /// RAII structure used to release the exclusive write access of a lock when @@ -108,16 +107,14 @@ pub struct RwLockReadGuard<'a, 'key: 'a, T: ?Sized, Key: Keyable + 'key, R: RawR /// [`RwLock`] /// /// [`try_write`]: `RwLock::try_write` -pub struct RwLockWriteGuard<'a, 'key: 'a, T: ?Sized, Key: Keyable + 'key, R: RawRwLock> { +pub struct RwLockWriteGuard<'a, T: ?Sized, R: RawRwLock> { rwlock: RwLockWriteRef<'a, T, R>, - thread_key: Key, - _phantom: PhantomData<&'key ()>, + thread_key: ThreadKey, } #[cfg(test)] mod tests { use crate::lockable::Lockable; - use crate::LockCollection; use crate::RwLock; use crate::ThreadKey; @@ -142,6 +139,16 @@ mod tests { } #[test] + fn read_lock_from_works() { + let key = ThreadKey::get().unwrap(); + let lock: crate::RwLock<_> = RwLock::from("Hello, world!"); + let reader = ReadLock::from(&lock); + + let guard = reader.lock(key); + assert_eq!(*guard, "Hello, world!"); + } + + #[test] fn write_lock_unlocked_when_initialized() { let key = ThreadKey::get().unwrap(); let lock: crate::RwLock<_> = RwLock::new("Hello, world!"); @@ -235,21 +242,6 @@ mod tests { } #[test] - fn write_ord() { - let key = ThreadKey::get().unwrap(); - let lock1: crate::RwLock<_> = RwLock::new(1); - let lock2: crate::RwLock<_> = RwLock::new(5); - let lock3: crate::RwLock<_> = RwLock::new(5); - let collection = LockCollection::try_new((&lock1, &lock2, &lock3)).unwrap(); - let guard = collection.lock(key); - - assert!(guard.0 < guard.1); - assert!(guard.1 > guard.0); - assert!(guard.1 == guard.2); - assert!(guard.0 != guard.2); - } - - #[test] fn read_ref_display_works() { let lock: crate::RwLock<_> = RwLock::new("Hello, world!"); let guard = unsafe { lock.try_read_no_key().unwrap() }; @@ -264,21 +256,6 @@ mod tests { } #[test] - fn read_ord() { - let key = ThreadKey::get().unwrap(); - let lock1: crate::RwLock<_> = RwLock::new(1); - let lock2: crate::RwLock<_> = RwLock::new(5); - let lock3: crate::RwLock<_> = RwLock::new(5); - let collection = LockCollection::try_new((&lock1, &lock2, &lock3)).unwrap(); - let guard = collection.read(key); - - assert!(guard.0 < guard.1); - assert!(guard.1 > guard.0); - assert!(guard.1 == guard.2); - assert!(guard.0 != guard.2); - } - - #[test] fn dropping_read_ref_releases_rwlock() { let lock: crate::RwLock<_> = RwLock::new("Hello, world!"); @@ -290,10 +267,10 @@ mod tests { #[test] fn dropping_write_guard_releases_rwlock() { - let mut key = ThreadKey::get().unwrap(); + let key = ThreadKey::get().unwrap(); let lock: crate::RwLock<_> = RwLock::new("Hello, world!"); - let guard = lock.write(&mut key); + let guard = lock.write(key); drop(guard); assert!(!lock.is_locked()); diff --git a/src/rwlock/read_guard.rs b/src/rwlock/read_guard.rs index bd22837..0d68c75 100644 --- a/src/rwlock/read_guard.rs +++ b/src/rwlock/read_guard.rs @@ -5,34 +5,14 @@ use std::ops::Deref; use lock_api::RawRwLock; -use crate::key::Keyable; use crate::lockable::RawLock; +use crate::ThreadKey; use super::{RwLock, RwLockReadGuard, RwLockReadRef}; // These impls make things slightly easier because now you can use // `println!("{guard}")` instead of `println!("{}", *guard)` -impl<T: PartialEq + ?Sized, R: RawRwLock> PartialEq for RwLockReadRef<'_, T, R> { - fn eq(&self, other: &Self) -> bool { - self.deref().eq(&**other) - } -} - -impl<T: Eq + ?Sized, R: RawRwLock> Eq for RwLockReadRef<'_, T, R> {} - -impl<T: PartialOrd + ?Sized, R: RawRwLock> PartialOrd for RwLockReadRef<'_, T, R> { - fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { - self.deref().partial_cmp(&**other) - } -} - -impl<T: Ord + ?Sized, R: RawRwLock> Ord for RwLockReadRef<'_, T, R> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.deref().cmp(&**other) - } -} - #[mutants::skip] // hashing involves PRNG and is hard to test #[cfg(not(tarpaulin_include))] impl<T: Hash + ?Sized, R: RawRwLock> Hash for RwLockReadRef<'_, T, R> { @@ -89,41 +69,9 @@ impl<'a, T: ?Sized, R: RawRwLock> RwLockReadRef<'a, T, R> { } } -#[mutants::skip] // it's hard to get two read guards safely -#[cfg(not(tarpaulin_include))] -impl<T: PartialEq + ?Sized, R: RawRwLock, Key: Keyable> PartialEq - for RwLockReadGuard<'_, '_, T, Key, R> -{ - fn eq(&self, other: &Self) -> bool { - self.deref().eq(&**other) - } -} - -#[mutants::skip] // it's hard to get two read guards safely -#[cfg(not(tarpaulin_include))] -impl<T: Eq + ?Sized, R: RawRwLock, Key: Keyable> Eq for RwLockReadGuard<'_, '_, T, Key, R> {} - -#[mutants::skip] // it's hard to get two read guards safely -#[cfg(not(tarpaulin_include))] -impl<T: PartialOrd + ?Sized, R: RawRwLock, Key: Keyable> PartialOrd - for RwLockReadGuard<'_, '_, T, Key, R> -{ - fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { - self.deref().partial_cmp(&**other) - } -} - -#[mutants::skip] // it's hard to get two read guards safely -#[cfg(not(tarpaulin_include))] -impl<T: Ord + ?Sized, R: RawRwLock, Key: Keyable> Ord for RwLockReadGuard<'_, '_, T, Key, R> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.deref().cmp(&**other) - } -} - #[mutants::skip] // hashing involves PRNG and is hard to test #[cfg(not(tarpaulin_include))] -impl<T: Hash + ?Sized, R: RawRwLock, Key: Keyable> Hash for RwLockReadGuard<'_, '_, T, Key, R> { +impl<T: Hash + ?Sized, R: RawRwLock> Hash for RwLockReadGuard<'_, T, R> { fn hash<H: std::hash::Hasher>(&self, state: &mut H) { self.deref().hash(state) } @@ -131,21 +79,19 @@ impl<T: Hash + ?Sized, R: RawRwLock, Key: Keyable> Hash for RwLockReadGuard<'_, #[mutants::skip] #[cfg(not(tarpaulin_include))] -impl<T: Debug + ?Sized, Key: Keyable, R: RawRwLock> Debug for RwLockReadGuard<'_, '_, T, Key, R> { +impl<T: Debug + ?Sized, R: RawRwLock> Debug for RwLockReadGuard<'_, T, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Debug::fmt(&**self, f) } } -impl<T: Display + ?Sized, Key: Keyable, R: RawRwLock> Display - for RwLockReadGuard<'_, '_, T, Key, R> -{ +impl<T: Display + ?Sized, R: RawRwLock> Display for RwLockReadGuard<'_, T, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Display::fmt(&**self, f) } } -impl<T: ?Sized, Key: Keyable, R: RawRwLock> Deref for RwLockReadGuard<'_, '_, T, Key, R> { +impl<T: ?Sized, R: RawRwLock> Deref for RwLockReadGuard<'_, T, R> { type Target = T; fn deref(&self) -> &Self::Target { @@ -153,21 +99,20 @@ impl<T: ?Sized, Key: Keyable, R: RawRwLock> Deref for RwLockReadGuard<'_, '_, T, } } -impl<T: ?Sized, Key: Keyable, R: RawRwLock> AsRef<T> for RwLockReadGuard<'_, '_, T, Key, R> { +impl<T: ?Sized, R: RawRwLock> AsRef<T> for RwLockReadGuard<'_, T, R> { fn as_ref(&self) -> &T { self } } -impl<'a, T: ?Sized, Key: Keyable, R: RawRwLock> RwLockReadGuard<'a, '_, T, Key, R> { +impl<'a, T: ?Sized, R: RawRwLock> RwLockReadGuard<'a, T, R> { /// Create a guard to the given mutex. Undefined if multiple guards to the /// same mutex exist at once. #[must_use] - pub(super) unsafe fn new(rwlock: &'a RwLock<T, R>, thread_key: Key) -> Self { + pub(super) unsafe fn new(rwlock: &'a RwLock<T, R>, thread_key: ThreadKey) -> Self { Self { rwlock: RwLockReadRef(rwlock, PhantomData), thread_key, - _phantom: PhantomData, } } } diff --git a/src/rwlock/read_lock.rs b/src/rwlock/read_lock.rs index 5dd83a7..05b184a 100644 --- a/src/rwlock/read_lock.rs +++ b/src/rwlock/read_lock.rs @@ -2,8 +2,8 @@ use std::fmt::Debug; use lock_api::RawRwLock; -use crate::key::Keyable; use crate::lockable::{Lockable, RawLock, Sharable}; +use crate::ThreadKey; use super::{ReadLock, RwLock, RwLockReadGuard, RwLockReadRef}; @@ -13,6 +13,11 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for ReadLock<'_, T, R> where Self: 'g; + type DataMut<'a> + = &'a T + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { ptrs.push(self.as_ref()); } @@ -20,6 +25,10 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for ReadLock<'_, T, R> unsafe fn guard(&self) -> Self::Guard<'_> { RwLockReadRef::new(self.as_ref()) } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.0.data_ref() + } } unsafe impl<T: Send, R: RawRwLock + Send + Sync> Sharable for ReadLock<'_, T, R> { @@ -28,9 +37,18 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Sharable for ReadLock<'_, T, R> where Self: 'g; + type DataRef<'a> + = &'a T + where + Self: 'a; + unsafe fn read_guard(&self) -> Self::Guard<'_> { RwLockReadRef::new(self.as_ref()) } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + self.0.data_ref() + } } #[mutants::skip] @@ -117,10 +135,8 @@ impl<T: ?Sized, R: RawRwLock> ReadLock<'_, T, R> { /// ``` /// /// [`ThreadKey`]: `crate::ThreadKey` - pub fn lock<'s, 'key: 's, Key: Keyable + 'key>( - &'s self, - key: Key, - ) -> RwLockReadGuard<'s, 'key, T, Key, R> { + #[must_use] + pub fn lock(&self, key: ThreadKey) -> RwLockReadGuard<'_, T, R> { self.0.read(key) } @@ -155,10 +171,7 @@ impl<T: ?Sized, R: RawRwLock> ReadLock<'_, T, R> { /// Err(_) => unreachable!(), /// }; /// ``` - pub fn try_lock<'s, 'key: 's, Key: Keyable + 'key>( - &'s self, - key: Key, - ) -> Result<RwLockReadGuard<'s, 'key, T, Key, R>, Key> { + pub fn try_lock(&self, key: ThreadKey) -> Result<RwLockReadGuard<'_, T, R>, ThreadKey> { self.0.try_read(key) } @@ -189,7 +202,8 @@ impl<T: ?Sized, R: RawRwLock> ReadLock<'_, T, R> { /// assert_eq!(*guard, 0); /// let key = ReadLock::unlock(guard); /// ``` - pub fn unlock<'key, Key: Keyable + 'key>(guard: RwLockReadGuard<'_, 'key, T, Key, R>) -> Key { + #[must_use] + pub fn unlock(guard: RwLockReadGuard<'_, T, R>) -> ThreadKey { RwLock::unlock_read(guard) } } diff --git a/src/rwlock/rwlock.rs b/src/rwlock/rwlock.rs index 038e6c7..905ecf8 100644 --- a/src/rwlock/rwlock.rs +++ b/src/rwlock/rwlock.rs @@ -6,10 +6,10 @@ use std::panic::AssertUnwindSafe; use lock_api::RawRwLock; use crate::handle_unwind::handle_unwind; -use crate::key::Keyable; use crate::lockable::{ Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable, }; +use crate::{Keyable, ThreadKey}; use super::{PoisonFlag, RwLock, RwLockReadGuard, RwLockReadRef, RwLockWriteGuard, RwLockWriteRef}; @@ -79,6 +79,11 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for RwLock<T, R> { where Self: 'g; + type DataMut<'a> + = &'a mut T + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { ptrs.push(self); } @@ -86,6 +91,10 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for RwLock<T, R> { unsafe fn guard(&self) -> Self::Guard<'_> { RwLockWriteRef::new(self) } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.data.get().as_mut().unwrap_unchecked() + } } unsafe impl<T: Send, R: RawRwLock + Send + Sync> Sharable for RwLock<T, R> { @@ -94,9 +103,18 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Sharable for RwLock<T, R> { where Self: 'g; + type DataRef<'a> + = &'a T + where + Self: 'a; + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { RwLockReadRef::new(self) } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + self.data.get().as_ref().unwrap_unchecked() + } } unsafe impl<T: Send, R: RawRwLock + Send + Sync> OwnedLockable for RwLock<T, R> {} @@ -230,6 +248,86 @@ impl<T: ?Sized, R> RwLock<T, R> { } impl<T: ?Sized, R: RawRwLock> RwLock<T, R> { + pub fn scoped_read<Ret>(&self, key: impl Keyable, f: impl Fn(&T) -> Ret) -> Ret { + unsafe { + // safety: we have the thread key + self.raw_read(); + + // safety: the rwlock was just locked + let r = f(self.data.get().as_ref().unwrap_unchecked()); + + // safety: the rwlock is already locked + self.raw_unlock_read(); + + drop(key); // ensure the key stays valid for long enough + + r + } + } + + pub fn scoped_try_read<Key: Keyable, Ret>( + &self, + key: Key, + f: impl Fn(&T) -> Ret, + ) -> Result<Ret, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_read() { + return Err(key); + } + + // safety: the rwlock was just locked + let r = f(self.data.get().as_ref().unwrap_unchecked()); + + // safety: the rwlock is already locked + self.raw_unlock_read(); + + drop(key); // ensure the key stays valid for long enough + + Ok(r) + } + } + + pub fn scoped_write<Ret>(&self, key: impl Keyable, f: impl Fn(&mut T) -> Ret) -> Ret { + unsafe { + // safety: we have the thread key + self.raw_lock(); + + // safety: we just locked the rwlock + let r = f(self.data.get().as_mut().unwrap_unchecked()); + + // safety: the rwlock is already locked + self.raw_unlock(); + + drop(key); // ensure the key stays valid for long enough + + r + } + } + + pub fn scoped_try_write<Key: Keyable, Ret>( + &self, + key: Key, + f: impl Fn(&mut T) -> Ret, + ) -> Result<Ret, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_lock() { + return Err(key); + } + + // safety: the rwlock was just locked + let r = f(self.data.get().as_mut().unwrap_unchecked()); + + // safety: the rwlock is already locked + self.raw_unlock(); + + drop(key); // ensure the key stays valid for long enough + + Ok(r) + } + } + /// Locks this `RwLock` with shared read access, blocking the current /// thread until it can be acquired. /// @@ -264,10 +362,7 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> { /// ``` /// /// [`ThreadKey`]: `crate::ThreadKey` - pub fn read<'s, 'key: 's, Key: Keyable>( - &'s self, - key: Key, - ) -> RwLockReadGuard<'s, 'key, T, Key, R> { + pub fn read(&self, key: ThreadKey) -> RwLockReadGuard<'_, T, R> { unsafe { self.raw_read(); @@ -305,10 +400,7 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> { /// Err(_) => unreachable!(), /// }; /// ``` - pub fn try_read<'s, 'key: 's, Key: Keyable>( - &'s self, - key: Key, - ) -> Result<RwLockReadGuard<'s, 'key, T, Key, R>, Key> { + pub fn try_read(&self, key: ThreadKey) -> Result<RwLockReadGuard<'_, T, R>, ThreadKey> { unsafe { if self.raw_try_read() { // safety: the lock is locked first @@ -369,10 +461,7 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> { /// ``` /// /// [`ThreadKey`]: `crate::ThreadKey` - pub fn write<'s, 'key: 's, Key: Keyable>( - &'s self, - key: Key, - ) -> RwLockWriteGuard<'s, 'key, T, Key, R> { + pub fn write(&self, key: ThreadKey) -> RwLockWriteGuard<'_, T, R> { unsafe { self.raw_lock(); @@ -407,10 +496,7 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> { /// let n = lock.read(key); /// assert_eq!(*n, 1); /// ``` - pub fn try_write<'s, 'key: 's, Key: Keyable>( - &'s self, - key: Key, - ) -> Result<RwLockWriteGuard<'s, 'key, T, Key, R>, Key> { + pub fn try_write(&self, key: ThreadKey) -> Result<RwLockWriteGuard<'_, T, R>, ThreadKey> { unsafe { if self.raw_try_lock() { // safety: the lock is locked first @@ -445,9 +531,8 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> { /// assert_eq!(*guard, 0); /// let key = RwLock::unlock_read(guard); /// ``` - pub fn unlock_read<'key, Key: Keyable + 'key>( - guard: RwLockReadGuard<'_, 'key, T, Key, R>, - ) -> Key { + #[must_use] + pub fn unlock_read(guard: RwLockReadGuard<'_, T, R>) -> ThreadKey { unsafe { guard.rwlock.0.raw_unlock_read(); } @@ -473,9 +558,8 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> { /// *guard += 20; /// let key = RwLock::unlock_write(guard); /// ``` - pub fn unlock_write<'key, Key: Keyable + 'key>( - guard: RwLockWriteGuard<'_, 'key, T, Key, R>, - ) -> Key { + #[must_use] + pub fn unlock_write(guard: RwLockWriteGuard<'_, T, R>) -> ThreadKey { unsafe { guard.rwlock.0.raw_unlock(); } diff --git a/src/rwlock/write_guard.rs b/src/rwlock/write_guard.rs index c971260..3fabf8e 100644 --- a/src/rwlock/write_guard.rs +++ b/src/rwlock/write_guard.rs @@ -5,34 +5,14 @@ use std::ops::{Deref, DerefMut}; use lock_api::RawRwLock; -use crate::key::Keyable; use crate::lockable::RawLock; +use crate::ThreadKey; use super::{RwLock, RwLockWriteGuard, RwLockWriteRef}; // These impls make things slightly easier because now you can use // `println!("{guard}")` instead of `println!("{}", *guard)` -impl<T: PartialEq + ?Sized, R: RawRwLock> PartialEq for RwLockWriteRef<'_, T, R> { - fn eq(&self, other: &Self) -> bool { - self.deref().eq(&**other) - } -} - -impl<T: Eq + ?Sized, R: RawRwLock> Eq for RwLockWriteRef<'_, T, R> {} - -impl<T: PartialOrd + ?Sized, R: RawRwLock> PartialOrd for RwLockWriteRef<'_, T, R> { - fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { - self.deref().partial_cmp(&**other) - } -} - -impl<T: Ord + ?Sized, R: RawRwLock> Ord for RwLockWriteRef<'_, T, R> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.deref().cmp(&**other) - } -} - #[mutants::skip] // hashing involves PRNG and is difficult to test #[cfg(not(tarpaulin_include))] impl<T: Hash + ?Sized, R: RawRwLock> Hash for RwLockWriteRef<'_, T, R> { @@ -104,41 +84,9 @@ impl<'a, T: ?Sized + 'a, R: RawRwLock> RwLockWriteRef<'a, T, R> { } } -#[mutants::skip] // it's hard to get two read guards safely -#[cfg(not(tarpaulin_include))] -impl<T: PartialEq + ?Sized, R: RawRwLock, Key: Keyable> PartialEq - for RwLockWriteGuard<'_, '_, T, Key, R> -{ - fn eq(&self, other: &Self) -> bool { - self.deref().eq(&**other) - } -} - -#[mutants::skip] // it's hard to get two read guards safely -#[cfg(not(tarpaulin_include))] -impl<T: Eq + ?Sized, R: RawRwLock, Key: Keyable> Eq for RwLockWriteGuard<'_, '_, T, Key, R> {} - -#[mutants::skip] // it's hard to get two read guards safely -#[cfg(not(tarpaulin_include))] -impl<T: PartialOrd + ?Sized, R: RawRwLock, Key: Keyable> PartialOrd - for RwLockWriteGuard<'_, '_, T, Key, R> -{ - fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> { - self.deref().partial_cmp(&**other) - } -} - -#[mutants::skip] // it's hard to get two read guards safely -#[cfg(not(tarpaulin_include))] -impl<T: Ord + ?Sized, R: RawRwLock, Key: Keyable> Ord for RwLockWriteGuard<'_, '_, T, Key, R> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - self.deref().cmp(&**other) - } -} - #[mutants::skip] // hashing involves PRNG and is difficult to test #[cfg(not(tarpaulin_include))] -impl<T: Hash + ?Sized, R: RawRwLock, Key: Keyable> Hash for RwLockWriteGuard<'_, '_, T, Key, R> { +impl<T: Hash + ?Sized, R: RawRwLock> Hash for RwLockWriteGuard<'_, T, R> { fn hash<H: std::hash::Hasher>(&self, state: &mut H) { self.deref().hash(state) } @@ -146,21 +94,19 @@ impl<T: Hash + ?Sized, R: RawRwLock, Key: Keyable> Hash for RwLockWriteGuard<'_, #[mutants::skip] #[cfg(not(tarpaulin_include))] -impl<T: Debug + ?Sized, Key: Keyable, R: RawRwLock> Debug for RwLockWriteGuard<'_, '_, T, Key, R> { +impl<T: Debug + ?Sized, R: RawRwLock> Debug for RwLockWriteGuard<'_, T, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Debug::fmt(&**self, f) } } -impl<T: Display + ?Sized, Key: Keyable, R: RawRwLock> Display - for RwLockWriteGuard<'_, '_, T, Key, R> -{ +impl<T: Display + ?Sized, R: RawRwLock> Display for RwLockWriteGuard<'_, T, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Display::fmt(&**self, f) } } -impl<T: ?Sized, Key: Keyable, R: RawRwLock> Deref for RwLockWriteGuard<'_, '_, T, Key, R> { +impl<T: ?Sized, R: RawRwLock> Deref for RwLockWriteGuard<'_, T, R> { type Target = T; fn deref(&self) -> &Self::Target { @@ -168,33 +114,32 @@ impl<T: ?Sized, Key: Keyable, R: RawRwLock> Deref for RwLockWriteGuard<'_, '_, T } } -impl<T: ?Sized, Key: Keyable, R: RawRwLock> DerefMut for RwLockWriteGuard<'_, '_, T, Key, R> { +impl<T: ?Sized, R: RawRwLock> DerefMut for RwLockWriteGuard<'_, T, R> { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.rwlock } } -impl<T: ?Sized, Key: Keyable, R: RawRwLock> AsRef<T> for RwLockWriteGuard<'_, '_, T, Key, R> { +impl<T: ?Sized, R: RawRwLock> AsRef<T> for RwLockWriteGuard<'_, T, R> { fn as_ref(&self) -> &T { self } } -impl<T: ?Sized, Key: Keyable, R: RawRwLock> AsMut<T> for RwLockWriteGuard<'_, '_, T, Key, R> { +impl<T: ?Sized, R: RawRwLock> AsMut<T> for RwLockWriteGuard<'_, T, R> { fn as_mut(&mut self) -> &mut T { self } } -impl<'a, T: ?Sized + 'a, Key: Keyable, R: RawRwLock> RwLockWriteGuard<'a, '_, T, Key, R> { +impl<'a, T: ?Sized + 'a, R: RawRwLock> RwLockWriteGuard<'a, T, R> { /// Create a guard to the given mutex. Undefined if multiple guards to the /// same mutex exist at once. #[must_use] - pub(super) unsafe fn new(rwlock: &'a RwLock<T, R>, thread_key: Key) -> Self { + pub(super) unsafe fn new(rwlock: &'a RwLock<T, R>, thread_key: ThreadKey) -> Self { Self { rwlock: RwLockWriteRef(rwlock, PhantomData), thread_key, - _phantom: PhantomData, } } } diff --git a/src/rwlock/write_lock.rs b/src/rwlock/write_lock.rs index cc96953..8a44a2d 100644 --- a/src/rwlock/write_lock.rs +++ b/src/rwlock/write_lock.rs @@ -2,8 +2,8 @@ use std::fmt::Debug; use lock_api::RawRwLock; -use crate::key::Keyable; use crate::lockable::{Lockable, RawLock}; +use crate::ThreadKey; use super::{RwLock, RwLockWriteGuard, RwLockWriteRef, WriteLock}; @@ -13,6 +13,11 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for WriteLock<'_, T, R where Self: 'g; + type DataMut<'a> + = &'a mut T + where + Self: 'a; + fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { ptrs.push(self.as_ref()); } @@ -20,6 +25,10 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for WriteLock<'_, T, R unsafe fn guard(&self) -> Self::Guard<'_> { RwLockWriteRef::new(self.as_ref()) } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.0.data_mut() + } } // Technically, the exclusive locks can also be shared, but there's currently @@ -108,10 +117,8 @@ impl<T: ?Sized, R: RawRwLock> WriteLock<'_, T, R> { /// ``` /// /// [`ThreadKey`]: `crate::ThreadKey` - pub fn lock<'s, 'key: 's, Key: Keyable + 'key>( - &'s self, - key: Key, - ) -> RwLockWriteGuard<'s, 'key, T, Key, R> { + #[must_use] + pub fn lock(&self, key: ThreadKey) -> RwLockWriteGuard<'_, T, R> { self.0.write(key) } @@ -145,10 +152,7 @@ impl<T: ?Sized, R: RawRwLock> WriteLock<'_, T, R> { /// Err(_) => unreachable!(), /// }; /// ``` - pub fn try_lock<'s, 'key: 's, Key: Keyable + 'key>( - &'s self, - key: Key, - ) -> Result<RwLockWriteGuard<'s, 'key, T, Key, R>, Key> { + pub fn try_lock(&self, key: ThreadKey) -> Result<RwLockWriteGuard<'_, T, R>, ThreadKey> { self.0.try_write(key) } @@ -176,7 +180,8 @@ impl<T: ?Sized, R: RawRwLock> WriteLock<'_, T, R> { /// *guard += 20; /// let key = WriteLock::unlock(guard); /// ``` - pub fn unlock<'key, Key: Keyable + 'key>(guard: RwLockWriteGuard<'_, 'key, T, Key, R>) -> Key { + #[must_use] + pub fn unlock(guard: RwLockWriteGuard<'_, T, R>) -> ThreadKey { RwLock::unlock_write(guard) } } |
