Skip to content

Commit

Permalink
add file_hash to config in TemplateSource
Browse files Browse the repository at this point in the history
  • Loading branch information
hammannr committed Oct 4, 2024
1 parent c60ba7f commit 48fba6d
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions alea/template_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from multihist import Hist1d
from inference_interface import template_to_multihist

from alea.utils import load_json
from alea.utils import load_json, compute_file_hash

logging.basicConfig(level=logging.INFO)
can_check_binning = True
Expand Down Expand Up @@ -50,6 +50,14 @@ def __init__(self, config: Dict, *args, **kwargs):
# override the default interpolation method
if "pdf_interpolation_method" not in config:
config["pdf_interpolation_method"] = "piecewise"

# add file hash to the config
format_named_parameters = self._get_format_named_parameters(config)
path = config["templatename"].format(**format_named_parameters)
file_hash = compute_file_hash(path)
print(f"{config['name']} File hash: {file_hash}")
config["file_hash"] = file_hash

super().__init__(config, *args, **kwargs)

def _check_binning(self, h, histogram_info: str):
Expand Down Expand Up @@ -96,9 +104,12 @@ def _check_binning(self, h, histogram_info: str):
@property
def format_named_parameters(self):
"""Get the named parameters in the config to dictionary format."""
format_named_parameters = {
k: self.config[k] for k in self.config.get("named_parameters", [])
}
format_named_parameters = self._get_format_named_parameters(self.config)
return format_named_parameters

@staticmethod
def _get_format_named_parameters(config: Dict) -> Dict:
format_named_parameters = {k: config[k] for k in config.get("named_parameters", [])}
return format_named_parameters

def build_histogram(self):
Expand Down

0 comments on commit 48fba6d

Please sign in to comment.