Skip to content

Commit

Permalink
Implement iter and clear
Browse files Browse the repository at this point in the history
  • Loading branch information
limemloh committed Mar 1, 2024
1 parent 9e808ff commit 2c7d8e2
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 75 deletions.
257 changes: 182 additions & 75 deletions concordium-std/src/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::{
mem::{self, MaybeUninit},
num,
num::NonZeroU32,
prims, state_btree_internals,
prims,
traits::*,
types::*,
vec::Vec,
Expand Down Expand Up @@ -3113,7 +3113,7 @@ impl<const M: usize, K, V, S> StateBTreeMap<M, K, V, S> {
pub fn insert(&mut self, key: K, value: V) -> Option<V>
where
S: HasStateApi,
K: Serialize + Ord + fmt::Debug,
K: Serialize + Ord,
V: Serial + DeserialWithState<S>, {
let old_value_option = self.map.insert_borrowed(&key, value);
if old_value_option.is_none() {
Expand All @@ -3125,10 +3125,18 @@ impl<const M: usize, K, V, S> StateBTreeMap<M, K, V, S> {
return old_value_option;
}

/// Remove a key from the map, returning the value at the key if the key was
/// previously in the map.
///
/// *Caution*: If `V` is a [StateBox], [StateMap], then it is
/// important to call [`Deletable::delete`] on the value returned when
/// you're finished with it. Otherwise, it will remain in the contract
/// state.
#[must_use]
pub fn remove_and_get(&mut self, key: &K) -> Option<V>
where
S: HasStateApi,
K: Serialize + Ord + fmt::Debug,
K: Serialize + Ord,
V: Serial + DeserialWithState<S> + Deletable, {
let v = self.map.remove_and_get(key);
if v.is_some() && !self.ordered_set.remove(key) {
Expand All @@ -3138,10 +3146,12 @@ impl<const M: usize, K, V, S> StateBTreeMap<M, K, V, S> {
v
}

/// Remove a key from the map.
/// This also deletes the value in the state.
pub fn remove(&mut self, key: &K)
where
S: HasStateApi,
K: Serialize + Ord + fmt::Debug,
K: Serialize + Ord,
V: Serial + DeserialWithState<S> + Deletable, {
if self.ordered_set.remove(key) {
self.map.remove(key);
Expand Down Expand Up @@ -3197,6 +3207,47 @@ impl<const M: usize, K, V, S> StateBTreeMap<M, K, V, S> {

/// Returns `true` is the map contains no elements.
pub fn is_empty(&self) -> bool { self.ordered_set.is_empty() }

/// Create an iterator over the entries of [`StateBTreeMap`].
/// Ordered by `K`.
pub fn iter(&self) -> StateBTreeMapIter<M, K, V, S>
where
S: HasStateApi, {
StateBTreeMapIter {
key_iter: self.ordered_set.iter(),
map: &self.map,
}
}

/// Clears the map, removing all key-value pairs.
/// This also includes values pointed at, if `V`, for example, is a
/// [StateBox]. **If applicable use [`clear_flat`](Self::clear_flat)
/// instead.**
pub fn clear(&mut self)
where
S: HasStateApi,
K: Serialize,
V: Serial + DeserialWithState<S> + Deletable, {
self.map.clear();
self.ordered_set.clear();
}

/// Clears the map, removing all key-value pairs.
/// **This should be used over [`clear`](Self::clear) if it is
/// applicable.** It avoids recursive deletion of values since the
/// values are required to be _flat_.
///
/// Unfortunately it is not possible to automatically choose between these
/// implementations. Once Rust gets trait specialization then this might
/// be possible.
pub fn clear_flat(&mut self)
where
S: HasStateApi,
K: Serialize,
V: Serialize, {
self.map.clear_flat();
self.ordered_set.clear();
}
}

impl<const M: usize, K, S> StateBTreeSet<M, K, S> {
Expand All @@ -3220,7 +3271,7 @@ impl<const M: usize, K, S> StateBTreeSet<M, K, S> {
pub fn insert(&mut self, key: K) -> bool
where
S: HasStateApi,
K: Serialize + Ord + fmt::Debug, {
K: Serialize + Ord, {
let Some(root_id) = self.root else {
let node_id = {
let (node_id, _node) = self.create_node(vec![key], Vec::new());
Expand Down Expand Up @@ -3288,6 +3339,35 @@ impl<const M: usize, K, S> StateBTreeSet<M, K, S> {
/// Returns `true` is the map contains no elements.
pub fn is_empty(&self) -> bool { self.root.is_none() }

/// Get an iterator over the elements in the `StateBTreeSet`. The iterator
/// returns elements in increasing order.
pub fn iter(&self) -> StateBTreeSetIter<M, K, S>
where
S: HasStateApi, {
StateBTreeSetIter {
length: self.len.try_into().unwrap_abort(),
next_node: self.root,
depth_first_stack: Vec::new(),
tree: &self,
_marker_lifetime: Default::default(),
}
}

/// Clears the map, removing all elements.
pub fn clear(&mut self)
where
S: HasStateApi, {
// Reset the information.
self.root = None;
self.next_node_id = state_btree_internals::NodeId {
id: 0,
};
self.len = 0;
// Then delete every node store in the state.
// Unwrapping is safe when only using the high-level API.
self.state_api.delete_prefix(&self.prefix).unwrap_abort();
}

/// Returns the smallest key in the map, which is strictly larger than the
/// provided key. `None` meaning no such key is present in the map.
pub fn higher(&self, key: &K) -> Option<K>
Expand Down Expand Up @@ -3361,7 +3441,7 @@ impl<const M: usize, K, S> StateBTreeSet<M, K, S> {

pub fn remove(&mut self, key: &K) -> bool
where
K: Ord + Serialize + fmt::Debug,
K: Ord + Serialize,
S: HasStateApi, {
let Some(root_node_id) = self.root else {
return false;
Expand Down Expand Up @@ -3602,7 +3682,7 @@ impl<const M: usize, K, S> StateBTreeSet<M, K, S> {
keys,
children,
};
let entry = self.state_api.create_entry(&self.node_key(node_id)).unwrap_abort();
let entry = self.state_api.create_entry(&node_id.as_key(&self.prefix)).unwrap_abort();
let mut ref_mut: StateRefMut<'_, state_btree_internals::Node<M, K>, S> =
StateRefMut::new(entry, self.state_api.clone());
ref_mut.set(node);
Expand All @@ -3626,7 +3706,7 @@ impl<const M: usize, K, S> StateBTreeSet<M, K, S> {
key: K,
) -> bool
where
K: Serialize + Ord + fmt::Debug,
K: Serialize + Ord,
S: HasStateApi, {
let mut node = initial_node;
loop {
Expand Down Expand Up @@ -3689,14 +3769,14 @@ impl<const M: usize, K, S> StateBTreeSet<M, K, S> {

/// Internal function for looking up a node in the tree.
/// This assumes the node is present and traps if this is not the case.
fn get_node<'a, 'b>(
fn get_node<'a, 'b, Key>(
&'a self,
node_id: state_btree_internals::NodeId,
) -> state_btree_internals::Node<M, K>
) -> state_btree_internals::Node<M, Key>
where
K: Deserial,
Key: Deserial,
S: HasStateApi, {
let key = self.node_key(node_id);
let key = node_id.as_key(&self.prefix);
let mut entry = self.state_api.lookup_entry(&key).unwrap_abort();
entry.get().unwrap_abort()
}
Expand All @@ -3710,51 +3790,10 @@ impl<const M: usize, K, S> StateBTreeSet<M, K, S> {
where
K: Serial,
S: HasStateApi, {
let key = self.node_key(node_id);
let key = node_id.as_key(&self.prefix);
let entry = self.state_api.lookup_entry(&key).unwrap_abort();
StateRefMut::new(entry, self.state_api.clone())
}

/// Construct the key for the node in the key-value store from the node ID.
fn node_key(&self, node_id: state_btree_internals::NodeId) -> [u8; BTREE_NODE_KEY_SIZE] {
// Create an uninitialized array of `MaybeUninit`. The `assume_init` is
// safe because the type we are claiming to have initialized here is a
// bunch of `MaybeUninit`s, which do not require initialization.
let mut prefixed: [MaybeUninit<u8>; BTREE_NODE_KEY_SIZE] =
unsafe { MaybeUninit::uninit().assume_init() };

for i in 0..STATE_ITEM_PREFIX_SIZE {
prefixed[i].write(self.prefix[i]);
}
let id_bytes = node_id.id.to_le_bytes();
for i in 0..id_bytes.len() {
prefixed[STATE_ITEM_PREFIX_SIZE + i].write(id_bytes[i]);
}
// Transmuting away the maybeuninit is safe since we have initialized all of
// them.
unsafe { mem::transmute(prefixed) }
}

pub(crate) fn debug(&self) -> String
where
S: HasStateApi,
K: Deserial + fmt::Debug, {
let Some(root_id) = self.root else {
return "Empty".to_owned();
};

let mut out = String::new();
let mut stack = vec![root_id];
while let Some(node_id) = stack.pop() {
let node = self.get_node(node_id);
out.push_str(format!("{} [\n", node_id.id).as_str());
out.push_str(node.debug().as_str());
out.push_str("]\n");

stack.extend(&node.children)
}
out
}
}

/// Byte size of the key used to store a BTree internal node in the smart
Expand Down Expand Up @@ -3783,44 +3822,112 @@ impl<const M: usize, K> state_btree_internals::Node<M, K> {

/// Check if the node holds the minimum number of keys.
fn is_at_min(&self) -> bool { self.len() == Self::MINIMUM_KEY_LEN }

fn debug(&self) -> String
where
K: fmt::Debug, {
let mut out = String::new();
if self.is_leaf() {
out.push_str(format!("Leaf with {:?}\n", self.keys).as_str());
} else {
for i in 0..self.len() {
out.push_str(format!("Child Id: {}\n", self.children[i].id).as_str());
out.push_str(format!("Key: {:?}\n", self.keys[i]).as_str());
}
out.push_str(format!("Child Id: {}\n", self.children[self.len()].id).as_str());
}
out
}
}

impl state_btree_internals::NodeId {
const SERIALIZED_BYTE_SIZE: usize = 4;
/// Byte size of `NodeId` when serialized.
pub(crate) const SERIALIZED_BYTE_SIZE: usize = 4;

/// Return a copy of the NodeId, then increments itself.
pub(crate) fn copy_then_increment(&mut self) -> Self {
let current = self.clone();
self.id += 1;
current
}

/// Construct the key for the node in the key-value store from the node ID.
fn as_key(&self, prefix: &StateItemPrefix) -> [u8; BTREE_NODE_KEY_SIZE] {
// Create an uninitialized array of `MaybeUninit`. The `assume_init` is
// safe because the type we are claiming to have initialized here is a
// bunch of `MaybeUninit`s, which do not require initialization.
let mut prefixed: [MaybeUninit<u8>; BTREE_NODE_KEY_SIZE] =
unsafe { MaybeUninit::uninit().assume_init() };
for i in 0..STATE_ITEM_PREFIX_SIZE {
prefixed[i].write(prefix[i]);
}
let id_bytes = self.id.to_le_bytes();
for i in 0..id_bytes.len() {
prefixed[STATE_ITEM_PREFIX_SIZE + i].write(id_bytes[i]);
}
// Transmuting away the maybeuninit is safe since we have initialized all of
// them.
unsafe { mem::transmute(prefixed) }
}
}

impl<const M: usize, K, V, S> Deletable for StateBTreeMap<M, K, V, S>
impl<K: Deserial> Deserial for state_btree_internals::KeyWrapper<K> {
fn deserial<R: Read>(source: &mut R) -> ParseResult<Self> {
let key = K::deserial(source)?;
Ok(Self {
key: Some(key),
})
}
}

impl<'a, 'b, const M: usize, K, S> Iterator for StateBTreeSetIter<'a, 'b, M, K, S>
where
'a: 'b,
K: Deserial,
S: HasStateApi,
{
fn delete(self) {
todo!("clear the map");
type Item = StateRef<'b, K>;

fn next(&mut self) -> Option<Self::Item> {
while let Some(id) = self.next_node.take() {
let node = self.tree.get_node(id);
if !node.is_leaf() {
self.next_node = Some(node.children[0]);
}
self.depth_first_stack.push((node, 0));
}

//self.clear();
if let Some((node, index)) = self.depth_first_stack.last_mut() {
let key = node.keys[*index].key.take().unwrap_abort();
*index += 1;
let no_more_keys = index == &node.keys.len();
if !node.is_leaf() {
let child_id = node.children[*index];
self.next_node = Some(child_id);
}
if no_more_keys {
self.depth_first_stack.pop();
}
self.length -= 1;
Some(StateRef::new(key))
} else {
None
}
}

fn size_hint(&self) -> (usize, Option<usize>) { (self.length, Some(self.length)) }
}

impl<'a, 'b, const M: usize, K, V, S> Iterator for StateBTreeMapIter<'a, 'b, M, K, V, S>
where
'a: 'b,
K: Serialize,
V: Serial + DeserialWithState<S> + 'b,
S: HasStateApi,
{
type Item = (StateRef<'b, K>, StateRef<'b, V>);

fn next(&mut self) -> Option<Self::Item> {
let next_key = self.key_iter.next()?;
let value = self.map.get(&next_key).unwrap_abort();
// Unwrap is safe, otherwise the map and the set have inconsistencies.
Some((next_key, value))
}

fn size_hint(&self) -> (usize, Option<usize>) { self.key_iter.size_hint() }
}

impl<const M: usize, K, V, S> Deletable for StateBTreeMap<M, K, V, S>
where
S: HasStateApi,
K: Serialize,
V: Serial + DeserialWithState<S> + Deletable,
{
fn delete(mut self) { self.clear(); }
}

#[cfg(test)]
Expand Down
12 changes: 12 additions & 0 deletions concordium-std/src/test_infrastructure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2681,4 +2681,16 @@ mod test {
assert!(tree.remove(&2));
assert!(!tree.contains(&2));
}

#[test]
fn test_btree_iter() {
let mut state_builder = TestStateBuilder::new();
let mut tree: StateBTreeSet<2, u32, _> = state_builder.new_btree_set();
let keys: Vec<u32> = (0..15).into_iter().collect();
for &k in &keys {
tree.insert(k);
}
let iter_keys: Vec<u32> = tree.iter().map(|k| k.clone()).collect();
assert_eq!(keys, iter_keys);
}
}
Loading

0 comments on commit 2c7d8e2

Please sign in to comment.