diff --git a/tests/test_wimprates.py b/tests/test_wimprates.py index ac1fd9d..727d5f8 100644 --- a/tests/test_wimprates.py +++ b/tests/test_wimprates.py @@ -14,6 +14,8 @@ class TestBenchmarks(unittest.TestCase): opts = dict(mw=50, sigma_nucleon=1e-45, + save_cache=False, + load_cache=False, ) def test_elastic(self): ref = 30.39515403337126 diff --git a/wimprates/summary.py b/wimprates/summary.py index 8ac5c5c..33a92d9 100644 --- a/wimprates/summary.py +++ b/wimprates/summary.py @@ -2,12 +2,14 @@ Summary functions """ import numericalunits as nu +nu.reset_units(42) # Comment this line this when debugging dimensional analysis errors import wimprates as wr export, __all__ = wr.exporter() @export +@wr.save_result def rate_wimp(es, mw, sigma_nucleon, interaction='SI', detection_mechanism='elastic_nr', m_med=float('inf'), t=None, halo_model=None, diff --git a/wimprates/utils.py b/wimprates/utils.py index 890fb38..5be1768 100644 --- a/wimprates/utils.py +++ b/wimprates/utils.py @@ -1,13 +1,17 @@ import functools +import hashlib import inspect import os import pickle +from typing import Any, Callable import warnings from boltons.funcutils import wraps import numpy as np from tqdm.autonotebook import tqdm +import wimprates as wr + def exporter(): """Export utility modified from https://stackoverflow.com/a/41895194 @@ -108,6 +112,7 @@ def pairwise_log_transform(a, b): arr = np.concatenate((a, b), axis=1) return np.log(arr) + @export def deprecated(reason): """ @@ -124,4 +129,58 @@ def new_func(*args, **kwargs): ) return func(*args, **kwargs) return new_func - return decorator \ No newline at end of file + return decorator + + +def _generate_hash(*args, **kwargs): + # Create a string with the arguments and module version + + args_str = wr.__version__ + str(args) + + # Add keyword arguments to the string + args_str += "".join( + [f"{key}{kwargs[key]}" for key in sorted(kwargs) if key != "progress_bar"] + ) + + # Generate a SHA-256 hash + return hashlib.sha256(args_str.encode()).hexdigest() + + +@export +def save_result(func: Callable) -> Callable[..., Any]: + @wraps(func) + def wrapper( + *args, cache_dir: str="wimprates_cache", save_cache: bool=True, load_cache: bool=True, **kwargs + ): + # Define the cache directory + CACHE_DIR = cache_dir + + # Generate the hash based on function arguments and module version + func_name = func.__name__ + cache_key = _generate_hash(*args, **kwargs) + + # Define the path to the cache file + cache_file = os.path.join(CACHE_DIR, f"{func_name}_{cache_key}.pkl") + + # Check if the result is already cached + if load_cache and os.path.exists(cache_file): + with open(cache_file, "rb") as f: + print("Loading from cache: ", cache_file) + return pickle.load(f) + + # Compute the result + result = func(*args, **kwargs) + + if save_cache: + # Ensure cache directory exists + if not os.path.exists(CACHE_DIR): + os.makedirs(CACHE_DIR) + + # Save the result to the cache + with open(cache_file, "wb") as f: + pickle.dump(result, f) + print("Result saved to cache: ", cache_file) + + return result + + return wrapper