summaryrefslogtreecommitdiff
path: root/src/collection/boxed.rs
diff options
context:
space:
mode:
authorBotahamec <botahamec@outlook.com>2025-02-28 16:09:11 -0500
committerBotahamec <botahamec@outlook.com>2025-02-28 16:09:11 -0500
commit4ba03be97e6cc7e790bbc9bfc18caaa228c8a262 (patch)
treea257184577a93ddf240aba698755c2886188788b /src/collection/boxed.rs
parent4a5ec04a29cba07c5960792528bd66b0f99ee3ee (diff)
Scoped lock API
Diffstat (limited to 'src/collection/boxed.rs')
-rw-r--r--src/collection/boxed.rs170
1 files changed, 113 insertions, 57 deletions
diff --git a/src/collection/boxed.rs b/src/collection/boxed.rs
index 0597e90..364ec97 100644
--- a/src/collection/boxed.rs
+++ b/src/collection/boxed.rs
@@ -1,25 +1,12 @@
use std::cell::UnsafeCell;
use std::fmt::Debug;
-use std::marker::PhantomData;
use crate::lockable::{Lockable, LockableIntoInner, OwnedLockable, RawLock, Sharable};
-use crate::Keyable;
+use crate::{Keyable, ThreadKey};
+use super::utils::ordered_contains_duplicates;
use super::{utils, BoxedLockCollection, LockGuard};
-/// returns `true` if the sorted list contains a duplicate
-#[must_use]
-fn contains_duplicates(l: &[&dyn RawLock]) -> bool {
- if l.is_empty() {
- // Return early to prevent panic in the below call to `windows`
- return false;
- }
-
- l.windows(2)
- // NOTE: addr_eq is necessary because eq would also compare the v-table pointers
- .any(|window| std::ptr::addr_eq(window[0], window[1]))
-}
-
unsafe impl<L: Lockable> RawLock for BoxedLockCollection<L> {
#[mutants::skip] // this should never be called
#[cfg(not(tarpaulin_include))]
@@ -65,6 +52,11 @@ unsafe impl<L: Lockable> Lockable for BoxedLockCollection<L> {
where
Self: 'g;
+ type DataMut<'a>
+ = L::DataMut<'a>
+ where
+ Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
ptrs.extend(self.locks())
}
@@ -72,6 +64,10 @@ unsafe impl<L: Lockable> Lockable for BoxedLockCollection<L> {
unsafe fn guard(&self) -> Self::Guard<'_> {
self.child().guard()
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ self.child().data_mut()
+ }
}
unsafe impl<L: Sharable> Sharable for BoxedLockCollection<L> {
@@ -80,9 +76,18 @@ unsafe impl<L: Sharable> Sharable for BoxedLockCollection<L> {
where
Self: 'g;
+ type DataRef<'a>
+ = L::DataRef<'a>
+ where
+ Self: 'a;
+
unsafe fn read_guard(&self) -> Self::ReadGuard<'_> {
self.child().read_guard()
}
+
+ unsafe fn data_ref(&self) -> Self::DataRef<'_> {
+ self.child().data_ref()
+ }
}
unsafe impl<L: OwnedLockable> OwnedLockable for BoxedLockCollection<L> {}
@@ -352,13 +357,53 @@ impl<L: Lockable> BoxedLockCollection<L> {
// safety: we are checking for duplicates before returning
unsafe {
let this = Self::new_unchecked(data);
- if contains_duplicates(this.locks()) {
+ if ordered_contains_duplicates(this.locks()) {
return None;
}
Some(this)
}
}
+ pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_lock();
+
+ // safety: the data was just locked
+ let r = f(self.data_mut());
+
+ // safety: the collection is still locked
+ self.raw_unlock();
+
+ drop(key); // ensure the key stays alive long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_lock<Key: Keyable, R>(
+ &self,
+ key: Key,
+ f: impl Fn(L::DataMut<'_>) -> R,
+ ) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_lock() {
+ return Err(key);
+ }
+
+ // safety: we just locked the collection
+ let r = f(self.data_mut());
+
+ // safety: the collection is still locked
+ self.raw_unlock();
+
+ drop(key); // ensures the key stays valid long enough
+
+ Ok(r)
+ }
+ }
+
/// Locks the collection
///
/// This function returns a guard that can be used to access the underlying
@@ -378,10 +423,8 @@ impl<L: Lockable> BoxedLockCollection<L> {
/// *guard.0 += 1;
/// *guard.1 = "1";
/// ```
- pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> LockGuard<'key, L::Guard<'g>, Key> {
+ #[must_use]
+ pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> {
unsafe {
// safety: we have the thread key
self.raw_lock();
@@ -390,7 +433,6 @@ impl<L: Lockable> BoxedLockCollection<L> {
// safety: we've already acquired the lock
guard: self.child().guard(),
key,
- _phantom: PhantomData,
}
}
}
@@ -424,10 +466,7 @@ impl<L: Lockable> BoxedLockCollection<L> {
/// };
///
/// ```
- pub fn try_lock<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> Result<LockGuard<'key, L::Guard<'g>, Key>, Key> {
+ pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> {
let guard = unsafe {
if !self.raw_try_lock() {
return Err(key);
@@ -437,11 +476,7 @@ impl<L: Lockable> BoxedLockCollection<L> {
self.child().guard()
};
- Ok(LockGuard {
- guard,
- key,
- _phantom: PhantomData,
- })
+ Ok(LockGuard { guard, key })
}
/// Unlocks the underlying lockable data type, returning the key that's
@@ -461,13 +496,53 @@ impl<L: Lockable> BoxedLockCollection<L> {
/// *guard.1 = "1";
/// let key = LockCollection::<(Mutex<i32>, Mutex<&str>)>::unlock(guard);
/// ```
- pub fn unlock<'key, Key: Keyable + 'key>(guard: LockGuard<'key, L::Guard<'_>, Key>) -> Key {
+ pub fn unlock(guard: LockGuard<L::Guard<'_>>) -> ThreadKey {
drop(guard.guard);
guard.key
}
}
impl<L: Sharable> BoxedLockCollection<L> {
+ pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_read();
+
+ // safety: the data was just locked
+ let r = f(self.data_ref());
+
+ // safety: the collection is still locked
+ self.raw_unlock_read();
+
+ drop(key); // ensure the key stays alive long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_read<Key: Keyable, R>(
+ &self,
+ key: Key,
+ f: impl Fn(L::DataRef<'_>) -> R,
+ ) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_read() {
+ return Err(key);
+ }
+
+ // safety: we just locked the collection
+ let r = f(self.data_ref());
+
+ // safety: the collection is still locked
+ self.raw_unlock_read();
+
+ drop(key); // ensures the key stays valid long enough
+
+ Ok(r)
+ }
+ }
+
/// Locks the collection, so that other threads can still read from it
///
/// This function returns a guard that can be used to access the underlying
@@ -487,10 +562,8 @@ impl<L: Sharable> BoxedLockCollection<L> {
/// assert_eq!(*guard.0, 0);
/// assert_eq!(*guard.1, "");
/// ```
- pub fn read<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> LockGuard<'key, L::ReadGuard<'g>, Key> {
+ #[must_use]
+ pub fn read(&self, key: ThreadKey) -> LockGuard<L::ReadGuard<'_>> {
unsafe {
// safety: we have the thread key
self.raw_read();
@@ -499,7 +572,6 @@ impl<L: Sharable> BoxedLockCollection<L> {
// safety: we've already acquired the lock
guard: self.child().read_guard(),
key,
- _phantom: PhantomData,
}
}
}
@@ -534,10 +606,7 @@ impl<L: Sharable> BoxedLockCollection<L> {
/// };
///
/// ```
- pub fn try_read<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> Result<LockGuard<'key, L::ReadGuard<'g>, Key>, Key> {
+ pub fn try_read(&self, key: ThreadKey) -> Result<LockGuard<L::ReadGuard<'_>>, ThreadKey> {
let guard = unsafe {
// safety: we have the thread key
if !self.raw_try_read() {
@@ -548,11 +617,7 @@ impl<L: Sharable> BoxedLockCollection<L> {
self.child().read_guard()
};
- Ok(LockGuard {
- guard,
- key,
- _phantom: PhantomData,
- })
+ Ok(LockGuard { guard, key })
}
/// Unlocks the underlying lockable data type, returning the key that's
@@ -570,9 +635,7 @@ impl<L: Sharable> BoxedLockCollection<L> {
/// let mut guard = lock.read(key);
/// let key = LockCollection::<(RwLock<i32>, RwLock<&str>)>::unlock_read(guard);
/// ```
- pub fn unlock_read<'key, Key: Keyable + 'key>(
- guard: LockGuard<'key, L::ReadGuard<'_>, Key>,
- ) -> Key {
+ pub fn unlock_read(guard: LockGuard<L::ReadGuard<'_>>) -> ThreadKey {
drop(guard.guard);
guard.key
}
@@ -635,7 +698,6 @@ mod tests {
.into_iter()
.collect();
let guard = collection.lock(key);
- // TODO impl PartialEq<T> for MutexRef<T>
assert_eq!(*guard[0], "foo");
assert_eq!(*guard[1], "bar");
assert_eq!(*guard[2], "baz");
@@ -647,7 +709,6 @@ mod tests {
let collection =
BoxedLockCollection::from([Mutex::new("foo"), Mutex::new("bar"), Mutex::new("baz")]);
let guard = collection.lock(key);
- // TODO impl PartialEq<T> for MutexRef<T>
assert_eq!(*guard[0], "foo");
assert_eq!(*guard[1], "bar");
assert_eq!(*guard[2], "baz");
@@ -666,7 +727,7 @@ mod tests {
let mut key = ThreadKey::get().unwrap();
let collection = BoxedLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]);
for (i, mutex) in (&collection).into_iter().enumerate() {
- assert_eq!(*mutex.lock(&mut key), i);
+ mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i))
}
}
@@ -675,7 +736,7 @@ mod tests {
let mut key = ThreadKey::get().unwrap();
let collection = BoxedLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]);
for (i, mutex) in collection.iter().enumerate() {
- assert_eq!(*mutex.lock(&mut key), i);
+ mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i))
}
}
@@ -704,11 +765,6 @@ mod tests {
}
#[test]
- fn contains_duplicates_empty() {
- assert!(!contains_duplicates(&[]))
- }
-
- #[test]
fn try_lock_works() {
let key = ThreadKey::get().unwrap();
let collection = BoxedLockCollection::new([Mutex::new(1), Mutex::new(2)]);