From 22e70fa4aa49eea4c18b9447786a8d4aafcfd69b Mon Sep 17 00:00:00 2001 From: Lorenzo Principe <28869147+lorenzomag@users.noreply.github.com> Date: Thu, 15 Aug 2024 21:09:32 +0200 Subject: [PATCH 1/5] Remove overhead on import Refactored a bit the code here to remove expensive and unutilised computation on import of wimprates --- wimprates/electron.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/wimprates/electron.py b/wimprates/electron.py index b86a4a1..90c3e0a 100644 --- a/wimprates/electron.py +++ b/wimprates/electron.py @@ -1,5 +1,6 @@ """Dark matter - electron scattering """ +from functools import lru_cache import numericalunits as nu import numpy as np from scipy.interpolate import RegularGridInterpolator, interp1d @@ -9,13 +10,17 @@ export, __all__ = wr.exporter() __all__ += ['dme_shells', 'l_to_letter', 'l_to_number'] -# Load form factor and construct interpolators -shell_data = wr.load_pickle('dme/dme_ionization_ff.pkl') -for _shell, _sd in shell_data.items(): - _sd['log10ffsquared_itp'] = RegularGridInterpolator( - (_sd['lnks'], _sd['lnqs']), - np.log10(_sd['ffsquared']), - bounds_error=False, fill_value=-float('inf'),) + +@lru_cache() +def get_shell_data(): + """Load form factor and construct interpolators""" + shell_data = wr.load_pickle('dme/dme_ionization_ff.pkl') + for _shell, _sd in shell_data.items(): + _sd['log10ffsquared_itp'] = RegularGridInterpolator( + (_sd['lnks'], _sd['lnqs']), + np.log10(_sd['ffsquared']), + bounds_error=False, fill_value=-float('inf'),) + return shell_data dme_shells = [(5, 1), (5, 0), (4, 2), (4, 1), (4, 0)] @@ -54,6 +59,9 @@ def dme_ionization_ff(shell, e_er, q): # Ry = rydberg = 13.6 eV ry = nu.me * nu.e ** 4 / (8 * nu.eps0 ** 2 * nu.hPlanck ** 2) lnk = np.log(e_er / ry) / 2 + + shell_data = get_shell_data() + return 10**(shell_data[shell]['log10ffsquared_itp']( np.vstack([lnk, lnq]).T)) @@ -84,8 +92,8 @@ def v_min_dme(eb, erec, q, mw): return (erec + eb) / q + q / (2 * mw) -# Precompute velocity integrals for t=None @export +@lru_cache() def velocity_integral_without_time(halo_model=None): halo_model = wr.StandardHaloModel() if halo_model is None else halo_model _v_mins = np.linspace(0, 1, 1000) * wr.v_max(None, halo_model.v_esc) @@ -105,7 +113,6 @@ def velocity_integral_without_time(halo_model=None): fill_value=0, bounds_error=False) return inverse_mean_speed_kms -inverse_mean_speed_kms = velocity_integral_without_time() @export @@ -141,9 +148,12 @@ def rate_dme(erec, n, l, mw, sigma_dme, # No bounds are given for the q integral # but the form factors are only specified in a limited range of q + shell_data = get_shell_data() qmax = (np.exp(shell_data[shell]['lnqs'].max()) * (nu.me * nu.c0 * nu.alphaFS)) + # Precompute velocity integrals for t=None + inverse_mean_speed_kms = velocity_integral_without_time() if t is None: # Use precomputed inverse mean speed, # so we only have to do a single integral From 3ab6231dd1fcfd2da24cdad1645ecd08135ddbc4 Mon Sep 17 00:00:00 2001 From: Lorenzo Principe <28869147+lorenzomag@users.noreply.github.com> Date: Thu, 15 Aug 2024 23:05:50 +0200 Subject: [PATCH 2/5] Implement optional multiprocessing for Migdal --- wimprates/migdal.py | 241 +++++++++++++++++++++++++++----------------- 1 file changed, 150 insertions(+), 91 deletions(-) diff --git a/wimprates/migdal.py b/wimprates/migdal.py index 6c23e4b..a4219f4 100644 --- a/wimprates/migdal.py +++ b/wimprates/migdal.py @@ -10,6 +10,7 @@ """ from collections.abc import Callable +from concurrent.futures import ProcessPoolExecutor from dataclasses import dataclass from typing import Any, Optional, Union @@ -17,6 +18,7 @@ from functools import lru_cache import numericalunits as nu import numpy as np +from tqdm.autonotebook import tqdm import pandas as pd from scipy.integrate import dblquad from scipy.interpolate import interp1d @@ -68,7 +70,7 @@ def l(self) -> str: return self.name[1:] -def _default_shells(material: str) -> list[str]: +def _default_shells(material: str) -> tuple[str]: """ Returns the default shells to consider for a given material. Args: @@ -88,45 +90,25 @@ def _default_shells(material: str) -> list[str]: Ge=["3*"], Si=["2*"], ) - return consider_shells[material] + return tuple(consider_shells[material]) -def create_cox_probability_function( - element, - state: str, - material: str, +def _create_cox_probability_function( + element: str, dipole: bool = False, ) -> Callable[..., np.ndarray[Any, Any]]: fn_name = "dpI1dipole" if dipole else "dpI1" fn = getattr(element, fn_name) - def get_probability( - e: Union[float, np.ndarray], # energy of released electron - erec: Optional[Union[float, np.ndarray]] = None, # recoil energy - v: Optional[Union[float, np.ndarray]] = None, # recoil speed - ) -> np.ndarray: - if erec is None: - if v is None: - raise ValueError("Either v or erec have to be provided") - elif v is None: - v = (2 * erec / wr.mn(material)) ** 0.5 / nu.c0 - else: - raise ValueError("Either v or erec have to be provided") - - e /= nu.keV - - input_points = wr.pairwise_log_transform(e, v) - return fn(input_points, state) / nu.keV # type: ignore - - return get_probability + return fn @export def get_migdal_transitions_probability_iterators( material: str = "Xe", model: str = "Ibe", - considered_shells: Optional[Union[list[str], str]] = None, + considered_shells: Optional[Union[tuple[str], str]] = None, dark_matter: bool = True, e_threshold: Optional[float] = None, dipole: bool = False, @@ -193,6 +175,7 @@ def get_migdal_transitions_probability_iterators( dipole=dipole, dark_matter=dark_matter, e_threshold=e_threshold, + **kwargs ) for state, binding_e in element.orbitals: @@ -205,10 +188,8 @@ def get_migdal_transitions_probability_iterators( material, binding_e * nu.keV, model, - single_ionization_probability=create_cox_probability_function( + single_ionization_probability=_create_cox_probability_function( element, - state, - material, dipole=dipole, ), ) @@ -230,10 +211,92 @@ def vmin_migdal( return np.maximum(0, y) +def get_diff_rate( + w: float, + shells: list[Shell], + mw: float, + sigma_nucleon: float, + halo_model: wr.StandardHaloModel, + interaction: str, + m_med: float, + migdal_model: str, + include_approx_nr: bool, + q_nr: float, + material: str, + t: Optional[float], + **kwargs, +): + result = 0 + for shell in shells: + + def diff_rate(v, erec): + # Observed energy = energy of emitted electron + # + binding energy of state + eelec = w - shell.binding_e - include_approx_nr * erec * q_nr + if eelec < 0: + return 0 + + if migdal_model == "Ibe": + return ( + # Usual elastic differential rate, + # common constants follow at end + wr.sigma_erec( + erec, + v, + mw, + sigma_nucleon, + interaction, + m_med=m_med, + material=material, + ) + * v + * halo_model.velocity_dist(v, t) + # Migdal effect |Z|^2 + # TODO: ?? what is explicit (eV/c)**2 doing here? + * (nu.me * (2 * erec / wr.mn(material)) ** 0.5 / (nu.eV / nu.c0)) + ** 2 + / (2 * np.pi) + * shell(eelec) + ) + elif migdal_model == "Cox": + vrec = (2 * erec / wr.mn(material)) ** 0.5 / nu.c0 + input_points = wr.pairwise_log_transform(eelec/nu.keV, vrec) + return ( + wr.sigma_erec( + erec, + v, + mw, + sigma_nucleon, + interaction, + m_med=m_med, + material=material, + ) + * v + * halo_model.velocity_dist(v, t) + * shell(input_points, shell.name) / nu.keV + ) + + # Note dblquad expects the function to be f(y, x), not f(x, y)... + result += dblquad( + diff_rate, + 0, + wr.e_max(mw, wr.v_max(t, halo_model.v_esc), wr.mn(material)), + lambda erec: vmin_migdal( + w=w - include_approx_nr * erec * q_nr, + erec=erec, + mw=mw, + material=material, + ), + lambda _: wr.v_max(t, halo_model.v_esc), + **kwargs, + )[0] + + return result + + @export -@wr.vectorize_first def rate_migdal( - w: np.ndarray, + w: Union[np.ndarray, float], mw: float, sigma_nucleon: float, interaction: str = "SI", @@ -243,11 +306,13 @@ def rate_migdal( material: str = "Xe", t: Optional[float] = None, halo_model: Optional[wr.StandardHaloModel] = None, - consider_shells: Optional[list[str]] = None, + consider_shells: Optional[tuple[str]] = None, migdal_model: str = "Ibe", dark_matter: bool = True, dipole: bool = False, e_threshold: Optional[float] = None, + progress_bar: bool = False, + multi_process: Optional[Union[bool, int]] = True, **kwargs, ) -> np.ndarray: """Differential rate per unit detector mass and deposited ER energy of @@ -279,8 +344,21 @@ def rate_migdal( Further kwargs are passed to scipy.integrate.quad numeric integrator (e.g. error tolerance). """ + _is_array = True + if not isinstance(w, np.ndarray): + if isinstance(w, float): + _is_array = False + w = np.array([w]) + else: + raise ValueError("w must be a float or a numpy array") + halo_model = wr.StandardHaloModel() if halo_model is None else halo_model + if progress_bar: + prog_bar = tqdm + else: + prog_bar = lambda x, *args, **kwargs: x + if not consider_shells: consider_shells = _default_shells(material) @@ -293,72 +371,53 @@ def rate_migdal( dark_matter=dark_matter, ) - result = 0 - for shell in shells: - - def diff_rate(v, erec): - # Observed energy = energy of emitted electron - # + binding energy of state - eelec = w - shell.binding_e - include_approx_nr * erec * q_nr - if eelec < 0: - return 0 - - if migdal_model == "Ibe": - return ( - # Usual elastic differential rate, - # common constants follow at end - wr.sigma_erec( - erec, - v, - mw, - sigma_nucleon, - interaction, - m_med=m_med, - material=material, - ) - * v - * halo_model.velocity_dist(v, t) - # Migdal effect |Z|^2 - # TODO: ?? what is explicit (eV/c)**2 doing here? - * (nu.me * (2 * erec / wr.mn(material)) ** 0.5 / (nu.eV / nu.c0)) - ** 2 - / (2 * np.pi) - * shell(eelec) - ) - elif migdal_model == "Cox": - return ( - wr.sigma_erec( - erec, - v, + results = [] + if multi_process and not dipole: + multi_process = None if isinstance(multi_process, bool) else multi_process + with ProcessPoolExecutor(multi_process) as executor: + futures = [] + for val in w: + futures.append( + executor.submit( + get_diff_rate, + val, + shells, mw, sigma_nucleon, + halo_model, interaction, - m_med=m_med, - material=material, + m_med, + migdal_model, + include_approx_nr, + q_nr, + material, + t, ) - * v - * halo_model.velocity_dist(v, t) - * shell(eelec, erec) ) - # Note dblquad expects the function to be f(y, x), not f(x, y)... - r = dblquad( - diff_rate, - 0, - wr.e_max(mw, wr.v_max(t, halo_model.v_esc), wr.mn(material)), - lambda erec: vmin_migdal( - w=w - include_approx_nr * erec * q_nr, - erec=erec, - mw=mw, - material=material, - ), - lambda _: wr.v_max(t, halo_model.v_esc), - **kwargs, - )[0] - - result += r + for future in prog_bar(futures, desc="Computing rates"): + results.append(future.result()) + else: + for val in prog_bar(w, desc="Computing rates"): + results.append( + get_diff_rate( + val, + shells, + mw, + sigma_nucleon, + halo_model, + interaction, + m_med, + migdal_model, + include_approx_nr, + q_nr, + material, + t, + ) + ) - return halo_model.rho_dm / mw * (1 / wr.mn(material)) * np.array(result) + results = np.array(results) if _is_array else float(results[0]) + return halo_model.rho_dm / mw * (1 / wr.mn(material)) * results @wr.deprecated("Use get_migdal_transitions_probability_iterators instead") From 411bfd60479ccaa6f79c34fdd3ce077c7b9b3081 Mon Sep 17 00:00:00 2001 From: Lorenzo Principe <28869147+lorenzomag@users.noreply.github.com> Date: Thu, 15 Aug 2024 23:07:51 +0200 Subject: [PATCH 3/5] Make Multiprocessing loop more Pythonic --- wimprates/migdal.py | 56 +++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/wimprates/migdal.py b/wimprates/migdal.py index a4219f4..0af6b6a 100644 --- a/wimprates/migdal.py +++ b/wimprates/migdal.py @@ -12,10 +12,11 @@ from collections.abc import Callable from concurrent.futures import ProcessPoolExecutor from dataclasses import dataclass +import os from typing import Any, Optional, Union from fnmatch import fnmatch -from functools import lru_cache +from functools import lru_cache, partial import numericalunits as nu import numpy as np from tqdm.autonotebook import tqdm @@ -312,7 +313,7 @@ def rate_migdal( dipole: bool = False, e_threshold: Optional[float] = None, progress_bar: bool = False, - multi_process: Optional[Union[bool, int]] = True, + multi_processing: Optional[Union[bool, int]] = True, **kwargs, ) -> np.ndarray: """Differential rate per unit detector mass and deposited ER energy of @@ -371,33 +372,34 @@ def rate_migdal( dark_matter=dark_matter, ) - results = [] - if multi_process and not dipole: - multi_process = None if isinstance(multi_process, bool) else multi_process - with ProcessPoolExecutor(multi_process) as executor: - futures = [] - for val in w: - futures.append( - executor.submit( - get_diff_rate, - val, - shells, - mw, - sigma_nucleon, - halo_model, - interaction, - m_med, - migdal_model, - include_approx_nr, - q_nr, - material, - t, - ) - ) + if multi_processing and not dipole: + multi_processing = None if isinstance(multi_processing, bool) else multi_processing + with ProcessPoolExecutor(multi_processing) as executor: + partial_get_diff_rate = partial( + get_diff_rate, + shells=shells, + mw=mw, + sigma_nucleon=sigma_nucleon, + halo_model=halo_model, + interaction=interaction, + m_med=m_med, + migdal_model=migdal_model, + include_approx_nr=include_approx_nr, + q_nr=q_nr, + material=material, + t=t, + ) - for future in prog_bar(futures, desc="Computing rates"): - results.append(future.result()) + n_workers = os.cpu_count() if multi_processing is None else multi_processing + results = list( + prog_bar( + executor.map(partial_get_diff_rate, w), + desc=f"Computing rates (MP={n_workers} workers)", + total=len(w), + ) + ) else: + results = [] for val in prog_bar(w, desc="Computing rates"): results.append( get_diff_rate( From a0e4d871e6ac3f5b9ab687c66f0016741a9e7cf4 Mon Sep 17 00:00:00 2001 From: Lorenzo Principe <28869147+lorenzomag@users.noreply.github.com> Date: Thu, 15 Aug 2024 23:08:57 +0200 Subject: [PATCH 4/5] Cache Cox Model instance during runtime --- wimprates/data/migdal/Cox/cox_wrapper.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/wimprates/data/migdal/Cox/cox_wrapper.py b/wimprates/data/migdal/Cox/cox_wrapper.py index 67dd5f6..7743eef 100644 --- a/wimprates/data/migdal/Cox/cox_wrapper.py +++ b/wimprates/data/migdal/Cox/cox_wrapper.py @@ -5,17 +5,18 @@ The working directory is then reset. """ +from functools import lru_cache import os import sys -import wimprates as wr - from .cox_submodule.Migdal import Migdal +import wimprates as wr export, __all__ = wr.exporter() @export +@lru_cache def cox_migdal_model(element: str, **kwargs) -> Migdal: """ This function creates a Cox Migdal model for a given element. From 8e6c5959a2e8c416806de0e4c2b341edde548d02 Mon Sep 17 00:00:00 2001 From: Lorenzo Principe <28869147+lorenzomag@users.noreply.github.com> Date: Thu, 15 Aug 2024 23:11:20 +0200 Subject: [PATCH 5/5] Add caching functionality to save and load results --- tests/test_wimprates.py | 2 ++ wimprates/summary.py | 2 ++ wimprates/utils.py | 61 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 64 insertions(+), 1 deletion(-) diff --git a/tests/test_wimprates.py b/tests/test_wimprates.py index ac1fd9d..727d5f8 100644 --- a/tests/test_wimprates.py +++ b/tests/test_wimprates.py @@ -14,6 +14,8 @@ class TestBenchmarks(unittest.TestCase): opts = dict(mw=50, sigma_nucleon=1e-45, + save_cache=False, + load_cache=False, ) def test_elastic(self): ref = 30.39515403337126 diff --git a/wimprates/summary.py b/wimprates/summary.py index 8ac5c5c..33a92d9 100644 --- a/wimprates/summary.py +++ b/wimprates/summary.py @@ -2,12 +2,14 @@ Summary functions """ import numericalunits as nu +nu.reset_units(42) # Comment this line this when debugging dimensional analysis errors import wimprates as wr export, __all__ = wr.exporter() @export +@wr.save_result def rate_wimp(es, mw, sigma_nucleon, interaction='SI', detection_mechanism='elastic_nr', m_med=float('inf'), t=None, halo_model=None, diff --git a/wimprates/utils.py b/wimprates/utils.py index 890fb38..5be1768 100644 --- a/wimprates/utils.py +++ b/wimprates/utils.py @@ -1,13 +1,17 @@ import functools +import hashlib import inspect import os import pickle +from typing import Any, Callable import warnings from boltons.funcutils import wraps import numpy as np from tqdm.autonotebook import tqdm +import wimprates as wr + def exporter(): """Export utility modified from https://stackoverflow.com/a/41895194 @@ -108,6 +112,7 @@ def pairwise_log_transform(a, b): arr = np.concatenate((a, b), axis=1) return np.log(arr) + @export def deprecated(reason): """ @@ -124,4 +129,58 @@ def new_func(*args, **kwargs): ) return func(*args, **kwargs) return new_func - return decorator \ No newline at end of file + return decorator + + +def _generate_hash(*args, **kwargs): + # Create a string with the arguments and module version + + args_str = wr.__version__ + str(args) + + # Add keyword arguments to the string + args_str += "".join( + [f"{key}{kwargs[key]}" for key in sorted(kwargs) if key != "progress_bar"] + ) + + # Generate a SHA-256 hash + return hashlib.sha256(args_str.encode()).hexdigest() + + +@export +def save_result(func: Callable) -> Callable[..., Any]: + @wraps(func) + def wrapper( + *args, cache_dir: str="wimprates_cache", save_cache: bool=True, load_cache: bool=True, **kwargs + ): + # Define the cache directory + CACHE_DIR = cache_dir + + # Generate the hash based on function arguments and module version + func_name = func.__name__ + cache_key = _generate_hash(*args, **kwargs) + + # Define the path to the cache file + cache_file = os.path.join(CACHE_DIR, f"{func_name}_{cache_key}.pkl") + + # Check if the result is already cached + if load_cache and os.path.exists(cache_file): + with open(cache_file, "rb") as f: + print("Loading from cache: ", cache_file) + return pickle.load(f) + + # Compute the result + result = func(*args, **kwargs) + + if save_cache: + # Ensure cache directory exists + if not os.path.exists(CACHE_DIR): + os.makedirs(CACHE_DIR) + + # Save the result to the cache + with open(cache_file, "wb") as f: + pickle.dump(result, f) + print("Result saved to cache: ", cache_file) + + return result + + return wrapper