
Commit

Merge branch 'columnflow:master' into run3_working_branch
haddadanas authored Jun 3, 2024
2 parents 159b6e0 + e3fef7b commit 7d806dd
Showing 6 changed files with 256 additions and 161 deletions.
19 changes: 12 additions & 7 deletions columnflow/tasks/framework/base.py
@@ -914,8 +914,8 @@ def __init__(self, *args, **kwargs):
# store dataset info for the global shift
key = (
self.global_shift_inst.name
if self.global_shift_inst and self.global_shift_inst.name in self.dataset_inst.info else
"nominal"
if self.global_shift_inst and self.global_shift_inst.name in self.dataset_inst.info
else "nominal"
)
self.dataset_info_inst = self.dataset_inst.get_info(key)

@@ -928,17 +928,22 @@ def store_parts(self):
return parts

@property
def file_merging_factor(self):
def file_merging_factor(self) -> int:
"""
Returns the number of files that are handled in one branch. Consecutive merging steps are
not handled yet.
Returns the number of files that are handled in one branch. When the :py:attr:`file_merging`
attribute is set to a positive integer, this value is returned. Otherwise, if the value is
zero, the original number of files is used instead.
Consecutive merging steps are not handled yet.
"""
n_files = self.dataset_info_inst.n_files

if isinstance(self.file_merging, int):
# interpret the file_merging attribute as the merging factor itself
# non-positive numbers mean "merge all in one"
n_merge = self.file_merging if self.file_merging > 0 else n_files
# zero means "merge all in one"
if self.file_merging < 0:
raise ValueError(f"invalid file_merging value {self.file_merging}")
n_merge = n_files if self.file_merging == 0 else self.file_merging
else:
# no merging at all
n_merge = 1
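For illustration, the resulting branch-level merging logic can be captured in a small stand-alone sketch; the function name and the example numbers below are hypothetical, the actual implementation is the file_merging_factor property above.

def resolve_merging_factor(file_merging, n_files: int) -> int:
    # sketch of the updated behaviour: positive integers are taken as the
    # merging factor, zero merges all files into a single branch, and
    # negative values are rejected
    if isinstance(file_merging, int):
        if file_merging < 0:
            raise ValueError(f"invalid file_merging value {file_merging}")
        return n_files if file_merging == 0 else file_merging
    # any non-integer setting disables merging
    return 1

# e.g. for a dataset with 12 files
assert resolve_merging_factor(5, 12) == 5     # five files per branch
assert resolve_merging_factor(0, 12) == 12    # all files in one branch
assert resolve_merging_factor(None, 12) == 1  # no merging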
23 changes: 9 additions & 14 deletions columnflow/tasks/histograms.py
@@ -15,7 +15,7 @@
ShiftSourcesMixin, WeightProducerMixin, ChunkedIOMixin,
)
from columnflow.tasks.framework.remote import RemoteWorkflow
from columnflow.tasks.reduction import MergeReducedEventsUser, MergeReducedEvents
from columnflow.tasks.reduction import ReducedEventsUser
from columnflow.tasks.production import ProduceColumns
from columnflow.tasks.ml import MLEvaluation
from columnflow.util import dev_sandbox
@@ -26,20 +26,17 @@ class CreateHistograms(
WeightProducerMixin,
MLModelsMixin,
ProducersMixin,
SelectorStepsMixin,
CalibratorsMixin,
ReducedEventsUser,
ChunkedIOMixin,
MergeReducedEventsUser,
law.LocalWorkflow,
RemoteWorkflow,
):
sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox"))

# upstream requirements
reqs = Requirements(
MergeReducedEventsUser.reqs,
ReducedEventsUser.reqs,
RemoteWorkflow.reqs,
MergeReducedEvents=MergeReducedEvents,
ProduceColumns=ProduceColumns,
MLEvaluation=MLEvaluation,
)
@@ -62,7 +59,7 @@ def workflow_requires(self):
reqs = super().workflow_requires()

# require the full merge forest
reqs["events"] = self.reqs.MergeReducedEvents.req(self, tree_index=-1)
reqs["events"] = self.reqs.ProvideReducedEvents.req(self)

if not self.pilot:
if self.producer_insts:
@@ -83,9 +80,7 @@ def workflow_requires(self):
return reqs

def requires(self):
reqs = {
"events": self.reqs.MergeReducedEvents.req(self, tree_index=self.branch, _exclude={"branch"}),
}
reqs = {"events": self.reqs.ProvideReducedEvents.req(self)}

if self.producer_insts:
reqs["producers"] = [
@@ -104,11 +99,11 @@ def requires(self):

return reqs

@MergeReducedEventsUser.maybe_dummy
workflow_condition = ReducedEventsUser.workflow_condition.copy()

@workflow_condition.output
def output(self):
return {
"hists": self.target(f"histograms__vars_{self.variables_repr}__{self.branch}.pickle"),
}
return {"hists": self.target(f"histograms__vars_{self.variables_repr}__{self.branch}.pickle")}

@law.decorator.log
@law.decorator.localize(input=True, output=False)
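The same migration recurs in the ml.py and production.py changes below: ReducedEventsUser replaces MergeReducedEventsUser as the base mixin, requirements are built from ProvideReducedEvents instead of MergeReducedEvents, and the workflow condition is copied from ReducedEventsUser rather than applied through the maybe_dummy decorator. A condensed sketch of the resulting task shape, assuming columnflow is installed and that ProvideReducedEvents is exposed via ReducedEventsUser.reqs (MyTask and its output target are hypothetical):

import law

from columnflow.tasks.framework.base import Requirements
from columnflow.tasks.framework.mixins import ChunkedIOMixin
from columnflow.tasks.framework.remote import RemoteWorkflow
from columnflow.tasks.reduction import ReducedEventsUser


class MyTask(ChunkedIOMixin, ReducedEventsUser, law.LocalWorkflow, RemoteWorkflow):
    # upstream requirements: no explicit MergeReducedEvents entry is needed anymore
    reqs = Requirements(ReducedEventsUser.reqs, RemoteWorkflow.reqs)

    def workflow_requires(self):
        reqs = super().workflow_requires()
        # a single request replaces the former tree_index=-1 merge-forest request
        reqs["events"] = self.reqs.ProvideReducedEvents.req(self)
        return reqs

    def requires(self):
        # branch-level requirement without tree_index / _exclude={"branch"} bookkeeping
        return {"events": self.reqs.ProvideReducedEvents.req(self)}

    # copied workflow condition replaces the @MergeReducedEventsUser.maybe_dummy decorator
    workflow_condition = ReducedEventsUser.workflow_condition.copy()

    @workflow_condition.output
    def output(self):
        return {"events": self.target(f"events_{self.branch}.parquet")}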
41 changes: 17 additions & 24 deletions columnflow/tasks/ml.py
@@ -26,7 +26,7 @@
from columnflow.tasks.framework.plotting import ProcessPlotSettingMixin, PlotBase
from columnflow.tasks.framework.remote import RemoteWorkflow
from columnflow.tasks.framework.decorators import view_output_plots
from columnflow.tasks.reduction import MergeReducedEventsUser, MergeReducedEvents
from columnflow.tasks.reduction import ReducedEventsUser
from columnflow.tasks.production import ProduceColumns
from columnflow.util import dev_sandbox, safe_div, DotDict, maybe_import
from columnflow.columnar_util import set_ak_column
@@ -38,10 +38,8 @@
class PrepareMLEvents(
MLModelDataMixin,
ProducersMixin,
SelectorMixin,
CalibratorsMixin,
ChunkedIOMixin,
MergeReducedEventsUser,
ReducedEventsUser,
law.LocalWorkflow,
RemoteWorkflow,
):
@@ -51,9 +49,8 @@ class PrepareMLEvents(

# upstream requirements
reqs = Requirements(
MergeReducedEventsUser.reqs,
ReducedEventsUser.reqs,
RemoteWorkflow.reqs,
MergeReducedEvents=MergeReducedEvents,
ProduceColumns=ProduceColumns,
)

@@ -104,7 +101,7 @@ def workflow_requires(self):
reqs = super().workflow_requires()

# require the full merge forest
reqs["events"] = self.reqs.MergeReducedEvents.req(self, tree_index=-1)
reqs["events"] = self.reqs.ProvideReducedEvents.req(self)

# add producer dependent requirements
if self.preparation_producer_inst:
@@ -121,9 +118,8 @@ def workflow_requires(self):
return reqs

def requires(self):
reqs = {
"events": self.reqs.MergeReducedEvents.req(self, tree_index=self.branch, _exclude={"branch"}),
}
reqs = {"events": self.reqs.ProvideReducedEvents.req(self)}

if self.preparation_producer_inst:
reqs["preparation_producer"] = self.preparation_producer_inst.run_requires()

@@ -136,7 +132,9 @@ def requires(self):

return reqs

@MergeReducedEventsUser.maybe_dummy
workflow_condition = ReducedEventsUser.workflow_condition.copy()

@workflow_condition.output
def output(self):
k = self.ml_model_inst.folds
outputs = {
@@ -194,7 +192,7 @@ def run(self):
num_fold_events = {f: 0 for f in range(self.ml_model_inst.folds)}

# iterate over chunks of events and columns
files = [inputs["events"]["collection"][0]["events"]]
files = [inputs["events"]["events"]]
if self.producer_insts:
files.extend([inp["columns"] for inp in inputs["producers"]])

@@ -596,10 +594,8 @@ def run(self):
class MLEvaluation(
MLModelMixin,
ProducersMixin,
SelectorMixin,
CalibratorsMixin,
ChunkedIOMixin,
MergeReducedEventsUser,
ReducedEventsUser,
law.LocalWorkflow,
RemoteWorkflow,
):
@@ -612,10 +608,9 @@

# upstream requirements
reqs = Requirements(
MergeReducedEventsUser.reqs,
ReducedEventsUser.reqs,
RemoteWorkflow.reqs,
MLTraining=MLTraining,
MergeReducedEvents=MergeReducedEvents,
ProduceColumns=ProduceColumns,
)

@@ -669,7 +664,7 @@ def workflow_requires(self):
producers=(self.producers,),
)

reqs["events"] = self.reqs.MergeReducedEvents.req_different_branching(self)
reqs["events"] = self.reqs.ProvideReducedEvents.req(self)

# add producer dependent requirements
if self.preparation_producer_inst:
@@ -694,11 +689,7 @@ def requires(self):
producers=(self.producers,),
branch=-1,
),
"events": self.reqs.MergeReducedEvents.req_different_branching(
self,
tree_index=self.branch,
branch=-1,
),
"events": self.reqs.ProvideReducedEvents.req(self, _exclude=self.exclude_params_branch),
}
if self.preparation_producer_inst:
reqs["preparation_producer"] = self.preparation_producer_inst.run_requires()
@@ -712,7 +703,9 @@ def requires(self):

return reqs

@MergeReducedEventsUser.maybe_dummy
workflow_condition = ReducedEventsUser.workflow_condition.copy()

@workflow_condition.output
def output(self):
return {"mlcolumns": self.target(f"mlcolumns_{self.branch}.parquet")}

24 changes: 10 additions & 14 deletions columnflow/tasks/production.py
@@ -8,20 +8,16 @@
import law

from columnflow.tasks.framework.base import Requirements, AnalysisTask, wrapper_factory
from columnflow.tasks.framework.mixins import (
CalibratorsMixin, SelectorStepsMixin, ProducerMixin, ChunkedIOMixin, ProducersMixin,
)
from columnflow.tasks.framework.mixins import ProducerMixin, ProducersMixin, ChunkedIOMixin
from columnflow.tasks.framework.remote import RemoteWorkflow
from columnflow.tasks.reduction import MergeReducedEventsUser, MergeReducedEvents
from columnflow.tasks.reduction import ReducedEventsUser
from columnflow.util import dev_sandbox


class ProduceColumns(
ProducerMixin,
SelectorStepsMixin,
CalibratorsMixin,
ChunkedIOMixin,
MergeReducedEventsUser,
ReducedEventsUser,
law.LocalWorkflow,
RemoteWorkflow,
):
@@ -30,9 +26,8 @@ class ProduceColumns(

# upstream requirements
reqs = Requirements(
MergeReducedEventsUser.reqs,
ReducedEventsUser.reqs,
RemoteWorkflow.reqs,
MergeReducedEvents=MergeReducedEvents,
)

# register shifts found in the chosen producer to this task
@@ -45,7 +40,7 @@ def workflow_requires(self):
reqs = super().workflow_requires()

# require the full merge forest
reqs["events"] = self.reqs.MergeReducedEvents.req(self, tree_index=-1)
reqs["events"] = self.reqs.ProvideReducedEvents.req(self)

# add producer dependent requirements
reqs["producer"] = self.producer_inst.run_requires()
@@ -54,11 +49,13 @@

def requires(self):
return {
"events": self.reqs.MergeReducedEvents.req(self, tree_index=self.branch, _exclude={"branch"}),
"events": self.reqs.ProvideReducedEvents.req(self),
"producer": self.producer_inst.run_requires(),
}

@MergeReducedEventsUser.maybe_dummy
workflow_condition = ReducedEventsUser.workflow_condition.copy()

@workflow_condition.output
def output(self):
outputs = {}

@@ -105,7 +102,7 @@ def run(self):

# prepare inputs for localization
with law.localize_file_targets(
[inputs["events"]["collection"][0]["events"], *reader_targets.values()],
[inputs["events"]["events"], *reader_targets.values()],
mode="r",
) as inps:
# iterate over chunks of events and diffs
@@ -177,7 +174,6 @@ def run(self):
class ProduceColumnsWrapper(
ProduceColumnsWrapperBase,
ProducersMixin,

):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
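Consumers of the reduced events also read their input target differently, as seen in the run() hunks of ml.py and production.py above; a minimal before/after sketch with plain dicts standing in for the actual law targets:

# illustrative shapes of the inputs returned by the old and new requirement
old_inputs = {"events": {"collection": {0: {"events": "reduced_events_0.parquet"}}}}
new_inputs = {"events": {"events": "reduced_events_0.parquet"}}

# before: the events file had to be picked out of a nested target collection
events_target = old_inputs["events"]["collection"][0]["events"]

# after: ProvideReducedEvents hands back the events target directly
events_target = new_inputs["events"]["events"]

assert events_target == "reduced_events_0.parquet"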