Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add try_get_or_insert_mut #178

Merged
merged 1 commit into from
Sep 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 84 additions & 16 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -531,9 +531,9 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
/// let b = ||->Result<&str, String> {Ok("b")};
/// assert_eq!(cache.try_get_or_insert(2, a), Ok(&"c"));
/// assert_eq!(cache.try_get_or_insert(3, a), Ok(&"d"));
/// assert_eq!(cache.try_get_or_insert(1, f), Err("failed".to_owned()));
/// assert_eq!(cache.try_get_or_insert(1, b), Ok(&"b"));
/// assert_eq!(cache.try_get_or_insert(1, a), Ok(&"b"));
/// assert_eq!(cache.try_get_or_insert(4, f), Err("failed".to_owned()));
/// assert_eq!(cache.try_get_or_insert(5, b), Ok(&"b"));
/// assert_eq!(cache.try_get_or_insert(5, a), Ok(&"b"));
/// ```
pub fn try_get_or_insert<F, E>(&mut self, k: K, f: F) -> Result<&V, E>
where
Expand All @@ -547,19 +547,15 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {

unsafe { Ok(&*(*node_ptr).val.as_ptr()) }
} else {
match f() {
Err(e) => Err(e),
Ok(v) => {
let (_, node) = self.replace_or_create_node(k, v);
let node_ptr: *mut LruEntry<K, V> = node.as_ptr();

self.attach(node_ptr);

let keyref = unsafe { (*node_ptr).key.as_ptr() };
self.map.insert(KeyRef { k: keyref }, node);
Ok(unsafe { &*(*node_ptr).val.as_ptr() })
}
}
let v = f()?;
let (_, node) = self.replace_or_create_node(k, v);
let node_ptr: *mut LruEntry<K, V> = node.as_ptr();

self.attach(node_ptr);

let keyref = unsafe { (*node_ptr).key.as_ptr() };
self.map.insert(KeyRef { k: keyref }, node);
Ok(unsafe { &*(*node_ptr).val.as_ptr() })
}
}

Expand Down Expand Up @@ -609,6 +605,58 @@ impl<K: Hash + Eq, V, S: BuildHasher> LruCache<K, V, S> {
}
}

/// Returns a mutable reference to the value of the key in the cache if it is
/// present in the cache and moves the key to the head of the LRU list.
/// If the key does not exist the provided `FnOnce` is used to populate
/// the list and a mutable reference is returned. If `FnOnce` returns `Err`,
/// returns the `Err`.
///
/// # Example
///
/// ```
/// use lru::LruCache;
/// use std::num::NonZeroUsize;
/// let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap());
///
/// cache.put(1, "a");
/// cache.put(2, "b");
/// cache.put(2, "c");
///
/// let f = ||->Result<&str, String> {Err("failed".to_owned())};
/// let a = ||->Result<&str, String> {Ok("a")};
/// let b = ||->Result<&str, String> {Ok("b")};
/// if let Ok(v) = cache.try_get_or_insert_mut(2, a) {
/// *v = "d";
/// }
/// assert_eq!(cache.try_get_or_insert_mut(2, a), Ok(&mut "d"));
/// assert_eq!(cache.try_get_or_insert_mut(3, f), Err("failed".to_owned()));
/// assert_eq!(cache.try_get_or_insert_mut(4, b), Ok(&mut "b"));
/// assert_eq!(cache.try_get_or_insert_mut(4, a), Ok(&mut "b"));
/// ```
pub fn try_get_or_insert_mut<'a, F, E>(&mut self, k: K, f: F) -> Result<&'a mut V, E>
where
F: FnOnce() -> Result<V, E>,
{
if let Some(node) = self.map.get_mut(&KeyRef { k: &k }) {
let node_ptr: *mut LruEntry<K, V> = node.as_ptr();

self.detach(node_ptr);
self.attach(node_ptr);

unsafe { Ok(&mut *(*node_ptr).val.as_mut_ptr()) }
} else {
let v = f()?;
let (_, node) = self.replace_or_create_node(k, v);
let node_ptr: *mut LruEntry<K, V> = node.as_ptr();

self.attach(node_ptr);

let keyref = unsafe { (*node_ptr).key.as_ptr() };
self.map.insert(KeyRef { k: keyref }, node);
unsafe { Ok(&mut *(*node_ptr).val.as_mut_ptr()) }
}
}

/// Returns a reference to the value corresponding to the key in the cache or `None` if it is
/// not present in the cache. Unlike `get`, `peek` does not update the LRU list so the key's
/// position will be unchanged.
Expand Down Expand Up @@ -1494,6 +1542,26 @@ mod tests {
assert_eq!(cache.get_or_insert_mut("lemon", || "red"), &"orange");
}

#[test]
fn test_try_get_or_insert_mut() {
let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap());

cache.put(1, "a");
cache.put(2, "b");
cache.put(2, "c");

let f = || -> Result<&str, &str> { Err("failed") };
let a = || -> Result<&str, &str> { Ok("a") };
let b = || -> Result<&str, &str> { Ok("b") };
if let Ok(v) = cache.try_get_or_insert_mut(2, a) {
*v = "d";
}
assert_eq!(cache.try_get_or_insert_mut(2, a), Ok(&mut "d"));
assert_eq!(cache.try_get_or_insert_mut(3, f), Err("failed"));
assert_eq!(cache.try_get_or_insert_mut(4, b), Ok(&mut "b"));
assert_eq!(cache.try_get_or_insert_mut(4, a), Ok(&mut "b"));
}

#[test]
fn test_put_and_get_mut() {
let mut cache = LruCache::new(NonZeroUsize::new(2).unwrap());
Expand Down
Loading