diff options
Diffstat (limited to 'src/collection/owned.rs')
| -rw-r--r-- | src/collection/owned.rs | 262 |
1 files changed, 202 insertions, 60 deletions
diff --git a/src/collection/owned.rs b/src/collection/owned.rs index c345b43..b9cf313 100644 --- a/src/collection/owned.rs +++ b/src/collection/owned.rs @@ -1,57 +1,47 @@ -use std::marker::PhantomData; - use crate::lockable::{ Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable, }; -use crate::Keyable; +use crate::{Keyable, ThreadKey}; use super::{utils, LockGuard, OwnedLockCollection}; -#[mutants::skip] // it's hard to test individual locks in an OwnedLockCollection -#[cfg(not(tarpaulin_include))] -fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> { - let mut locks = Vec::new(); - data.get_ptrs(&mut locks); - locks -} - unsafe impl<L: Lockable> RawLock for OwnedLockCollection<L> { #[mutants::skip] // this should never run #[cfg(not(tarpaulin_include))] fn poison(&self) { - let locks = get_locks(&self.data); + let locks = utils::get_locks_unsorted(&self.data); for lock in locks { lock.poison(); } } unsafe fn raw_lock(&self) { - utils::ordered_lock(&get_locks(&self.data)) + utils::ordered_lock(&utils::get_locks_unsorted(&self.data)) } unsafe fn raw_try_lock(&self) -> bool { - let locks = get_locks(&self.data); + let locks = utils::get_locks_unsorted(&self.data); utils::ordered_try_lock(&locks) } unsafe fn raw_unlock(&self) { - let locks = get_locks(&self.data); + let locks = utils::get_locks_unsorted(&self.data); for lock in locks { lock.raw_unlock(); } } unsafe fn raw_read(&self) { - utils::ordered_read(&get_locks(&self.data)) + utils::ordered_read(&utils::get_locks_unsorted(&self.data)) } unsafe fn raw_try_read(&self) -> bool { - let locks = get_locks(&self.data); + let locks = utils::get_locks_unsorted(&self.data); utils::ordered_try_read(&locks) } unsafe fn raw_unlock_read(&self) { - let locks = get_locks(&self.data); + let locks = utils::get_locks_unsorted(&self.data); for lock in locks { lock.raw_unlock_read(); } @@ -64,6 +54,11 @@ unsafe impl<L: Lockable> Lockable for OwnedLockCollection<L> { where Self: 'g; + type DataMut<'a> + = L::DataMut<'a> + where + Self: 'a; + #[mutants::skip] // It's hard to test lkocks in an OwnedLockCollection, because they're owned #[cfg(not(tarpaulin_include))] fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) { @@ -73,6 +68,10 @@ unsafe impl<L: Lockable> Lockable for OwnedLockCollection<L> { unsafe fn guard(&self) -> Self::Guard<'_> { self.data.guard() } + + unsafe fn data_mut(&self) -> Self::DataMut<'_> { + self.data.data_mut() + } } impl<L: LockableGetMut> LockableGetMut for OwnedLockCollection<L> { @@ -100,9 +99,18 @@ unsafe impl<L: Sharable> Sharable for OwnedLockCollection<L> { where Self: 'g; + type DataRef<'a> + = L::DataRef<'a> + where + Self: 'a; + unsafe fn read_guard(&self) -> Self::ReadGuard<'_> { self.data.read_guard() } + + unsafe fn data_ref(&self) -> Self::DataRef<'_> { + self.data.data_ref() + } } unsafe impl<L: OwnedLockable> OwnedLockable for OwnedLockCollection<L> {} @@ -177,6 +185,46 @@ impl<L: OwnedLockable> OwnedLockCollection<L> { Self { data } } + pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_lock(); + + // safety: the data was just locked + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_lock<Key: Keyable, R>( + &self, + key: Key, + f: impl Fn(L::DataMut<'_>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_lock() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_mut()); + + // safety: the collection is still locked + self.raw_unlock(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection /// /// This function returns a guard that can be used to access the underlying @@ -197,10 +245,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> { /// *guard.0 += 1; /// *guard.1 = "1"; /// ``` - pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::Guard<'g>, Key> { + pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> { let guard = unsafe { // safety: we have the thread key, and these locks happen in a // predetermined order @@ -210,11 +255,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> { self.data.guard() }; - LockGuard { - guard, - key, - _phantom: PhantomData, - } + LockGuard { guard, key } } /// Attempts to lock the without blocking. @@ -247,10 +288,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> { /// }; /// /// ``` - pub fn try_lock<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Result<LockGuard<'key, L::Guard<'g>, Key>, Key> { + pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> { let guard = unsafe { if !self.raw_try_lock() { return Err(key); @@ -260,11 +298,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> { self.data.guard() }; - Ok(LockGuard { - guard, - key, - _phantom: PhantomData, - }) + Ok(LockGuard { guard, key }) } /// Unlocks the underlying lockable data type, returning the key that's @@ -286,15 +320,53 @@ impl<L: OwnedLockable> OwnedLockCollection<L> { /// let key = OwnedLockCollection::<(Mutex<i32>, Mutex<&str>)>::unlock(guard); /// ``` #[allow(clippy::missing_const_for_fn)] - pub fn unlock<'g, 'key: 'g, Key: Keyable + 'key>( - guard: LockGuard<'key, L::Guard<'g>, Key>, - ) -> Key { + pub fn unlock(guard: LockGuard<L::Guard<'_>>) -> ThreadKey { drop(guard.guard); guard.key } } impl<L: Sharable> OwnedLockCollection<L> { + pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R { + unsafe { + // safety: we have the thread key + self.raw_read(); + + // safety: the data was just locked + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensure the key stays alive long enough + + r + } + } + + pub fn scoped_try_read<Key: Keyable, R>( + &self, + key: Key, + f: impl Fn(L::DataRef<'_>) -> R, + ) -> Result<R, Key> { + unsafe { + // safety: we have the thread key + if !self.raw_try_read() { + return Err(key); + } + + // safety: we just locked the collection + let r = f(self.data_ref()); + + // safety: the collection is still locked + self.raw_unlock_read(); + + drop(key); // ensures the key stays valid long enough + + Ok(r) + } + } + /// Locks the collection, so that other threads can still read from it /// /// This function returns a guard that can be used to access the underlying @@ -315,10 +387,7 @@ impl<L: Sharable> OwnedLockCollection<L> { /// assert_eq!(*guard.0, 0); /// assert_eq!(*guard.1, ""); /// ``` - pub fn read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> LockGuard<'key, L::ReadGuard<'g>, Key> { + pub fn read(&self, key: ThreadKey) -> LockGuard<L::ReadGuard<'_>> { unsafe { // safety: we have the thread key self.raw_read(); @@ -327,7 +396,6 @@ impl<L: Sharable> OwnedLockCollection<L> { // safety: we've already acquired the lock guard: self.data.read_guard(), key, - _phantom: PhantomData, } } } @@ -355,33 +423,26 @@ impl<L: Sharable> OwnedLockCollection<L> { /// let lock = OwnedLockCollection::new(data); /// /// match lock.try_read(key) { - /// Some(mut guard) => { + /// Ok(mut guard) => { /// assert_eq!(*guard.0, 5); /// assert_eq!(*guard.1, "6"); /// }, - /// None => unreachable!(), + /// Err(_) => unreachable!(), /// }; /// /// ``` - pub fn try_read<'g, 'key: 'g, Key: Keyable + 'key>( - &'g self, - key: Key, - ) -> Option<LockGuard<'key, L::ReadGuard<'g>, Key>> { + pub fn try_read(&self, key: ThreadKey) -> Result<LockGuard<L::ReadGuard<'_>>, ThreadKey> { let guard = unsafe { // safety: we have the thread key if !self.raw_try_read() { - return None; + return Err(key); } // safety: we've acquired the locks self.data.read_guard() }; - Some(LockGuard { - guard, - key, - _phantom: PhantomData, - }) + Ok(LockGuard { guard, key }) } /// Unlocks the underlying lockable data type, returning the key that's @@ -401,9 +462,7 @@ impl<L: Sharable> OwnedLockCollection<L> { /// let key = OwnedLockCollection::<(RwLock<i32>, RwLock<&str>)>::unlock_read(guard); /// ``` #[allow(clippy::missing_const_for_fn)] - pub fn unlock_read<'g, 'key: 'g, Key: Keyable + 'key>( - guard: LockGuard<'key, L::ReadGuard<'g>, Key>, - ) -> Key { + pub fn unlock_read(guard: LockGuard<L::ReadGuard<'_>>) -> ThreadKey { drop(guard.guard); guard.key } @@ -495,7 +554,90 @@ impl<L: LockableIntoInner> OwnedLockCollection<L> { #[cfg(test)] mod tests { use super::*; - use crate::Mutex; + use crate::{Mutex, ThreadKey}; + + #[test] + fn get_mut_applies_changes() { + let key = ThreadKey::get().unwrap(); + let mut collection = OwnedLockCollection::new([Mutex::new("foo"), Mutex::new("bar")]); + assert_eq!(*collection.get_mut()[0], "foo"); + assert_eq!(*collection.get_mut()[1], "bar"); + *collection.get_mut()[0] = "baz"; + + let guard = collection.lock(key); + assert_eq!(*guard[0], "baz"); + assert_eq!(*guard[1], "bar"); + } + + #[test] + fn into_inner_works() { + let key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::from([Mutex::new("foo")]); + let mut guard = collection.lock(key); + *guard[0] = "bar"; + drop(guard); + + let array = collection.into_inner(); + assert_eq!(array.len(), 1); + assert_eq!(array[0], "bar"); + } + + #[test] + fn from_into_iter_is_correct() { + let array = [Mutex::new(0), Mutex::new(1), Mutex::new(2), Mutex::new(3)]; + let mut collection: OwnedLockCollection<Vec<Mutex<usize>>> = array.into_iter().collect(); + assert_eq!(collection.get_mut().len(), 4); + for (i, lock) in collection.into_iter().enumerate() { + assert_eq!(lock.into_inner(), i); + } + } + + #[test] + fn from_iter_is_correct() { + let array = [Mutex::new(0), Mutex::new(1), Mutex::new(2), Mutex::new(3)]; + let mut collection: OwnedLockCollection<Vec<Mutex<usize>>> = array.into_iter().collect(); + let collection: &mut Vec<_> = collection.as_mut(); + assert_eq!(collection.len(), 4); + for (i, lock) in collection.iter_mut().enumerate() { + assert_eq!(*lock.get_mut(), i); + } + } + + #[test] + fn try_lock_works_on_unlocked() { + let key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new((Mutex::new(0), Mutex::new(1))); + let guard = collection.try_lock(key).unwrap(); + assert_eq!(*guard.0, 0); + assert_eq!(*guard.1, 1); + } + + #[test] + fn try_lock_fails_on_locked() { + let key = ThreadKey::get().unwrap(); + let collection = OwnedLockCollection::new((Mutex::new(0), Mutex::new(1))); + + std::thread::scope(|s| { + s.spawn(|| { + let key = ThreadKey::get().unwrap(); + #[allow(unused)] + let guard = collection.lock(key); + std::mem::forget(guard); + }); + }); + + assert!(collection.try_lock(key).is_err()); + } + + #[test] + fn default_works() { + type MyCollection = OwnedLockCollection<(Mutex<i32>, Mutex<Option<i32>>, Mutex<String>)>; + let collection = MyCollection::default(); + let inner = collection.into_inner(); + assert_eq!(inner.0, 0); + assert_eq!(inner.1, None); + assert_eq!(inner.2, String::new()); + } #[test] fn can_be_extended() { |
