Skip to content

Commit

Permalink
Merge pull request #39 from MatthewCaseres/use-the-raw-cache
Browse files Browse the repository at this point in the history
Use the raw cache
  • Loading branch information
lewisfogden authored May 6, 2024
2 parents b56c70a + a122deb commit 8f8ad7a
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 104 deletions.
76 changes: 24 additions & 52 deletions src/heavylight/memory_optimized_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,11 @@
class FunctionCall:
func_name: str
args: tuple
kwargs: FrozenSet[Tuple[str, Any]]

def __repr__(self):
if len(self.kwargs) == 0:
if len(self.args) == 1:
if len(self.args) == 1:
return f"{self.func_name}({self.args[0]})"
return f"{self.func_name}{self.args}"
return f"{self.func_name}({', '.join(map(str, self.args))}, {', '.join(f'{k}={v}' for k, v in self.kwargs)})"

ArgsHash = Tuple[Tuple, frozenset]
return f"{self.func_name}{self.args}"

class CacheGraph:
"""
Expand All @@ -33,8 +28,8 @@ def reset(self):
Clear all internal state of the cache graph.
"""
self.stack: list[FunctionCall] = [] # what function is currently being called
self._caches: defaultdict[str, dict[ArgsHash, Any]] = defaultdict(dict) # Results of function calls, ugly keys like ((1, 2), frozenset([('a', 1)]))
self._caches_agg: defaultdict[str, dict[ArgsHash, Any]] = defaultdict(dict)
self.cache: defaultdict[str, dict[Tuple, Any]] = defaultdict(dict) # Results of function calls, keyed per function name by the positional-args tuple, e.g. (1, 2)
self.cache_agg: defaultdict[str, dict[Tuple, Any]] = defaultdict(dict)
self.graph: defaultdict[FunctionCall, set[FunctionCall]] = defaultdict(set) # Call graph, graph[caller] = [callee1, callee2, ...]
# Typically aggregated results for a function at a timestep.
# What is the last function that needs the result of a function? Used to help in clearing the cache
Expand All @@ -45,8 +40,8 @@ def reset(self):
self.cache_misses: defaultdict[FunctionCall, int] = defaultdict(int)

def check_if_cached(self, function_call: FunctionCall):
name_in_cache = function_call.func_name in self._caches
return name_in_cache and (function_call.args, function_call.kwargs) in self._caches[function_call.func_name]
name_in_cache = function_call.func_name in self.cache
return name_in_cache and function_call.args in self.cache[function_call.func_name]

def optimize(self):
self.can_clear = defaultdict(list)
Expand All @@ -66,8 +61,9 @@ def __call__(self, storage_func: Union[Callable[[int], Any], None] = None):
def custom_cache_decorator(func):
@wraps(func)
def wrapper(*args, **kwargs):
frozen_kwargs = frozenset(kwargs.items())
function_call = FunctionCall(func.__name__, args, frozen_kwargs)
if len(kwargs) > 0:
raise ValueError("Keyword arguments are not supported in heavylight")
function_call = FunctionCall(func.__name__, args)
if self.stack:
self.graph[self.stack[-1]].add(function_call)
self.last_needed_by[function_call] = self.stack[-1]
Expand All @@ -76,36 +72,27 @@ def wrapper(*args, **kwargs):
self.cache_misses[function_call] += 1
self.stack.append(function_call)
result = func(*args, **kwargs)
self._caches[func.__name__][(args, frozen_kwargs)] = result
self.cache[func.__name__][args] = result
for clearable_call in self.can_clear[function_call]:
del self._caches[clearable_call.func_name][(clearable_call.args, clearable_call.kwargs)]
del self.cache[clearable_call.func_name][(clearable_call.args)]
self.stack.pop()
self._store_result(storage_func, func, (args, frozen_kwargs), result)
self._store_result(storage_func, func, args, result)
return result
return self._caches[func.__name__][(args, frozen_kwargs)]
return self.cache[func.__name__][args]
decorator = CacheMethod(self, wrapper, storage_func)
return decorator
return custom_cache_decorator

def _store_result(self, storage_func: Union[Callable, None], func: Callable, args_hash: ArgsHash, raw_result: Any):
def _store_result(self, storage_func: Union[Callable, None], func: Callable, args_tuple: Tuple, raw_result: Any):
"""We might want to store an intermediate result"""
if storage_func is None:
return
stored_result = storage_func(raw_result)
self._caches_agg[func.__name__][args_hash] = stored_result
self.cache_agg[func.__name__][args_tuple] = stored_result

def size(self):
return sum(len(cache) for cache in self._caches.values())

@property
def cache(self):
caches = defaultdict(dict, {func_name: {get_pretty_key(k): v for k, v in cache.items()} for func_name, cache in self._caches.items()})
return caches
return sum(len(cache) for cache in self.cache.values())

@property
def cache_agg(self):
caches = defaultdict(dict, {func_name: {get_pretty_key(k): v for k, v in cache.items()} for func_name, cache in self._caches_agg.items()})
return caches

class CacheMethod:
def __init__(self, cache_graph: CacheGraph, func: Callable, agg_func: Union[Callable, None] = None):
Expand All @@ -121,41 +108,26 @@ def df(self):
def df_agg(self):
return pd.DataFrame({self.func.__name__: self.cache_agg})

# only run the dictionary comprehension for the particular method we want to access
# simply returning self.cache_graph.caches[self.func.__name__] would run the dictionary comprehension for all methods
@property
def cache(self):
return {get_pretty_key(k): v for k, v in self._cache.items()}

@property
def cache_agg(self):
return {get_pretty_key(k): v for k, v in self._cache_agg.items()}
return self.cache_graph.cache[self.func.__name__]

@property
def _cache(self):
return self.cache_graph._caches[self.func.__name__]

@property
def _cache_agg(self):
return self.cache_graph._caches_agg[self.func.__name__]
def cache_agg(self):
return self.cache_graph.cache_agg[self.func.__name__]

def __setitem__(self, key, value):
if isinstance(key, int):
self._cache[((key,), frozenset())] = value
self.cache[(key,)] = value
elif isinstance(key, tuple):
self.cache[key] = value
else:
self._cache[(key, frozenset())] = value
raise ValueError("Key must be an integer or a tuple")

def __repr__(self):
return f"<Cache Function: {self.func.__name__}, Size: {len(self._cache)}>"
return f"<Cache Function: {self.func.__name__}, Size: {len(self.cache)}>"

def __call__(self, *args: Any, **kwds: Any) -> Any:
return self.func(*args, **kwds)

def get_pretty_key(key: ArgsHash):
if len(key[0]) == 1 and len(key[1]) == 0:
return key[0][0]
elif len(key[1]) == 0:
return key[0]
else:
return key

2 changes: 1 addition & 1 deletion src/heavylight/memory_optimized_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _run_model(self, proj_len: int):
for func in self._single_param_timestep_funcs:
# We avoid recalling any functions that have already been cached, resolves issue #15 lewisfogden/heavylight
if (
not FunctionCall(func.func.__name__, (t,), frozenset())
not FunctionCall(func.func.__name__, (t,))
in self._cache_graph.all_calls
):
func(t)
Expand Down
38 changes: 9 additions & 29 deletions tests/test_lightmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,8 @@ def fib(self, t, mult_factor):
def test_method_call_and_cache_retrievals():
sm = SimpleModel(np.linspace(.1, 1, 10))
sm.forward_rate(0)
assert sm.forward_rate.cache[0] == .04
assert sm._cache_graph._caches['forward_rate'][((0,), frozenset())] == .04
assert sm._cache_graph.cache['forward_rate'][0] == .04
assert sm.forward_rate.cache[(0,)] == .04
assert sm._cache_graph.cache['forward_rate'][(0,)] == .04
assert len(sm._cache_graph.cache) == 1
assert len(sm.forward_rate.cache) == 1
sm.forward_rate(5)
Expand All @@ -76,34 +75,15 @@ def test_caching_speedups():
# but it is still cached
assert len(sm.fib.cache) == 201

def test_no_kwargs(): # keyword arguments are not supported because the same function call could then be hashed in multiple ways
sm = SimpleModel(np.linspace(.1, 1, 10))
with pytest.raises(ValueError):
sm.fib(5, mult_factor=1.1)

def test_reset_cache():
sm = SimpleModel(np.linspace(.1, 1, 10))
sm.RunModel(5)
assert round(np.sum(sm.pols_death.cache[0]), 10) == .055
assert round(np.sum(sm.pols_death.cache[(0,)]), 10) == .055
sm.mortality_rate = .02
sm.RunModel(5)
assert round(np.sum(sm.pols_death.cache[0]), 10) == .11

class TestPrettyCacheModel(LightModel):
def __init__(self):
super().__init__()
def t(self, t):
return self.wowee(t, 1) + self.zowee(t, x=1)
def wowee(self, t, x):
return 1
def zowee(self, t, x):
return 2

def test_pretty_cache():
pretty_model = TestPrettyCacheModel()
pretty_model.Clear()
pretty_model.RunModel(0)
assert pretty_model._cache_graph.cache['wowee'][(0,1)] == 1
assert pretty_model.wowee.cache[(0,1)] == 1
assert pretty_model._cache_graph.cache['zowee'][((0,),frozenset({'x': 1}.items()))] == 2
assert pretty_model.zowee.cache[((0,),frozenset({'x': 1}.items()))] == 2
assert pretty_model._cache_graph.cache['t'][0] == 3
assert pretty_model.t.cache[0] == 3
# can inject into the cache
pretty_model.zowee[1] = 'hello cache'
assert pretty_model.zowee.cache[1] == 'hello cache'
assert round(np.sum(sm.pols_death.cache[(0,)]), 10) == .11
18 changes: 9 additions & 9 deletions tests/test_lightmodel_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,23 +30,23 @@ def t(self, t):
def multi_param(self, t, t2):
return np.ones((self.size,))

expected_cache = {'pols_death': {0: np.array([0.01, 0.01]),
1: np.array([0.0099, 0.0099])},
'pols_if': {0: np.array([1., 1.]), 1: np.array([0.99, 0.99])},
'test_agg_none': {0: np.array([1., 1.]), 1: np.array([1., 1.])},
't': {0: np.array([1., 1.]), 1: np.array([1., 1.])},
expected_cache = {'pols_death': {(0,): np.array([0.01, 0.01]),
(1,): np.array([0.0099, 0.0099])},
'pols_if': {(0,): np.array([1., 1.]), (1,): np.array([0.99, 0.99])},
'test_agg_none': {(0,): np.array([1., 1.]), (1,): np.array([1., 1.])},
't': {(0,): np.array([1., 1.]), (1,): np.array([1., 1.])},
'multi_param': {(0, 0): np.array([1., 1.]), (1, 1): np.array([1., 1.])}}

expected_cache_no_multi = { k: v for k, v in expected_cache.items() if k != 'multi_param' }

expected_cache_agg = {'pols_if': {0: 2.0, 1: 1.98},
'pols_death': {0: 0.02, 1: 0.0198},
't': {0: 2, 1: 2},
expected_cache_agg = {'pols_if': {(0,): 2.0, (1,): 1.98},
'pols_death': {(0,): 0.02, (1,): 0.0198},
't': {(0,): 2, (1,): 2},
'multi_param': {(0, 0): 2, (1, 1): 2}}

expected_cache_agg_no_multi = { k: v for k, v in expected_cache_agg.items() if k != 'multi_param' }

expected_cache_agg_none_aggfunc = {'t': {0: 2, 1: 2}}
expected_cache_agg_none_aggfunc = {'t': {(0,): 2, (1,): 2}}

def test_model_df_before_run():
tm = TestModel(default_agg_function)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_memory_savings.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def pv_cashflow(self, t):

def calculate_cache_graph_size(model: LightModel):
cg = model._cache_graph
return sum(val.nbytes for cache in cg._caches.values() for val in cache.values())
return sum(val.nbytes for cache in cg.cache.values() for val in cache.values())

def run_model_calculate_max_cache(model: SimpleModel, max_time: int):
max_cache_size = 0
Expand Down Expand Up @@ -91,12 +91,12 @@ def test_run_optimize():
sm.RunOptimized()
assert len(sm._cache_graph.cache_misses.values()) > 0
assert all(x == 1 for x in sm._cache_graph.cache_misses.values())
assert sm.num_pols_if.cache_agg[1] == 9.9
assert sm.num_pols_if.cache_agg[(1,)] == 9.9
assert len(sm.num_pols_if.cache) == 0
# do it again
sm.initial_pols_if = np.ones((100,))
sm.RunOptimized()
assert len(sm._cache_graph.cache_misses.values()) > 0
assert all(x == 1 for x in sm._cache_graph.cache_misses.values())
assert round(sm.num_pols_if.cache_agg[1], 10) == 99
assert round(sm.num_pols_if.cache_agg[(1,)], 10) == 99
assert len(sm.num_pols_if.cache) == 0
18 changes: 8 additions & 10 deletions tests/test_optimized_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
import pytest

def test_function_call():
fc_args_kwargs = FunctionCall("func", (1, 2), frozenset([('a', 1)]))
fc_args_kwargs = FunctionCall("func", (1, 2))
assert fc_args_kwargs.func_name == "func"
assert fc_args_kwargs.args == (1, 2)
assert fc_args_kwargs.kwargs == frozenset([('a', 1)])
assert repr(fc_args_kwargs) == "func(1, 2, a=1)"
fc_single_arg_no_kwargs = FunctionCall("func", (1,), frozenset())
assert repr(fc_args_kwargs) == "func(1, 2)"
fc_single_arg_no_kwargs = FunctionCall("func", (1,))
assert repr(fc_single_arg_no_kwargs) == "func(1)"
fc_multiple_args_no_kwargs = FunctionCall("func", (1, "hello"), frozenset())
fc_multiple_args_no_kwargs = FunctionCall("func", (1, "hello"))
assert repr(fc_multiple_args_no_kwargs) == "func(1, 'hello')"

def test_cache_graph_storage_function():
Expand All @@ -21,7 +20,7 @@ def fib(n):
return fib(n - 1) + fib(n - 2)
fib(5)
assert len(cg.cache_agg["fib"]) == 6
for k, v in cg.cache_agg["fib"].items():
for (k, *rest), v in cg.cache_agg["fib"].items():
assert v == fib(k)**2

def test_cache_dunders():
Expand All @@ -34,12 +33,11 @@ def fib(n):
fib(5)
assert repr(fib) == "<Cache Function: fib, Size: 6>"
assert len(fib.cache) == 6
test_key = 5
test_key = (5,)
assert fib.cache[test_key] == cg.cache['fib'][test_key] == 5
assert fib.cache[5] == 5 # prettified keys
fib[5] = 10
fib[test_key[0]] = 10
assert fib.cache[test_key] == cg.cache['fib'][test_key] == 10
assert fib.cache[5] == 10
assert fib.cache[test_key] == 10
fib[(5,)] = 100
assert fib.cache[test_key] == cg.cache['fib'][test_key] == 100

Expand Down

0 comments on commit 8f8ad7a

Please sign in to comment.