Skip to content

Commit

Permalink
Add caching functionality to save and load results
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzomag committed Aug 15, 2024
1 parent bfa4394 commit 623ed89
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 1 deletion.
2 changes: 2 additions & 0 deletions tests/test_wimprates.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
class TestBenchmarks(unittest.TestCase):
opts = dict(mw=50,
sigma_nucleon=1e-45,
save_cache=False,
load_cache=False,
)
def test_elastic(self):
ref = 30.39515403337126
Expand Down
2 changes: 2 additions & 0 deletions wimprates/summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
Summary functions
"""
import numericalunits as nu
nu.reset_units(42) # Comment this line this when debugging dimensional analysis errors

import wimprates as wr
export, __all__ = wr.exporter()


@export
@wr.save_result
def rate_wimp(es, mw, sigma_nucleon, interaction='SI',
detection_mechanism='elastic_nr', m_med=float('inf'),
t=None, halo_model=None,
Expand Down
61 changes: 60 additions & 1 deletion wimprates/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import functools
import hashlib
import inspect
import os
import pickle
from typing import Any, Callable
import warnings

from boltons.funcutils import wraps
import numpy as np
from tqdm.autonotebook import tqdm

import wimprates as wr


def exporter():
"""Export utility modified from https://stackoverflow.com/a/41895194
Expand Down Expand Up @@ -108,6 +112,7 @@ def pairwise_log_transform(a, b):
arr = np.concatenate((a, b), axis=1)
return np.log(arr)


@export
def deprecated(reason):
"""
Expand All @@ -124,4 +129,58 @@ def new_func(*args, **kwargs):
)
return func(*args, **kwargs)
return new_func
return decorator
return decorator


def _generate_hash(*args, **kwargs):
# Create a string with the arguments and module version

args_str = wr.__version__ + str(args)

# Add keyword arguments to the string
args_str += "".join(
[f"{key}{kwargs[key]}" for key in sorted(kwargs) if key != "progress_bar"]
)

# Generate a SHA-256 hash
return hashlib.sha256(args_str.encode()).hexdigest()


@export
def save_result(func: Callable) -> Callable[..., Any]:
@wraps(func)
def wrapper(
*args, cache_dir: str="wimprates_cache", save_cache: bool=True, load_cache: bool=True, **kwargs
):
# Define the cache directory
CACHE_DIR = cache_dir

# Generate the hash based on function arguments and module version
func_name = func.__name__
cache_key = _generate_hash(*args, **kwargs)

# Define the path to the cache file
cache_file = os.path.join(CACHE_DIR, f"{func_name}_{cache_key}.pkl")

# Check if the result is already cached
if load_cache and os.path.exists(cache_file):
with open(cache_file, "rb") as f:
print("Loading from cache: ", cache_file)
return pickle.load(f)

# Compute the result
result = func(*args, **kwargs)

if save_cache:
# Ensure cache directory exists
if not os.path.exists(CACHE_DIR):
os.makedirs(CACHE_DIR)

# Save the result to the cache
with open(cache_file, "wb") as f:
pickle.dump(result, f)
print("Result saved to cache: ", cache_file)

return result

return wrapper

0 comments on commit 623ed89

Please sign in to comment.