summaryrefslogtreecommitdiff
path: root/src/collection/ref.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/collection/ref.rs')
-rw-r--r--src/collection/ref.rs261
1 files changed, 196 insertions, 65 deletions
diff --git a/src/collection/ref.rs b/src/collection/ref.rs
index c86f298..b68b72f 100644
--- a/src/collection/ref.rs
+++ b/src/collection/ref.rs
@@ -1,32 +1,11 @@
use std::fmt::Debug;
-use std::marker::PhantomData;
use crate::lockable::{Lockable, OwnedLockable, RawLock, Sharable};
-use crate::Keyable;
+use crate::{Keyable, ThreadKey};
+use super::utils::{get_locks, ordered_contains_duplicates};
use super::{utils, LockGuard, RefLockCollection};
-#[must_use]
-pub fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> {
- let mut locks = Vec::new();
- data.get_ptrs(&mut locks);
- locks.sort_by_key(|lock| &raw const **lock);
- locks
-}
-
-/// returns `true` if the sorted list contains a duplicate
-#[must_use]
-fn contains_duplicates(l: &[&dyn RawLock]) -> bool {
- if l.is_empty() {
- // Return early to prevent panic in the below call to `windows`
- return false;
- }
-
- l.windows(2)
- // NOTE: addr_eq is necessary because eq would also compare the v-table pointers
- .any(|window| std::ptr::addr_eq(window[0], window[1]))
-}
-
impl<'a, L> IntoIterator for &'a RefLockCollection<'a, L>
where
&'a L: IntoIterator,
@@ -83,6 +62,11 @@ unsafe impl<L: Lockable> Lockable for RefLockCollection<'_, L> {
where
Self: 'g;
+ type DataMut<'a>
+ = L::DataMut<'a>
+ where
+ Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
ptrs.extend_from_slice(&self.locks);
}
@@ -90,6 +74,10 @@ unsafe impl<L: Lockable> Lockable for RefLockCollection<'_, L> {
unsafe fn guard(&self) -> Self::Guard<'_> {
self.data.guard()
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ self.data.data_mut()
+ }
}
unsafe impl<L: Sharable> Sharable for RefLockCollection<'_, L> {
@@ -98,9 +86,18 @@ unsafe impl<L: Sharable> Sharable for RefLockCollection<'_, L> {
where
Self: 'g;
+ type DataRef<'a>
+ = L::DataRef<'a>
+ where
+ Self: 'a;
+
unsafe fn read_guard(&self) -> Self::ReadGuard<'_> {
self.data.read_guard()
}
+
+ unsafe fn data_ref(&self) -> Self::DataRef<'_> {
+ self.data.data_ref()
+ }
}
impl<T, L: AsRef<T>> AsRef<T> for RefLockCollection<'_, L> {
@@ -230,13 +227,53 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
#[must_use]
pub fn try_new(data: &'a L) -> Option<Self> {
let locks = get_locks(data);
- if contains_duplicates(&locks) {
+ if ordered_contains_duplicates(&locks) {
return None;
}
Some(Self { data, locks })
}
+ pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_lock();
+
+ // safety: the data was just locked
+ let r = f(self.data_mut());
+
+ // safety: the collection is still locked
+ self.raw_unlock();
+
+ drop(key); // ensure the key stays alive long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_lock<Key: Keyable, R>(
+ &self,
+ key: Key,
+ f: impl Fn(L::DataMut<'_>) -> R,
+ ) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_lock() {
+ return Err(key);
+ }
+
+ // safety: we just locked the collection
+ let r = f(self.data_mut());
+
+ // safety: the collection is still locked
+ self.raw_unlock();
+
+ drop(key); // ensures the key stays valid long enough
+
+ Ok(r)
+ }
+ }
+
/// Locks the collection
///
/// This function returns a guard that can be used to access the underlying
@@ -257,10 +294,8 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
/// *guard.0 += 1;
/// *guard.1 = "1";
/// ```
- pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> LockGuard<'key, L::Guard<'g>, Key> {
+ #[must_use]
+ pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> {
let guard = unsafe {
// safety: we have the thread key
self.raw_lock();
@@ -269,11 +304,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
self.data.guard()
};
- LockGuard {
- guard,
- key,
- _phantom: PhantomData,
- }
+ LockGuard { guard, key }
}
/// Attempts to lock the without blocking.
@@ -306,10 +337,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
/// };
///
/// ```
- pub fn try_lock<'g, 'key: 'a, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> Result<LockGuard<'key, L::Guard<'g>, Key>, Key> {
+ pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> {
let guard = unsafe {
if !self.raw_try_lock() {
return Err(key);
@@ -319,11 +347,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
self.data.guard()
};
- Ok(LockGuard {
- guard,
- key,
- _phantom: PhantomData,
- })
+ Ok(LockGuard { guard, key })
}
/// Unlocks the underlying lockable data type, returning the key that's
@@ -345,13 +369,53 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
/// let key = RefLockCollection::<(Mutex<i32>, Mutex<&str>)>::unlock(guard);
/// ```
#[allow(clippy::missing_const_for_fn)]
- pub fn unlock<'g, 'key, Key: Keyable + 'key>(guard: LockGuard<'key, L::Guard<'g>, Key>) -> Key {
+ pub fn unlock(guard: LockGuard<L::Guard<'_>>) -> ThreadKey {
drop(guard.guard);
guard.key
}
}
-impl<'a, L: Sharable> RefLockCollection<'a, L> {
+impl<L: Sharable> RefLockCollection<'_, L> {
+ pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_read();
+
+ // safety: the data was just locked
+ let r = f(self.data_ref());
+
+ // safety: the collection is still locked
+ self.raw_unlock_read();
+
+ drop(key); // ensure the key stays alive long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_read<Key: Keyable, R>(
+ &self,
+ key: Key,
+ f: impl Fn(L::DataRef<'_>) -> R,
+ ) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_read() {
+ return Err(key);
+ }
+
+ // safety: we just locked the collection
+ let r = f(self.data_ref());
+
+ // safety: the collection is still locked
+ self.raw_unlock_read();
+
+ drop(key); // ensures the key stays valid long enough
+
+ Ok(r)
+ }
+ }
+
/// Locks the collection, so that other threads can still read from it
///
/// This function returns a guard that can be used to access the underlying
@@ -372,10 +436,8 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> {
/// assert_eq!(*guard.0, 0);
/// assert_eq!(*guard.1, "");
/// ```
- pub fn read<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> LockGuard<'key, L::ReadGuard<'g>, Key> {
+ #[must_use]
+ pub fn read(&self, key: ThreadKey) -> LockGuard<L::ReadGuard<'_>> {
unsafe {
// safety: we have the thread key
self.raw_read();
@@ -384,7 +446,6 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> {
// safety: we've already acquired the lock
guard: self.data.read_guard(),
key,
- _phantom: PhantomData,
}
}
}
@@ -412,33 +473,26 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> {
/// let lock = RefLockCollection::new(&data);
///
/// match lock.try_read(key) {
- /// Some(mut guard) => {
+ /// Ok(mut guard) => {
/// assert_eq!(*guard.0, 5);
/// assert_eq!(*guard.1, "6");
/// },
- /// None => unreachable!(),
+ /// Err(_) => unreachable!(),
/// };
///
/// ```
- pub fn try_read<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> Option<LockGuard<'key, L::ReadGuard<'g>, Key>> {
+ pub fn try_read(&self, key: ThreadKey) -> Result<LockGuard<L::ReadGuard<'_>>, ThreadKey> {
let guard = unsafe {
// safety: we have the thread key
if !self.raw_try_read() {
- return None;
+ return Err(key);
}
// safety: we've acquired the locks
self.data.read_guard()
};
- Some(LockGuard {
- guard,
- key,
- _phantom: PhantomData,
- })
+ Ok(LockGuard { guard, key })
}
/// Unlocks the underlying lockable data type, returning the key that's
@@ -458,9 +512,7 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> {
/// let key = RefLockCollection::<(RwLock<i32>, RwLock<&str>)>::unlock_read(guard);
/// ```
#[allow(clippy::missing_const_for_fn)]
- pub fn unlock_read<'key: 'a, Key: Keyable + 'key>(
- guard: LockGuard<'key, L::ReadGuard<'a>, Key>,
- ) -> Key {
+ pub fn unlock_read(guard: LockGuard<L::ReadGuard<'_>>) -> ThreadKey {
drop(guard.guard);
guard.key
}
@@ -497,7 +549,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
- use crate::{Mutex, ThreadKey};
+ use crate::{Mutex, RwLock, ThreadKey};
#[test]
fn non_duplicates_allowed() {
@@ -513,6 +565,85 @@ mod tests {
}
#[test]
+ fn try_lock_succeeds_for_unlocked_collection() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [Mutex::new(24), Mutex::new(42)];
+ let collection = RefLockCollection::new(&mutexes);
+ let guard = collection.try_lock(key).unwrap();
+ assert_eq!(*guard[0], 24);
+ assert_eq!(*guard[1], 42);
+ }
+
+ #[test]
+ fn try_lock_fails_for_locked_collection() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [Mutex::new(24), Mutex::new(42)];
+ let collection = RefLockCollection::new(&mutexes);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let guard = mutexes[1].lock(key);
+ assert_eq!(*guard, 42);
+ std::mem::forget(guard);
+ });
+ });
+
+ let guard = collection.try_lock(key);
+ assert!(guard.is_err());
+ }
+
+ #[test]
+ fn try_read_succeeds_for_unlocked_collection() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [RwLock::new(24), RwLock::new(42)];
+ let collection = RefLockCollection::new(&mutexes);
+ let guard = collection.try_read(key).unwrap();
+ assert_eq!(*guard[0], 24);
+ assert_eq!(*guard[1], 42);
+ }
+
+ #[test]
+ fn try_read_fails_for_locked_collection() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [RwLock::new(24), RwLock::new(42)];
+ let collection = RefLockCollection::new(&mutexes);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let guard = mutexes[1].write(key);
+ assert_eq!(*guard, 42);
+ std::mem::forget(guard);
+ });
+ });
+
+ let guard = collection.try_read(key);
+ assert!(guard.is_err());
+ }
+
+ #[test]
+ fn can_read_twice_on_different_threads() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [RwLock::new(24), RwLock::new(42)];
+ let collection = RefLockCollection::new(&mutexes);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let guard = collection.read(key);
+ assert_eq!(*guard[0], 24);
+ assert_eq!(*guard[1], 42);
+ std::mem::forget(guard);
+ });
+ });
+
+ let guard = collection.try_read(key).unwrap();
+ assert_eq!(*guard[0], 24);
+ assert_eq!(*guard[1], 42);
+ }
+
+ #[test]
fn works_in_collection() {
let key = ThreadKey::get().unwrap();
let mutex1 = Mutex::new(0);