From d7c970766e0f670ba2f55b5a126b2327d2aabde7 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Wed, 22 May 2024 14:54:33 -0500 Subject: [PATCH] Refactor cache decorator These changes make the `cache` decorator operate more like `diskcache`'s existing `memoize` method. They also remove the use of hash value as store keys. --- outlines/caching.py | 129 ++++++++++++++++++++++++++++---------------- 1 file changed, 84 insertions(+), 45 deletions(-) diff --git a/outlines/caching.py b/outlines/caching.py index 68207a0e4..52d66af74 100644 --- a/outlines/caching.py +++ b/outlines/caching.py @@ -1,15 +1,40 @@ import asyncio import functools -import hashlib import os from typing import Callable, Optional import cloudpickle -from diskcache import Cache +from diskcache import Cache, Disk +from diskcache.core import ENOVAL, UNKNOWN, args_to_key, full_name _caching_enabled = True +class CloudpickleDisk(Disk): + def __init__(self, directory, compress_level=1, **kwargs): + self.compress_level = compress_level + super().__init__(directory, **kwargs) + + def put(self, key): + data = cloudpickle.dumps(key) + return super().put(data) + + def get(self, key, raw): + data = super().get(key, raw) + return cloudpickle.loads(data) + + def store(self, value, read, key=UNKNOWN): + if not read: + value = cloudpickle.dumps(value) + return super().store(value, read, key=key) + + def fetch(self, mode, filename, value, read): + data = super().fetch(mode, filename, value, read) + if not read: + data = cloudpickle.loads(data) + return data + + @functools.lru_cache(1) def get_cache(): """Get the context object that contains previously-computed return values. @@ -26,7 +51,12 @@ def get_cache(): home_dir = os.path.expanduser("~") cache_dir = os.environ.get("OUTLINES_CACHE_DIR", f"{home_dir}/.cache/outlines") - memory = Cache(cache_dir, eviction_policy="none", cull_limit=0) + memory = Cache( + cache_dir, + eviction_policy="none", + cull_limit=0, + disk=CloudpickleDisk, + ) # ensure if version upgrade occurs, old cache is pruned if outlines_version != memory.get("__version__"): @@ -36,63 +66,72 @@ def get_cache(): return memory -def hash_arguments(*args, **kwargs) -> str: - """Create a hash out of the args and kwargs provided""" - result = hashlib.md5() - for item in list(args) + sorted(kwargs.items()): - result.update(cloudpickle.dumps(item)) - return result.hexdigest() - - -def cache(key_function: Optional[Callable] = None): +def cache(expire: Optional[float] = None, typed=False, ignore=()): """Caching decorator for memoizing function calls. + The cache key is created based on the values returned by the key_function callable if provided or based on the arguments of the decorated function directly otherwise + + This is based on `diskcache`'s `memoize`. + Parameters ---------- - key_function - A callable function used to generate a unique key for each function call. It's - called with the arguments of the decorated function as arguments + expire + Seconds until arguments expire. + typed + Cache different types separately. + ignore + Positional or keyword arguments to ignore. + Returns ------- - A decorator function that can be applied to other functions. + A decorator function that can be applied to other functions. """ def decorator(cached_function: Callable): memory = get_cache() - def wrapper(*args, **kwargs): - if not _caching_enabled: - return cached_function(*args, **kwargs) - if key_function: - key_args = key_function(*args, **kwargs) - cache_key = hash_arguments(*key_args) - else: - cache_key = hash_arguments(*args, **kwargs) - if cache_key in memory: - return memory[cache_key] - result = cached_function(*args, **kwargs) - memory[cache_key] = result - return result - - async def async_wrapper(*args, **kwargs): - if not _caching_enabled: - return await cached_function(*args, **kwargs) - if key_function: - key_args = key_function(*args, **kwargs) - cache_key = hash_arguments(*key_args) - else: - cache_key = hash_arguments(*args, **kwargs) - if cache_key in memory: - return memory[cache_key] - result = await cached_function(*args, **kwargs) - memory[cache_key] = result - return result + base = (full_name(cached_function),) if asyncio.iscoroutinefunction(cached_function): - return async_wrapper + + async def wrapper(*args, **kwargs): + if not _caching_enabled: + return await cached_function(*args, **kwargs) + + cache_key = wrapper.__cache_key__(*args, **kwargs) + result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True) + + if result is ENOVAL: + result = await cached_function(*args, **kwargs) + wrapper.__memory__.set(cache_key, result, expire, retry=True) + + return result + else: - return wrapper + + def wrapper(*args, **kwargs): + if not _caching_enabled: + return cached_function(*args, **kwargs) + + cache_key = wrapper.__cache_key__(*args, **kwargs) + result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True) + + if result is ENOVAL: + result = cached_function(*args, **kwargs) + wrapper.__memory__.set(cache_key, result, expire, retry=True) + + return result + + def __cache_key__(*args, **kwargs): + """Make key for cache given function arguments.""" + return args_to_key(base, args, kwargs, typed, ignore) + + wrapper.__cache_key__ = __cache_key__ # type: ignore + wrapper.__memory__ = memory # type: ignore + wrapper.__wrapped__ = cached_function # type: ignore + + return wrapper return decorator