From 5309458622a679f96e3850e789f7ddf245997b75 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Wed, 6 Nov 2024 14:25:32 +0100 Subject: [PATCH] Add parse_selection_fields utility and enhance histogram task with selection handling --- columnflow/tasks/histograms.py | 31 ++++++++++++++++++++---- columnflow/util.py | 43 +++++++++++++++++++++++++++++++++- 2 files changed, 69 insertions(+), 5 deletions(-) diff --git a/columnflow/tasks/histograms.py b/columnflow/tasks/histograms.py index fb2907867..f6c581e1c 100644 --- a/columnflow/tasks/histograms.py +++ b/columnflow/tasks/histograms.py @@ -6,6 +6,9 @@ from __future__ import annotations +from functools import reduce +from operator import and_ + import luigi import law @@ -19,7 +22,7 @@ from columnflow.tasks.reduction import ReducedEventsUser from columnflow.tasks.production import ProduceColumns from columnflow.tasks.ml import MLEvaluation -from columnflow.util import dev_sandbox +from columnflow.util import dev_sandbox, parse_selection_fields class CreateHistograms( @@ -148,12 +151,16 @@ def run(self): self.config_inst.get_variable(var_name) for var_name in law.util.flatten(self.variable_tuples.values()) ) - for inp in ( + for inp in (( [variable_inst.expression] if isinstance(variable_inst.expression, str) # for variable_inst with custom expressions, read columns declared via aux key else variable_inst.x("inputs", []) - ) + ) + ( + parse_selection_fields(variable_inst.selection, only_fields=True) + if isinstance(variable_inst.selection, str) + else variable_inst.x("inputs", []) + )) } # empty float array to use when input files have no entries @@ -243,8 +250,24 @@ def expr(events, *args, **kwargs): if len(events) == 0 and not has_ak_column(events, route): return empty_f32 return route.apply(events, null_value=variable_inst.null_value) + + # prepare the selection + mask = ak.Array(np.ones(len(events), dtype=np.bool)) + sel = variable_inst.selection + if sel != "1": + if isinstance(sel, str): + selections = [ + op(Route(s).apply(events, null_value=variable_inst.null_value), val) + for (s, op, val) in parse_selection_fields(variable_inst.selection) + ] + selections = reduce(and_, selections) + mask = selections + elif callable(sel): + mask = sel(events) + else: + raise ValueError(f"invalid selection: {sel}") # apply it - fill_data[variable_inst.name] = expr(events) + fill_data[variable_inst.name] = ak.where(mask, expr(events), variable_inst.null_value) # fill it fill_hist( diff --git a/columnflow/util.py b/columnflow/util.py index cc5d8502f..a85237efa 100644 --- a/columnflow/util.py +++ b/columnflow/util.py @@ -11,7 +11,7 @@ "maybe_import", "import_plt", "import_ROOT", "import_file", "create_random_name", "expand_path", "real_path", "ensure_dir", "wget", "call_thread", "call_proc", "ensure_proxy", "dev_sandbox", "safe_div", "try_float", "try_complex", "try_int", "is_pattern", "is_regex", "pattern_matcher", - "dict_add_strict", "get_source_code", + "dict_add_strict", "get_source_code", "parse_selection_fields", "DotDict", "MockModule", "FunctionArgs", "ClassPropertyDescriptor", "classproperty", "DerivableMeta", "Derivable", ] @@ -27,6 +27,7 @@ import re import inspect import multiprocessing +import operator import multiprocessing.pool from functools import wraps from collections import OrderedDict @@ -532,6 +533,46 @@ def get_source_code(obj: Any, indent: str | int = None) -> str: return code +def parse_selection_fields(selection, only_fields=False): + """ + Parses the fields used in the selection string and returns them as a list. + """ + if not isinstance(selection, str): + return [] + if selection == "1": + return [] + + # Find mask concatenation + pattern = r"(\||&)" + sels = re.split(pattern, selection) if re.search(pattern, selection) else [selection] + + operations = { + ">": operator.gt, + ">=": operator.ge, + "<": operator.lt, + "<=": operator.le, + "==": operator.eq, + "!=": operator.ne, + } + + results = [] + op_pattern = r"([<>]=?|==|!=)" + for sel in sels: + if sel.strip() in ("|", "&"): + continue + if re.search(op_pattern, sel): + parts = re.split(op_pattern, sel, maxsplit=1) + parts = tuple(part.strip(" ()") for part in parts) + results.append((parts[0], operations[parts[1]], float(parts[2]))) + else: + results.append((sel, None, None)) + + if only_fields: + return [r[0] for r in results] + + return results + + class DotDict(OrderedDict): """ Subclass of *OrderedDict* that provides read and write access to items via attributes by