summaryrefslogtreecommitdiff
path: root/src/collection/ref.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/collection/ref.rs')
-rw-r--r--src/collection/ref.rs261
1 files changed, 179 insertions, 82 deletions
diff --git a/src/collection/ref.rs b/src/collection/ref.rs
index b68b72f..5f96533 100644
--- a/src/collection/ref.rs
+++ b/src/collection/ref.rs
@@ -3,7 +3,10 @@ use std::fmt::Debug;
use crate::lockable::{Lockable, OwnedLockable, RawLock, Sharable};
use crate::{Keyable, ThreadKey};
-use super::utils::{get_locks, ordered_contains_duplicates};
+use super::utils::{
+ get_locks, ordered_contains_duplicates, scoped_read, scoped_try_read, scoped_try_write,
+ scoped_write,
+};
use super::{utils, LockGuard, RefLockCollection};
impl<'a, L> IntoIterator for &'a RefLockCollection<'a, L>
@@ -27,17 +30,17 @@ unsafe impl<L: Lockable> RawLock for RefLockCollection<'_, L> {
}
}
- unsafe fn raw_lock(&self) {
- utils::ordered_lock(&self.locks)
+ unsafe fn raw_write(&self) {
+ utils::ordered_write(&self.locks)
}
- unsafe fn raw_try_lock(&self) -> bool {
- utils::ordered_try_lock(&self.locks)
+ unsafe fn raw_try_write(&self) -> bool {
+ utils::ordered_try_write(&self.locks)
}
- unsafe fn raw_unlock(&self) {
+ unsafe fn raw_unlock_write(&self) {
for lock in &self.locks {
- lock.raw_unlock();
+ lock.raw_unlock_write();
}
}
@@ -68,7 +71,7 @@ unsafe impl<L: Lockable> Lockable for RefLockCollection<'_, L> {
Self: 'a;
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
- ptrs.extend_from_slice(&self.locks);
+ ptrs.push(self)
}
unsafe fn guard(&self) -> Self::Guard<'_> {
@@ -100,7 +103,7 @@ unsafe impl<L: Sharable> Sharable for RefLockCollection<'_, L> {
}
}
-impl<T, L: AsRef<T>> AsRef<T> for RefLockCollection<'_, L> {
+impl<T: ?Sized, L: AsRef<T>> AsRef<T> for RefLockCollection<'_, L> {
fn as_ref(&self) -> &T {
self.data.as_ref()
}
@@ -234,44 +237,16 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
Some(Self { data, locks })
}
- pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R {
- unsafe {
- // safety: we have the thread key
- self.raw_lock();
-
- // safety: the data was just locked
- let r = f(self.data_mut());
-
- // safety: the collection is still locked
- self.raw_unlock();
-
- drop(key); // ensure the key stays alive long enough
-
- r
- }
+ pub fn scoped_lock<'s, R>(&'s self, key: impl Keyable, f: impl Fn(L::DataMut<'s>) -> R) -> R {
+ scoped_write(self, key, f)
}
- pub fn scoped_try_lock<Key: Keyable, R>(
- &self,
+ pub fn scoped_try_lock<'s, Key: Keyable, R>(
+ &'s self,
key: Key,
- f: impl Fn(L::DataMut<'_>) -> R,
+ f: impl Fn(L::DataMut<'s>) -> R,
) -> Result<R, Key> {
- unsafe {
- // safety: we have the thread key
- if !self.raw_try_lock() {
- return Err(key);
- }
-
- // safety: we just locked the collection
- let r = f(self.data_mut());
-
- // safety: the collection is still locked
- self.raw_unlock();
-
- drop(key); // ensures the key stays valid long enough
-
- Ok(r)
- }
+ scoped_try_write(self, key, f)
}
/// Locks the collection
@@ -298,7 +273,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> {
let guard = unsafe {
// safety: we have the thread key
- self.raw_lock();
+ self.raw_write();
// safety: we've locked all of this already
self.data.guard()
@@ -339,7 +314,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
/// ```
pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> {
let guard = unsafe {
- if !self.raw_try_lock() {
+ if !self.raw_try_write() {
return Err(key);
}
@@ -376,44 +351,16 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
}
impl<L: Sharable> RefLockCollection<'_, L> {
- pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R {
- unsafe {
- // safety: we have the thread key
- self.raw_read();
-
- // safety: the data was just locked
- let r = f(self.data_ref());
-
- // safety: the collection is still locked
- self.raw_unlock_read();
-
- drop(key); // ensure the key stays alive long enough
-
- r
- }
+ pub fn scoped_read<'a, R>(&'a self, key: impl Keyable, f: impl Fn(L::DataRef<'a>) -> R) -> R {
+ scoped_read(self, key, f)
}
- pub fn scoped_try_read<Key: Keyable, R>(
- &self,
+ pub fn scoped_try_read<'a, Key: Keyable, R>(
+ &'a self,
key: Key,
- f: impl Fn(L::DataRef<'_>) -> R,
+ f: impl Fn(L::DataRef<'a>) -> R,
) -> Result<R, Key> {
- unsafe {
- // safety: we have the thread key
- if !self.raw_try_read() {
- return Err(key);
- }
-
- // safety: we just locked the collection
- let r = f(self.data_ref());
-
- // safety: the collection is still locked
- self.raw_unlock_read();
-
- drop(key); // ensures the key stays valid long enough
-
- Ok(r)
- }
+ scoped_try_read(self, key, f)
}
/// Locks the collection, so that other threads can still read from it
@@ -565,6 +512,88 @@ mod tests {
}
#[test]
+ fn from() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [Mutex::new("foo"), Mutex::new("bar"), Mutex::new("baz")];
+ let collection = RefLockCollection::from(&mutexes);
+ let guard = collection.lock(key);
+ assert_eq!(*guard[0], "foo");
+ assert_eq!(*guard[1], "bar");
+ assert_eq!(*guard[2], "baz");
+ }
+
+ #[test]
+ fn scoped_lock_changes_collection() {
+ let mut key = ThreadKey::get().unwrap();
+ let mutexes = [Mutex::new(24), Mutex::new(42)];
+ let collection = RefLockCollection::new(&mutexes);
+ let sum = collection.scoped_lock(&mut key, |guard| {
+ *guard[0] = 128;
+ *guard[0] + *guard[1]
+ });
+
+ assert_eq!(sum, 128 + 42);
+
+ let guard = collection.lock(key);
+ assert_eq!(*guard[0], 128);
+ assert_eq!(*guard[1], 42);
+ }
+
+ #[test]
+ fn scoped_read_sees_changes() {
+ let mut key = ThreadKey::get().unwrap();
+ let mutexes = [RwLock::new(24), RwLock::new(42)];
+ let collection = RefLockCollection::new(&mutexes);
+ collection.scoped_lock(&mut key, |guard| {
+ *guard[0] = 128;
+ });
+
+ let sum = collection.scoped_read(&mut key, |guard| {
+ assert_eq!(*guard[0], 128);
+ assert_eq!(*guard[1], 42);
+ *guard[0] + *guard[1]
+ });
+
+ assert_eq!(sum, 128 + 42);
+ }
+
+ #[test]
+ fn scoped_try_lock_can_fail() {
+ let key = ThreadKey::get().unwrap();
+ let locks = [Mutex::new(1), Mutex::new(2)];
+ let collection = RefLockCollection::new(&locks);
+ let guard = collection.lock(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = collection.scoped_try_lock(key, |_| {});
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
+ fn scoped_try_read_can_fail() {
+ let key = ThreadKey::get().unwrap();
+ let locks = [RwLock::new(1), RwLock::new(2)];
+ let collection = RefLockCollection::new(&locks);
+ let guard = collection.lock(key);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let r = collection.scoped_try_read(key, |_| {});
+ assert!(r.is_err());
+ });
+ });
+
+ drop(guard);
+ }
+
+ #[test]
fn try_lock_succeeds_for_unlocked_collection() {
let key = ThreadKey::get().unwrap();
let mutexes = [Mutex::new(24), Mutex::new(42)];
@@ -644,17 +673,85 @@ mod tests {
}
#[test]
+ fn into_ref_iterator() {
+ let mut key = ThreadKey::get().unwrap();
+ let mutexes = [Mutex::new(0), Mutex::new(1), Mutex::new(2)];
+ let collection = RefLockCollection::new(&mutexes);
+ for (i, mutex) in (&collection).into_iter().enumerate() {
+ mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i))
+ }
+ }
+
+ #[test]
+ fn ref_iterator() {
+ let mut key = ThreadKey::get().unwrap();
+ let mutexes = [Mutex::new(0), Mutex::new(1), Mutex::new(2)];
+ let collection = RefLockCollection::new(&mutexes);
+ for (i, mutex) in collection.iter().enumerate() {
+ mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i))
+ }
+ }
+
+ #[test]
fn works_in_collection() {
let key = ThreadKey::get().unwrap();
- let mutex1 = Mutex::new(0);
- let mutex2 = Mutex::new(1);
+ let mutex1 = RwLock::new(0);
+ let mutex2 = RwLock::new(1);
let collection0 = [&mutex1, &mutex2];
let collection1 = RefLockCollection::try_new(&collection0).unwrap();
let collection = RefLockCollection::try_new(&collection1).unwrap();
- let guard = collection.lock(key);
+ let mut guard = collection.lock(key);
assert!(mutex1.is_locked());
assert!(mutex2.is_locked());
+ assert_eq!(*guard[0], 0);
+ assert_eq!(*guard[1], 1);
+ *guard[1] = 2;
drop(guard);
+
+ let key = ThreadKey::get().unwrap();
+ let guard = collection.read(key);
+ assert!(mutex1.is_locked());
+ assert!(mutex2.is_locked());
+ assert_eq!(*guard[0], 0);
+ assert_eq!(*guard[1], 2);
+ }
+
+ #[test]
+ fn unlock_collection_works() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = (Mutex::new("foo"), Mutex::new("bar"));
+ let collection = RefLockCollection::new(&mutexes);
+ let guard = collection.lock(key);
+
+ let key = RefLockCollection::<(Mutex<_>, Mutex<_>)>::unlock(guard);
+ assert!(collection.try_lock(key).is_ok())
+ }
+
+ #[test]
+ fn read_unlock_collection_works() {
+ let key = ThreadKey::get().unwrap();
+ let locks = (RwLock::new("foo"), RwLock::new("bar"));
+ let collection = RefLockCollection::new(&locks);
+ let guard = collection.read(key);
+
+ let key = RefLockCollection::<(&RwLock<_>, &RwLock<_>)>::unlock_read(guard);
+ assert!(collection.try_lock(key).is_ok())
+ }
+
+ #[test]
+ fn as_ref_works() {
+ let mutexes = [Mutex::new(0), Mutex::new(1)];
+ let collection = RefLockCollection::new(&mutexes);
+
+ assert!(std::ptr::addr_eq(&mutexes, collection.as_ref()))
+ }
+
+ #[test]
+ fn child() {
+ let mutexes = [Mutex::new(0), Mutex::new(1)];
+ let collection = RefLockCollection::new(&mutexes);
+
+ assert!(std::ptr::addr_eq(&mutexes, collection.child()))
}
}