Use MaybeUninit to reduce memory usage by half

This commit is contained in:
Gwilym Kuiper 2022-03-19 22:42:09 +00:00
parent cd9798d01f
commit ee983ef7ec

View file

@ -1,7 +1,9 @@
use alloc::vec::Vec; use alloc::vec::Vec;
use core::{ use core::{
hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}, hash::{BuildHasher, BuildHasherDefault, Hash, Hasher},
iter, mem, iter,
mem::{self, MaybeUninit},
ptr,
}; };
use rustc_hash::FxHasher; use rustc_hash::FxHasher;
@ -10,27 +12,118 @@ type HashType = u32;
struct Node<K, V> { struct Node<K, V> {
hash: HashType, hash: HashType,
// distance_to_initial_bucket = -1 => key and value are uninit.
// distance_to_initial_bucket >= 0 => key and value are init
distance_to_initial_bucket: i32, distance_to_initial_bucket: i32,
key: K, key: MaybeUninit<K>,
value: V, value: MaybeUninit<V>,
} }
impl<K, V> Node<K, V> { impl<K, V> Node<K, V> {
fn with_new_key_value(self, new_key: K, new_value: V) -> (Self, V) { fn with_new_key_value(self, new_key: K, new_value: V) -> (Self, Option<V>) {
( (
Self { Self {
hash: self.hash, hash: self.hash,
distance_to_initial_bucket: self.distance_to_initial_bucket, distance_to_initial_bucket: self.distance_to_initial_bucket,
key: new_key, key: MaybeUninit::new(new_key),
value: new_value, value: MaybeUninit::new(new_value),
}, },
self.value, self.get_owned_value(),
) )
} }
fn new() -> Self {
Self {
hash: 0,
distance_to_initial_bucket: -1,
key: MaybeUninit::uninit(),
value: MaybeUninit::uninit(),
}
}
fn new_with(key: K, value: V, hash: HashType) -> Self {
Self {
hash,
distance_to_initial_bucket: 0,
key: MaybeUninit::new(key),
value: MaybeUninit::new(value),
}
}
fn get_value_ref(&self) -> Option<&V> {
if self.has_value() {
Some(unsafe { self.value.assume_init_ref() })
} else {
None
}
}
fn get_value_mut(&mut self) -> Option<&mut V> {
if self.has_value() {
Some(unsafe { self.value.assume_init_mut() })
} else {
None
}
}
fn get_owned_value(mut self) -> Option<V> {
if self.has_value() {
let value = mem::replace(&mut self.value, MaybeUninit::uninit());
self.distance_to_initial_bucket = -1;
Some(unsafe { value.assume_init() })
} else {
None
}
}
fn key_ref(&self) -> Option<&K> {
if self.distance_to_initial_bucket >= 0 {
Some(unsafe { self.key.assume_init_ref() })
} else {
None
}
}
fn has_value(&self) -> bool {
self.distance_to_initial_bucket >= 0
}
fn take(&mut self) -> Self {
mem::take(self)
}
fn take_key_value(&mut self) -> Option<(K, V, HashType)> {
if self.has_value() {
let key = mem::replace(&mut self.key, MaybeUninit::uninit());
let value = mem::replace(&mut self.value, MaybeUninit::uninit());
self.distance_to_initial_bucket = -1;
Some(unsafe { (key.assume_init(), value.assume_init(), self.hash) })
} else {
None
}
}
}
impl<K, V> Drop for Node<K, V> {
fn drop(&mut self) {
if self.distance_to_initial_bucket >= 0 {
unsafe { ptr::drop_in_place(self.key.as_mut_ptr()) };
unsafe { ptr::drop_in_place(self.value.as_mut_ptr()) };
}
}
}
impl<K, V> Default for Node<K, V> {
fn default() -> Self {
Self::new()
}
} }
struct NodeStorage<K, V> { struct NodeStorage<K, V> {
nodes: Vec<Option<Node<K, V>>>, nodes: Vec<Node<K, V>>,
max_distance_to_initial_bucket: i32, max_distance_to_initial_bucket: i32,
number_of_items: usize, number_of_items: usize,
@ -41,7 +134,7 @@ impl<K, V> NodeStorage<K, V> {
assert!(capacity.is_power_of_two(), "Capacity must be a power of 2"); assert!(capacity.is_power_of_two(), "Capacity must be a power of 2");
Self { Self {
nodes: iter::repeat_with(|| None).take(capacity).collect(), nodes: iter::repeat_with(Default::default).take(capacity).collect(),
max_distance_to_initial_bucket: 0, max_distance_to_initial_bucket: 0,
number_of_items: 0, number_of_items: 0,
} }
@ -63,26 +156,21 @@ impl<K, V> NodeStorage<K, V> {
self.len() self.len()
); );
let mut new_node = Node { let mut new_node = Node::new_with(key, value, hash);
hash,
distance_to_initial_bucket: 0,
key,
value,
};
loop { loop {
let location = fast_mod( let location = fast_mod(
self.capacity(), self.capacity(),
new_node.hash + new_node.distance_to_initial_bucket as HashType, new_node.hash + new_node.distance_to_initial_bucket as HashType,
); );
let current_node = self.nodes[location].as_mut(); let current_node = &mut self.nodes[location];
if let Some(current_node) = current_node { if current_node.has_value() {
if current_node.distance_to_initial_bucket <= new_node.distance_to_initial_bucket { if current_node.distance_to_initial_bucket <= new_node.distance_to_initial_bucket {
mem::swap(&mut new_node, current_node); mem::swap(&mut new_node, current_node);
} }
} else { } else {
self.nodes[location] = Some(new_node); self.nodes[location] = new_node;
break; break;
} }
@ -97,34 +185,23 @@ impl<K, V> NodeStorage<K, V> {
fn remove_from_location(&mut self, location: usize) -> V { fn remove_from_location(&mut self, location: usize) -> V {
let mut current_location = location; let mut current_location = location;
self.number_of_items -= 1;
let result = loop { loop {
let next_location = fast_mod(self.capacity(), (current_location + 1) as HashType); let next_location = fast_mod(self.capacity(), (current_location + 1) as HashType);
// if the next node is empty, or the next location has 0 distance to initial bucket then // if the next node is empty, or the next location has 0 distance to initial bucket then
// we can clear the current node // we can clear the current node
if self.nodes[next_location].is_none() if !self.nodes[next_location].has_value()
|| self.nodes[next_location] || self.nodes[next_location].distance_to_initial_bucket == 0
.as_ref()
.unwrap()
.distance_to_initial_bucket
== 0
{ {
break self.nodes[current_location].take().unwrap(); return self.nodes[current_location].take_key_value().unwrap().1;
} }
if self.nodes[next_location].is_none() {}
self.nodes.swap(current_location, next_location); self.nodes.swap(current_location, next_location);
self.nodes[current_location] self.nodes[current_location].distance_to_initial_bucket -= 1;
.as_mut()
.unwrap()
.distance_to_initial_bucket -= 1;
current_location = next_location; current_location = next_location;
}; }
self.number_of_items -= 1;
result.value
} }
fn get_location(&self, key: &K, hash: HashType) -> Option<usize> fn get_location(&self, key: &K, hash: HashType) -> Option<usize>
@ -138,8 +215,8 @@ impl<K, V> NodeStorage<K, V> {
); );
let node = &self.nodes[location]; let node = &self.nodes[location];
if let Some(node) = node { if let Some(node_key_ref) = node.key_ref() {
if &node.key == key { if node_key_ref == key {
return Some(location); return Some(location);
} }
} else { } else {
@ -188,8 +265,10 @@ impl<K, V> HashMap<K, V> {
let mut new_node_storage = NodeStorage::with_size(new_size); let mut new_node_storage = NodeStorage::with_size(new_size);
for node in self.nodes.nodes.drain(..).flatten() { for mut node in self.nodes.nodes.drain(..) {
new_node_storage.insert_new(node.key, node.value, node.hash); if let Some((key, value, hash)) = node.take_key_value() {
new_node_storage.insert_new(key, value, hash);
}
} }
self.nodes = new_node_storage; self.nodes = new_node_storage;
@ -215,11 +294,11 @@ where
let hash = self.hash(&key); let hash = self.hash(&key);
if let Some(location) = self.nodes.get_location(&key, hash) { if let Some(location) = self.nodes.get_location(&key, hash) {
let old_node = self.nodes.nodes[location].take().unwrap(); let old_node = self.nodes.nodes[location].take();
let (new_node, old_value) = old_node.with_new_key_value(key, value); let (new_node, old_value) = old_node.with_new_key_value(key, value);
self.nodes.nodes[location] = Some(new_node); self.nodes.nodes[location] = new_node;
return Some(old_value); return old_value;
} }
if self.nodes.capacity() * 85 / 100 <= self.len() { if self.nodes.capacity() * 85 / 100 <= self.len() {
@ -235,14 +314,15 @@ where
self.nodes self.nodes
.get_location(key, hash) .get_location(key, hash)
.map(|location| &self.nodes.nodes[location].as_ref().unwrap().value) .map(|location| self.nodes.nodes[location].get_value_ref())
.flatten()
} }
pub fn get_mut(&mut self, key: &K) -> Option<&mut V> { pub fn get_mut(&mut self, key: &K) -> Option<&mut V> {
let hash = self.hash(key); let hash = self.hash(key);
if let Some(location) = self.nodes.get_location(key, hash) { if let Some(location) = self.nodes.get_location(key, hash) {
Some(&mut self.nodes.nodes[location].as_mut().unwrap().value) self.nodes.nodes[location].get_value_mut()
} else { } else {
None None
} }
@ -282,12 +362,12 @@ impl<'a, K, V> Iterator for Iter<'a, K, V> {
return None; return None;
} }
if let Some(node) = &self.map.nodes.nodes[self.at] { let node = &self.map.nodes.nodes[self.at];
self.at += 1; self.at += 1;
return Some((&node.key, &node.value));
}
self.at += 1; if node.has_value() {
return Some((node.key_ref().unwrap(), node.get_value_ref().unwrap()));
}
} }
} }
} }
@ -341,7 +421,7 @@ where
{ {
match self { match self {
Entry::Occupied(e) => { Entry::Occupied(e) => {
f(&mut e.entry.value); f(e.entry.get_value_mut().unwrap());
Entry::Occupied(e) Entry::Occupied(e)
} }
Entry::Vacant(e) => Entry::Vacant(e), Entry::Vacant(e) => Entry::Vacant(e),
@ -359,7 +439,7 @@ where
if let Some(location) = location { if let Some(location) = location {
Entry::Occupied(OccupiedEntry { Entry::Occupied(OccupiedEntry {
entry: self.nodes.nodes[location].as_mut().unwrap(), entry: &mut self.nodes.nodes[location],
}) })
} else { } else {
Entry::Vacant(VacantEntry { key, map: self }) Entry::Vacant(VacantEntry { key, map: self })