Skip to content

Commit

Permalink
Merge pull request columnflow#558 from haddadanas/fix_normalization_w…
Browse files Browse the repository at this point in the history
…eights_for_stitching

Small fixes for weight calculation and stitching
  • Loading branch information
pkausw authored Nov 7, 2024
2 parents d67acd2 + 15459d1 commit 69343b8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 10 deletions.
35 changes: 28 additions & 7 deletions columnflow/production/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,12 @@ def normalization_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Arra
attribute. When py:attr:`allow_stitching` is set to True, the sum of event weights is computed
for all datasets with a leaf process contained in the leaf processes of the
py:attr:`dataset_inst`. For stitching, the process_id needs to be reconstructed for each leaf
process on a per event basis.
process on a per event basis. Moreover, when stitching is enabled, an additional normalization
weight is computed for the inclusive dataset only and stored in a column named
`<weight_name>_inclusive_only`. This weight resembles the normalization weight of the
inclusive dataset as if it were unstitched, and should therefore only be applied when using the
inclusive dataset as a standalone dataset.
"""
# read the process id column
process_id = np.asarray(events.process_id)
Expand All @@ -197,6 +202,10 @@ def normalization_weights(self: Producer, events: ak.Array, **kwargs) -> ak.Arra
norm_weight = events.mc_weight * process_weight
events = set_ak_column(events, self.weight_name, norm_weight, value_type=np.float32)

# If we are stitching, we also compute the inclusive weight for debugging purposes
if self.allow_stitching and self.dataset_inst == self.inclusive_dataset:
incl_norm_weight = events.mc_weight * self.inclusive_weight
events = set_ak_column(events, self.weight_name_incl, incl_norm_weight, value_type=np.float32)
return events


Expand All @@ -206,11 +215,6 @@ def normalization_weights_requires(self: Producer, reqs: dict) -> None:
Adds the requirements needed by the underlying py:attr:`task` to access selection stats into
*reqs*.
"""
if self.allow_stitching:
self.stitching_datasets = self.get_stitching_datasets()
else:
self.stitching_datasets = [self.dataset_inst]

# check that all datasets are known
for dataset in self.stitching_datasets:
if not self.config_inst.has_dataset(dataset):
Expand Down Expand Up @@ -292,7 +296,7 @@ def normalization_weights_setup(
# create an event weight lookup table
process_weight_table = sp.sparse.lil_matrix((1, max_id + 1), dtype=np.float32)
if self.allow_stitching and self.get_xsecs_from_inclusive_dataset:
inclusive_dataset = self.get_inclusive_dataset()
inclusive_dataset = self.inclusive_dataset
logger.info(f"using inclusive dataset {inclusive_dataset.name} for cross section lookup")

# get the branching ratios from the inclusive sample
Expand All @@ -310,6 +314,11 @@ def normalization_weights_setup(
f"{inclusive_dataset}",
)
inclusive_xsec = inclusive_proc.get_xsec(self.config_inst.campaign.ecm).nominal
self.inclusive_weight = (
lumi * inclusive_xsec / normalization_selection_stats[inclusive_dataset.name]["sum_mc_weight"]
if self.dataset_inst == inclusive_dataset
else 0
)
for process_id, br in branching_ratios.items():
sum_weights = merged_selection_stats["sum_mc_weight_per_process"][str(process_id)]
process_weight_table[0, process_id] = lumi * inclusive_xsec * br / sum_weights
Expand All @@ -330,7 +339,19 @@ def normalization_weights_init(self: Producer) -> None:
"""
Initializes the normalization weights producer by setting up the normalization weight column.
"""
if getattr(self, "dataset_inst", None) is None:
return

self.produces.add(self.weight_name)
if self.allow_stitching:
self.stitching_datasets = self.get_stitching_datasets()
self.inclusive_dataset = self.get_inclusive_dataset()
else:
self.stitching_datasets = [self.dataset_inst]

if self.allow_stitching and self.dataset_inst == self.inclusive_dataset:
self.weight_name_incl = f"{self.weight_name}_inclusive_only"
self.produces.add(self.weight_name_incl)


stitched_normalization_weights = normalization_weights.derive(
Expand Down
6 changes: 3 additions & 3 deletions columnflow/selection/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,15 @@ def increment_stats(
f"but found a sequence: {obj}",
)
if len(obj) == 1:
weights = obj[0]
weights = ak.values_astype(obj[0], np.float64)
elif len(obj) == 2:
weights, weight_mask = obj
weights, weight_mask = ak.values_astype(obj[0], np.float64), obj[1]
else:
raise Exception(f"cannot interpret as weights and optional mask: '{obj}'")
elif op == self.NUM:
weight_mask = obj
else: # SUM
weights = obj
weights = ak.values_astype(obj, np.float64)

# when mask is an Ellipsis, it cannot be AND joined to other masks, so convert to true mask
if weight_mask is Ellipsis:
Expand Down

0 comments on commit 69343b8

Please sign in to comment.