summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/collection.rs8
-rw-r--r--src/collection/boxed.rs170
-rw-r--r--src/collection/guard.rs83
-rw-r--r--src/collection/owned.rs262
-rw-r--r--src/collection/ref.rs261
-rw-r--r--src/collection/retry.rs176
-rw-r--r--src/collection/utils.rs40
-rw-r--r--src/lockable.rs126
-rw-r--r--src/mutex.rs58
-rw-r--r--src/mutex/guard.rs71
-rw-r--r--src/mutex/mutex.rs61
-rw-r--r--src/poisonable.rs229
-rw-r--r--src/poisonable/error.rs14
-rw-r--r--src/poisonable/guard.rs64
-rw-r--r--src/poisonable/poisonable.rs157
-rw-r--r--src/rwlock.rs57
-rw-r--r--src/rwlock/read_guard.rs71
-rw-r--r--src/rwlock/read_lock.rs34
-rw-r--r--src/rwlock/rwlock.rs130
-rw-r--r--src/rwlock/write_guard.rs75
-rw-r--r--src/rwlock/write_lock.rs25
21 files changed, 1438 insertions, 734 deletions
diff --git a/src/collection.rs b/src/collection.rs
index db68382..e50cc30 100644
--- a/src/collection.rs
+++ b/src/collection.rs
@@ -1,7 +1,6 @@
use std::cell::UnsafeCell;
-use std::marker::PhantomData;
-use crate::{key::Keyable, lockable::RawLock};
+use crate::{lockable::RawLock, ThreadKey};
mod boxed;
mod guard;
@@ -122,8 +121,7 @@ pub struct RetryingLockCollection<L> {
/// A RAII guard for a generic [`Lockable`] type.
///
/// [`Lockable`]: `crate::lockable::Lockable`
-pub struct LockGuard<'key, Guard, Key: Keyable + 'key> {
+pub struct LockGuard<Guard> {
guard: Guard,
- key: Key,
- _phantom: PhantomData<&'key ()>,
+ key: ThreadKey,
}
diff --git a/src/collection/boxed.rs b/src/collection/boxed.rs
index 0597e90..364ec97 100644
--- a/src/collection/boxed.rs
+++ b/src/collection/boxed.rs
@@ -1,25 +1,12 @@
use std::cell::UnsafeCell;
use std::fmt::Debug;
-use std::marker::PhantomData;
use crate::lockable::{Lockable, LockableIntoInner, OwnedLockable, RawLock, Sharable};
-use crate::Keyable;
+use crate::{Keyable, ThreadKey};
+use super::utils::ordered_contains_duplicates;
use super::{utils, BoxedLockCollection, LockGuard};
-/// returns `true` if the sorted list contains a duplicate
-#[must_use]
-fn contains_duplicates(l: &[&dyn RawLock]) -> bool {
- if l.is_empty() {
- // Return early to prevent panic in the below call to `windows`
- return false;
- }
-
- l.windows(2)
- // NOTE: addr_eq is necessary because eq would also compare the v-table pointers
- .any(|window| std::ptr::addr_eq(window[0], window[1]))
-}
-
unsafe impl<L: Lockable> RawLock for BoxedLockCollection<L> {
#[mutants::skip] // this should never be called
#[cfg(not(tarpaulin_include))]
@@ -65,6 +52,11 @@ unsafe impl<L: Lockable> Lockable for BoxedLockCollection<L> {
where
Self: 'g;
+ type DataMut<'a>
+ = L::DataMut<'a>
+ where
+ Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
ptrs.extend(self.locks())
}
@@ -72,6 +64,10 @@ unsafe impl<L: Lockable> Lockable for BoxedLockCollection<L> {
unsafe fn guard(&self) -> Self::Guard<'_> {
self.child().guard()
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ self.child().data_mut()
+ }
}
unsafe impl<L: Sharable> Sharable for BoxedLockCollection<L> {
@@ -80,9 +76,18 @@ unsafe impl<L: Sharable> Sharable for BoxedLockCollection<L> {
where
Self: 'g;
+ type DataRef<'a>
+ = L::DataRef<'a>
+ where
+ Self: 'a;
+
unsafe fn read_guard(&self) -> Self::ReadGuard<'_> {
self.child().read_guard()
}
+
+ unsafe fn data_ref(&self) -> Self::DataRef<'_> {
+ self.child().data_ref()
+ }
}
unsafe impl<L: OwnedLockable> OwnedLockable for BoxedLockCollection<L> {}
@@ -352,13 +357,53 @@ impl<L: Lockable> BoxedLockCollection<L> {
// safety: we are checking for duplicates before returning
unsafe {
let this = Self::new_unchecked(data);
- if contains_duplicates(this.locks()) {
+ if ordered_contains_duplicates(this.locks()) {
return None;
}
Some(this)
}
}
+ pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_lock();
+
+ // safety: the data was just locked
+ let r = f(self.data_mut());
+
+ // safety: the collection is still locked
+ self.raw_unlock();
+
+ drop(key); // ensure the key stays alive long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_lock<Key: Keyable, R>(
+ &self,
+ key: Key,
+ f: impl Fn(L::DataMut<'_>) -> R,
+ ) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_lock() {
+ return Err(key);
+ }
+
+ // safety: we just locked the collection
+ let r = f(self.data_mut());
+
+ // safety: the collection is still locked
+ self.raw_unlock();
+
+ drop(key); // ensures the key stays valid long enough
+
+ Ok(r)
+ }
+ }
+
/// Locks the collection
///
/// This function returns a guard that can be used to access the underlying
@@ -378,10 +423,8 @@ impl<L: Lockable> BoxedLockCollection<L> {
/// *guard.0 += 1;
/// *guard.1 = "1";
/// ```
- pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> LockGuard<'key, L::Guard<'g>, Key> {
+ #[must_use]
+ pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> {
unsafe {
// safety: we have the thread key
self.raw_lock();
@@ -390,7 +433,6 @@ impl<L: Lockable> BoxedLockCollection<L> {
// safety: we've already acquired the lock
guard: self.child().guard(),
key,
- _phantom: PhantomData,
}
}
}
@@ -424,10 +466,7 @@ impl<L: Lockable> BoxedLockCollection<L> {
/// };
///
/// ```
- pub fn try_lock<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> Result<LockGuard<'key, L::Guard<'g>, Key>, Key> {
+ pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> {
let guard = unsafe {
if !self.raw_try_lock() {
return Err(key);
@@ -437,11 +476,7 @@ impl<L: Lockable> BoxedLockCollection<L> {
self.child().guard()
};
- Ok(LockGuard {
- guard,
- key,
- _phantom: PhantomData,
- })
+ Ok(LockGuard { guard, key })
}
/// Unlocks the underlying lockable data type, returning the key that's
@@ -461,13 +496,53 @@ impl<L: Lockable> BoxedLockCollection<L> {
/// *guard.1 = "1";
/// let key = LockCollection::<(Mutex<i32>, Mutex<&str>)>::unlock(guard);
/// ```
- pub fn unlock<'key, Key: Keyable + 'key>(guard: LockGuard<'key, L::Guard<'_>, Key>) -> Key {
+ pub fn unlock(guard: LockGuard<L::Guard<'_>>) -> ThreadKey {
drop(guard.guard);
guard.key
}
}
impl<L: Sharable> BoxedLockCollection<L> {
+ pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_read();
+
+ // safety: the data was just locked
+ let r = f(self.data_ref());
+
+ // safety: the collection is still locked
+ self.raw_unlock_read();
+
+ drop(key); // ensure the key stays alive long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_read<Key: Keyable, R>(
+ &self,
+ key: Key,
+ f: impl Fn(L::DataRef<'_>) -> R,
+ ) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_read() {
+ return Err(key);
+ }
+
+ // safety: we just locked the collection
+ let r = f(self.data_ref());
+
+ // safety: the collection is still locked
+ self.raw_unlock_read();
+
+ drop(key); // ensures the key stays valid long enough
+
+ Ok(r)
+ }
+ }
+
/// Locks the collection, so that other threads can still read from it
///
/// This function returns a guard that can be used to access the underlying
@@ -487,10 +562,8 @@ impl<L: Sharable> BoxedLockCollection<L> {
/// assert_eq!(*guard.0, 0);
/// assert_eq!(*guard.1, "");
/// ```
- pub fn read<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> LockGuard<'key, L::ReadGuard<'g>, Key> {
+ #[must_use]
+ pub fn read(&self, key: ThreadKey) -> LockGuard<L::ReadGuard<'_>> {
unsafe {
// safety: we have the thread key
self.raw_read();
@@ -499,7 +572,6 @@ impl<L: Sharable> BoxedLockCollection<L> {
// safety: we've already acquired the lock
guard: self.child().read_guard(),
key,
- _phantom: PhantomData,
}
}
}
@@ -534,10 +606,7 @@ impl<L: Sharable> BoxedLockCollection<L> {
/// };
///
/// ```
- pub fn try_read<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> Result<LockGuard<'key, L::ReadGuard<'g>, Key>, Key> {
+ pub fn try_read(&self, key: ThreadKey) -> Result<LockGuard<L::ReadGuard<'_>>, ThreadKey> {
let guard = unsafe {
// safety: we have the thread key
if !self.raw_try_read() {
@@ -548,11 +617,7 @@ impl<L: Sharable> BoxedLockCollection<L> {
self.child().read_guard()
};
- Ok(LockGuard {
- guard,
- key,
- _phantom: PhantomData,
- })
+ Ok(LockGuard { guard, key })
}
/// Unlocks the underlying lockable data type, returning the key that's
@@ -570,9 +635,7 @@ impl<L: Sharable> BoxedLockCollection<L> {
/// let mut guard = lock.read(key);
/// let key = LockCollection::<(RwLock<i32>, RwLock<&str>)>::unlock_read(guard);
/// ```
- pub fn unlock_read<'key, Key: Keyable + 'key>(
- guard: LockGuard<'key, L::ReadGuard<'_>, Key>,
- ) -> Key {
+ pub fn unlock_read(guard: LockGuard<L::ReadGuard<'_>>) -> ThreadKey {
drop(guard.guard);
guard.key
}
@@ -635,7 +698,6 @@ mod tests {
.into_iter()
.collect();
let guard = collection.lock(key);
- // TODO impl PartialEq<T> for MutexRef<T>
assert_eq!(*guard[0], "foo");
assert_eq!(*guard[1], "bar");
assert_eq!(*guard[2], "baz");
@@ -647,7 +709,6 @@ mod tests {
let collection =
BoxedLockCollection::from([Mutex::new("foo"), Mutex::new("bar"), Mutex::new("baz")]);
let guard = collection.lock(key);
- // TODO impl PartialEq<T> for MutexRef<T>
assert_eq!(*guard[0], "foo");
assert_eq!(*guard[1], "bar");
assert_eq!(*guard[2], "baz");
@@ -666,7 +727,7 @@ mod tests {
let mut key = ThreadKey::get().unwrap();
let collection = BoxedLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]);
for (i, mutex) in (&collection).into_iter().enumerate() {
- assert_eq!(*mutex.lock(&mut key), i);
+ mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i))
}
}
@@ -675,7 +736,7 @@ mod tests {
let mut key = ThreadKey::get().unwrap();
let collection = BoxedLockCollection::new([Mutex::new(0), Mutex::new(1), Mutex::new(2)]);
for (i, mutex) in collection.iter().enumerate() {
- assert_eq!(*mutex.lock(&mut key), i);
+ mutex.scoped_lock(&mut key, |val| assert_eq!(*val, i))
}
}
@@ -704,11 +765,6 @@ mod tests {
}
#[test]
- fn contains_duplicates_empty() {
- assert!(!contains_duplicates(&[]))
- }
-
- #[test]
fn try_lock_works() {
let key = ThreadKey::get().unwrap();
let collection = BoxedLockCollection::new([Mutex::new(1), Mutex::new(2)]);
diff --git a/src/collection/guard.rs b/src/collection/guard.rs
index eea13ed..78d9895 100644
--- a/src/collection/guard.rs
+++ b/src/collection/guard.rs
@@ -2,41 +2,11 @@ use std::fmt::{Debug, Display};
use std::hash::Hash;
use std::ops::{Deref, DerefMut};
-use crate::key::Keyable;
-
use super::LockGuard;
-#[mutants::skip] // it's hard to get two guards safely
-#[cfg(not(tarpaulin_include))]
-impl<Guard: PartialEq, Key: Keyable> PartialEq for LockGuard<'_, Guard, Key> {
- fn eq(&self, other: &Self) -> bool {
- self.guard.eq(&other.guard)
- }
-}
-
-#[mutants::skip] // it's hard to get two guards safely
-#[cfg(not(tarpaulin_include))]
-impl<Guard: PartialOrd, Key: Keyable> PartialOrd for LockGuard<'_, Guard, Key> {
- fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
- self.guard.partial_cmp(&other.guard)
- }
-}
-
-#[mutants::skip] // it's hard to get two guards safely
-#[cfg(not(tarpaulin_include))]
-impl<Guard: Eq, Key: Keyable> Eq for LockGuard<'_, Guard, Key> {}
-
-#[mutants::skip] // it's hard to get two guards safely
-#[cfg(not(tarpaulin_include))]
-impl<Guard: Ord, Key: Keyable> Ord for LockGuard<'_, Guard, Key> {
- fn cmp(&self, other: &Self) -> std::cmp::Ordering {
- self.guard.cmp(&other.guard)
- }
-}
-
#[mutants::skip] // hashing involves RNG and is hard to test
#[cfg(not(tarpaulin_include))]
-impl<Guard: Hash, Key: Keyable> Hash for LockGuard<'_, Guard, Key> {
+impl<Guard: Hash> Hash for LockGuard<Guard> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.guard.hash(state)
}
@@ -44,19 +14,19 @@ impl<Guard: Hash, Key: Keyable> Hash for LockGuard<'_, Guard, Key> {
#[mutants::skip]
#[cfg(not(tarpaulin_include))]
-impl<Guard: Debug, Key: Keyable> Debug for LockGuard<'_, Guard, Key> {
+impl<Guard: Debug> Debug for LockGuard<Guard> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&**self, f)
}
}
-impl<Guard: Display, Key: Keyable> Display for LockGuard<'_, Guard, Key> {
+impl<Guard: Display> Display for LockGuard<Guard> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Display::fmt(&**self, f)
}
}
-impl<Guard, Key: Keyable> Deref for LockGuard<'_, Guard, Key> {
+impl<Guard> Deref for LockGuard<Guard> {
type Target = Guard;
fn deref(&self) -> &Self::Target {
@@ -64,19 +34,19 @@ impl<Guard, Key: Keyable> Deref for LockGuard<'_, Guard, Key> {
}
}
-impl<Guard, Key: Keyable> DerefMut for LockGuard<'_, Guard, Key> {
+impl<Guard> DerefMut for LockGuard<Guard> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.guard
}
}
-impl<Guard, Key: Keyable> AsRef<Guard> for LockGuard<'_, Guard, Key> {
+impl<Guard> AsRef<Guard> for LockGuard<Guard> {
fn as_ref(&self) -> &Guard {
&self.guard
}
}
-impl<Guard, Key: Keyable> AsMut<Guard> for LockGuard<'_, Guard, Key> {
+impl<Guard> AsMut<Guard> for LockGuard<Guard> {
fn as_mut(&mut self) -> &mut Guard {
&mut self.guard
}
@@ -97,56 +67,53 @@ mod tests {
#[test]
fn deref_mut_works() {
- let mut key = ThreadKey::get().unwrap();
+ let key = ThreadKey::get().unwrap();
let locks = (Mutex::new(1), Mutex::new(2));
let lock = LockCollection::new_ref(&locks);
- let mut guard = lock.lock(&mut key);
+ let mut guard = lock.lock(key);
*guard.0 = 3;
- drop(guard);
+ let key = LockCollection::<(Mutex<_>, Mutex<_>)>::unlock(guard);
- let guard = locks.0.lock(&mut key);
+ let guard = locks.0.lock(key);
assert_eq!(*guard, 3);
- drop(guard);
+ let key = Mutex::unlock(guard);
- let guard = locks.1.lock(&mut key);
+ let guard = locks.1.lock(key);
assert_eq!(*guard, 2);
- drop(guard);
}
#[test]
fn as_ref_works() {
- let mut key = ThreadKey::get().unwrap();
+ let key = ThreadKey::get().unwrap();
let locks = (Mutex::new(1), Mutex::new(2));
let lock = LockCollection::new_ref(&locks);
- let mut guard = lock.lock(&mut key);
+ let mut guard = lock.lock(key);
*guard.0 = 3;
- drop(guard);
+ let key = LockCollection::<(Mutex<_>, Mutex<_>)>::unlock(guard);
- let guard = locks.0.lock(&mut key);
+ let guard = locks.0.lock(key);
assert_eq!(guard.as_ref(), &3);
- drop(guard);
+ let key = Mutex::unlock(guard);
- let guard = locks.1.lock(&mut key);
+ let guard = locks.1.lock(key);
assert_eq!(guard.as_ref(), &2);
- drop(guard);
}
#[test]
fn as_mut_works() {
- let mut key = ThreadKey::get().unwrap();
+ let key = ThreadKey::get().unwrap();
let locks = (Mutex::new(1), Mutex::new(2));
let lock = LockCollection::new_ref(&locks);
- let mut guard = lock.lock(&mut key);
+ let mut guard = lock.lock(key);
let guard_mut = guard.as_mut();
*guard_mut.0 = 3;
- drop(guard);
+ let key = LockCollection::<(Mutex<_>, Mutex<_>)>::unlock(guard);
- let guard = locks.0.lock(&mut key);
+ let guard = locks.0.lock(key);
assert_eq!(guard.as_ref(), &3);
- drop(guard);
+ let key = Mutex::unlock(guard);
- let guard = locks.1.lock(&mut key);
+ let guard = locks.1.lock(key);
assert_eq!(guard.as_ref(), &2);
- drop(guard);
}
}
diff --git a/src/collection/owned.rs b/src/collection/owned.rs
index c345b43..b9cf313 100644
--- a/src/collection/owned.rs
+++ b/src/collection/owned.rs
@@ -1,57 +1,47 @@
-use std::marker::PhantomData;
-
use crate::lockable::{
Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable,
};
-use crate::Keyable;
+use crate::{Keyable, ThreadKey};
use super::{utils, LockGuard, OwnedLockCollection};
-#[mutants::skip] // it's hard to test individual locks in an OwnedLockCollection
-#[cfg(not(tarpaulin_include))]
-fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> {
- let mut locks = Vec::new();
- data.get_ptrs(&mut locks);
- locks
-}
-
unsafe impl<L: Lockable> RawLock for OwnedLockCollection<L> {
#[mutants::skip] // this should never run
#[cfg(not(tarpaulin_include))]
fn poison(&self) {
- let locks = get_locks(&self.data);
+ let locks = utils::get_locks_unsorted(&self.data);
for lock in locks {
lock.poison();
}
}
unsafe fn raw_lock(&self) {
- utils::ordered_lock(&get_locks(&self.data))
+ utils::ordered_lock(&utils::get_locks_unsorted(&self.data))
}
unsafe fn raw_try_lock(&self) -> bool {
- let locks = get_locks(&self.data);
+ let locks = utils::get_locks_unsorted(&self.data);
utils::ordered_try_lock(&locks)
}
unsafe fn raw_unlock(&self) {
- let locks = get_locks(&self.data);
+ let locks = utils::get_locks_unsorted(&self.data);
for lock in locks {
lock.raw_unlock();
}
}
unsafe fn raw_read(&self) {
- utils::ordered_read(&get_locks(&self.data))
+ utils::ordered_read(&utils::get_locks_unsorted(&self.data))
}
unsafe fn raw_try_read(&self) -> bool {
- let locks = get_locks(&self.data);
+ let locks = utils::get_locks_unsorted(&self.data);
utils::ordered_try_read(&locks)
}
unsafe fn raw_unlock_read(&self) {
- let locks = get_locks(&self.data);
+ let locks = utils::get_locks_unsorted(&self.data);
for lock in locks {
lock.raw_unlock_read();
}
@@ -64,6 +54,11 @@ unsafe impl<L: Lockable> Lockable for OwnedLockCollection<L> {
where
Self: 'g;
+ type DataMut<'a>
+ = L::DataMut<'a>
+ where
+ Self: 'a;
+
	#[mutants::skip] // It's hard to test locks in an OwnedLockCollection, because they're owned
#[cfg(not(tarpaulin_include))]
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
@@ -73,6 +68,10 @@ unsafe impl<L: Lockable> Lockable for OwnedLockCollection<L> {
unsafe fn guard(&self) -> Self::Guard<'_> {
self.data.guard()
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ self.data.data_mut()
+ }
}
impl<L: LockableGetMut> LockableGetMut for OwnedLockCollection<L> {
@@ -100,9 +99,18 @@ unsafe impl<L: Sharable> Sharable for OwnedLockCollection<L> {
where
Self: 'g;
+ type DataRef<'a>
+ = L::DataRef<'a>
+ where
+ Self: 'a;
+
unsafe fn read_guard(&self) -> Self::ReadGuard<'_> {
self.data.read_guard()
}
+
+ unsafe fn data_ref(&self) -> Self::DataRef<'_> {
+ self.data.data_ref()
+ }
}
unsafe impl<L: OwnedLockable> OwnedLockable for OwnedLockCollection<L> {}
@@ -177,6 +185,46 @@ impl<L: OwnedLockable> OwnedLockCollection<L> {
Self { data }
}
+ pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_lock();
+
+ // safety: the data was just locked
+ let r = f(self.data_mut());
+
+ // safety: the collection is still locked
+ self.raw_unlock();
+
+ drop(key); // ensure the key stays alive long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_lock<Key: Keyable, R>(
+ &self,
+ key: Key,
+ f: impl Fn(L::DataMut<'_>) -> R,
+ ) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_lock() {
+ return Err(key);
+ }
+
+ // safety: we just locked the collection
+ let r = f(self.data_mut());
+
+ // safety: the collection is still locked
+ self.raw_unlock();
+
+ drop(key); // ensures the key stays valid long enough
+
+ Ok(r)
+ }
+ }
+
/// Locks the collection
///
/// This function returns a guard that can be used to access the underlying
@@ -197,10 +245,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> {
/// *guard.0 += 1;
/// *guard.1 = "1";
/// ```
- pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> LockGuard<'key, L::Guard<'g>, Key> {
+ pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> {
let guard = unsafe {
// safety: we have the thread key, and these locks happen in a
// predetermined order
@@ -210,11 +255,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> {
self.data.guard()
};
- LockGuard {
- guard,
- key,
- _phantom: PhantomData,
- }
+ LockGuard { guard, key }
}
	/// Attempts to lock the collection without blocking.
@@ -247,10 +288,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> {
/// };
///
/// ```
- pub fn try_lock<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> Result<LockGuard<'key, L::Guard<'g>, Key>, Key> {
+ pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> {
let guard = unsafe {
if !self.raw_try_lock() {
return Err(key);
@@ -260,11 +298,7 @@ impl<L: OwnedLockable> OwnedLockCollection<L> {
self.data.guard()
};
- Ok(LockGuard {
- guard,
- key,
- _phantom: PhantomData,
- })
+ Ok(LockGuard { guard, key })
}
/// Unlocks the underlying lockable data type, returning the key that's
@@ -286,15 +320,53 @@ impl<L: OwnedLockable> OwnedLockCollection<L> {
/// let key = OwnedLockCollection::<(Mutex<i32>, Mutex<&str>)>::unlock(guard);
/// ```
#[allow(clippy::missing_const_for_fn)]
- pub fn unlock<'g, 'key: 'g, Key: Keyable + 'key>(
- guard: LockGuard<'key, L::Guard<'g>, Key>,
- ) -> Key {
+ pub fn unlock(guard: LockGuard<L::Guard<'_>>) -> ThreadKey {
drop(guard.guard);
guard.key
}
}
impl<L: Sharable> OwnedLockCollection<L> {
+ pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_read();
+
+ // safety: the data was just locked
+ let r = f(self.data_ref());
+
+ // safety: the collection is still locked
+ self.raw_unlock_read();
+
+ drop(key); // ensure the key stays alive long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_read<Key: Keyable, R>(
+ &self,
+ key: Key,
+ f: impl Fn(L::DataRef<'_>) -> R,
+ ) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_read() {
+ return Err(key);
+ }
+
+ // safety: we just locked the collection
+ let r = f(self.data_ref());
+
+ // safety: the collection is still locked
+ self.raw_unlock_read();
+
+ drop(key); // ensures the key stays valid long enough
+
+ Ok(r)
+ }
+ }
+
/// Locks the collection, so that other threads can still read from it
///
/// This function returns a guard that can be used to access the underlying
@@ -315,10 +387,7 @@ impl<L: Sharable> OwnedLockCollection<L> {
/// assert_eq!(*guard.0, 0);
/// assert_eq!(*guard.1, "");
/// ```
- pub fn read<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> LockGuard<'key, L::ReadGuard<'g>, Key> {
+ pub fn read(&self, key: ThreadKey) -> LockGuard<L::ReadGuard<'_>> {
unsafe {
// safety: we have the thread key
self.raw_read();
@@ -327,7 +396,6 @@ impl<L: Sharable> OwnedLockCollection<L> {
// safety: we've already acquired the lock
guard: self.data.read_guard(),
key,
- _phantom: PhantomData,
}
}
}
@@ -355,33 +423,26 @@ impl<L: Sharable> OwnedLockCollection<L> {
/// let lock = OwnedLockCollection::new(data);
///
/// match lock.try_read(key) {
- /// Some(mut guard) => {
+ /// Ok(mut guard) => {
/// assert_eq!(*guard.0, 5);
/// assert_eq!(*guard.1, "6");
/// },
- /// None => unreachable!(),
+ /// Err(_) => unreachable!(),
/// };
///
/// ```
- pub fn try_read<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> Option<LockGuard<'key, L::ReadGuard<'g>, Key>> {
+ pub fn try_read(&self, key: ThreadKey) -> Result<LockGuard<L::ReadGuard<'_>>, ThreadKey> {
let guard = unsafe {
// safety: we have the thread key
if !self.raw_try_read() {
- return None;
+ return Err(key);
}
// safety: we've acquired the locks
self.data.read_guard()
};
- Some(LockGuard {
- guard,
- key,
- _phantom: PhantomData,
- })
+ Ok(LockGuard { guard, key })
}
/// Unlocks the underlying lockable data type, returning the key that's
@@ -401,9 +462,7 @@ impl<L: Sharable> OwnedLockCollection<L> {
/// let key = OwnedLockCollection::<(RwLock<i32>, RwLock<&str>)>::unlock_read(guard);
/// ```
#[allow(clippy::missing_const_for_fn)]
- pub fn unlock_read<'g, 'key: 'g, Key: Keyable + 'key>(
- guard: LockGuard<'key, L::ReadGuard<'g>, Key>,
- ) -> Key {
+ pub fn unlock_read(guard: LockGuard<L::ReadGuard<'_>>) -> ThreadKey {
drop(guard.guard);
guard.key
}
@@ -495,7 +554,90 @@ impl<L: LockableIntoInner> OwnedLockCollection<L> {
#[cfg(test)]
mod tests {
use super::*;
- use crate::Mutex;
+ use crate::{Mutex, ThreadKey};
+
+ #[test]
+ fn get_mut_applies_changes() {
+ let key = ThreadKey::get().unwrap();
+ let mut collection = OwnedLockCollection::new([Mutex::new("foo"), Mutex::new("bar")]);
+ assert_eq!(*collection.get_mut()[0], "foo");
+ assert_eq!(*collection.get_mut()[1], "bar");
+ *collection.get_mut()[0] = "baz";
+
+ let guard = collection.lock(key);
+ assert_eq!(*guard[0], "baz");
+ assert_eq!(*guard[1], "bar");
+ }
+
+ #[test]
+ fn into_inner_works() {
+ let key = ThreadKey::get().unwrap();
+ let collection = OwnedLockCollection::from([Mutex::new("foo")]);
+ let mut guard = collection.lock(key);
+ *guard[0] = "bar";
+ drop(guard);
+
+ let array = collection.into_inner();
+ assert_eq!(array.len(), 1);
+ assert_eq!(array[0], "bar");
+ }
+
+ #[test]
+ fn from_into_iter_is_correct() {
+ let array = [Mutex::new(0), Mutex::new(1), Mutex::new(2), Mutex::new(3)];
+ let mut collection: OwnedLockCollection<Vec<Mutex<usize>>> = array.into_iter().collect();
+ assert_eq!(collection.get_mut().len(), 4);
+ for (i, lock) in collection.into_iter().enumerate() {
+ assert_eq!(lock.into_inner(), i);
+ }
+ }
+
+ #[test]
+ fn from_iter_is_correct() {
+ let array = [Mutex::new(0), Mutex::new(1), Mutex::new(2), Mutex::new(3)];
+ let mut collection: OwnedLockCollection<Vec<Mutex<usize>>> = array.into_iter().collect();
+ let collection: &mut Vec<_> = collection.as_mut();
+ assert_eq!(collection.len(), 4);
+ for (i, lock) in collection.iter_mut().enumerate() {
+ assert_eq!(*lock.get_mut(), i);
+ }
+ }
+
+ #[test]
+ fn try_lock_works_on_unlocked() {
+ let key = ThreadKey::get().unwrap();
+ let collection = OwnedLockCollection::new((Mutex::new(0), Mutex::new(1)));
+ let guard = collection.try_lock(key).unwrap();
+ assert_eq!(*guard.0, 0);
+ assert_eq!(*guard.1, 1);
+ }
+
+ #[test]
+ fn try_lock_fails_on_locked() {
+ let key = ThreadKey::get().unwrap();
+ let collection = OwnedLockCollection::new((Mutex::new(0), Mutex::new(1)));
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ #[allow(unused)]
+ let guard = collection.lock(key);
+ std::mem::forget(guard);
+ });
+ });
+
+ assert!(collection.try_lock(key).is_err());
+ }
+
+ #[test]
+ fn default_works() {
+ type MyCollection = OwnedLockCollection<(Mutex<i32>, Mutex<Option<i32>>, Mutex<String>)>;
+ let collection = MyCollection::default();
+ let inner = collection.into_inner();
+ assert_eq!(inner.0, 0);
+ assert_eq!(inner.1, None);
+ assert_eq!(inner.2, String::new());
+ }
#[test]
fn can_be_extended() {
diff --git a/src/collection/ref.rs b/src/collection/ref.rs
index c86f298..b68b72f 100644
--- a/src/collection/ref.rs
+++ b/src/collection/ref.rs
@@ -1,32 +1,11 @@
use std::fmt::Debug;
-use std::marker::PhantomData;
use crate::lockable::{Lockable, OwnedLockable, RawLock, Sharable};
-use crate::Keyable;
+use crate::{Keyable, ThreadKey};
+use super::utils::{get_locks, ordered_contains_duplicates};
use super::{utils, LockGuard, RefLockCollection};
-#[must_use]
-pub fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> {
- let mut locks = Vec::new();
- data.get_ptrs(&mut locks);
- locks.sort_by_key(|lock| &raw const **lock);
- locks
-}
-
-/// returns `true` if the sorted list contains a duplicate
-#[must_use]
-fn contains_duplicates(l: &[&dyn RawLock]) -> bool {
- if l.is_empty() {
- // Return early to prevent panic in the below call to `windows`
- return false;
- }
-
- l.windows(2)
- // NOTE: addr_eq is necessary because eq would also compare the v-table pointers
- .any(|window| std::ptr::addr_eq(window[0], window[1]))
-}
-
impl<'a, L> IntoIterator for &'a RefLockCollection<'a, L>
where
&'a L: IntoIterator,
@@ -83,6 +62,11 @@ unsafe impl<L: Lockable> Lockable for RefLockCollection<'_, L> {
where
Self: 'g;
+ type DataMut<'a>
+ = L::DataMut<'a>
+ where
+ Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
ptrs.extend_from_slice(&self.locks);
}
@@ -90,6 +74,10 @@ unsafe impl<L: Lockable> Lockable for RefLockCollection<'_, L> {
unsafe fn guard(&self) -> Self::Guard<'_> {
self.data.guard()
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ self.data.data_mut()
+ }
}
unsafe impl<L: Sharable> Sharable for RefLockCollection<'_, L> {
@@ -98,9 +86,18 @@ unsafe impl<L: Sharable> Sharable for RefLockCollection<'_, L> {
where
Self: 'g;
+ type DataRef<'a>
+ = L::DataRef<'a>
+ where
+ Self: 'a;
+
unsafe fn read_guard(&self) -> Self::ReadGuard<'_> {
self.data.read_guard()
}
+
+ unsafe fn data_ref(&self) -> Self::DataRef<'_> {
+ self.data.data_ref()
+ }
}
impl<T, L: AsRef<T>> AsRef<T> for RefLockCollection<'_, L> {
@@ -230,13 +227,53 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
#[must_use]
pub fn try_new(data: &'a L) -> Option<Self> {
let locks = get_locks(data);
- if contains_duplicates(&locks) {
+ if ordered_contains_duplicates(&locks) {
return None;
}
Some(Self { data, locks })
}
+ pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_lock();
+
+ // safety: the data was just locked
+ let r = f(self.data_mut());
+
+ // safety: the collection is still locked
+ self.raw_unlock();
+
+ drop(key); // ensure the key stays alive long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_lock<Key: Keyable, R>(
+ &self,
+ key: Key,
+ f: impl Fn(L::DataMut<'_>) -> R,
+ ) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_lock() {
+ return Err(key);
+ }
+
+ // safety: we just locked the collection
+ let r = f(self.data_mut());
+
+ // safety: the collection is still locked
+ self.raw_unlock();
+
+ drop(key); // ensures the key stays valid long enough
+
+ Ok(r)
+ }
+ }
+
/// Locks the collection
///
/// This function returns a guard that can be used to access the underlying
@@ -257,10 +294,8 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
/// *guard.0 += 1;
/// *guard.1 = "1";
/// ```
- pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> LockGuard<'key, L::Guard<'g>, Key> {
+ #[must_use]
+ pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> {
let guard = unsafe {
// safety: we have the thread key
self.raw_lock();
@@ -269,11 +304,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
self.data.guard()
};
- LockGuard {
- guard,
- key,
- _phantom: PhantomData,
- }
+ LockGuard { guard, key }
}
	/// Attempts to lock the collection without blocking.
@@ -306,10 +337,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
/// };
///
/// ```
- pub fn try_lock<'g, 'key: 'a, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> Result<LockGuard<'key, L::Guard<'g>, Key>, Key> {
+ pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> {
let guard = unsafe {
if !self.raw_try_lock() {
return Err(key);
@@ -319,11 +347,7 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
self.data.guard()
};
- Ok(LockGuard {
- guard,
- key,
- _phantom: PhantomData,
- })
+ Ok(LockGuard { guard, key })
}
/// Unlocks the underlying lockable data type, returning the key that's
@@ -345,13 +369,53 @@ impl<'a, L: Lockable> RefLockCollection<'a, L> {
/// let key = RefLockCollection::<(Mutex<i32>, Mutex<&str>)>::unlock(guard);
/// ```
#[allow(clippy::missing_const_for_fn)]
- pub fn unlock<'g, 'key, Key: Keyable + 'key>(guard: LockGuard<'key, L::Guard<'g>, Key>) -> Key {
+ pub fn unlock(guard: LockGuard<L::Guard<'_>>) -> ThreadKey {
drop(guard.guard);
guard.key
}
}
-impl<'a, L: Sharable> RefLockCollection<'a, L> {
+impl<L: Sharable> RefLockCollection<'_, L> {
+ pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_read();
+
+ // safety: the data was just locked
+ let r = f(self.data_ref());
+
+ // safety: the collection is still locked
+ self.raw_unlock_read();
+
+ drop(key); // ensure the key stays alive long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_read<Key: Keyable, R>(
+ &self,
+ key: Key,
+ f: impl Fn(L::DataRef<'_>) -> R,
+ ) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_read() {
+ return Err(key);
+ }
+
+ // safety: we just locked the collection
+ let r = f(self.data_ref());
+
+ // safety: the collection is still locked
+ self.raw_unlock_read();
+
+ drop(key); // ensures the key stays valid long enough
+
+ Ok(r)
+ }
+ }
+
/// Locks the collection, so that other threads can still read from it
///
/// This function returns a guard that can be used to access the underlying
@@ -372,10 +436,8 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> {
/// assert_eq!(*guard.0, 0);
/// assert_eq!(*guard.1, "");
/// ```
- pub fn read<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> LockGuard<'key, L::ReadGuard<'g>, Key> {
+ #[must_use]
+ pub fn read(&self, key: ThreadKey) -> LockGuard<L::ReadGuard<'_>> {
unsafe {
// safety: we have the thread key
self.raw_read();
@@ -384,7 +446,6 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> {
// safety: we've already acquired the lock
guard: self.data.read_guard(),
key,
- _phantom: PhantomData,
}
}
}
@@ -412,33 +473,26 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> {
/// let lock = RefLockCollection::new(&data);
///
/// match lock.try_read(key) {
- /// Some(mut guard) => {
+ /// Ok(mut guard) => {
/// assert_eq!(*guard.0, 5);
/// assert_eq!(*guard.1, "6");
/// },
- /// None => unreachable!(),
+ /// Err(_) => unreachable!(),
/// };
///
/// ```
- pub fn try_read<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> Option<LockGuard<'key, L::ReadGuard<'g>, Key>> {
+ pub fn try_read(&self, key: ThreadKey) -> Result<LockGuard<L::ReadGuard<'_>>, ThreadKey> {
let guard = unsafe {
// safety: we have the thread key
if !self.raw_try_read() {
- return None;
+ return Err(key);
}
// safety: we've acquired the locks
self.data.read_guard()
};
- Some(LockGuard {
- guard,
- key,
- _phantom: PhantomData,
- })
+ Ok(LockGuard { guard, key })
}
/// Unlocks the underlying lockable data type, returning the key that's
@@ -458,9 +512,7 @@ impl<'a, L: Sharable> RefLockCollection<'a, L> {
/// let key = RefLockCollection::<(RwLock<i32>, RwLock<&str>)>::unlock_read(guard);
/// ```
#[allow(clippy::missing_const_for_fn)]
- pub fn unlock_read<'key: 'a, Key: Keyable + 'key>(
- guard: LockGuard<'key, L::ReadGuard<'a>, Key>,
- ) -> Key {
+ pub fn unlock_read(guard: LockGuard<L::ReadGuard<'_>>) -> ThreadKey {
drop(guard.guard);
guard.key
}
@@ -497,7 +549,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
- use crate::{Mutex, ThreadKey};
+ use crate::{Mutex, RwLock, ThreadKey};
#[test]
fn non_duplicates_allowed() {
@@ -513,6 +565,85 @@ mod tests {
}
#[test]
+ fn try_lock_succeeds_for_unlocked_collection() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [Mutex::new(24), Mutex::new(42)];
+ let collection = RefLockCollection::new(&mutexes);
+ let guard = collection.try_lock(key).unwrap();
+ assert_eq!(*guard[0], 24);
+ assert_eq!(*guard[1], 42);
+ }
+
+ #[test]
+ fn try_lock_fails_for_locked_collection() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [Mutex::new(24), Mutex::new(42)];
+ let collection = RefLockCollection::new(&mutexes);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let guard = mutexes[1].lock(key);
+ assert_eq!(*guard, 42);
+ std::mem::forget(guard);
+ });
+ });
+
+ let guard = collection.try_lock(key);
+ assert!(guard.is_err());
+ }
+
+ #[test]
+ fn try_read_succeeds_for_unlocked_collection() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [RwLock::new(24), RwLock::new(42)];
+ let collection = RefLockCollection::new(&mutexes);
+ let guard = collection.try_read(key).unwrap();
+ assert_eq!(*guard[0], 24);
+ assert_eq!(*guard[1], 42);
+ }
+
+ #[test]
+ fn try_read_fails_for_locked_collection() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [RwLock::new(24), RwLock::new(42)];
+ let collection = RefLockCollection::new(&mutexes);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let guard = mutexes[1].write(key);
+ assert_eq!(*guard, 42);
+ std::mem::forget(guard);
+ });
+ });
+
+ let guard = collection.try_read(key);
+ assert!(guard.is_err());
+ }
+
+ #[test]
+ fn can_read_twice_on_different_threads() {
+ let key = ThreadKey::get().unwrap();
+ let mutexes = [RwLock::new(24), RwLock::new(42)];
+ let collection = RefLockCollection::new(&mutexes);
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let guard = collection.read(key);
+ assert_eq!(*guard[0], 24);
+ assert_eq!(*guard[1], 42);
+ std::mem::forget(guard);
+ });
+ });
+
+ let guard = collection.try_read(key).unwrap();
+ assert_eq!(*guard[0], 24);
+ assert_eq!(*guard[1], 42);
+ }
+
+ #[test]
fn works_in_collection() {
let key = ThreadKey::get().unwrap();
let mutex1 = Mutex::new(0);
diff --git a/src/collection/retry.rs b/src/collection/retry.rs
index 331b669..775ea29 100644
--- a/src/collection/retry.rs
+++ b/src/collection/retry.rs
@@ -1,24 +1,18 @@
use std::cell::Cell;
use std::collections::HashSet;
-use std::marker::PhantomData;
use crate::collection::utils;
use crate::handle_unwind::handle_unwind;
use crate::lockable::{
Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable,
};
-use crate::Keyable;
+use crate::{Keyable, ThreadKey};
-use super::utils::{attempt_to_recover_locks_from_panic, attempt_to_recover_reads_from_panic};
+use super::utils::{
+ attempt_to_recover_locks_from_panic, attempt_to_recover_reads_from_panic, get_locks_unsorted,
+};
use super::{LockGuard, RetryingLockCollection};
-/// Get all raw locks in the collection
-fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> {
- let mut locks = Vec::new();
- data.get_ptrs(&mut locks);
- locks
-}
-
/// Checks that a collection contains no duplicate references to a lock.
fn contains_duplicates<L: Lockable>(data: L) -> bool {
let mut locks = Vec::new();
@@ -40,14 +34,14 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
#[mutants::skip] // this should never run
#[cfg(not(tarpaulin_include))]
fn poison(&self) {
- let locks = get_locks(&self.data);
+ let locks = get_locks_unsorted(&self.data);
for lock in locks {
lock.poison();
}
}
unsafe fn raw_lock(&self) {
- let locks = get_locks(&self.data);
+ let locks = get_locks_unsorted(&self.data);
if locks.is_empty() {
// this probably prevents a panic later
@@ -109,7 +103,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
}
unsafe fn raw_try_lock(&self) -> bool {
- let locks = get_locks(&self.data);
+ let locks = get_locks_unsorted(&self.data);
if locks.is_empty() {
// this is an interesting case, but it doesn't give us access to
@@ -139,7 +133,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
}
unsafe fn raw_unlock(&self) {
- let locks = get_locks(&self.data);
+ let locks = get_locks_unsorted(&self.data);
for lock in locks {
lock.raw_unlock();
@@ -147,7 +141,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
}
unsafe fn raw_read(&self) {
- let locks = get_locks(&self.data);
+ let locks = get_locks_unsorted(&self.data);
if locks.is_empty() {
// this probably prevents a panic later
@@ -200,7 +194,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
}
unsafe fn raw_try_read(&self) -> bool {
- let locks = get_locks(&self.data);
+ let locks = get_locks_unsorted(&self.data);
if locks.is_empty() {
// this is an interesting case, but it doesn't give us access to
@@ -229,7 +223,7 @@ unsafe impl<L: Lockable> RawLock for RetryingLockCollection<L> {
}
unsafe fn raw_unlock_read(&self) {
- let locks = get_locks(&self.data);
+ let locks = get_locks_unsorted(&self.data);
for lock in locks {
lock.raw_unlock_read();
@@ -243,6 +237,11 @@ unsafe impl<L: Lockable> Lockable for RetryingLockCollection<L> {
where
Self: 'g;
+ type DataMut<'a>
+ = L::DataMut<'a>
+ where
+ Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
self.data.get_ptrs(ptrs)
}
@@ -250,6 +249,10 @@ unsafe impl<L: Lockable> Lockable for RetryingLockCollection<L> {
unsafe fn guard(&self) -> Self::Guard<'_> {
self.data.guard()
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ self.data.data_mut()
+ }
}
unsafe impl<L: Sharable> Sharable for RetryingLockCollection<L> {
@@ -258,9 +261,18 @@ unsafe impl<L: Sharable> Sharable for RetryingLockCollection<L> {
where
Self: 'g;
+ type DataRef<'a>
+ = L::DataRef<'a>
+ where
+ Self: 'a;
+
unsafe fn read_guard(&self) -> Self::ReadGuard<'_> {
self.data.read_guard()
}
+
+ unsafe fn data_ref(&self) -> Self::DataRef<'_> {
+ self.data.data_ref()
+ }
}
unsafe impl<L: OwnedLockable> OwnedLockable for RetryingLockCollection<L> {}
@@ -516,6 +528,46 @@ impl<L: Lockable> RetryingLockCollection<L> {
(!contains_duplicates(&data)).then_some(Self { data })
}
+ pub fn scoped_lock<R>(&self, key: impl Keyable, f: impl Fn(L::DataMut<'_>) -> R) -> R {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_lock();
+
+ // safety: the data was just locked
+ let r = f(self.data_mut());
+
+ // safety: the collection is still locked
+ self.raw_unlock();
+
+ drop(key); // ensure the key stays alive long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_lock<Key: Keyable, R>(
+ &self,
+ key: Key,
+ f: impl Fn(L::DataMut<'_>) -> R,
+ ) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_lock() {
+ return Err(key);
+ }
+
+ // safety: we just locked the collection
+ let r = f(self.data_mut());
+
+ // safety: the collection is still locked
+ self.raw_unlock();
+
+ drop(key); // ensures the key stays valid long enough
+
+ Ok(r)
+ }
+ }
+
/// Locks the collection
///
/// This function returns a guard that can be used to access the underlying
@@ -536,10 +588,7 @@ impl<L: Lockable> RetryingLockCollection<L> {
/// *guard.0 += 1;
/// *guard.1 = "1";
/// ```
- pub fn lock<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> LockGuard<'key, L::Guard<'g>, Key> {
+ pub fn lock(&self, key: ThreadKey) -> LockGuard<L::Guard<'_>> {
unsafe {
// safety: we're taking the thread key
self.raw_lock();
@@ -548,7 +597,6 @@ impl<L: Lockable> RetryingLockCollection<L> {
// safety: we just locked the collection
guard: self.guard(),
key,
- _phantom: PhantomData,
}
}
}
@@ -583,10 +631,7 @@ impl<L: Lockable> RetryingLockCollection<L> {
/// };
///
/// ```
- pub fn try_lock<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> Result<LockGuard<'key, L::Guard<'g>, Key>, Key> {
+ pub fn try_lock(&self, key: ThreadKey) -> Result<LockGuard<L::Guard<'_>>, ThreadKey> {
unsafe {
// safety: we're taking the thread key
if self.raw_try_lock() {
@@ -594,7 +639,6 @@ impl<L: Lockable> RetryingLockCollection<L> {
// safety: we just succeeded in locking everything
guard: self.guard(),
key,
- _phantom: PhantomData,
})
} else {
Err(key)
@@ -620,13 +664,53 @@ impl<L: Lockable> RetryingLockCollection<L> {
/// *guard.1 = "1";
/// let key = RetryingLockCollection::<(Mutex<i32>, Mutex<&str>)>::unlock(guard);
/// ```
- pub fn unlock<'key, Key: Keyable + 'key>(guard: LockGuard<'key, L::Guard<'_>, Key>) -> Key {
+ pub fn unlock(guard: LockGuard<L::Guard<'_>>) -> ThreadKey {
drop(guard.guard);
guard.key
}
}
impl<L: Sharable> RetryingLockCollection<L> {
+ pub fn scoped_read<R>(&self, key: impl Keyable, f: impl Fn(L::DataRef<'_>) -> R) -> R {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_read();
+
+ // safety: the data was just locked
+ let r = f(self.data_ref());
+
+ // safety: the collection is still locked
+ self.raw_unlock_read();
+
+ drop(key); // ensure the key stays alive long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_read<Key: Keyable, R>(
+ &self,
+ key: Key,
+ f: impl Fn(L::DataRef<'_>) -> R,
+ ) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_read() {
+ return Err(key);
+ }
+
+ // safety: we just locked the collection
+ let r = f(self.data_ref());
+
+ // safety: the collection is still locked
+ self.raw_unlock_read();
+
+ drop(key); // ensures the key stays valid long enough
+
+ Ok(r)
+ }
+ }
+
/// Locks the collection, so that other threads can still read from it
///
/// This function returns a guard that can be used to access the underlying
@@ -647,10 +731,7 @@ impl<L: Sharable> RetryingLockCollection<L> {
/// assert_eq!(*guard.0, 0);
/// assert_eq!(*guard.1, "");
/// ```
- pub fn read<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> LockGuard<'key, L::ReadGuard<'g>, Key> {
+ pub fn read(&self, key: ThreadKey) -> LockGuard<L::ReadGuard<'_>> {
unsafe {
// safety: we're taking the thread key
self.raw_read();
@@ -659,7 +740,6 @@ impl<L: Sharable> RetryingLockCollection<L> {
// safety: we just locked the collection
guard: self.read_guard(),
key,
- _phantom: PhantomData,
}
}
}
@@ -687,25 +767,25 @@ impl<L: Sharable> RetryingLockCollection<L> {
/// let lock = RetryingLockCollection::new(data);
///
/// match lock.try_read(key) {
- /// Some(mut guard) => {
+ /// Ok(mut guard) => {
/// assert_eq!(*guard.0, 5);
/// assert_eq!(*guard.1, "6");
/// },
- /// None => unreachable!(),
+ /// Err(_) => unreachable!(),
/// };
///
/// ```
- pub fn try_read<'g, 'key: 'g, Key: Keyable + 'key>(
- &'g self,
- key: Key,
- ) -> Option<LockGuard<'key, L::ReadGuard<'g>, Key>> {
+ pub fn try_read(&self, key: ThreadKey) -> Result<LockGuard<L::ReadGuard<'_>>, ThreadKey> {
unsafe {
// safety: we're taking the thread key
- self.raw_try_lock().then(|| LockGuard {
+			if !self.raw_try_read() {
+ return Err(key);
+ }
+
+ Ok(LockGuard {
// safety: we just succeeded in locking everything
guard: self.read_guard(),
key,
- _phantom: PhantomData,
})
}
}
@@ -726,9 +806,7 @@ impl<L: Sharable> RetryingLockCollection<L> {
/// let mut guard = lock.read(key);
/// let key = RetryingLockCollection::<(RwLock<i32>, RwLock<&str>)>::unlock_read(guard);
/// ```
- pub fn unlock_read<'key, Key: Keyable + 'key>(
- guard: LockGuard<'key, L::ReadGuard<'_>, Key>,
- ) -> Key {
+ pub fn unlock_read(guard: LockGuard<L::ReadGuard<'_>>) -> ThreadKey {
drop(guard.guard);
guard.key
}
@@ -833,7 +911,7 @@ where
mod tests {
use super::*;
use crate::collection::BoxedLockCollection;
- use crate::{Mutex, RwLock, ThreadKey};
+ use crate::{LockCollection, Mutex, RwLock, ThreadKey};
#[test]
fn nonduplicate_lock_references_are_allowed() {
@@ -869,7 +947,6 @@ mod tests {
let rwlock1 = RwLock::new(0);
let rwlock2 = RwLock::new(0);
let collection = RetryingLockCollection::try_new([&rwlock1, &rwlock2]).unwrap();
- // TODO Poisonable::read
let guard = collection.read(key);
@@ -909,13 +986,14 @@ mod tests {
#[test]
fn lock_empty_lock_collection() {
- let mut key = ThreadKey::get().unwrap();
+ let key = ThreadKey::get().unwrap();
let collection: RetryingLockCollection<[RwLock<i32>; 0]> = RetryingLockCollection::new([]);
- let guard = collection.lock(&mut key);
+ let guard = collection.lock(key);
assert!(guard.len() == 0);
+ let key = LockCollection::<[RwLock<_>; 0]>::unlock(guard);
- let guard = collection.read(&mut key);
+ let guard = collection.read(key);
assert!(guard.len() == 0);
}
}
diff --git a/src/collection/utils.rs b/src/collection/utils.rs
index d6d50f4..1d96e5c 100644
--- a/src/collection/utils.rs
+++ b/src/collection/utils.rs
@@ -1,7 +1,35 @@
use std::cell::Cell;
use crate::handle_unwind::handle_unwind;
-use crate::lockable::RawLock;
+use crate::lockable::{Lockable, RawLock};
+
+#[must_use]
+pub fn get_locks<L: Lockable>(data: &L) -> Vec<&dyn RawLock> {
+ let mut locks = Vec::new();
+ data.get_ptrs(&mut locks);
+ locks.sort_by_key(|lock| &raw const **lock);
+ locks
+}
+
+#[must_use]
+pub fn get_locks_unsorted<L: Lockable>(data: &L) -> Vec<&dyn RawLock> {
+ let mut locks = Vec::new();
+ data.get_ptrs(&mut locks);
+ locks
+}
+
+/// returns `true` if the sorted list contains a duplicate
+#[must_use]
+pub fn ordered_contains_duplicates(l: &[&dyn RawLock]) -> bool {
+ if l.is_empty() {
+		// Fast path: `windows(2)` yields no windows for an empty slice, so this is just a shortcut
+ return false;
+ }
+
+ l.windows(2)
+ // NOTE: addr_eq is necessary because eq would also compare the v-table pointers
+ .any(|window| std::ptr::addr_eq(window[0], window[1]))
+}
/// Lock a set of locks in the given order. It's UB to call this without a `ThreadKey`
pub unsafe fn ordered_lock(locks: &[&dyn RawLock]) {
@@ -115,3 +143,13 @@ pub unsafe fn attempt_to_recover_reads_from_panic(locked: &[&dyn RawLock]) {
|| locked.iter().for_each(|l| l.poison()),
)
}
+
+#[cfg(test)]
+mod tests {
+ use crate::collection::utils::ordered_contains_duplicates;
+
+ #[test]
+ fn empty_array_does_not_contain_duplicates() {
+ assert!(!ordered_contains_duplicates(&[]))
+ }
+}
diff --git a/src/lockable.rs b/src/lockable.rs
index 4f7cfe5..f125c02 100644
--- a/src/lockable.rs
+++ b/src/lockable.rs
@@ -111,6 +111,10 @@ pub unsafe trait Lockable {
where
Self: 'g;
+ type DataMut<'a>
+ where
+ Self: 'a;
+
/// Yields a list of references to the [`RawLock`]s contained within this
/// value.
///
@@ -129,6 +133,9 @@ pub unsafe trait Lockable {
/// unlocked until this guard is dropped.
#[must_use]
unsafe fn guard(&self) -> Self::Guard<'_>;
+
+ #[must_use]
+ unsafe fn data_mut(&self) -> Self::DataMut<'_>;
}
/// Allows a lock to be accessed by multiple readers.
@@ -145,6 +152,10 @@ pub unsafe trait Sharable: Lockable {
where
Self: 'g;
+ type DataRef<'a>
+ where
+ Self: 'a;
+
/// Returns a guard that can be used to immutably access the underlying
/// data.
///
@@ -155,6 +166,9 @@ pub unsafe trait Sharable: Lockable {
/// unlocked until this guard is dropped.
#[must_use]
unsafe fn read_guard(&self) -> Self::ReadGuard<'_>;
+
+ #[must_use]
+ unsafe fn data_ref(&self) -> Self::DataRef<'_>;
}
/// A type that may be locked and unlocked, and is known to be the only valid
@@ -210,6 +224,11 @@ unsafe impl<T: Lockable> Lockable for &T {
where
Self: 'g;
+ type DataMut<'a>
+ = T::DataMut<'a>
+ where
+ Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
(*self).get_ptrs(ptrs);
}
@@ -217,6 +236,10 @@ unsafe impl<T: Lockable> Lockable for &T {
unsafe fn guard(&self) -> Self::Guard<'_> {
(*self).guard()
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ (*self).data_mut()
+ }
}
unsafe impl<T: Sharable> Sharable for &T {
@@ -225,9 +248,18 @@ unsafe impl<T: Sharable> Sharable for &T {
where
Self: 'g;
+ type DataRef<'a>
+ = T::DataRef<'a>
+ where
+ Self: 'a;
+
unsafe fn read_guard(&self) -> Self::ReadGuard<'_> {
(*self).read_guard()
}
+
+ unsafe fn data_ref(&self) -> Self::DataRef<'_> {
+ (*self).data_ref()
+ }
}
unsafe impl<T: Lockable> Lockable for &mut T {
@@ -236,6 +268,11 @@ unsafe impl<T: Lockable> Lockable for &mut T {
where
Self: 'g;
+ type DataMut<'a>
+ = T::DataMut<'a>
+ where
+ Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
(**self).get_ptrs(ptrs)
}
@@ -243,6 +280,10 @@ unsafe impl<T: Lockable> Lockable for &mut T {
unsafe fn guard(&self) -> Self::Guard<'_> {
(**self).guard()
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ (**self).data_mut()
+ }
}
impl<T: LockableGetMut> LockableGetMut for &mut T {
@@ -262,9 +303,18 @@ unsafe impl<T: Sharable> Sharable for &mut T {
where
Self: 'g;
+ type DataRef<'a>
+ = T::DataRef<'a>
+ where
+ Self: 'a;
+
unsafe fn read_guard(&self) -> Self::ReadGuard<'_> {
(**self).read_guard()
}
+
+ unsafe fn data_ref(&self) -> Self::DataRef<'_> {
+ (**self).data_ref()
+ }
}
unsafe impl<T: OwnedLockable> OwnedLockable for &mut T {}
@@ -276,6 +326,8 @@ macro_rules! tuple_impls {
unsafe impl<$($generic: Lockable,)*> Lockable for ($($generic,)*) {
type Guard<'g> = ($($generic::Guard<'g>,)*) where Self: 'g;
+ type DataMut<'a> = ($($generic::DataMut<'a>,)*) where Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
$(self.$value.get_ptrs(ptrs));*
}
@@ -285,6 +337,10 @@ macro_rules! tuple_impls {
// I don't think any other way of doing it compiles
($(self.$value.guard(),)*)
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ ($(self.$value.data_mut(),)*)
+ }
}
impl<$($generic: LockableGetMut,)*> LockableGetMut for ($($generic,)*) {
@@ -306,9 +362,15 @@ macro_rules! tuple_impls {
unsafe impl<$($generic: Sharable,)*> Sharable for ($($generic,)*) {
type ReadGuard<'g> = ($($generic::ReadGuard<'g>,)*) where Self: 'g;
+ type DataRef<'a> = ($($generic::DataRef<'a>,)*) where Self: 'a;
+
unsafe fn read_guard(&self) -> Self::ReadGuard<'_> {
($(self.$value.read_guard(),)*)
}
+
+ unsafe fn data_ref(&self) -> Self::DataRef<'_> {
+ ($(self.$value.data_ref(),)*)
+ }
}
unsafe impl<$($generic: OwnedLockable,)*> OwnedLockable for ($($generic,)*) {}
@@ -329,6 +391,11 @@ unsafe impl<T: Lockable, const N: usize> Lockable for [T; N] {
where
Self: 'g;
+ type DataMut<'a>
+ = [T::DataMut<'a>; N]
+ where
+ Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
for lock in self {
lock.get_ptrs(ptrs);
@@ -345,6 +412,15 @@ unsafe impl<T: Lockable, const N: usize> Lockable for [T; N] {
guards.map(|g| g.assume_init())
}
+
+ unsafe fn data_mut<'a>(&'a self) -> Self::DataMut<'a> {
+ let mut guards = MaybeUninit::<[MaybeUninit<T::DataMut<'a>>; N]>::uninit().assume_init();
+ for i in 0..N {
+ guards[i].write(self[i].data_mut());
+ }
+
+ guards.map(|g| g.assume_init())
+ }
}
impl<T: LockableGetMut, const N: usize> LockableGetMut for [T; N] {
@@ -386,6 +462,11 @@ unsafe impl<T: Sharable, const N: usize> Sharable for [T; N] {
where
Self: 'g;
+ type DataRef<'a>
+ = [T::DataRef<'a>; N]
+ where
+ Self: 'a;
+
unsafe fn read_guard<'g>(&'g self) -> Self::ReadGuard<'g> {
let mut guards = MaybeUninit::<[MaybeUninit<T::ReadGuard<'g>>; N]>::uninit().assume_init();
for i in 0..N {
@@ -394,6 +475,15 @@ unsafe impl<T: Sharable, const N: usize> Sharable for [T; N] {
guards.map(|g| g.assume_init())
}
+
+ unsafe fn data_ref<'a>(&'a self) -> Self::DataRef<'a> {
+ let mut guards = MaybeUninit::<[MaybeUninit<T::DataRef<'a>>; N]>::uninit().assume_init();
+ for i in 0..N {
+ guards[i].write(self[i].data_ref());
+ }
+
+ guards.map(|g| g.assume_init())
+ }
}
unsafe impl<T: OwnedLockable, const N: usize> OwnedLockable for [T; N] {}
@@ -404,6 +494,11 @@ unsafe impl<T: Lockable> Lockable for Box<[T]> {
where
Self: 'g;
+ type DataMut<'a>
+ = Box<[T::DataMut<'a>]>
+ where
+ Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
for lock in self {
lock.get_ptrs(ptrs);
@@ -413,6 +508,10 @@ unsafe impl<T: Lockable> Lockable for Box<[T]> {
unsafe fn guard(&self) -> Self::Guard<'_> {
self.iter().map(|lock| lock.guard()).collect()
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ self.iter().map(|lock| lock.data_mut()).collect()
+ }
}
impl<T: LockableGetMut + 'static> LockableGetMut for Box<[T]> {
@@ -432,9 +531,18 @@ unsafe impl<T: Sharable> Sharable for Box<[T]> {
where
Self: 'g;
+ type DataRef<'a>
+ = Box<[T::DataRef<'a>]>
+ where
+ Self: 'a;
+
unsafe fn read_guard(&self) -> Self::ReadGuard<'_> {
self.iter().map(|lock| lock.read_guard()).collect()
}
+
+ unsafe fn data_ref(&self) -> Self::DataRef<'_> {
+ self.iter().map(|lock| lock.data_ref()).collect()
+ }
}
unsafe impl<T: Sharable> Sharable for Vec<T> {
@@ -443,9 +551,18 @@ unsafe impl<T: Sharable> Sharable for Vec<T> {
where
Self: 'g;
+ type DataRef<'a>
+ = Box<[T::DataRef<'a>]>
+ where
+ Self: 'a;
+
unsafe fn read_guard(&self) -> Self::ReadGuard<'_> {
self.iter().map(|lock| lock.read_guard()).collect()
}
+
+ unsafe fn data_ref(&self) -> Self::DataRef<'_> {
+ self.iter().map(|lock| lock.data_ref()).collect()
+ }
}
unsafe impl<T: OwnedLockable> OwnedLockable for Box<[T]> {}
@@ -457,6 +574,11 @@ unsafe impl<T: Lockable> Lockable for Vec<T> {
where
Self: 'g;
+ type DataMut<'a>
+ = Box<[T::DataMut<'a>]>
+ where
+ Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
for lock in self {
lock.get_ptrs(ptrs);
@@ -466,6 +588,10 @@ unsafe impl<T: Lockable> Lockable for Vec<T> {
unsafe fn guard(&self) -> Self::Guard<'_> {
self.iter().map(|lock| lock.guard()).collect()
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ self.iter().map(|lock| lock.data_mut()).collect()
+ }
}
// I'd make a generic impl<T: Lockable, I: IntoIterator<Item=T>> Lockable for I
diff --git a/src/mutex.rs b/src/mutex.rs
index d6cba7d..2022501 100644
--- a/src/mutex.rs
+++ b/src/mutex.rs
@@ -3,8 +3,8 @@ use std::marker::PhantomData;
use lock_api::RawMutex;
-use crate::key::Keyable;
use crate::poisonable::PoisonFlag;
+use crate::ThreadKey;
mod guard;
mod mutex;
@@ -87,20 +87,19 @@ pub type ParkingMutex<T> = Mutex<T, parking_lot::RawMutex>;
/// let mut key = ThreadKey::get().unwrap();
///
/// // Here we use a block to limit the lifetime of the lock guard.
-/// let result = {
-/// let mut data = data_mutex_clone.lock(&mut key);
+/// let result = data_mutex_clone.scoped_lock(&mut key, |data| {
/// let result = data.iter().fold(0, |acc, x| acc + x * 2);
/// data.push(result);
/// result
/// // The mutex guard gets dropped here, so the lock is released
-/// };
+/// });
/// // The thread key is available again
/// *res_mutex_clone.lock(key) += result;
/// }));
/// });
///
-/// let mut key = ThreadKey::get().unwrap();
-/// let mut data = data_mutex.lock(&mut key);
+/// let key = ThreadKey::get().unwrap();
+/// let mut data = data_mutex.lock(key);
/// let result = data.iter().fold(0, |acc, x| acc + x * 2);
/// data.push(result);
///
@@ -108,12 +107,12 @@ pub type ParkingMutex<T> = Mutex<T, parking_lot::RawMutex>;
/// // allows other threads to start working on the data immediately. Dropping
/// // the data also gives us access to the thread key, so we can lock
/// // another mutex.
-/// drop(data);
+/// let key = Mutex::unlock(data);
///
/// // Here the mutex guard is not assigned to a variable and so, even if the
/// // scope does not end after this line, the mutex is still released: there is
/// // no deadlock.
-/// *res_mutex.lock(&mut key) += result;
+/// *res_mutex.lock(key) += result;
///
/// threads.into_iter().for_each(|thread| {
/// thread
@@ -121,6 +120,7 @@ pub type ParkingMutex<T> = Mutex<T, parking_lot::RawMutex>;
/// .expect("The thread creating or execution failed !")
/// });
///
+/// let key = ThreadKey::get().unwrap();
/// assert_eq!(*res_mutex.lock(key), 800);
/// ```
///
@@ -153,10 +153,9 @@ pub struct MutexRef<'a, T: ?Sized + 'a, R: RawMutex>(
//
// This is the most lifetime-intensive thing I've ever written. Can I graduate
// from borrow checker university now?
-pub struct MutexGuard<'a, 'key: 'a, T: ?Sized + 'a, Key: Keyable + 'key, R: RawMutex> {
+pub struct MutexGuard<'a, T: ?Sized + 'a, R: RawMutex> {
mutex: MutexRef<'a, T, R>, // this way we don't need to re-implement Drop
- thread_key: Key,
- _phantom: PhantomData<&'key ()>,
+ thread_key: ThreadKey,
}
#[cfg(test)]
@@ -194,61 +193,46 @@ mod tests {
#[test]
fn display_works_for_ref() {
let mutex: crate::Mutex<_> = Mutex::new("Hello, world!");
- let guard = unsafe { mutex.try_lock_no_key().unwrap() }; // TODO lock_no_key
+ let guard = unsafe { mutex.try_lock_no_key().unwrap() };
assert_eq!(guard.to_string(), "Hello, world!".to_string());
}
#[test]
- fn ord_works() {
- let key = ThreadKey::get().unwrap();
- let mutex1: crate::Mutex<_> = Mutex::new(1);
- let mutex2: crate::Mutex<_> = Mutex::new(2);
- let mutex3: crate::Mutex<_> = Mutex::new(2);
- let collection = LockCollection::try_new((&mutex1, &mutex2, &mutex3)).unwrap();
-
- let guard = collection.lock(key);
- assert!(guard.0 < guard.1);
- assert!(guard.1 > guard.0);
- assert!(guard.1 == guard.2);
- assert!(guard.0 != guard.2)
- }
-
- #[test]
fn ref_as_mut() {
- let mut key = ThreadKey::get().unwrap();
+ let key = ThreadKey::get().unwrap();
let collection = LockCollection::new(crate::Mutex::new(0));
- let mut guard = collection.lock(&mut key);
+ let mut guard = collection.lock(key);
let guard_mut = guard.as_mut().as_mut();
*guard_mut = 3;
- drop(guard);
+ let key = LockCollection::<crate::Mutex<_>>::unlock(guard);
- let guard = collection.lock(&mut key);
+ let guard = collection.lock(key);
assert_eq!(guard.as_ref().as_ref(), &3);
}
#[test]
fn guard_as_mut() {
- let mut key = ThreadKey::get().unwrap();
+ let key = ThreadKey::get().unwrap();
let mutex = crate::Mutex::new(0);
- let mut guard = mutex.lock(&mut key);
+ let mut guard = mutex.lock(key);
let guard_mut = guard.as_mut();
*guard_mut = 3;
- drop(guard);
+ let key = Mutex::unlock(guard);
- let guard = mutex.lock(&mut key);
+ let guard = mutex.lock(key);
assert_eq!(guard.as_ref(), &3);
}
#[test]
fn dropping_guard_releases_mutex() {
- let mut key = ThreadKey::get().unwrap();
+ let key = ThreadKey::get().unwrap();
let mutex: crate::Mutex<_> = Mutex::new("Hello, world!");
- let guard = mutex.lock(&mut key);
+ let guard = mutex.lock(key);
drop(guard);
assert!(!mutex.is_locked());
diff --git a/src/mutex/guard.rs b/src/mutex/guard.rs
index 4e4d5f1..22e59c1 100644
--- a/src/mutex/guard.rs
+++ b/src/mutex/guard.rs
@@ -5,34 +5,14 @@ use std::ops::{Deref, DerefMut};
use lock_api::RawMutex;
-use crate::key::Keyable;
use crate::lockable::RawLock;
+use crate::ThreadKey;
use super::{Mutex, MutexGuard, MutexRef};
// These impls make things slightly easier because now you can use
// `println!("{guard}")` instead of `println!("{}", *guard)`
-impl<T: PartialEq + ?Sized, R: RawMutex> PartialEq for MutexRef<'_, T, R> {
- fn eq(&self, other: &Self) -> bool {
- self.deref().eq(&**other)
- }
-}
-
-impl<T: Eq + ?Sized, R: RawMutex> Eq for MutexRef<'_, T, R> {}
-
-impl<T: PartialOrd + ?Sized, R: RawMutex> PartialOrd for MutexRef<'_, T, R> {
- fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
- self.deref().partial_cmp(&**other)
- }
-}
-
-impl<T: Ord + ?Sized, R: RawMutex> Ord for MutexRef<'_, T, R> {
- fn cmp(&self, other: &Self) -> std::cmp::Ordering {
- self.deref().cmp(&**other)
- }
-}
-
#[mutants::skip] // hashing involves RNG and is hard to test
#[cfg(not(tarpaulin_include))]
impl<T: Hash + ?Sized, R: RawMutex> Hash for MutexRef<'_, T, R> {
@@ -107,39 +87,9 @@ impl<'a, T: ?Sized, R: RawMutex> MutexRef<'a, T, R> {
// it's kinda annoying to re-implement some of this stuff on guards
// there's nothing i can do about that
-#[mutants::skip] // it's hard to get two guards safely
-#[cfg(not(tarpaulin_include))]
-impl<T: PartialEq + ?Sized, R: RawMutex, Key: Keyable> PartialEq for MutexGuard<'_, '_, T, Key, R> {
- fn eq(&self, other: &Self) -> bool {
- self.deref().eq(&**other)
- }
-}
-
-#[mutants::skip] // it's hard to get two guards safely
-#[cfg(not(tarpaulin_include))]
-impl<T: Eq + ?Sized, R: RawMutex, Key: Keyable> Eq for MutexGuard<'_, '_, T, Key, R> {}
-
-#[mutants::skip] // it's hard to get two guards safely
-#[cfg(not(tarpaulin_include))]
-impl<T: PartialOrd + ?Sized, R: RawMutex, Key: Keyable> PartialOrd
- for MutexGuard<'_, '_, T, Key, R>
-{
- fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
- self.deref().partial_cmp(&**other)
- }
-}
-
-#[mutants::skip] // it's hard to get two guards safely
-#[cfg(not(tarpaulin_include))]
-impl<T: Ord + ?Sized, R: RawMutex, Key: Keyable> Ord for MutexGuard<'_, '_, T, Key, R> {
- fn cmp(&self, other: &Self) -> std::cmp::Ordering {
- self.deref().cmp(&**other)
- }
-}
-
#[mutants::skip] // hashing involves RNG and is hard to test
#[cfg(not(tarpaulin_include))]
-impl<T: Hash + ?Sized, R: RawMutex, Key: Keyable> Hash for MutexGuard<'_, '_, T, Key, R> {
+impl<T: Hash + ?Sized, R: RawMutex> Hash for MutexGuard<'_, T, R> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.deref().hash(state)
}
@@ -147,19 +97,19 @@ impl<T: Hash + ?Sized, R: RawMutex, Key: Keyable> Hash for MutexGuard<'_, '_, T,
#[mutants::skip]
#[cfg(not(tarpaulin_include))]
-impl<T: Debug + ?Sized, Key: Keyable, R: RawMutex> Debug for MutexGuard<'_, '_, T, Key, R> {
+impl<T: Debug + ?Sized, R: RawMutex> Debug for MutexGuard<'_, T, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&**self, f)
}
}
-impl<T: Display + ?Sized, Key: Keyable, R: RawMutex> Display for MutexGuard<'_, '_, T, Key, R> {
+impl<T: Display + ?Sized, R: RawMutex> Display for MutexGuard<'_, T, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Display::fmt(&**self, f)
}
}
-impl<T: ?Sized, Key: Keyable, R: RawMutex> Deref for MutexGuard<'_, '_, T, Key, R> {
+impl<T: ?Sized, R: RawMutex> Deref for MutexGuard<'_, T, R> {
type Target = T;
fn deref(&self) -> &Self::Target {
@@ -167,33 +117,32 @@ impl<T: ?Sized, Key: Keyable, R: RawMutex> Deref for MutexGuard<'_, '_, T, Key,
}
}
-impl<T: ?Sized, Key: Keyable, R: RawMutex> DerefMut for MutexGuard<'_, '_, T, Key, R> {
+impl<T: ?Sized, R: RawMutex> DerefMut for MutexGuard<'_, T, R> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.mutex
}
}
-impl<T: ?Sized, Key: Keyable, R: RawMutex> AsRef<T> for MutexGuard<'_, '_, T, Key, R> {
+impl<T: ?Sized, R: RawMutex> AsRef<T> for MutexGuard<'_, T, R> {
fn as_ref(&self) -> &T {
self
}
}
-impl<T: ?Sized, Key: Keyable, R: RawMutex> AsMut<T> for MutexGuard<'_, '_, T, Key, R> {
+impl<T: ?Sized, R: RawMutex> AsMut<T> for MutexGuard<'_, T, R> {
fn as_mut(&mut self) -> &mut T {
self
}
}
-impl<'a, T: ?Sized, Key: Keyable, R: RawMutex> MutexGuard<'a, '_, T, Key, R> {
+impl<'a, T: ?Sized, R: RawMutex> MutexGuard<'a, T, R> {
/// Create a guard to the given mutex. Undefined if multiple guards to the
/// same mutex exist at once.
#[must_use]
- pub(super) unsafe fn new(mutex: &'a Mutex<T, R>, thread_key: Key) -> Self {
+ pub(super) unsafe fn new(mutex: &'a Mutex<T, R>, thread_key: ThreadKey) -> Self {
Self {
mutex: MutexRef(mutex, PhantomData),
thread_key,
- _phantom: PhantomData,
}
}
}
diff --git a/src/mutex/mutex.rs b/src/mutex/mutex.rs
index 0bd5286..1d8ce8b 100644
--- a/src/mutex/mutex.rs
+++ b/src/mutex/mutex.rs
@@ -6,9 +6,9 @@ use std::panic::AssertUnwindSafe;
use lock_api::RawMutex;
use crate::handle_unwind::handle_unwind;
-use crate::key::Keyable;
use crate::lockable::{Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock};
use crate::poisonable::PoisonFlag;
+use crate::{Keyable, ThreadKey};
use super::{Mutex, MutexGuard, MutexRef};
@@ -62,6 +62,11 @@ unsafe impl<T: Send, R: RawMutex + Send + Sync> Lockable for Mutex<T, R> {
where
Self: 'g;
+ type DataMut<'a>
+ = &'a mut T
+ where
+ Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
ptrs.push(self);
}
@@ -69,6 +74,10 @@ unsafe impl<T: Send, R: RawMutex + Send + Sync> Lockable for Mutex<T, R> {
unsafe fn guard(&self) -> Self::Guard<'_> {
MutexRef::new(self)
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ self.data.get().as_mut().unwrap_unchecked()
+ }
}
impl<T: Send, R: RawMutex + Send + Sync> LockableIntoInner for Mutex<T, R> {
@@ -214,6 +223,46 @@ impl<T: ?Sized, R> Mutex<T, R> {
}
impl<T: ?Sized, R: RawMutex> Mutex<T, R> {
+ pub fn scoped_lock<Ret>(&self, key: impl Keyable, f: impl FnOnce(&mut T) -> Ret) -> Ret {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_lock();
+
+ // safety: the mutex was just locked
+ let r = f(self.data.get().as_mut().unwrap_unchecked());
+
+ // safety: we locked the mutex already
+ self.raw_unlock();
+
+ drop(key); // ensures we drop the key in the correct place
+
+ r
+ }
+ }
+
+ pub fn scoped_try_lock<Key: Keyable, Ret>(
+ &self,
+ key: Key,
+ f: impl FnOnce(&mut T) -> Ret,
+ ) -> Result<Ret, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_lock() {
+ return Err(key);
+ }
+
+ // safety: the mutex was just locked
+ let r = f(self.data.get().as_mut().unwrap_unchecked());
+
+ // safety: we locked the mutex already
+ self.raw_unlock();
+
+ drop(key); // ensures we drop the key in the correct place
+
+ Ok(r)
+ }
+ }
+
/// Block the thread until this mutex can be locked, and lock it.
///
/// Upon returning, the thread is the only thread with a lock on the
@@ -237,7 +286,7 @@ impl<T: ?Sized, R: RawMutex> Mutex<T, R> {
/// let key = ThreadKey::get().unwrap();
/// assert_eq!(*mutex.lock(key), 10);
/// ```
- pub fn lock<'s, 'k: 's, Key: Keyable>(&'s self, key: Key) -> MutexGuard<'s, 'k, T, Key, R> {
+ pub fn lock(&self, key: ThreadKey) -> MutexGuard<'_, T, R> {
unsafe {
// safety: we have the thread key
self.raw_lock();
@@ -280,10 +329,7 @@ impl<T: ?Sized, R: RawMutex> Mutex<T, R> {
/// let key = ThreadKey::get().unwrap();
/// assert_eq!(*mutex.lock(key), 10);
/// ```
- pub fn try_lock<'s, 'k: 's, Key: Keyable>(
- &'s self,
- key: Key,
- ) -> Result<MutexGuard<'s, 'k, T, Key, R>, Key> {
+ pub fn try_lock(&self, key: ThreadKey) -> Result<MutexGuard<'_, T, R>, ThreadKey> {
unsafe {
// safety: we have the key to the mutex
if self.raw_try_lock() {
@@ -322,7 +368,8 @@ impl<T: ?Sized, R: RawMutex> Mutex<T, R> {
///
/// let key = Mutex::unlock(guard);
/// ```
- pub fn unlock<'a, 'k: 'a, Key: Keyable + 'k>(guard: MutexGuard<'a, 'k, T, Key, R>) -> Key {
+ #[must_use]
+ pub fn unlock(guard: MutexGuard<'_, T, R>) -> ThreadKey {
unsafe {
guard.mutex.0.raw_unlock();
}
diff --git a/src/poisonable.rs b/src/poisonable.rs
index bb6ad28..8d5a810 100644
--- a/src/poisonable.rs
+++ b/src/poisonable.rs
@@ -1,6 +1,8 @@
use std::marker::PhantomData;
use std::sync::atomic::AtomicBool;
+use crate::ThreadKey;
+
mod error;
mod flag;
mod guard;
@@ -55,10 +57,9 @@ pub struct PoisonRef<'a, G> {
/// An RAII guard for a [`Poisonable`].
///
/// This is created by calling methods like [`Poisonable::lock`].
-pub struct PoisonGuard<'a, 'key: 'a, G, Key: 'key> {
+pub struct PoisonGuard<'a, G> {
guard: PoisonRef<'a, G>,
- key: Key,
- _phantom: PhantomData<&'key ()>,
+ key: ThreadKey,
}
/// A type of error which can be returned when acquiring a [`Poisonable`] lock.
@@ -69,9 +70,9 @@ pub struct PoisonError<Guard> {
/// An enumeration of possible errors associated with
/// [`TryLockPoisonableResult`] which can occur while trying to acquire a lock
/// (i.e.: [`Poisonable::try_lock`]).
-pub enum TryLockPoisonableError<'flag, 'key: 'flag, G, Key: 'key> {
- Poisoned(PoisonError<PoisonGuard<'flag, 'key, G, Key>>),
- WouldBlock(Key),
+pub enum TryLockPoisonableError<'flag, G> {
+ Poisoned(PoisonError<PoisonGuard<'flag, G>>),
+ WouldBlock(ThreadKey),
}
/// A type alias for the result of a lock method which can poisoned.
@@ -89,8 +90,8 @@ pub type PoisonResult<Guard> = Result<Guard, PoisonError<Guard>>;
/// For more information, see [`PoisonResult`]. A `TryLockPoisonableResult`
/// doesn't necessarily hold the associated guard in the [`Err`] type as the
/// lock might not have been acquired for other reasons.
-pub type TryLockPoisonableResult<'flag, 'key, G, Key> =
- Result<PoisonGuard<'flag, 'key, G, Key>, TryLockPoisonableError<'flag, 'key, G, Key>>;
+pub type TryLockPoisonableResult<'flag, G> =
+ Result<PoisonGuard<'flag, G>, TryLockPoisonableError<'flag, G>>;
#[cfg(test)]
mod tests {
@@ -101,6 +102,105 @@ mod tests {
use crate::{LockCollection, Mutex, ThreadKey};
#[test]
+ fn locking_poisoned_mutex_returns_error_in_collection() {
+ let key = ThreadKey::get().unwrap();
+ let mutex = LockCollection::new(Poisonable::new(Mutex::new(42)));
+
+ std::thread::scope(|s| {
+ s.spawn(|| {
+ let key = ThreadKey::get().unwrap();
+ let mut guard1 = mutex.lock(key);
+ let guard = guard1.as_deref_mut().unwrap();
+ assert_eq!(**guard, 42);
+ panic!();
+
+ #[allow(unreachable_code)]
+ drop(guard1);
+ })
+ .join()
+ .unwrap_err();
+ });
+
+ let error = mutex.lock(key);
+ let error = error.as_deref().unwrap_err();
+ assert_eq!(***error.get_ref(), 42);
+ }
+
+ #[test]
+ fn non_poisoned_get_mut_is_ok() {
+ let mut mutex = Poisonable::new(Mutex::new(42));
+ let guard = mutex.get_mut();
+ assert!(guard.is_ok());
+ assert_eq!(*guard.unwrap(), 42);
+ }
+
+ #[test]
+ fn non_poisoned_get_mut_is_err() {
+ let mut mutex = Poisonable::new(Mutex::new(42));
+
+ let _ = std::panic::catch_unwind(|| {
+ let key = ThreadKey::get().unwrap();
+ #[allow(unused_variables)]
+ let guard = mutex.lock(key);
+ panic!();
+ #[allow(unreachable_code)]
+ drop(guard);
+ });
+
+ let guard = mutex.get_mut();
+ assert!(guard.is_err());
+ assert_eq!(**guard.unwrap_err().get_ref(), 42);
+ }
+
+ #[test]
+ fn unpoisoned_into_inner() {
+ let mutex = Poisonable::new(Mutex::new("foo"));
+ assert_eq!(mutex.into_inner().unwrap(), "foo");
+ }
+
+ #[test]
+ fn poisoned_into_inner() {
+ let mutex = Poisonable::from(Mutex::new("foo"));
+
+ std::panic::catch_unwind(|| {
+ let key = ThreadKey::get().unwrap();
+ #[allow(unused_variables)]
+ let guard = mutex.lock(key);
+ panic!();
+ #[allow(unreachable_code)]
+ drop(guard);
+ })
+ .unwrap_err();
+
+ let error = mutex.into_inner().unwrap_err();
+ assert_eq!(error.into_inner(), "foo");
+ }
+
+ #[test]
+ fn unpoisoned_into_child() {
+ let mutex = Poisonable::new(Mutex::new("foo"));
+ assert_eq!(mutex.into_child().unwrap().into_inner(), "foo");
+ }
+
+ #[test]
+ fn poisoned_into_child() {
+ let mutex = Poisonable::from(Mutex::new("foo"));
+
+ std::panic::catch_unwind(|| {
+ let key = ThreadKey::get().unwrap();
+ #[allow(unused_variables)]
+ let guard = mutex.lock(key);
+ panic!();
+ #[allow(unreachable_code)]
+ drop(guard);
+ })
+ .unwrap_err();
+
+ let error = mutex.into_child().unwrap_err();
+ assert_eq!(error.into_inner().into_inner(), "foo");
+ }
+
+ #[test]
fn display_works() {
let key = ThreadKey::get().unwrap();
let mutex = Poisonable::new(Mutex::new("Hello, world!"));
@@ -111,21 +211,75 @@ mod tests {
}
#[test]
- fn ord_works() {
+ fn ref_as_ref() {
+ let key = ThreadKey::get().unwrap();
+ let collection = LockCollection::new(Poisonable::new(Mutex::new("foo")));
+ let guard = collection.lock(key);
+ let Ok(ref guard) = guard.as_ref() else {
+ panic!()
+ };
+ assert_eq!(**guard.as_ref(), "foo");
+ }
+
+ #[test]
+ fn ref_as_mut() {
let key = ThreadKey::get().unwrap();
- let lock1 = Poisonable::new(Mutex::new(1));
- let lock2 = Poisonable::new(Mutex::new(3));
- let lock3 = Poisonable::new(Mutex::new(3));
- let collection = LockCollection::try_new((&lock1, &lock2, &lock3)).unwrap();
+ let collection = LockCollection::new(Poisonable::new(Mutex::new("foo")));
+ let mut guard1 = collection.lock(key);
+ let Ok(ref mut guard) = guard1.as_mut() else {
+ panic!()
+ };
+ let guard = guard.as_mut();
+ **guard = "bar";
+
+ let key = LockCollection::<Poisonable<Mutex<_>>>::unlock(guard1);
+ let guard = collection.lock(key);
+ let guard = guard.as_deref().unwrap();
+ assert_eq!(*guard.as_ref(), "bar");
+ }
+ #[test]
+ fn guard_as_ref() {
+ let key = ThreadKey::get().unwrap();
+ let collection = Poisonable::new(Mutex::new("foo"));
let guard = collection.lock(key);
- let guard1 = guard.0.as_ref().unwrap();
- let guard2 = guard.1.as_ref().unwrap();
- let guard3 = guard.2.as_ref().unwrap();
- assert_eq!(guard1.cmp(guard2), std::cmp::Ordering::Less);
- assert_eq!(guard2.cmp(guard1), std::cmp::Ordering::Greater);
- assert!(guard2 == guard3);
- assert!(guard1 != guard3);
+ let Ok(ref guard) = guard.as_ref() else {
+ panic!()
+ };
+ assert_eq!(**guard.as_ref(), "foo");
+ }
+
+ #[test]
+ fn guard_as_mut() {
+ let key = ThreadKey::get().unwrap();
+ let mutex = Poisonable::new(Mutex::new("foo"));
+ let mut guard1 = mutex.lock(key);
+ let Ok(ref mut guard) = guard1.as_mut() else {
+ panic!()
+ };
+ let guard = guard.as_mut();
+ **guard = "bar";
+
+ let key = Poisonable::<Mutex<_>>::unlock(guard1.unwrap());
+ let guard = mutex.lock(key);
+ let guard = guard.as_deref().unwrap();
+ assert_eq!(*guard, "bar");
+ }
+
+ #[test]
+ fn deref_mut_in_collection() {
+ let key = ThreadKey::get().unwrap();
+ let collection = LockCollection::new(Poisonable::new(Mutex::new(42)));
+ let mut guard1 = collection.lock(key);
+ let Ok(ref mut guard) = guard1.as_mut() else {
+ panic!()
+ };
+ // TODO make this more convenient
+ assert_eq!(***guard, 42);
+ ***guard = 24;
+
+ let key = LockCollection::<Poisonable<Mutex<_>>>::unlock(guard1);
+ _ = collection.lock(key);
}
#[test]
@@ -202,18 +356,45 @@ mod tests {
assert!(mutex.is_poisoned());
- let mut key = ThreadKey::get().unwrap();
- let mut error = mutex.lock(&mut key).unwrap_err();
+ let key: ThreadKey = ThreadKey::get().unwrap();
+ let mut error = mutex.lock(key).unwrap_err();
let error1 = error.as_mut();
**error1 = "bar";
- drop(error);
+ let key = Poisonable::<Mutex<_>>::unlock(error.into_inner());
mutex.clear_poison();
- let guard = mutex.lock(&mut key).unwrap();
+ let guard = mutex.lock(key).unwrap();
assert_eq!(&**guard, "bar");
}
#[test]
+ fn try_error_from_lock_error() {
+ let mutex = Poisonable::new(Mutex::new("foo"));
+
+ let _ = std::panic::catch_unwind(|| {
+ let key = ThreadKey::get().unwrap();
+ #[allow(unused_variables)]
+ let guard = mutex.lock(key);
+ panic!();
+
+ #[allow(unknown_lints)]
+ #[allow(unreachable_code)]
+ drop(guard);
+ });
+
+ assert!(mutex.is_poisoned());
+
+ let key = ThreadKey::get().unwrap();
+ let error = mutex.lock(key).unwrap_err();
+ let error = TryLockPoisonableError::from(error);
+
+ let TryLockPoisonableError::Poisoned(error) = error else {
+ panic!()
+ };
+ assert_eq!(&**error.into_inner(), "foo");
+ }
+
+ #[test]
fn new_poisonable_is_not_poisoned() {
let mutex = Poisonable::new(Mutex::new(42));
assert!(!mutex.is_poisoned());
diff --git a/src/poisonable/error.rs b/src/poisonable/error.rs
index bff011d..b69df5d 100644
--- a/src/poisonable/error.rs
+++ b/src/poisonable/error.rs
@@ -109,7 +109,7 @@ impl<Guard> PoisonError<Guard> {
///
/// let key = ThreadKey::get().unwrap();
/// let p_err = mutex.lock(key).unwrap_err();
- /// let data: &PoisonGuard<_, _> = p_err.get_ref();
+ /// let data: &PoisonGuard<_> = p_err.get_ref();
/// println!("recovered {} items", data.len());
/// ```
#[must_use]
@@ -154,7 +154,7 @@ impl<Guard> PoisonError<Guard> {
#[mutants::skip]
#[cfg(not(tarpaulin_include))]
-impl<G, Key> fmt::Debug for TryLockPoisonableError<'_, '_, G, Key> {
+impl<G> fmt::Debug for TryLockPoisonableError<'_, G> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match *self {
Self::Poisoned(..) => "Poisoned(..)".fmt(f),
@@ -163,7 +163,7 @@ impl<G, Key> fmt::Debug for TryLockPoisonableError<'_, '_, G, Key> {
}
}
-impl<G, Key> fmt::Display for TryLockPoisonableError<'_, '_, G, Key> {
+impl<G> fmt::Display for TryLockPoisonableError<'_, G> {
#[cfg_attr(test, mutants::skip)]
#[cfg(not(tarpaulin_include))]
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
@@ -175,12 +175,10 @@ impl<G, Key> fmt::Display for TryLockPoisonableError<'_, '_, G, Key> {
}
}
-impl<G, Key> Error for TryLockPoisonableError<'_, '_, G, Key> {}
+impl<G> Error for TryLockPoisonableError<'_, G> {}
-impl<'flag, 'key, G, Key> From<PoisonError<PoisonGuard<'flag, 'key, G, Key>>>
- for TryLockPoisonableError<'flag, 'key, G, Key>
-{
- fn from(value: PoisonError<PoisonGuard<'flag, 'key, G, Key>>) -> Self {
+impl<'flag, G> From<PoisonError<PoisonGuard<'flag, G>>> for TryLockPoisonableError<'flag, G> {
+ fn from(value: PoisonError<PoisonGuard<'flag, G>>) -> Self {
Self::Poisoned(value)
}
}
diff --git a/src/poisonable/guard.rs b/src/poisonable/guard.rs
index 3f85d25..b887e2d 100644
--- a/src/poisonable/guard.rs
+++ b/src/poisonable/guard.rs
@@ -3,8 +3,6 @@ use std::hash::Hash;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
-use crate::Keyable;
-
use super::{PoisonFlag, PoisonGuard, PoisonRef};
impl<'a, Guard> PoisonRef<'a, Guard> {
@@ -28,26 +26,6 @@ impl<Guard> Drop for PoisonRef<'_, Guard> {
}
}
-impl<Guard: PartialEq> PartialEq for PoisonRef<'_, Guard> {
- fn eq(&self, other: &Self) -> bool {
- self.guard.eq(&other.guard)
- }
-}
-
-impl<Guard: PartialOrd> PartialOrd for PoisonRef<'_, Guard> {
- fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
- self.guard.partial_cmp(&other.guard)
- }
-}
-
-impl<Guard: Eq> Eq for PoisonRef<'_, Guard> {}
-
-impl<Guard: Ord> Ord for PoisonRef<'_, Guard> {
- fn cmp(&self, other: &Self) -> std::cmp::Ordering {
- self.guard.cmp(&other.guard)
- }
-}
-
#[mutants::skip] // hashing involves RNG and is hard to test
#[cfg(not(tarpaulin_include))]
impl<Guard: Hash> Hash for PoisonRef<'_, Guard> {
@@ -96,37 +74,9 @@ impl<Guard> AsMut<Guard> for PoisonRef<'_, Guard> {
}
}
-#[mutants::skip] // it's hard to get two guards safely
-#[cfg(not(tarpaulin_include))]
-impl<Guard: PartialEq, Key: Keyable> PartialEq for PoisonGuard<'_, '_, Guard, Key> {
- fn eq(&self, other: &Self) -> bool {
- self.guard.eq(&other.guard)
- }
-}
-
-#[mutants::skip] // it's hard to get two guards safely
-#[cfg(not(tarpaulin_include))]
-impl<Guard: PartialOrd, Key: Keyable> PartialOrd for PoisonGuard<'_, '_, Guard, Key> {
- fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
- self.guard.partial_cmp(&other.guard)
- }
-}
-
-#[mutants::skip] // it's hard to get two guards safely
-#[cfg(not(tarpaulin_include))]
-impl<Guard: Eq, Key: Keyable> Eq for PoisonGuard<'_, '_, Guard, Key> {}
-
-#[mutants::skip] // it's hard to get two guards safely
-#[cfg(not(tarpaulin_include))]
-impl<Guard: Ord, Key: Keyable> Ord for PoisonGuard<'_, '_, Guard, Key> {
- fn cmp(&self, other: &Self) -> std::cmp::Ordering {
- self.guard.cmp(&other.guard)
- }
-}
-
#[mutants::skip] // hashing involves RNG and is hard to test
#[cfg(not(tarpaulin_include))]
-impl<Guard: Hash, Key: Keyable> Hash for PoisonGuard<'_, '_, Guard, Key> {
+impl<Guard: Hash> Hash for PoisonGuard<'_, Guard> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.guard.hash(state)
}
@@ -134,19 +84,19 @@ impl<Guard: Hash, Key: Keyable> Hash for PoisonGuard<'_, '_, Guard, Key> {
#[mutants::skip]
#[cfg(not(tarpaulin_include))]
-impl<Guard: Debug, Key: Keyable> Debug for PoisonGuard<'_, '_, Guard, Key> {
+impl<Guard: Debug> Debug for PoisonGuard<'_, Guard> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.guard, f)
}
}
-impl<Guard: Display, Key: Keyable> Display for PoisonGuard<'_, '_, Guard, Key> {
+impl<Guard: Display> Display for PoisonGuard<'_, Guard> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Display::fmt(&self.guard, f)
}
}
-impl<T, Guard: Deref<Target = T>, Key: Keyable> Deref for PoisonGuard<'_, '_, Guard, Key> {
+impl<T, Guard: Deref<Target = T>> Deref for PoisonGuard<'_, Guard> {
type Target = T;
fn deref(&self) -> &Self::Target {
@@ -155,20 +105,20 @@ impl<T, Guard: Deref<Target = T>, Key: Keyable> Deref for PoisonGuard<'_, '_, Gu
}
}
-impl<T, Guard: DerefMut<Target = T>, Key: Keyable> DerefMut for PoisonGuard<'_, '_, Guard, Key> {
+impl<T, Guard: DerefMut<Target = T>> DerefMut for PoisonGuard<'_, Guard> {
fn deref_mut(&mut self) -> &mut Self::Target {
#[allow(clippy::explicit_auto_deref)] // fixing this results in a compiler error
&mut *self.guard.guard
}
}
-impl<Guard, Key: Keyable> AsRef<Guard> for PoisonGuard<'_, '_, Guard, Key> {
+impl<Guard> AsRef<Guard> for PoisonGuard<'_, Guard> {
fn as_ref(&self) -> &Guard {
&self.guard.guard
}
}
-impl<Guard, Key: Keyable> AsMut<Guard> for PoisonGuard<'_, '_, Guard, Key> {
+impl<Guard> AsMut<Guard> for PoisonGuard<'_, Guard> {
fn as_mut(&mut self) -> &mut Guard {
&mut self.guard.guard
}
diff --git a/src/poisonable/poisonable.rs b/src/poisonable/poisonable.rs
index 2dac4bb..efe4ed0 100644
--- a/src/poisonable/poisonable.rs
+++ b/src/poisonable/poisonable.rs
@@ -1,10 +1,9 @@
-use std::marker::PhantomData;
use std::panic::{RefUnwindSafe, UnwindSafe};
use crate::lockable::{
Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable,
};
-use crate::Keyable;
+use crate::{Keyable, ThreadKey};
use super::{
PoisonError, PoisonFlag, PoisonGuard, PoisonRef, PoisonResult, Poisonable,
@@ -49,6 +48,11 @@ unsafe impl<L: Lockable> Lockable for Poisonable<L> {
where
Self: 'g;
+ type DataMut<'a>
+ = PoisonResult<L::DataMut<'a>>
+ where
+ Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
self.inner.get_ptrs(ptrs)
}
@@ -62,6 +66,14 @@ unsafe impl<L: Lockable> Lockable for Poisonable<L> {
Ok(ref_guard)
}
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ if self.is_poisoned() {
+ Err(PoisonError::new(self.inner.data_mut()))
+ } else {
+ Ok(self.inner.data_mut())
+ }
+ }
}
unsafe impl<L: Sharable> Sharable for Poisonable<L> {
@@ -70,6 +82,11 @@ unsafe impl<L: Sharable> Sharable for Poisonable<L> {
where
Self: 'g;
+ type DataRef<'a>
+ = PoisonResult<L::DataRef<'a>>
+ where
+ Self: 'a;
+
unsafe fn read_guard(&self) -> Self::ReadGuard<'_> {
let ref_guard = PoisonRef::new(&self.poisoned, self.inner.read_guard());
@@ -79,6 +96,14 @@ unsafe impl<L: Sharable> Sharable for Poisonable<L> {
Ok(ref_guard)
}
}
+
+ unsafe fn data_ref(&self) -> Self::DataRef<'_> {
+ if self.is_poisoned() {
+ Err(PoisonError::new(self.inner.data_ref()))
+ } else {
+ Ok(self.inner.data_ref())
+ }
+ }
}
unsafe impl<L: OwnedLockable> OwnedLockable for Poisonable<L> {}
@@ -266,14 +291,10 @@ impl<L> Poisonable<L> {
impl<L: Lockable> Poisonable<L> {
/// Creates a guard for the poisonable, without locking it
- unsafe fn guard<'flag, 'key, Key: Keyable + 'key>(
- &'flag self,
- key: Key,
- ) -> PoisonResult<PoisonGuard<'flag, 'key, L::Guard<'flag>, Key>> {
+ unsafe fn guard(&self, key: ThreadKey) -> PoisonResult<PoisonGuard<'_, L::Guard<'_>>> {
let guard = PoisonGuard {
guard: PoisonRef::new(&self.poisoned, self.inner.guard()),
key,
- _phantom: PhantomData,
};
if self.is_poisoned() {
@@ -285,6 +306,50 @@ impl<L: Lockable> Poisonable<L> {
}
impl<L: Lockable + RawLock> Poisonable<L> {
+ pub fn scoped_lock<'a, R>(
+ &'a self,
+ key: impl Keyable,
+ f: impl Fn(<Self as Lockable>::DataMut<'a>) -> R,
+ ) -> R {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_lock();
+
+ // safety: the data was just locked
+ let r = f(self.data_mut());
+
+ // safety: the collection is still locked
+ self.raw_unlock();
+
+ drop(key); // ensure the key stays alive long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_lock<'a, Key: Keyable, R>(
+ &'a self,
+ key: Key,
+ f: impl Fn(<Self as Lockable>::DataMut<'a>) -> R,
+ ) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_lock() {
+ return Err(key);
+ }
+
+ // safety: we just locked the collection
+ let r = f(self.data_mut());
+
+ // safety: the collection is still locked
+ self.raw_unlock();
+
+ drop(key); // ensures the key stays valid long enough
+
+ Ok(r)
+ }
+ }
+
/// Acquires the lock, blocking the current thread until it is ok to do so.
///
/// This function will block the current thread until it is available to
@@ -316,10 +381,7 @@ impl<L: Lockable + RawLock> Poisonable<L> {
/// let key = ThreadKey::get().unwrap();
/// assert_eq!(*mutex.lock(key).unwrap(), 10);
/// ```
- pub fn lock<'flag, 'key, Key: Keyable + 'key>(
- &'flag self,
- key: Key,
- ) -> PoisonResult<PoisonGuard<'flag, 'key, L::Guard<'flag>, Key>> {
+ pub fn lock(&self, key: ThreadKey) -> PoisonResult<PoisonGuard<'_, L::Guard<'_>>> {
unsafe {
self.inner.raw_lock();
self.guard(key)
@@ -370,10 +432,7 @@ impl<L: Lockable + RawLock> Poisonable<L> {
///
/// [`Poisoned`]: `TryLockPoisonableError::Poisoned`
/// [`WouldBlock`]: `TryLockPoisonableError::WouldBlock`
- pub fn try_lock<'flag, 'key, Key: Keyable + 'key>(
- &'flag self,
- key: Key,
- ) -> TryLockPoisonableResult<'flag, 'key, L::Guard<'flag>, Key> {
+ pub fn try_lock(&self, key: ThreadKey) -> TryLockPoisonableResult<'_, L::Guard<'_>> {
unsafe {
if self.inner.raw_try_lock() {
Ok(self.guard(key)?)
@@ -398,23 +457,17 @@ impl<L: Lockable + RawLock> Poisonable<L> {
///
/// let key = Poisonable::<Mutex<_>>::unlock(guard);
/// ```
- pub fn unlock<'flag, 'key, Key: Keyable + 'key>(
- guard: PoisonGuard<'flag, 'key, L::Guard<'flag>, Key>,
- ) -> Key {
+ pub fn unlock<'flag>(guard: PoisonGuard<'flag, L::Guard<'flag>>) -> ThreadKey {
drop(guard.guard);
guard.key
}
}
impl<L: Sharable + RawLock> Poisonable<L> {
- unsafe fn read_guard<'flag, 'key, Key: Keyable + 'key>(
- &'flag self,
- key: Key,
- ) -> PoisonResult<PoisonGuard<'flag, 'key, L::ReadGuard<'flag>, Key>> {
+ unsafe fn read_guard(&self, key: ThreadKey) -> PoisonResult<PoisonGuard<'_, L::ReadGuard<'_>>> {
let guard = PoisonGuard {
guard: PoisonRef::new(&self.poisoned, self.inner.read_guard()),
key,
- _phantom: PhantomData,
};
if self.is_poisoned() {
@@ -424,6 +477,50 @@ impl<L: Sharable + RawLock> Poisonable<L> {
Ok(guard)
}
+ pub fn scoped_read<'a, R>(
+ &'a self,
+ key: impl Keyable,
+ f: impl Fn(<Self as Sharable>::DataRef<'a>) -> R,
+ ) -> R {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_read();
+
+ // safety: the data was just locked
+ let r = f(self.data_ref());
+
+ // safety: the collection is still locked
+ self.raw_unlock_read();
+
+ drop(key); // ensure the key stays alive long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_read<'a, Key: Keyable, R>(
+ &'a self,
+ key: Key,
+ f: impl Fn(<Self as Sharable>::DataRef<'a>) -> R,
+ ) -> Result<R, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_read() {
+ return Err(key);
+ }
+
+ // safety: we just locked the collection
+ let r = f(self.data_ref());
+
+ // safety: the collection is still locked
+ self.raw_unlock_read();
+
+ drop(key); // ensures the key stays valid long enough
+
+ Ok(r)
+ }
+ }
+
/// Locks with shared read access, blocking the current thread until it can
/// be acquired.
///
@@ -457,10 +554,7 @@ impl<L: Sharable + RawLock> Poisonable<L> {
/// assert!(c_lock.read(key).is_ok());
/// }).join().expect("thread::spawn failed");
/// ```
- pub fn read<'flag, 'key, Key: Keyable + 'key>(
- &'flag self,
- key: Key,
- ) -> PoisonResult<PoisonGuard<'flag, 'key, L::ReadGuard<'flag>, Key>> {
+ pub fn read(&self, key: ThreadKey) -> PoisonResult<PoisonGuard<'_, L::ReadGuard<'_>>> {
unsafe {
self.inner.raw_read();
self.read_guard(key)
@@ -504,10 +598,7 @@ impl<L: Sharable + RawLock> Poisonable<L> {
///
/// [`Poisoned`]: `TryLockPoisonableError::Poisoned`
/// [`WouldBlock`]: `TryLockPoisonableError::WouldBlock`
- pub fn try_read<'flag, 'key, Key: Keyable + 'key>(
- &'flag self,
- key: Key,
- ) -> TryLockPoisonableResult<'flag, 'key, L::ReadGuard<'flag>, Key> {
+ pub fn try_read(&self, key: ThreadKey) -> TryLockPoisonableResult<'_, L::ReadGuard<'_>> {
unsafe {
if self.inner.raw_try_read() {
Ok(self.read_guard(key)?)
@@ -530,9 +621,7 @@ impl<L: Sharable + RawLock> Poisonable<L> {
/// let mut guard = lock.read(key).unwrap();
/// let key = Poisonable::<RwLock<_>>::unlock_read(guard);
/// ```
- pub fn unlock_read<'flag, 'key, Key: Keyable + 'key>(
- guard: PoisonGuard<'flag, 'key, L::ReadGuard<'flag>, Key>,
- ) -> Key {
+ pub fn unlock_read<'flag>(guard: PoisonGuard<'flag, L::ReadGuard<'flag>>) -> ThreadKey {
drop(guard.guard);
guard.key
}
diff --git a/src/rwlock.rs b/src/rwlock.rs
index f78e648..b604370 100644
--- a/src/rwlock.rs
+++ b/src/rwlock.rs
@@ -3,8 +3,8 @@ use std::marker::PhantomData;
use lock_api::RawRwLock;
-use crate::key::Keyable;
use crate::poisonable::PoisonFlag;
+use crate::ThreadKey;
mod rwlock;
@@ -95,10 +95,9 @@ pub struct RwLockWriteRef<'a, T: ?Sized, R: RawRwLock>(
///
/// [`read`]: `RwLock::read`
/// [`try_read`]: `RwLock::try_read`
-pub struct RwLockReadGuard<'a, 'key: 'a, T: ?Sized, Key: Keyable + 'key, R: RawRwLock> {
+pub struct RwLockReadGuard<'a, T: ?Sized, R: RawRwLock> {
rwlock: RwLockReadRef<'a, T, R>,
- thread_key: Key,
- _phantom: PhantomData<&'key ()>,
+ thread_key: ThreadKey,
}
/// RAII structure used to release the exclusive write access of a lock when
@@ -108,16 +107,14 @@ pub struct RwLockReadGuard<'a, 'key: 'a, T: ?Sized, Key: Keyable + 'key, R: RawR
/// [`RwLock`]
///
/// [`try_write`]: `RwLock::try_write`
-pub struct RwLockWriteGuard<'a, 'key: 'a, T: ?Sized, Key: Keyable + 'key, R: RawRwLock> {
+pub struct RwLockWriteGuard<'a, T: ?Sized, R: RawRwLock> {
rwlock: RwLockWriteRef<'a, T, R>,
- thread_key: Key,
- _phantom: PhantomData<&'key ()>,
+ thread_key: ThreadKey,
}
#[cfg(test)]
mod tests {
use crate::lockable::Lockable;
- use crate::LockCollection;
use crate::RwLock;
use crate::ThreadKey;
@@ -142,6 +139,16 @@ mod tests {
}
#[test]
+ fn read_lock_from_works() {
+ let key = ThreadKey::get().unwrap();
+ let lock: crate::RwLock<_> = RwLock::from("Hello, world!");
+ let reader = ReadLock::from(&lock);
+
+ let guard = reader.lock(key);
+ assert_eq!(*guard, "Hello, world!");
+ }
+
+ #[test]
fn write_lock_unlocked_when_initialized() {
let key = ThreadKey::get().unwrap();
let lock: crate::RwLock<_> = RwLock::new("Hello, world!");
@@ -235,21 +242,6 @@ mod tests {
}
#[test]
- fn write_ord() {
- let key = ThreadKey::get().unwrap();
- let lock1: crate::RwLock<_> = RwLock::new(1);
- let lock2: crate::RwLock<_> = RwLock::new(5);
- let lock3: crate::RwLock<_> = RwLock::new(5);
- let collection = LockCollection::try_new((&lock1, &lock2, &lock3)).unwrap();
- let guard = collection.lock(key);
-
- assert!(guard.0 < guard.1);
- assert!(guard.1 > guard.0);
- assert!(guard.1 == guard.2);
- assert!(guard.0 != guard.2);
- }
-
- #[test]
fn read_ref_display_works() {
let lock: crate::RwLock<_> = RwLock::new("Hello, world!");
let guard = unsafe { lock.try_read_no_key().unwrap() };
@@ -264,21 +256,6 @@ mod tests {
}
#[test]
- fn read_ord() {
- let key = ThreadKey::get().unwrap();
- let lock1: crate::RwLock<_> = RwLock::new(1);
- let lock2: crate::RwLock<_> = RwLock::new(5);
- let lock3: crate::RwLock<_> = RwLock::new(5);
- let collection = LockCollection::try_new((&lock1, &lock2, &lock3)).unwrap();
- let guard = collection.read(key);
-
- assert!(guard.0 < guard.1);
- assert!(guard.1 > guard.0);
- assert!(guard.1 == guard.2);
- assert!(guard.0 != guard.2);
- }
-
- #[test]
fn dropping_read_ref_releases_rwlock() {
let lock: crate::RwLock<_> = RwLock::new("Hello, world!");
@@ -290,10 +267,10 @@ mod tests {
#[test]
fn dropping_write_guard_releases_rwlock() {
- let mut key = ThreadKey::get().unwrap();
+ let key = ThreadKey::get().unwrap();
let lock: crate::RwLock<_> = RwLock::new("Hello, world!");
- let guard = lock.write(&mut key);
+ let guard = lock.write(key);
drop(guard);
assert!(!lock.is_locked());
diff --git a/src/rwlock/read_guard.rs b/src/rwlock/read_guard.rs
index bd22837..0d68c75 100644
--- a/src/rwlock/read_guard.rs
+++ b/src/rwlock/read_guard.rs
@@ -5,34 +5,14 @@ use std::ops::Deref;
use lock_api::RawRwLock;
-use crate::key::Keyable;
use crate::lockable::RawLock;
+use crate::ThreadKey;
use super::{RwLock, RwLockReadGuard, RwLockReadRef};
// These impls make things slightly easier because now you can use
// `println!("{guard}")` instead of `println!("{}", *guard)`
-impl<T: PartialEq + ?Sized, R: RawRwLock> PartialEq for RwLockReadRef<'_, T, R> {
- fn eq(&self, other: &Self) -> bool {
- self.deref().eq(&**other)
- }
-}
-
-impl<T: Eq + ?Sized, R: RawRwLock> Eq for RwLockReadRef<'_, T, R> {}
-
-impl<T: PartialOrd + ?Sized, R: RawRwLock> PartialOrd for RwLockReadRef<'_, T, R> {
- fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
- self.deref().partial_cmp(&**other)
- }
-}
-
-impl<T: Ord + ?Sized, R: RawRwLock> Ord for RwLockReadRef<'_, T, R> {
- fn cmp(&self, other: &Self) -> std::cmp::Ordering {
- self.deref().cmp(&**other)
- }
-}
-
#[mutants::skip] // hashing involves PRNG and is hard to test
#[cfg(not(tarpaulin_include))]
impl<T: Hash + ?Sized, R: RawRwLock> Hash for RwLockReadRef<'_, T, R> {
@@ -89,41 +69,9 @@ impl<'a, T: ?Sized, R: RawRwLock> RwLockReadRef<'a, T, R> {
}
}
-#[mutants::skip] // it's hard to get two read guards safely
-#[cfg(not(tarpaulin_include))]
-impl<T: PartialEq + ?Sized, R: RawRwLock, Key: Keyable> PartialEq
- for RwLockReadGuard<'_, '_, T, Key, R>
-{
- fn eq(&self, other: &Self) -> bool {
- self.deref().eq(&**other)
- }
-}
-
-#[mutants::skip] // it's hard to get two read guards safely
-#[cfg(not(tarpaulin_include))]
-impl<T: Eq + ?Sized, R: RawRwLock, Key: Keyable> Eq for RwLockReadGuard<'_, '_, T, Key, R> {}
-
-#[mutants::skip] // it's hard to get two read guards safely
-#[cfg(not(tarpaulin_include))]
-impl<T: PartialOrd + ?Sized, R: RawRwLock, Key: Keyable> PartialOrd
- for RwLockReadGuard<'_, '_, T, Key, R>
-{
- fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
- self.deref().partial_cmp(&**other)
- }
-}
-
-#[mutants::skip] // it's hard to get two read guards safely
-#[cfg(not(tarpaulin_include))]
-impl<T: Ord + ?Sized, R: RawRwLock, Key: Keyable> Ord for RwLockReadGuard<'_, '_, T, Key, R> {
- fn cmp(&self, other: &Self) -> std::cmp::Ordering {
- self.deref().cmp(&**other)
- }
-}
-
#[mutants::skip] // hashing involves PRNG and is hard to test
#[cfg(not(tarpaulin_include))]
-impl<T: Hash + ?Sized, R: RawRwLock, Key: Keyable> Hash for RwLockReadGuard<'_, '_, T, Key, R> {
+impl<T: Hash + ?Sized, R: RawRwLock> Hash for RwLockReadGuard<'_, T, R> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.deref().hash(state)
}
@@ -131,21 +79,19 @@ impl<T: Hash + ?Sized, R: RawRwLock, Key: Keyable> Hash for RwLockReadGuard<'_,
#[mutants::skip]
#[cfg(not(tarpaulin_include))]
-impl<T: Debug + ?Sized, Key: Keyable, R: RawRwLock> Debug for RwLockReadGuard<'_, '_, T, Key, R> {
+impl<T: Debug + ?Sized, R: RawRwLock> Debug for RwLockReadGuard<'_, T, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&**self, f)
}
}
-impl<T: Display + ?Sized, Key: Keyable, R: RawRwLock> Display
- for RwLockReadGuard<'_, '_, T, Key, R>
-{
+impl<T: Display + ?Sized, R: RawRwLock> Display for RwLockReadGuard<'_, T, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Display::fmt(&**self, f)
}
}
-impl<T: ?Sized, Key: Keyable, R: RawRwLock> Deref for RwLockReadGuard<'_, '_, T, Key, R> {
+impl<T: ?Sized, R: RawRwLock> Deref for RwLockReadGuard<'_, T, R> {
type Target = T;
fn deref(&self) -> &Self::Target {
@@ -153,21 +99,20 @@ impl<T: ?Sized, Key: Keyable, R: RawRwLock> Deref for RwLockReadGuard<'_, '_, T,
}
}
-impl<T: ?Sized, Key: Keyable, R: RawRwLock> AsRef<T> for RwLockReadGuard<'_, '_, T, Key, R> {
+impl<T: ?Sized, R: RawRwLock> AsRef<T> for RwLockReadGuard<'_, T, R> {
fn as_ref(&self) -> &T {
self
}
}
-impl<'a, T: ?Sized, Key: Keyable, R: RawRwLock> RwLockReadGuard<'a, '_, T, Key, R> {
+impl<'a, T: ?Sized, R: RawRwLock> RwLockReadGuard<'a, T, R> {
/// Create a guard to the given mutex. Undefined if multiple guards to the
/// same mutex exist at once.
#[must_use]
- pub(super) unsafe fn new(rwlock: &'a RwLock<T, R>, thread_key: Key) -> Self {
+ pub(super) unsafe fn new(rwlock: &'a RwLock<T, R>, thread_key: ThreadKey) -> Self {
Self {
rwlock: RwLockReadRef(rwlock, PhantomData),
thread_key,
- _phantom: PhantomData,
}
}
}
diff --git a/src/rwlock/read_lock.rs b/src/rwlock/read_lock.rs
index 5dd83a7..05b184a 100644
--- a/src/rwlock/read_lock.rs
+++ b/src/rwlock/read_lock.rs
@@ -2,8 +2,8 @@ use std::fmt::Debug;
use lock_api::RawRwLock;
-use crate::key::Keyable;
use crate::lockable::{Lockable, RawLock, Sharable};
+use crate::ThreadKey;
use super::{ReadLock, RwLock, RwLockReadGuard, RwLockReadRef};
@@ -13,6 +13,11 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for ReadLock<'_, T, R>
where
Self: 'g;
+ type DataMut<'a>
+ = &'a T
+ where
+ Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
ptrs.push(self.as_ref());
}
@@ -20,6 +25,10 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for ReadLock<'_, T, R>
unsafe fn guard(&self) -> Self::Guard<'_> {
RwLockReadRef::new(self.as_ref())
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ self.0.data_ref()
+ }
}
unsafe impl<T: Send, R: RawRwLock + Send + Sync> Sharable for ReadLock<'_, T, R> {
@@ -28,9 +37,18 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Sharable for ReadLock<'_, T, R>
where
Self: 'g;
+ type DataRef<'a>
+ = &'a T
+ where
+ Self: 'a;
+
unsafe fn read_guard(&self) -> Self::Guard<'_> {
RwLockReadRef::new(self.as_ref())
}
+
+ unsafe fn data_ref(&self) -> Self::DataRef<'_> {
+ self.0.data_ref()
+ }
}
#[mutants::skip]
@@ -117,10 +135,8 @@ impl<T: ?Sized, R: RawRwLock> ReadLock<'_, T, R> {
/// ```
///
/// [`ThreadKey`]: `crate::ThreadKey`
- pub fn lock<'s, 'key: 's, Key: Keyable + 'key>(
- &'s self,
- key: Key,
- ) -> RwLockReadGuard<'s, 'key, T, Key, R> {
+ #[must_use]
+ pub fn lock(&self, key: ThreadKey) -> RwLockReadGuard<'_, T, R> {
self.0.read(key)
}
@@ -155,10 +171,7 @@ impl<T: ?Sized, R: RawRwLock> ReadLock<'_, T, R> {
/// Err(_) => unreachable!(),
/// };
/// ```
- pub fn try_lock<'s, 'key: 's, Key: Keyable + 'key>(
- &'s self,
- key: Key,
- ) -> Result<RwLockReadGuard<'s, 'key, T, Key, R>, Key> {
+ pub fn try_lock(&self, key: ThreadKey) -> Result<RwLockReadGuard<'_, T, R>, ThreadKey> {
self.0.try_read(key)
}
@@ -189,7 +202,8 @@ impl<T: ?Sized, R: RawRwLock> ReadLock<'_, T, R> {
/// assert_eq!(*guard, 0);
/// let key = ReadLock::unlock(guard);
/// ```
- pub fn unlock<'key, Key: Keyable + 'key>(guard: RwLockReadGuard<'_, 'key, T, Key, R>) -> Key {
+ #[must_use]
+ pub fn unlock(guard: RwLockReadGuard<'_, T, R>) -> ThreadKey {
RwLock::unlock_read(guard)
}
}
diff --git a/src/rwlock/rwlock.rs b/src/rwlock/rwlock.rs
index 038e6c7..905ecf8 100644
--- a/src/rwlock/rwlock.rs
+++ b/src/rwlock/rwlock.rs
@@ -6,10 +6,10 @@ use std::panic::AssertUnwindSafe;
use lock_api::RawRwLock;
use crate::handle_unwind::handle_unwind;
-use crate::key::Keyable;
use crate::lockable::{
Lockable, LockableGetMut, LockableIntoInner, OwnedLockable, RawLock, Sharable,
};
+use crate::{Keyable, ThreadKey};
use super::{PoisonFlag, RwLock, RwLockReadGuard, RwLockReadRef, RwLockWriteGuard, RwLockWriteRef};
@@ -79,6 +79,11 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for RwLock<T, R> {
where
Self: 'g;
+ type DataMut<'a>
+ = &'a mut T
+ where
+ Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
ptrs.push(self);
}
@@ -86,6 +91,10 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for RwLock<T, R> {
unsafe fn guard(&self) -> Self::Guard<'_> {
RwLockWriteRef::new(self)
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ self.data.get().as_mut().unwrap_unchecked()
+ }
}
unsafe impl<T: Send, R: RawRwLock + Send + Sync> Sharable for RwLock<T, R> {
@@ -94,9 +103,18 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Sharable for RwLock<T, R> {
where
Self: 'g;
+ type DataRef<'a>
+ = &'a T
+ where
+ Self: 'a;
+
unsafe fn read_guard(&self) -> Self::ReadGuard<'_> {
RwLockReadRef::new(self)
}
+
+ unsafe fn data_ref(&self) -> Self::DataRef<'_> {
+ self.data.get().as_ref().unwrap_unchecked()
+ }
}
unsafe impl<T: Send, R: RawRwLock + Send + Sync> OwnedLockable for RwLock<T, R> {}
@@ -230,6 +248,86 @@ impl<T: ?Sized, R> RwLock<T, R> {
}
impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
+ pub fn scoped_read<Ret>(&self, key: impl Keyable, f: impl Fn(&T) -> Ret) -> Ret {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_read();
+
+ // safety: the rwlock was just locked
+ let r = f(self.data.get().as_ref().unwrap_unchecked());
+
+ // safety: the rwlock is already locked
+ self.raw_unlock_read();
+
+ drop(key); // ensure the key stays valid for long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_read<Key: Keyable, Ret>(
+ &self,
+ key: Key,
+ f: impl Fn(&T) -> Ret,
+ ) -> Result<Ret, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_read() {
+ return Err(key);
+ }
+
+ // safety: the rwlock was just locked
+ let r = f(self.data.get().as_ref().unwrap_unchecked());
+
+ // safety: the rwlock is already locked
+ self.raw_unlock_read();
+
+ drop(key); // ensure the key stays valid for long enough
+
+ Ok(r)
+ }
+ }
+
+ pub fn scoped_write<Ret>(&self, key: impl Keyable, f: impl Fn(&mut T) -> Ret) -> Ret {
+ unsafe {
+ // safety: we have the thread key
+ self.raw_lock();
+
+ // safety: we just locked the rwlock
+ let r = f(self.data.get().as_mut().unwrap_unchecked());
+
+ // safety: the rwlock is already locked
+ self.raw_unlock();
+
+ drop(key); // ensure the key stays valid for long enough
+
+ r
+ }
+ }
+
+ pub fn scoped_try_write<Key: Keyable, Ret>(
+ &self,
+ key: Key,
+ f: impl Fn(&mut T) -> Ret,
+ ) -> Result<Ret, Key> {
+ unsafe {
+ // safety: we have the thread key
+ if !self.raw_try_lock() {
+ return Err(key);
+ }
+
+ // safety: the rwlock was just locked
+ let r = f(self.data.get().as_mut().unwrap_unchecked());
+
+ // safety: the rwlock is already locked
+ self.raw_unlock();
+
+ drop(key); // ensure the key stays valid for long enough
+
+ Ok(r)
+ }
+ }
+
/// Locks this `RwLock` with shared read access, blocking the current
/// thread until it can be acquired.
///
@@ -264,10 +362,7 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
/// ```
///
/// [`ThreadKey`]: `crate::ThreadKey`
- pub fn read<'s, 'key: 's, Key: Keyable>(
- &'s self,
- key: Key,
- ) -> RwLockReadGuard<'s, 'key, T, Key, R> {
+ pub fn read(&self, key: ThreadKey) -> RwLockReadGuard<'_, T, R> {
unsafe {
self.raw_read();
@@ -305,10 +400,7 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
/// Err(_) => unreachable!(),
/// };
/// ```
- pub fn try_read<'s, 'key: 's, Key: Keyable>(
- &'s self,
- key: Key,
- ) -> Result<RwLockReadGuard<'s, 'key, T, Key, R>, Key> {
+ pub fn try_read(&self, key: ThreadKey) -> Result<RwLockReadGuard<'_, T, R>, ThreadKey> {
unsafe {
if self.raw_try_read() {
// safety: the lock is locked first
@@ -369,10 +461,7 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
/// ```
///
/// [`ThreadKey`]: `crate::ThreadKey`
- pub fn write<'s, 'key: 's, Key: Keyable>(
- &'s self,
- key: Key,
- ) -> RwLockWriteGuard<'s, 'key, T, Key, R> {
+ pub fn write(&self, key: ThreadKey) -> RwLockWriteGuard<'_, T, R> {
unsafe {
self.raw_lock();
@@ -407,10 +496,7 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
/// let n = lock.read(key);
/// assert_eq!(*n, 1);
/// ```
- pub fn try_write<'s, 'key: 's, Key: Keyable>(
- &'s self,
- key: Key,
- ) -> Result<RwLockWriteGuard<'s, 'key, T, Key, R>, Key> {
+ pub fn try_write(&self, key: ThreadKey) -> Result<RwLockWriteGuard<'_, T, R>, ThreadKey> {
unsafe {
if self.raw_try_lock() {
// safety: the lock is locked first
@@ -445,9 +531,8 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
/// assert_eq!(*guard, 0);
/// let key = RwLock::unlock_read(guard);
/// ```
- pub fn unlock_read<'key, Key: Keyable + 'key>(
- guard: RwLockReadGuard<'_, 'key, T, Key, R>,
- ) -> Key {
+ #[must_use]
+ pub fn unlock_read(guard: RwLockReadGuard<'_, T, R>) -> ThreadKey {
unsafe {
guard.rwlock.0.raw_unlock_read();
}
@@ -473,9 +558,8 @@ impl<T: ?Sized, R: RawRwLock> RwLock<T, R> {
/// *guard += 20;
/// let key = RwLock::unlock_write(guard);
/// ```
- pub fn unlock_write<'key, Key: Keyable + 'key>(
- guard: RwLockWriteGuard<'_, 'key, T, Key, R>,
- ) -> Key {
+ #[must_use]
+ pub fn unlock_write(guard: RwLockWriteGuard<'_, T, R>) -> ThreadKey {
unsafe {
guard.rwlock.0.raw_unlock();
}
diff --git a/src/rwlock/write_guard.rs b/src/rwlock/write_guard.rs
index c971260..3fabf8e 100644
--- a/src/rwlock/write_guard.rs
+++ b/src/rwlock/write_guard.rs
@@ -5,34 +5,14 @@ use std::ops::{Deref, DerefMut};
use lock_api::RawRwLock;
-use crate::key::Keyable;
use crate::lockable::RawLock;
+use crate::ThreadKey;
use super::{RwLock, RwLockWriteGuard, RwLockWriteRef};
// These impls make things slightly easier because now you can use
// `println!("{guard}")` instead of `println!("{}", *guard)`
-impl<T: PartialEq + ?Sized, R: RawRwLock> PartialEq for RwLockWriteRef<'_, T, R> {
- fn eq(&self, other: &Self) -> bool {
- self.deref().eq(&**other)
- }
-}
-
-impl<T: Eq + ?Sized, R: RawRwLock> Eq for RwLockWriteRef<'_, T, R> {}
-
-impl<T: PartialOrd + ?Sized, R: RawRwLock> PartialOrd for RwLockWriteRef<'_, T, R> {
- fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
- self.deref().partial_cmp(&**other)
- }
-}
-
-impl<T: Ord + ?Sized, R: RawRwLock> Ord for RwLockWriteRef<'_, T, R> {
- fn cmp(&self, other: &Self) -> std::cmp::Ordering {
- self.deref().cmp(&**other)
- }
-}
-
#[mutants::skip] // hashing involves PRNG and is difficult to test
#[cfg(not(tarpaulin_include))]
impl<T: Hash + ?Sized, R: RawRwLock> Hash for RwLockWriteRef<'_, T, R> {
@@ -104,41 +84,9 @@ impl<'a, T: ?Sized + 'a, R: RawRwLock> RwLockWriteRef<'a, T, R> {
}
}
-#[mutants::skip] // it's hard to get two read guards safely
-#[cfg(not(tarpaulin_include))]
-impl<T: PartialEq + ?Sized, R: RawRwLock, Key: Keyable> PartialEq
- for RwLockWriteGuard<'_, '_, T, Key, R>
-{
- fn eq(&self, other: &Self) -> bool {
- self.deref().eq(&**other)
- }
-}
-
-#[mutants::skip] // it's hard to get two read guards safely
-#[cfg(not(tarpaulin_include))]
-impl<T: Eq + ?Sized, R: RawRwLock, Key: Keyable> Eq for RwLockWriteGuard<'_, '_, T, Key, R> {}
-
-#[mutants::skip] // it's hard to get two read guards safely
-#[cfg(not(tarpaulin_include))]
-impl<T: PartialOrd + ?Sized, R: RawRwLock, Key: Keyable> PartialOrd
- for RwLockWriteGuard<'_, '_, T, Key, R>
-{
- fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
- self.deref().partial_cmp(&**other)
- }
-}
-
-#[mutants::skip] // it's hard to get two read guards safely
-#[cfg(not(tarpaulin_include))]
-impl<T: Ord + ?Sized, R: RawRwLock, Key: Keyable> Ord for RwLockWriteGuard<'_, '_, T, Key, R> {
- fn cmp(&self, other: &Self) -> std::cmp::Ordering {
- self.deref().cmp(&**other)
- }
-}
-
#[mutants::skip] // hashing involves PRNG and is difficult to test
#[cfg(not(tarpaulin_include))]
-impl<T: Hash + ?Sized, R: RawRwLock, Key: Keyable> Hash for RwLockWriteGuard<'_, '_, T, Key, R> {
+impl<T: Hash + ?Sized, R: RawRwLock> Hash for RwLockWriteGuard<'_, T, R> {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
self.deref().hash(state)
}
@@ -146,21 +94,19 @@ impl<T: Hash + ?Sized, R: RawRwLock, Key: Keyable> Hash for RwLockWriteGuard<'_,
#[mutants::skip]
#[cfg(not(tarpaulin_include))]
-impl<T: Debug + ?Sized, Key: Keyable, R: RawRwLock> Debug for RwLockWriteGuard<'_, '_, T, Key, R> {
+impl<T: Debug + ?Sized, R: RawRwLock> Debug for RwLockWriteGuard<'_, T, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&**self, f)
}
}
-impl<T: Display + ?Sized, Key: Keyable, R: RawRwLock> Display
- for RwLockWriteGuard<'_, '_, T, Key, R>
-{
+impl<T: Display + ?Sized, R: RawRwLock> Display for RwLockWriteGuard<'_, T, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
Display::fmt(&**self, f)
}
}
-impl<T: ?Sized, Key: Keyable, R: RawRwLock> Deref for RwLockWriteGuard<'_, '_, T, Key, R> {
+impl<T: ?Sized, R: RawRwLock> Deref for RwLockWriteGuard<'_, T, R> {
type Target = T;
fn deref(&self) -> &Self::Target {
@@ -168,33 +114,32 @@ impl<T: ?Sized, Key: Keyable, R: RawRwLock> Deref for RwLockWriteGuard<'_, '_, T
}
}
-impl<T: ?Sized, Key: Keyable, R: RawRwLock> DerefMut for RwLockWriteGuard<'_, '_, T, Key, R> {
+impl<T: ?Sized, R: RawRwLock> DerefMut for RwLockWriteGuard<'_, T, R> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.rwlock
}
}
-impl<T: ?Sized, Key: Keyable, R: RawRwLock> AsRef<T> for RwLockWriteGuard<'_, '_, T, Key, R> {
+impl<T: ?Sized, R: RawRwLock> AsRef<T> for RwLockWriteGuard<'_, T, R> {
fn as_ref(&self) -> &T {
self
}
}
-impl<T: ?Sized, Key: Keyable, R: RawRwLock> AsMut<T> for RwLockWriteGuard<'_, '_, T, Key, R> {
+impl<T: ?Sized, R: RawRwLock> AsMut<T> for RwLockWriteGuard<'_, T, R> {
fn as_mut(&mut self) -> &mut T {
self
}
}
-impl<'a, T: ?Sized + 'a, Key: Keyable, R: RawRwLock> RwLockWriteGuard<'a, '_, T, Key, R> {
+impl<'a, T: ?Sized + 'a, R: RawRwLock> RwLockWriteGuard<'a, T, R> {
/// Create a guard to the given mutex. Undefined if multiple guards to the
/// same mutex exist at once.
#[must_use]
- pub(super) unsafe fn new(rwlock: &'a RwLock<T, R>, thread_key: Key) -> Self {
+ pub(super) unsafe fn new(rwlock: &'a RwLock<T, R>, thread_key: ThreadKey) -> Self {
Self {
rwlock: RwLockWriteRef(rwlock, PhantomData),
thread_key,
- _phantom: PhantomData,
}
}
}
diff --git a/src/rwlock/write_lock.rs b/src/rwlock/write_lock.rs
index cc96953..8a44a2d 100644
--- a/src/rwlock/write_lock.rs
+++ b/src/rwlock/write_lock.rs
@@ -2,8 +2,8 @@ use std::fmt::Debug;
use lock_api::RawRwLock;
-use crate::key::Keyable;
use crate::lockable::{Lockable, RawLock};
+use crate::ThreadKey;
use super::{RwLock, RwLockWriteGuard, RwLockWriteRef, WriteLock};
@@ -13,6 +13,11 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for WriteLock<'_, T, R
where
Self: 'g;
+ type DataMut<'a>
+ = &'a mut T
+ where
+ Self: 'a;
+
fn get_ptrs<'a>(&'a self, ptrs: &mut Vec<&'a dyn RawLock>) {
ptrs.push(self.as_ref());
}
@@ -20,6 +25,10 @@ unsafe impl<T: Send, R: RawRwLock + Send + Sync> Lockable for WriteLock<'_, T, R
unsafe fn guard(&self) -> Self::Guard<'_> {
RwLockWriteRef::new(self.as_ref())
}
+
+ unsafe fn data_mut(&self) -> Self::DataMut<'_> {
+ self.0.data_mut()
+ }
}
// Technically, the exclusive locks can also be shared, but there's currently
@@ -108,10 +117,8 @@ impl<T: ?Sized, R: RawRwLock> WriteLock<'_, T, R> {
/// ```
///
/// [`ThreadKey`]: `crate::ThreadKey`
- pub fn lock<'s, 'key: 's, Key: Keyable + 'key>(
- &'s self,
- key: Key,
- ) -> RwLockWriteGuard<'s, 'key, T, Key, R> {
+ #[must_use]
+ pub fn lock(&self, key: ThreadKey) -> RwLockWriteGuard<'_, T, R> {
self.0.write(key)
}
@@ -145,10 +152,7 @@ impl<T: ?Sized, R: RawRwLock> WriteLock<'_, T, R> {
/// Err(_) => unreachable!(),
/// };
/// ```
- pub fn try_lock<'s, 'key: 's, Key: Keyable + 'key>(
- &'s self,
- key: Key,
- ) -> Result<RwLockWriteGuard<'s, 'key, T, Key, R>, Key> {
+ pub fn try_lock(&self, key: ThreadKey) -> Result<RwLockWriteGuard<'_, T, R>, ThreadKey> {
self.0.try_write(key)
}
@@ -176,7 +180,8 @@ impl<T: ?Sized, R: RawRwLock> WriteLock<'_, T, R> {
/// *guard += 20;
/// let key = WriteLock::unlock(guard);
/// ```
- pub fn unlock<'key, Key: Keyable + 'key>(guard: RwLockWriteGuard<'_, 'key, T, Key, R>) -> Key {
+ #[must_use]
+ pub fn unlock(guard: RwLockWriteGuard<'_, T, R>) -> ThreadKey {
RwLock::unlock_write(guard)
}
}