diff --git a/analysis_templates/cms_minimal/law.cfg b/analysis_templates/cms_minimal/law.cfg index 4bed389aa..8a2356008 100644 --- a/analysis_templates/cms_minimal/law.cfg +++ b/analysis_templates/cms_minimal/law.cfg @@ -47,6 +47,9 @@ default_keep_reduced_events: True # slightly to the left to avoid them being excluded from the last bin; None leads to automatic mode default_histogram_last_edge_inclusive: None +# boolean flag that, if True, sets the *hists* output of cf.SelectEvents and cf.MergeSelectionStats to optional +default_selection_hists_optional: True + # wether or not the ensure_proxy decorator should be skipped, even if used by task's run methods skip_ensure_proxy: False diff --git a/columnflow/tasks/selection.py b/columnflow/tasks/selection.py index a74c1fc44..8f032b532 100644 --- a/columnflow/tasks/selection.py +++ b/columnflow/tasks/selection.py @@ -15,12 +15,21 @@ from columnflow.tasks.external import GetDatasetLFNs from columnflow.tasks.calibration import CalibrateEvents from columnflow.production import Producer -from columnflow.util import maybe_import, ensure_proxy, dev_sandbox, safe_div +from columnflow.util import maybe_import, ensure_proxy, dev_sandbox, safe_div, DotDict np = maybe_import("numpy") ak = maybe_import("awkward") +logger = law.logger.get_logger(__name__) + +default_selection_hists_optional = law.config.get_expanded_bool( + "analysis", + "default_selection_hists_optional", + True, +) + + class SelectEvents( SelectorMixin, CalibratorsMixin, @@ -29,6 +38,9 @@ class SelectEvents( law.LocalWorkflow, RemoteWorkflow, ): + # flag that sets the *hists* output to optional if True + selection_hists_optional = default_selection_hists_optional + # default sandbox, might be overwritten by selector function sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) @@ -85,6 +97,7 @@ def output(self): outputs = { "results": self.target(f"results_{self.branch}.parquet"), "stats": self.target(f"stats_{self.branch}.json"), + "hists": self.target(f"hists_{self.branch}.pickle", optional=self.selection_hists_optional), } # add additional columns in case the selector produces some @@ -112,6 +125,7 @@ def run(self): result_chunks = {} column_chunks = {} stats = defaultdict(float) + hists = DotDict() # run the selector setup reader_targets = self.selector_inst.run_setup(reqs["selector"], inputs["selector"]) @@ -182,7 +196,7 @@ def run(self): ) # invoke the selection function - events, results = self.selector_inst(events, stats) + events, results = self.selector_inst(events, stats, hists=hists) # complain when there is no event mask if results.event is None: @@ -232,6 +246,7 @@ def run(self): # save stats outputs["stats"].dump(stats, indent=4, formatter="json") + outputs["hists"].dump(hists, formatter="pickle") # print some stats eff = safe_div(stats["num_events_selected"], stats["num_events"]) @@ -273,6 +288,12 @@ class MergeSelectionStats( DatasetTask, law.tasks.ForestMerge, ): + # flag that sets the *hists* output to optional if True + selection_hists_optional = default_selection_hists_optional + + # default sandbox, might be overwritten by selector function (needed to load hist objects) + sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) + # merge 25 stats files into 1 at every step of the merging cascade merge_factor = 25 @@ -300,7 +321,10 @@ def merge_requires(self, start_branch, end_branch): ) def merge_output(self): - return {"stats": self.target("stats.json")} + return { + "stats": self.target("stats.json"), + "hists": self.target("hists.pickle", optional=self.selection_hists_optional), + } def trace_merge_inputs(self, inputs): return super().trace_merge_inputs(inputs["collection"].targets.values()) @@ -312,12 +336,29 @@ def run(self): def merge(self, inputs, output): # merge input stats merged_stats = defaultdict(float) + merged_hists = {} + + # check that hists are present for all inputs + hist_inputs_exist = [inp["hists"].exists() for inp in inputs] + if any(hist_inputs_exist) and not all(hist_inputs_exist): + logger.warning( + f"For dataset {self.dataset_inst.name}, cf.SelectEvents has produced hists for " + "some but not all files. Histograms will not be merged and an empty pickle file will be stored.", + ) + for inp in inputs: stats = inp["stats"].load(formatter="json", cache=False) self.merge_counts(merged_stats, stats) + # merge hists only if all hists are present + if all(hist_inputs_exist): + for inp in inputs: + hists = inp["hists"].load(formatter="pickle", cache=False) + self.merge_counts(merged_hists, hists) + # write the output output["stats"].dump(merged_stats, indent=4, formatter="json", cache=False) + output["hists"].dump(merged_hists, formatter="pickle", cache=False) @classmethod def merge_counts(cls, dst: dict, src: dict) -> dict: diff --git a/law.cfg b/law.cfg index e8491204a..c946a7880 100644 --- a/law.cfg +++ b/law.cfg @@ -43,6 +43,9 @@ default_keep_reduced_events: True # slightly to the left to avoid them being excluded from the last bin; None leads to automatic mode default_histogram_last_edge_inclusive: None +# boolean flag that, if True, sets the *hists* output of cf.SelectEvents and cf.MergeSelectionStats to optional +default_selection_hists_optional: True + # wether or not the ensure_proxy decorator should be skipped, even if used by task's run methods skip_ensure_proxy: False