diff --git a/agb/src/hash_map.rs b/agb/src/hash_map.rs index b8e6a474..138b9c4d 100644 --- a/agb/src/hash_map.rs +++ b/agb/src/hash_map.rs @@ -1,7 +1,9 @@ use alloc::vec::Vec; use core::{ hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}, - iter, mem, + iter, + mem::{self, MaybeUninit}, + ptr, }; use rustc_hash::FxHasher; @@ -10,27 +12,118 @@ type HashType = u32; struct Node { hash: HashType, + + // distance_to_initial_bucket = -1 => key and value are uninit. + // distance_to_initial_bucket >= 0 => key and value are init distance_to_initial_bucket: i32, - key: K, - value: V, + key: MaybeUninit, + value: MaybeUninit, } impl Node { - fn with_new_key_value(self, new_key: K, new_value: V) -> (Self, V) { + fn with_new_key_value(self, new_key: K, new_value: V) -> (Self, Option) { ( Self { hash: self.hash, distance_to_initial_bucket: self.distance_to_initial_bucket, - key: new_key, - value: new_value, + key: MaybeUninit::new(new_key), + value: MaybeUninit::new(new_value), }, - self.value, + self.get_owned_value(), ) } + + fn new() -> Self { + Self { + hash: 0, + distance_to_initial_bucket: -1, + key: MaybeUninit::uninit(), + value: MaybeUninit::uninit(), + } + } + + fn new_with(key: K, value: V, hash: HashType) -> Self { + Self { + hash, + distance_to_initial_bucket: 0, + key: MaybeUninit::new(key), + value: MaybeUninit::new(value), + } + } + + fn get_value_ref(&self) -> Option<&V> { + if self.has_value() { + Some(unsafe { self.value.assume_init_ref() }) + } else { + None + } + } + + fn get_value_mut(&mut self) -> Option<&mut V> { + if self.has_value() { + Some(unsafe { self.value.assume_init_mut() }) + } else { + None + } + } + + fn get_owned_value(mut self) -> Option { + if self.has_value() { + let value = mem::replace(&mut self.value, MaybeUninit::uninit()); + self.distance_to_initial_bucket = -1; + + Some(unsafe { value.assume_init() }) + } else { + None + } + } + + fn key_ref(&self) -> Option<&K> { + if self.distance_to_initial_bucket >= 0 { + Some(unsafe { self.key.assume_init_ref() }) + } else { + None + } + } + + fn has_value(&self) -> bool { + self.distance_to_initial_bucket >= 0 + } + + fn take(&mut self) -> Self { + mem::take(self) + } + + fn take_key_value(&mut self) -> Option<(K, V, HashType)> { + if self.has_value() { + let key = mem::replace(&mut self.key, MaybeUninit::uninit()); + let value = mem::replace(&mut self.value, MaybeUninit::uninit()); + self.distance_to_initial_bucket = -1; + + Some(unsafe { (key.assume_init(), value.assume_init(), self.hash) }) + } else { + None + } + } +} + +impl Drop for Node { + fn drop(&mut self) { + if self.distance_to_initial_bucket >= 0 { + unsafe { ptr::drop_in_place(self.key.as_mut_ptr()) }; + unsafe { ptr::drop_in_place(self.value.as_mut_ptr()) }; + } + } +} + +impl Default for Node { + fn default() -> Self { + Self::new() + } } struct NodeStorage { - nodes: Vec>>, + nodes: Vec>, max_distance_to_initial_bucket: i32, number_of_items: usize, @@ -41,7 +134,7 @@ impl NodeStorage { assert!(capacity.is_power_of_two(), "Capacity must be a power of 2"); Self { - nodes: iter::repeat_with(|| None).take(capacity).collect(), + nodes: iter::repeat_with(Default::default).take(capacity).collect(), max_distance_to_initial_bucket: 0, number_of_items: 0, } @@ -63,26 +156,21 @@ impl NodeStorage { self.len() ); - let mut new_node = Node { - hash, - distance_to_initial_bucket: 0, - key, - value, - }; + let mut new_node = Node::new_with(key, value, hash); loop { let location = fast_mod( self.capacity(), new_node.hash + new_node.distance_to_initial_bucket as HashType, ); - let current_node = self.nodes[location].as_mut(); + let current_node = &mut self.nodes[location]; - if let Some(current_node) = current_node { + if current_node.has_value() { if current_node.distance_to_initial_bucket <= new_node.distance_to_initial_bucket { mem::swap(&mut new_node, current_node); } } else { - self.nodes[location] = Some(new_node); + self.nodes[location] = new_node; break; } @@ -97,34 +185,23 @@ impl NodeStorage { fn remove_from_location(&mut self, location: usize) -> V { let mut current_location = location; + self.number_of_items -= 1; - let result = loop { + loop { let next_location = fast_mod(self.capacity(), (current_location + 1) as HashType); // if the next node is empty, or the next location has 0 distance to initial bucket then // we can clear the current node - if self.nodes[next_location].is_none() - || self.nodes[next_location] - .as_ref() - .unwrap() - .distance_to_initial_bucket - == 0 + if !self.nodes[next_location].has_value() + || self.nodes[next_location].distance_to_initial_bucket == 0 { - break self.nodes[current_location].take().unwrap(); + return self.nodes[current_location].take_key_value().unwrap().1; } - if self.nodes[next_location].is_none() {} self.nodes.swap(current_location, next_location); - self.nodes[current_location] - .as_mut() - .unwrap() - .distance_to_initial_bucket -= 1; + self.nodes[current_location].distance_to_initial_bucket -= 1; current_location = next_location; - }; - - self.number_of_items -= 1; - - result.value + } } fn get_location(&self, key: &K, hash: HashType) -> Option @@ -138,8 +215,8 @@ impl NodeStorage { ); let node = &self.nodes[location]; - if let Some(node) = node { - if &node.key == key { + if let Some(node_key_ref) = node.key_ref() { + if node_key_ref == key { return Some(location); } } else { @@ -188,8 +265,10 @@ impl HashMap { let mut new_node_storage = NodeStorage::with_size(new_size); - for node in self.nodes.nodes.drain(..).flatten() { - new_node_storage.insert_new(node.key, node.value, node.hash); + for mut node in self.nodes.nodes.drain(..) { + if let Some((key, value, hash)) = node.take_key_value() { + new_node_storage.insert_new(key, value, hash); + } } self.nodes = new_node_storage; @@ -215,11 +294,11 @@ where let hash = self.hash(&key); if let Some(location) = self.nodes.get_location(&key, hash) { - let old_node = self.nodes.nodes[location].take().unwrap(); + let old_node = self.nodes.nodes[location].take(); let (new_node, old_value) = old_node.with_new_key_value(key, value); - self.nodes.nodes[location] = Some(new_node); + self.nodes.nodes[location] = new_node; - return Some(old_value); + return old_value; } if self.nodes.capacity() * 85 / 100 <= self.len() { @@ -235,14 +314,15 @@ where self.nodes .get_location(key, hash) - .map(|location| &self.nodes.nodes[location].as_ref().unwrap().value) + .map(|location| self.nodes.nodes[location].get_value_ref()) + .flatten() } pub fn get_mut(&mut self, key: &K) -> Option<&mut V> { let hash = self.hash(key); if let Some(location) = self.nodes.get_location(key, hash) { - Some(&mut self.nodes.nodes[location].as_mut().unwrap().value) + self.nodes.nodes[location].get_value_mut() } else { None } @@ -282,12 +362,12 @@ impl<'a, K, V> Iterator for Iter<'a, K, V> { return None; } - if let Some(node) = &self.map.nodes.nodes[self.at] { - self.at += 1; - return Some((&node.key, &node.value)); - } - + let node = &self.map.nodes.nodes[self.at]; self.at += 1; + + if node.has_value() { + return Some((node.key_ref().unwrap(), node.get_value_ref().unwrap())); + } } } } @@ -341,7 +421,7 @@ where { match self { Entry::Occupied(e) => { - f(&mut e.entry.value); + f(e.entry.get_value_mut().unwrap()); Entry::Occupied(e) } Entry::Vacant(e) => Entry::Vacant(e), @@ -359,7 +439,7 @@ where if let Some(location) = location { Entry::Occupied(OccupiedEntry { - entry: self.nodes.nodes[location].as_mut().unwrap(), + entry: &mut self.nodes.nodes[location], }) } else { Entry::Vacant(VacantEntry { key, map: self })