Refactor cache decorator
These changes make the `cache` decorator operate more like `diskcache`'s
existing `memoize` method. They also remove the use of hash values as store
keys.
brandonwillard committed May 23, 2024
1 parent ffab2ac commit d7c9707
Showing 1 changed file with 84 additions and 45 deletions.
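
For orientation, here is a minimal usage sketch of the refactored interface. It is an illustration rather than code from the commit: `expensive_completion` is a hypothetical function, and only the decorator parameters (`expire`, `typed`, `ignore`) and the module path come from the diff below.

    from outlines.caching import cache

    @cache(expire=3600)  # hypothetical choice: cached entries expire after one hour
    def expensive_completion(prompt: str) -> str:
        # Stand-in for a slow call whose result is worth memoizing.
        return prompt.upper()

    expensive_completion("hello")  # computed and written to the on-disk cache
    expensive_completion("hello")  # served from the cache on repeat calls
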
outlines/caching.py (84 additions, 45 deletions)
@@ -1,15 +1,40 @@
 import asyncio
 import functools
-import hashlib
 import os
 from typing import Callable, Optional

 import cloudpickle
-from diskcache import Cache
+from diskcache import Cache, Disk
+from diskcache.core import ENOVAL, UNKNOWN, args_to_key, full_name

 _caching_enabled = True


+class CloudpickleDisk(Disk):
+    def __init__(self, directory, compress_level=1, **kwargs):
+        self.compress_level = compress_level
+        super().__init__(directory, **kwargs)
+
+    def put(self, key):
+        data = cloudpickle.dumps(key)
+        return super().put(data)
+
+    def get(self, key, raw):
+        data = super().get(key, raw)
+        return cloudpickle.loads(data)
+
+    def store(self, value, read, key=UNKNOWN):
+        if not read:
+            value = cloudpickle.dumps(value)
+        return super().store(value, read, key=key)
+
+    def fetch(self, mode, filename, value, read):
+        data = super().fetch(mode, filename, value, read)
+        if not read:
+            data = cloudpickle.loads(data)
+        return data
+
+
 @functools.lru_cache(1)
 def get_cache():
     """Get the context object that contains previously-computed return values.
@@ -26,7 +51,12 @@ def get_cache():

     home_dir = os.path.expanduser("~")
     cache_dir = os.environ.get("OUTLINES_CACHE_DIR", f"{home_dir}/.cache/outlines")
-    memory = Cache(cache_dir, eviction_policy="none", cull_limit=0)
+    memory = Cache(
+        cache_dir,
+        eviction_policy="none",
+        cull_limit=0,
+        disk=CloudpickleDisk,
+    )

     # ensure if version upgrade occurs, old cache is pruned
     if outlines_version != memory.get("__version__"):
@@ -36,63 +66,72 @@ def get_cache():
     return memory


-def hash_arguments(*args, **kwargs) -> str:
-    """Create a hash out of the args and kwargs provided"""
-    result = hashlib.md5()
-    for item in list(args) + sorted(kwargs.items()):
-        result.update(cloudpickle.dumps(item))
-    return result.hexdigest()
-
-
-def cache(key_function: Optional[Callable] = None):
+def cache(expire: Optional[float] = None, typed=False, ignore=()):
     """Caching decorator for memoizing function calls.
+
     The cache key is created based on the values returned by the key_function callable
     if provided or based on the arguments of the decorated function directly otherwise
+
+    This is based on `diskcache`'s `memoize`.
+
     Parameters
     ----------
-    key_function
-        A callable function used to generate a unique key for each function call. It's
-        called with the arguments of the decorated function as arguments
+    expire
+        Seconds until arguments expire.
+    typed
+        Cache different types separately.
+    ignore
+        Positional or keyword arguments to ignore.
+
     Returns
     -------
-    A decorator function that can be applied to other functions.
+        A decorator function that can be applied to other functions.
     """

     def decorator(cached_function: Callable):
         memory = get_cache()

-        def wrapper(*args, **kwargs):
-            if not _caching_enabled:
-                return cached_function(*args, **kwargs)
-            if key_function:
-                key_args = key_function(*args, **kwargs)
-                cache_key = hash_arguments(*key_args)
-            else:
-                cache_key = hash_arguments(*args, **kwargs)
-            if cache_key in memory:
-                return memory[cache_key]
-            result = cached_function(*args, **kwargs)
-            memory[cache_key] = result
-            return result
-
-        async def async_wrapper(*args, **kwargs):
-            if not _caching_enabled:
-                return await cached_function(*args, **kwargs)
-            if key_function:
-                key_args = key_function(*args, **kwargs)
-                cache_key = hash_arguments(*key_args)
-            else:
-                cache_key = hash_arguments(*args, **kwargs)
-            if cache_key in memory:
-                return memory[cache_key]
-            result = await cached_function(*args, **kwargs)
-            memory[cache_key] = result
-            return result
+        base = (full_name(cached_function),)

         if asyncio.iscoroutinefunction(cached_function):
-            return async_wrapper
+
+            async def wrapper(*args, **kwargs):
+                if not _caching_enabled:
+                    return await cached_function(*args, **kwargs)
+
+                cache_key = wrapper.__cache_key__(*args, **kwargs)
+                result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True)
+
+                if result is ENOVAL:
+                    result = await cached_function(*args, **kwargs)
+                    wrapper.__memory__.set(cache_key, result, expire, retry=True)
+
+                return result
+
         else:
-            return wrapper
+
+            def wrapper(*args, **kwargs):
+                if not _caching_enabled:
+                    return cached_function(*args, **kwargs)
+
+                cache_key = wrapper.__cache_key__(*args, **kwargs)
+                result = wrapper.__memory__.get(cache_key, default=ENOVAL, retry=True)
+
+                if result is ENOVAL:
+                    result = cached_function(*args, **kwargs)
+                    wrapper.__memory__.set(cache_key, result, expire, retry=True)
+
+                return result
+
+        def __cache_key__(*args, **kwargs):
+            """Make key for cache given function arguments."""
+            return args_to_key(base, args, kwargs, typed, ignore)
+
+        wrapper.__cache_key__ = __cache_key__  # type: ignore
+        wrapper.__memory__ = memory  # type: ignore
+        wrapper.__wrapped__ = cached_function  # type: ignore
+
+        return wrapper

     return decorator


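To make the change to the store keys concrete: instead of an MD5 digest of pickled arguments, the wrapper now carries the same introspection attributes as `diskcache`'s `memoize`, and the key is a structured tuple built by `args_to_key` from the function's full name plus its call arguments. A rough sketch under that reading, with a made-up `add` function; the exact shape of the key tuple is a `diskcache` implementation detail.

    from outlines.caching import cache

    @cache()
    def add(x, y=1):
        return x + y

    add(2, y=3)                      # first call computes and stores the result
    key = add.__cache_key__(2, y=3)  # structured key from args_to_key, no hashing
    value = add.__memory__.get(key)  # read the stored value back from the Cache (-> 5)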