diff options
Diffstat (limited to 'ai/src/transposition_table.rs')
| -rw-r--r-- | ai/src/transposition_table.rs | 129 |
1 files changed, 96 insertions, 33 deletions
diff --git a/ai/src/transposition_table.rs b/ai/src/transposition_table.rs index b801789..2b56a66 100644 --- a/ai/src/transposition_table.rs +++ b/ai/src/transposition_table.rs @@ -1,49 +1,44 @@ use crate::CheckersBitBoard; - -#[cfg(debug_assertions)] -const TABLE_SIZE: usize = 1_000_000 / std::mem::size_of::<TranspositionTableEntry>(); - -#[cfg(not(debug_assertions))] -const TABLE_SIZE: usize = 10_000_000 / std::mem::size_of::<TranspositionTableEntry>(); - -const EMPTY_ENTRY: Option<TranspositionTableEntry> = None; -static mut REPLACE_TABLE: [Option<TranspositionTableEntry>; TABLE_SIZE] = [EMPTY_ENTRY; TABLE_SIZE]; -static mut DEPTH_TABLE: [Option<TranspositionTableEntry>; TABLE_SIZE] = [EMPTY_ENTRY; TABLE_SIZE]; +use parking_lot::RwLock; +use std::num::NonZeroU8; #[derive(Copy, Clone, Debug)] struct TranspositionTableEntry { board: CheckersBitBoard, eval: f32, - depth: u8, -} - -pub struct TranspositionTableReference { - replace_table: &'static mut [Option<TranspositionTableEntry>; TABLE_SIZE], - depth_table: &'static mut [Option<TranspositionTableEntry>; TABLE_SIZE], + depth: NonZeroU8, } impl TranspositionTableEntry { - const fn new(board: CheckersBitBoard, eval: f32, depth: u8) -> Self { + const fn new(board: CheckersBitBoard, eval: f32, depth: NonZeroU8) -> Self { Self { board, eval, depth } } } -impl TranspositionTableReference { - pub fn new() -> Self { - Self { - replace_table: unsafe { &mut REPLACE_TABLE }, - depth_table: unsafe { &mut DEPTH_TABLE }, - } - } +pub struct TranspositionTable { + replace_table: Box<[RwLock<Option<TranspositionTableEntry>>]>, + depth_table: Box<[RwLock<Option<TranspositionTableEntry>>]>, +} +#[derive(Copy, Clone, Debug)] +pub struct TranspositionTableRef<'a> { + replace_table: &'a [RwLock<Option<TranspositionTableEntry>>], + depth_table: &'a [RwLock<Option<TranspositionTableEntry>>], +} + +impl<'a> TranspositionTableRef<'a> { pub fn get(self, board: CheckersBitBoard, depth: u8) -> Option<f32> { + let table_len = self.replace_table.as_ref().len(); + // try the replace table let entry = unsafe { self.replace_table - .get_unchecked(board.hash_code() as usize % TABLE_SIZE) + .as_ref() + .get_unchecked(board.hash_code() as usize % table_len) + .read() }; if let Some(entry) = *entry { - if entry.board == board && entry.depth >= depth { + if entry.board == board && entry.depth.get() >= depth { return Some(entry.eval); } } @@ -51,12 +46,14 @@ impl TranspositionTableReference { // try the depth table let entry = unsafe { self.depth_table - .get_unchecked(board.hash_code() as usize % TABLE_SIZE) + .as_ref() + .get_unchecked(board.hash_code() as usize % table_len) + .read() }; match *entry { Some(entry) => { if entry.board == board { - if entry.depth >= depth { + if entry.depth.get() >= depth { Some(entry.eval) } else { None @@ -69,18 +66,57 @@ impl TranspositionTableReference { } } - pub fn insert(self, board: CheckersBitBoard, eval: f32, depth: u8) { - // insert to the replace table + pub fn get_any_depth(self, board: CheckersBitBoard) -> Option<f32> { + let table_len = self.replace_table.as_ref().len(); + + // try the depth table let entry = unsafe { + self.depth_table + .as_ref() + .get_unchecked(board.hash_code() as usize % table_len) + .read() + }; + if let Some(entry) = *entry { + if entry.board == board { + return Some(entry.eval); + } + } + + // try the replace table + let entry = unsafe { + self.replace_table + .as_ref() + .get_unchecked(board.hash_code() as usize % table_len) + .read() + }; + match *entry { + Some(entry) => { + if entry.board == board { + Some(entry.eval) + } else { + None + } + } + None => None, + } + } + + pub fn insert(&self, board: CheckersBitBoard, eval: f32, depth: NonZeroU8) { + let table_len = self.replace_table.as_ref().len(); + + // insert to the replace table + let mut entry = unsafe { self.replace_table - .get_unchecked_mut(board.hash_code() as usize % TABLE_SIZE) + .get_unchecked(board.hash_code() as usize % table_len) + .write() }; *entry = Some(TranspositionTableEntry::new(board, eval, depth)); // insert to the depth table, only if the new depth is higher - let entry = unsafe { + let mut entry = unsafe { self.depth_table - .get_unchecked_mut(board.hash_code() as usize % TABLE_SIZE) + .get_unchecked(board.hash_code() as usize % table_len) + .write() }; match *entry { Some(entry_val) => { @@ -92,3 +128,30 @@ impl TranspositionTableReference { } } } + +impl TranspositionTable { + pub fn new(table_size: usize) -> Self { + let mut replace_table = Box::new_uninit_slice(table_size / 2); + let mut depth_table = Box::new_uninit_slice(table_size / 2); + + for entry in replace_table.iter_mut() { + entry.write(RwLock::new(None)); + } + + for entry in depth_table.iter_mut() { + entry.write(RwLock::new(None)); + } + + Self { + replace_table: unsafe { replace_table.assume_init() }, + depth_table: unsafe { depth_table.assume_init() }, + } + } + + pub fn mut_ref(&mut self) -> TranspositionTableRef { + TranspositionTableRef { + replace_table: &self.replace_table, + depth_table: &self.depth_table, + } + } +} |
