diff --git a/src/lib.rs b/src/lib.rs index e642c67..be91a84 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -192,7 +192,7 @@ pub struct LruCache { map: HashMap, Box>, S>, cap: NonZeroUsize, - // head and tail are sigil nodes to faciliate inserting entries + // head and tail are sigil nodes to facilitate inserting entries head: *mut LruEntry, tail: *mut LruEntry, } @@ -366,7 +366,7 @@ impl LruCache { } // Used internally to swap out a node if the cache is full or to create a new node if space - // is available. Shared between `put`, `push`, and `get_or_insert`. + // is available. Shared between `put`, `push`, `get_or_insert`, and `get_or_insert_mut`. #[allow(clippy::type_complexity)] fn replace_or_create_node(&mut self, k: K, v: V) -> (Option<(K, V)>, Box>) { if self.len() == self.cap().get() { @@ -510,6 +510,52 @@ impl LruCache { } } + /// Returns a mutable reference to the value of the key in the cache if it is + /// present in the cache and moves the key to the head of the LRU list. + /// If the key does not exist the provided `FnOnce` is used to populate + /// the list and a mutable reference is returned. + /// + /// # Example + /// + /// ``` + /// use lru::LruCache; + /// use std::num::NonZeroUsize; + /// let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap()); + /// + /// cache.put(1, "a"); + /// cache.put(2, "b"); + /// + /// let v = cache.get_or_insert_mut(2, ||"c"); + /// assert_eq!(v, &"b"); + /// *v = "d"; + /// assert_eq!(cache.get_or_insert_mut(2, ||"e"), &mut "d"); + /// assert_eq!(cache.get_or_insert_mut(3, ||"f"), &mut "f"); + /// assert_eq!(cache.get_or_insert_mut(3, ||"e"), &mut "f"); + /// ``` + pub fn get_or_insert_mut<'a, F>(&mut self, k: K, f: F) -> &'a mut V + where + F: FnOnce() -> V, + { + if let Some(node) = self.map.get_mut(&KeyRef { k: &k }) { + let node_ptr: *mut LruEntry = &mut **node; + + self.detach(node_ptr); + self.attach(node_ptr); + + unsafe { &mut (*(*node_ptr).val.as_mut_ptr()) as &mut V } + } else { + let v = f(); + let (_, mut node) = self.replace_or_create_node(k, v); + + let node_ptr: *mut LruEntry = &mut *node; + self.attach(node_ptr); + + let keyref = unsafe { (*node_ptr).key.as_ptr() }; + self.map.insert(KeyRef { k: keyref }, node); + unsafe { &mut (*(*node_ptr).val.as_mut_ptr()) as &mut V } + } + } + /// Returns a reference to the value corresponding to the key in the cache or `None` if it is /// not present in the cache. Unlike `get`, `peek` does not update the LRU list so the key's /// position will be unchanged. @@ -1252,10 +1298,31 @@ mod tests { assert_eq!(cache.cap().get(), 2); assert_eq!(cache.len(), 2); assert!(!cache.is_empty()); - assert_eq!(cache.get_or_insert(&"apple", || "orange"), &"red"); - assert_eq!(cache.get_or_insert(&"banana", || "orange"), &"yellow"); - assert_eq!(cache.get_or_insert(&"lemon", || "orange"), &"orange"); - assert_eq!(cache.get_or_insert(&"lemon", || "red"), &"orange"); + assert_eq!(cache.get_or_insert("apple", || "orange"), &"red"); + assert_eq!(cache.get_or_insert("banana", || "orange"), &"yellow"); + assert_eq!(cache.get_or_insert("lemon", || "orange"), &"orange"); + assert_eq!(cache.get_or_insert("lemon", || "red"), &"orange"); + } + + #[test] + fn test_put_and_get_or_insert_mut() { + let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap()); + assert!(cache.is_empty()); + + assert_eq!(cache.put("apple", "red"), None); + assert_eq!(cache.put("banana", "yellow"), None); + + assert_eq!(cache.cap().get(), 2); + assert_eq!(cache.len(), 2); + + let v = cache.get_or_insert_mut("apple", || "orange"); + assert_eq!(v, &"red"); + *v = "blue"; + + assert_eq!(cache.get_or_insert_mut("apple", || "orange"), &"blue"); + assert_eq!(cache.get_or_insert_mut("banana", || "orange"), &"yellow"); + assert_eq!(cache.get_or_insert_mut("lemon", || "orange"), &"orange"); + assert_eq!(cache.get_or_insert_mut("lemon", || "red"), &"orange"); } #[test]