From 4efa4e7dbe9d44ff5a1a4163ccc9028363ff2935 Mon Sep 17 00:00:00 2001 From: Marcel R Date: Thu, 12 Jan 2023 13:15:17 +0100 Subject: [PATCH 01/36] Add MergeMLEvaluation task. --- columnflow/tasks/ml.py | 45 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 4a64ff40b..52a497bfb 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -465,3 +465,48 @@ def run(self): require_cls=MLEvaluation, enable=["configs", "skip_configs", "shifts", "skip_shifts", "datasets", "skip_datasets"], ) + + +class MergeMLEvaluation( + MergeReducedEventsUser, + MLModelMixin, + ProducersMixin, + SelectorMixin, + CalibratorsMixin, + ChunkedIOMixin, + law.tasks.ForestMerge, + RemoteWorkflow, +): + sandbox = dev_sandbox("bash::$CF_BASE/sandboxes/venv_columnar.sh") + + # recursively merge 20 files into one + merge_factor = 20 + + # default upstream dependency task classes + dep_MLEvaluation = MLEvaluation + + def create_branch_map(self): + # DatasetTask implements a custom branch map, but we want to use the one in ForestMerge + return law.tasks.ForestMerge.create_branch_map(self) + + def merge_workflow_requires(self): + return self.dep_MLEvaluation.req(self, _exclude={"branches"}) + + def merge_requires(self, start_branch, end_branch): + return [ + self.dep_MLEvaluation.req(self, branch=b) + for b in range(start_branch, end_branch) + ] + + def merge_output(self): + return self.target("ml_cols.parquet") + + def merge(self, inputs, output): + law.pyarrow.merge_parquet_task(self, inputs, output) + + +MergeMLEvaluationWrapper = wrapper_factory( + base_cls=AnalysisTask, + require_cls=MergeMLEvaluation, + enable=["configs", "skip_configs", "datasets", "skip_datasets", "shifts", "skip_shifts"], +) From 50594618e9736652df8b9965546d77a0b0e2ead2 Mon Sep 17 00:00:00 2001 From: Marcel R Date: Thu, 12 Jan 2023 13:15:17 +0100 Subject: [PATCH 02/36] Add MergeMLEvaluation task. --- columnflow/tasks/ml.py | 45 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 422354920..6ea1d2308 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -565,3 +565,48 @@ def run(self): require_cls=MLEvaluation, enable=["configs", "skip_configs", "shifts", "skip_shifts", "datasets", "skip_datasets"], ) + + +class MergeMLEvaluation( + MergeReducedEventsUser, + MLModelMixin, + ProducersMixin, + SelectorMixin, + CalibratorsMixin, + ChunkedIOMixin, + law.tasks.ForestMerge, + RemoteWorkflow, +): + sandbox = dev_sandbox("bash::$CF_BASE/sandboxes/venv_columnar.sh") + + # recursively merge 20 files into one + merge_factor = 20 + + # default upstream dependency task classes + dep_MLEvaluation = MLEvaluation + + def create_branch_map(self): + # DatasetTask implements a custom branch map, but we want to use the one in ForestMerge + return law.tasks.ForestMerge.create_branch_map(self) + + def merge_workflow_requires(self): + return self.dep_MLEvaluation.req(self, _exclude={"branches"}) + + def merge_requires(self, start_branch, end_branch): + return [ + self.dep_MLEvaluation.req(self, branch=b) + for b in range(start_branch, end_branch) + ] + + def merge_output(self): + return self.target("ml_cols.parquet") + + def merge(self, inputs, output): + law.pyarrow.merge_parquet_task(self, inputs, output) + + +MergeMLEvaluationWrapper = wrapper_factory( + base_cls=AnalysisTask, + require_cls=MergeMLEvaluation, + enable=["configs", "skip_configs", "datasets", "skip_datasets", "shifts", "skip_shifts"], +) From 4b7d8d8112d99135fc6cb88df3cb8fad0781eab1 Mon Sep 17 00:00:00 2001 From: Mathis Frahm Date: Wed, 22 Feb 2023 16:42:35 +0100 Subject: [PATCH 03/36] starting point for PlotMLEvaluation task --- columnflow/plotting/plot_ml_evaluation.py | 22 ++++ columnflow/tasks/ml.py | 138 ++++++++++++++++++++-- 2 files changed, 152 insertions(+), 8 deletions(-) create mode 100644 columnflow/plotting/plot_ml_evaluation.py diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py new file mode 100644 index 000000000..0a4568fb8 --- /dev/null +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -0,0 +1,22 @@ +# coding: utf-8 + +""" +Example plot functions for ML Evaluation +""" + +from __future__ import annotations + +from columnflow.util import maybe_import + +ak = maybe_import("awkward") +od = maybe_import("order") +plt = maybe_import("matplotlib.pyplot") + +def plot_ml_evaluation( + events: ak.Array, + config_inst: od.Config, + category_inst: od.Category, + **kwargs, +) -> plt.Figure: + + return None diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 6ea1d2308..56c30d3b8 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -7,6 +7,8 @@ import law import luigi +from collections import OrderedDict + from columnflow.tasks.framework.base import Requirements, AnalysisTask, DatasetTask, wrapper_factory from columnflow.tasks.framework.mixins import ( CalibratorsMixin, @@ -16,11 +18,15 @@ MLModelTrainingMixin, MLModelMixin, ChunkedIOMixin, + CategoriesMixin, + SelectorStepsMixin, ) +from columnflow.tasks.framework.plotting import ProcessPlotSettingMixin, PlotBase from columnflow.tasks.framework.remote import RemoteWorkflow +from columnflow.tasks.framework.decorators import view_output_plots from columnflow.tasks.reduction import MergeReducedEventsUser, MergeReducedEvents from columnflow.tasks.production import ProduceColumns -from columnflow.util import dev_sandbox, safe_div +from columnflow.util import dev_sandbox, safe_div, DotDict class PrepareMLEvents( @@ -568,12 +574,13 @@ def run(self): class MergeMLEvaluation( - MergeReducedEventsUser, MLModelMixin, ProducersMixin, SelectorMixin, CalibratorsMixin, ChunkedIOMixin, + DatasetTask, + # MergeReducedEventsUser, law.tasks.ForestMerge, RemoteWorkflow, ): @@ -582,27 +589,31 @@ class MergeMLEvaluation( # recursively merge 20 files into one merge_factor = 20 - # default upstream dependency task classes - dep_MLEvaluation = MLEvaluation + # upstream requirements + reqs = Requirements( + RemoteWorkflow.reqs, + MLEvaluation=MLEvaluation, + ) def create_branch_map(self): # DatasetTask implements a custom branch map, but we want to use the one in ForestMerge return law.tasks.ForestMerge.create_branch_map(self) def merge_workflow_requires(self): - return self.dep_MLEvaluation.req(self, _exclude={"branches"}) + return self.reqs.MLEvaluation.req(self, _exclude={"branches"}) def merge_requires(self, start_branch, end_branch): return [ - self.dep_MLEvaluation.req(self, branch=b) + self.reqs.MLEvaluation.req(self, branch=b) for b in range(start_branch, end_branch) ] def merge_output(self): - return self.target("ml_cols.parquet") + return {"mlcolumns": self.target("mlcolumns.parquet")} def merge(self, inputs, output): - law.pyarrow.merge_parquet_task(self, inputs, output) + inputs = [inp["mlcolumns"] for inp in inputs] + law.pyarrow.merge_parquet_task(self, inputs, output["mlcolumns"]) MergeMLEvaluationWrapper = wrapper_factory( @@ -610,3 +621,114 @@ def merge(self, inputs, output): require_cls=MergeMLEvaluation, enable=["configs", "skip_configs", "datasets", "skip_datasets", "shifts", "skip_shifts"], ) + + +class PlotMLEvaluation( + ProcessPlotSettingMixin, + CategoriesMixin, + MLModelMixin, + ProducersMixin, + SelectorStepsMixin, + CalibratorsMixin, + law.LocalWorkflow, + RemoteWorkflow, +): + + sandbox = dev_sandbox("bash::$CF_BASE/sandboxes/venv_columnar.sh") + + plot_function = PlotBase.plot_function.copy( + default="columnflow.plotting.plot_ml_evaluation.plot_ml_evaluation", + add_default_to_description=True, + ) + + # upstream requirements + reqs = Requirements( + RemoteWorkflow.reqs, + MergeMLEvaluation=MergeMLEvaluation, + ) + + def store_parts(self): + parts = super().store_parts() + parts.insert_before("version", "plot", f"datasets_{self.datasets_repr}") + return parts + + def create_branch_map(self): + return [ + DotDict({"category": cat_name}) + for cat_name in sorted(self.categories) + ] + + def requires(self): + return { + d: self.reqs.MergeMLEvaluation.req( + self, + dataset=d, + branch=-1, + _exclude={"branches"}, + ) + for d in self.datasets + } + + def workflow_requires(self, only_super: bool = False): + reqs = super().workflow_requires() + if only_super: + return reqs + + reqs["merged_ml_evaluation"] = self.requires_from_branch() + + return reqs + + def output(self): + b = self.branch_data + return self.target(f"plot__proc_{self.processes_repr}__cat_{b.category}.pdf") + + @law.decorator.log + @view_output_plots + def run(self): + import awkward as ak + + category_inst = self.config_inst.get_category(self.branch_data.category) + leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] + process_insts = list(map(self.config_inst.get_process, self.processes)) + sub_process_insts = { + proc: [sub for sub, _, _ in proc.walk_processes(include_self=True)] + for proc in process_insts + } + + with self.publish_step(f"plotting in {category_inst.name}"): + all_events = OrderedDict() + for dataset, inp in self.input().items(): + dataset_inst = self.config_inst.get_dataset(dataset) + if len(dataset_inst.processes) != 1: + raise NotImplementedError( + f"dataset {dataset_inst.name} has {len(dataset_inst.processes)} assigned, " + "which is not implemented yet.", + ) + + events = ak.from_parquet(self.input()[dataset].path) + + # masking with leaf categories + category_mask = False + for leaf in leaf_category_insts: + category_mask = ak.where(ak.any(events.category_ids == leaf.id, axis=1), True, category_mask) + + events = events[category_mask] + + # loop per process + for process_inst in process_insts: + # skip when the dataset is already known to not contain any sub process + if not any(map(dataset_inst.has_process, sub_process_insts[process_inst])): + continue + + # TODO: use process_ids to correctly assign events to processes e.g. for sample stitching + if process_inst.name in all_events.keys(): + all_events[process_inst.name] = ak.concatenate([all_events[process_inst.name], events]) + else: + all_events[process_inst] = events + + figs, _ = self.call_plot_func( + self.plot_function, + all_events, + self.config_inst, + category_inst, + ) From a879dd6a439bf4f2b9c6174c7d27d05e14dfef53 Mon Sep 17 00:00:00 2001 From: Marcel R Date: Tue, 6 Jun 2023 18:18:02 +0200 Subject: [PATCH 04/36] Linting. --- columnflow/plotting/plot_ml_evaluation.py | 2 +- modules/law | 2 +- modules/order | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index 0a4568fb8..ae81be025 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -12,11 +12,11 @@ od = maybe_import("order") plt = maybe_import("matplotlib.pyplot") + def plot_ml_evaluation( events: ak.Array, config_inst: od.Config, category_inst: od.Category, **kwargs, ) -> plt.Figure: - return None diff --git a/modules/law b/modules/law index eac3987bd..f54534a3f 160000 --- a/modules/law +++ b/modules/law @@ -1 +1 @@ -Subproject commit eac3987bd273a1c56ab2c94a2000e5623d379ddf +Subproject commit f54534a3f1971aef9bf48e89be60995bd0544925 diff --git a/modules/order b/modules/order index 000383347..2b2b69579 160000 --- a/modules/order +++ b/modules/order @@ -1 +1 @@ -Subproject commit 0003833474f638c6728edb5500374fff091ed212 +Subproject commit 2b2b69579a867ce5f1fa0a89fe0b378155ddd2a8 From 2e1ebad40a588980b0f5bf4f145f9ebe4d0392ea Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Tue, 3 Oct 2023 15:46:14 +0200 Subject: [PATCH 05/36] Added the MergeMLEvaluation task and Wrapper --- columnflow/tasks/ml.py | 73 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 6d939e6dc..072f773fe 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -608,3 +608,76 @@ def run(self): require_cls=MLEvaluation, enable=["configs", "skip_configs", "shifts", "skip_shifts", "datasets", "skip_datasets"], ) + + +class MergeMLEvaluation( + MLModelDataMixin, + DatasetTask, + law.tasks.ForestMerge, + RemoteWorkflow, +): + sandbox = dev_sandbox(law.config.get("analysis", "default_columnar_sandbox")) + + # disable the shift parameter + shift = None + effective_shift = None + allow_empty_shift = True + + # in each step, merge 10 into 1 + merge_factor = 10 + + allow_empty_ml_model = False + + # upstream requirements + reqs = Requirements( + RemoteWorkflow.reqs, + MLEvaluation=MLEvaluation, + ) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # tell ForestMerge to not cache the internal merging structure by default, + # (this is enabled in merge_workflow_requires) + self._cache_forest = False + + def create_branch_map(self): + # DatasetTask implements a custom branch map, but we want to use the one in ForestMerge + return law.tasks.ForestMerge.create_branch_map(self) + + def merge_workflow_requires(self): + req = self.reqs.MLEvaluation.req(self, _exclude={"branches"}) + + # if the merging stats exist, allow the forest to be cached + self._cache_forest = req.merging_stats_exist + + return req + + def merge_requires(self, start_leaf, end_leaf): + return [ + self.reqs.MLEvaluation.req(self, branch=i) + for i in range(start_leaf, end_leaf) + ] + + def trace_merge_inputs(self, inputs): + return super().trace_merge_inputs([inp["mlcolumns"] for inp in inputs]) + + def merge_output(self): + return {"mlcolumns": self.target(f"{self.ml_model}/mlcolumns.parquet")} + + @law.decorator.log + def run(self): + return super().run() + + def merge(self, inputs, output): + if not self.is_leaf(): + inputs = [inp["mlcolumns"] for inp in inputs] + + law.pyarrow.merge_parquet_task(self, inputs, output["mlcolumns"]) + + +MergeMLEvaluationWrapper = wrapper_factory( + base_cls=AnalysisTask, + require_cls=MergeMLEvaluation, + enable=["configs", "skip_configs", "datasets", "skip_datasets"], +) From ffb38538bc201881975b6e0ebf0e71c9d7b5637a Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Thu, 5 Oct 2023 16:07:33 +0200 Subject: [PATCH 06/36] fixed linting issues --- columnflow/tasks/ml.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 955f24c2a..36be3f88d 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -748,7 +748,7 @@ def run(self): "which is not implemented yet.", ) - events = ak.from_parquet(self.input()[dataset]['mlcolumns'].path) + events = ak.from_parquet(self.input()[dataset]["mlcolumns"].path) # masking with leaf categories category_mask = False From 726da748d55a5dead6957819c4ecd63a7b9ee047 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Thu, 5 Oct 2023 16:23:11 +0200 Subject: [PATCH 07/36] Merged with upstream master --- .all-contributorsrc | 9 +++++++++ README.md | 1 + columnflow/inference/__init__.py | 3 +++ columnflow/inference/cms/datacard.py | 30 +++++++++++++--------------- columnflow/tasks/ml.py | 10 +++++++++- 5 files changed, 36 insertions(+), 17 deletions(-) diff --git a/.all-contributorsrc b/.all-contributorsrc index 6b4464691..98aaf356b 100644 --- a/.all-contributorsrc +++ b/.all-contributorsrc @@ -118,6 +118,15 @@ "contributions": [ "ideas" ] + }, + { + "login": "haddadanas", + "name": "haddadanas", + "avatar_url": "https://avatars.githubusercontent.com/u/103462379?v=4", + "profile": "https://github.com/haddadanas", + "contributions": [ + "code" + ] } ], "commitType": "docs" diff --git a/README.md b/README.md index 3b686c30d..3c02735ac 100644 --- a/README.md +++ b/README.md @@ -138,6 +138,7 @@ For a better overview of the tasks that are triggered by the commands below, che Johannes Lange
Johannes Lange

💻 BalduinLetzer
BalduinLetzer

💻 JanekMoels
JanekMoels

🤔 + haddadanas
haddadanas

💻 diff --git a/columnflow/inference/__init__.py b/columnflow/inference/__init__.py index 5a350b886..71f971c0b 100644 --- a/columnflow/inference/__init__.py +++ b/columnflow/inference/__init__.py @@ -251,6 +251,7 @@ def category_spec( config_data_datasets: Sequence[str] | None = None, data_from_processes: Sequence[str] | None = None, mc_stats: float | tuple | None = None, + empty_bin_value: float = 1e-5, ) -> DotDict: """ Returns a dictionary representing a category (interchangeably called bin or channel in other @@ -264,6 +265,7 @@ def category_spec( when *config_data_datasets* is not defined, make of a fake data contribution. - *mc_stats*: Either *None* to disable MC stat uncertainties, or a float or tuple of floats to control the options of MC stat options. + - *empty_bin_value*: When bins are no content, they are filled with this value. """ return DotDict([ ("name", str(name)), @@ -272,6 +274,7 @@ def category_spec( ("config_data_datasets", list(map(str, config_data_datasets or []))), ("data_from_processes", list(map(str, data_from_processes or []))), ("mc_stats", mc_stats), + ("empty_bin_value", empty_bin_value), ("processes", []), ]) diff --git a/columnflow/inference/cms/datacard.py b/columnflow/inference/cms/datacard.py index b8ecee735..d006f84bd 100644 --- a/columnflow/inference/cms/datacard.py +++ b/columnflow/inference/cms/datacard.py @@ -315,7 +315,6 @@ def write( def write_shapes( self, shapes_path: str, - fill_empty_bins: float = 1e-5, ) -> tuple[ dict[str, dict[str, float]], dict[str, dict[str, dict[str, tuple[float, float]]]], @@ -330,8 +329,6 @@ def write_shapes( "category -> process -> parameter -> (down effect, up effect)", - the datacard pattern for extracting nominal shapes, and - the datacard pattern for extracting systematic shapes. - - When *fill_empty_bins* is non-zero, empty (and negative!) bins are filled with this value. """ # create the directory shapes_path = real_path(shapes_path) @@ -352,15 +349,19 @@ def write_shapes( # create the output file out_file = uproot.recreate(shapes_path) - # helper to fill empty bins in-place - def fill_empty(h): - value = h.view().value - mask = value <= 0 - value[mask] = fill_empty_bins - h.view().variance[mask] = fill_empty_bins - # iterate through shapes for cat_name, hists in self.histograms.items(): + cat_obj = self.inference_model_inst.get_category(cat_name) + + # helper to fill empty bins in-place + fill_empty = lambda h: None + if cat_obj.empty_bin_value: + def fill_empty(h): + value = h.view().value + mask = value <= 0 + value[mask] = cat_obj.empty_bin_value + h.view().variance[mask] = cat_obj.empty_bin_value + _rates = rates[cat_name] = OrderedDict() _effects = effects[cat_name] = OrderedDict() for proc_name, _hists in hists.items(): @@ -376,8 +377,7 @@ def fill_empty(h): # nominal shape h_nom = _hists["nominal"].copy() * scale - if fill_empty_bins: - fill_empty(h_nom) + fill_empty(h_nom) nom_name = nom_pattern.format(category=cat_name, process=proc_name) out_file[nom_name] = h_nom _rates[proc_name] = h_nom.sum().value @@ -450,9 +450,8 @@ def get_shapes(param_name): continue # empty bins are always filled - if fill_empty_bins: - fill_empty(h_down) - fill_empty(h_up) + fill_empty(h_down) + fill_empty(h_up) # save them when they represent real shapes if param_obj.type.is_shape: @@ -478,7 +477,6 @@ def get_shapes(param_name): ) # dedicated data handling - cat_obj = self.inference_model_inst.get_category(cat_name) if cat_obj.config_data_datasets: if "data" not in hists: raise Exception( diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 36be3f88d..46ae3222c 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -623,10 +623,13 @@ class MergeMLEvaluation( CalibratorsMixin, ChunkedIOMixin, DatasetTask, - # MergeReducedEventsUser, law.tasks.ForestMerge, RemoteWorkflow, ): + """ + Task to merge events for a dataset, where the `MLEvaluation` produces multiple parquet files. + The task serves as a helper task for plotting the ML evaluation results in the `PlotMLResults` task. + """ sandbox = dev_sandbox("bash::$CF_BASE/sandboxes/venv_columnar.sh") # recursively merge 20 files into one @@ -663,6 +666,11 @@ def merge(self, inputs, output): base_cls=AnalysisTask, require_cls=MergeMLEvaluation, enable=["configs", "skip_configs", "datasets", "skip_datasets", "shifts", "skip_shifts"], + docs=""" + Wrapper task to merge events for multiple datasets. + + :enables: ["configs", "skip_configs", "datasets", "skip_datasets", "shifts", "skip_shifts"] + """, ) From 819afea30d28fb3edc9392fd1f8c0abad78f1b37 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Thu, 19 Oct 2023 10:25:23 +0200 Subject: [PATCH 08/36] name change + docs --- columnflow/tasks/ml.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 46ae3222c..fdac57c6a 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -674,7 +674,7 @@ def merge(self, inputs, output): ) -class PlotMLEvaluation( +class PlotMLResultsBase( ProcessPlotSettingMixin, CategoriesMixin, MLModelMixin, @@ -684,12 +684,20 @@ class PlotMLEvaluation( law.LocalWorkflow, RemoteWorkflow, ): + """A base class, used for the implementation of the ML plotting tasks. This class implements + a `plot_function` parameter for choosing a desired plotting function and a `prepare_inputs` method, + that returns a dict with the chosen datasets. + Raises: + NotImplementedError: This error is raised if a givin dataset contains more than one process. + """ sandbox = dev_sandbox("bash::$CF_BASE/sandboxes/venv_columnar.sh") plot_function = PlotBase.plot_function.copy( default="columnflow.plotting.plot_ml_evaluation.plot_ml_evaluation", add_default_to_description=True, + description="The full path of the desired plot function, that is to be called on the inputs. \ + The full path should be givin using the dot notation", ) # upstream requirements @@ -756,7 +764,7 @@ def run(self): "which is not implemented yet.", ) - events = ak.from_parquet(self.input()[dataset]["mlcolumns"].path) + events = ak.from_parquet(inp["mlcolumns"].path) # masking with leaf categories category_mask = False @@ -776,10 +784,17 @@ def run(self): all_events[process_inst.name] = ak.concatenate([all_events[process_inst.name], events]) else: all_events[process_inst] = events - + from IPython import embed; embed() figs, _ = self.call_plot_func( self.plot_function, events=all_events, config_inst=self.config_inst, category_inst=category_inst, ) + +class PlotMLResults(PlotMLResultsBase): + + def output(self): + output = {"plot": super().output(), + "array": self.target(f"plot__proc_{self.processes_repr}.parquet")} + return output \ No newline at end of file From 46693905cd293d6b3efbbf167a4269bff5864904 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Tue, 24 Oct 2023 15:44:17 +0200 Subject: [PATCH 09/36] working PlotMlResults --- columnflow/tasks/ml.py | 108 ++++++++++++++++++++++++++++------------- 1 file changed, 74 insertions(+), 34 deletions(-) diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index b0acced00..74095e188 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -693,10 +693,12 @@ class PlotMLResultsBase( ): """A base class, used for the implementation of the ML plotting tasks. This class implements a `plot_function` parameter for choosing a desired plotting function and a `prepare_inputs` method, - that returns a dict with the chosen datasets. + that returns a dict with the chosen events. Raises: NotImplementedError: This error is raised if a givin dataset contains more than one process. + Exception: This exception is raised if `plot_sub_processes` is used without providing the + `process_ids` column in the data """ sandbox = dev_sandbox("bash::$CF_BASE/sandboxes/venv_columnar.sh") @@ -707,6 +709,20 @@ class PlotMLResultsBase( The full path should be givin using the dot notation", ) + skip_processes = law.CSVParameter( + default=("",), + description="names of processes to skip; These processes will not be displayed int he plot. \ + config; default: ('*',)", + brace_expand=True, + ) + + plot_sub_processes = luigi.BoolParameter( + default=False, + significant=False, + description="when True, each process is divided into the different subprocesses \ + which will be used as classes for the plot; default: False", + ) + # upstream requirements reqs = Requirements( RemoteWorkflow.reqs, @@ -748,9 +764,7 @@ def output(self): b = self.branch_data return self.target(f"plot__proc_{self.processes_repr}__cat_{b.category}.pdf") - @law.decorator.log - @view_output_plots - def run(self): + def prepare_inputs(self): import awkward as ak category_inst = self.config_inst.get_category(self.branch_data.category) @@ -761,47 +775,73 @@ def run(self): for proc in process_insts } - with self.publish_step(f"plotting in {category_inst.name}"): - all_events = OrderedDict() - for dataset, inp in self.input().items(): - dataset_inst = self.config_inst.get_dataset(dataset) - if len(dataset_inst.processes) != 1: - raise NotImplementedError( - f"dataset {dataset_inst.name} has {len(dataset_inst.processes)} assigned, " - "which is not implemented yet.", - ) + all_events = OrderedDict() + for dataset, inp in self.input().items(): + dataset_inst = self.config_inst.get_dataset(dataset) + if len(dataset_inst.processes) != 1: + raise NotImplementedError( + f"dataset {dataset_inst.name} has {len(dataset_inst.processes)} assigned, " + "which is not implemented yet.", + ) - events = ak.from_parquet(inp["mlcolumns"].path) + events = ak.from_parquet(inp["mlcolumns"].path) - # masking with leaf categories - category_mask = False - for leaf in leaf_category_insts: - category_mask = ak.where(ak.any(events.category_ids == leaf.id, axis=1), True, category_mask) + # masking with leaf categories + category_mask = False + for leaf in leaf_category_insts: + category_mask = ak.where(ak.any(events.category_ids == leaf.id, axis=1), True, category_mask) - events = events[category_mask] + events = events[category_mask] - # loop per process - for process_inst in process_insts: - # skip when the dataset is already known to not contain any sub process - if not any(map(dataset_inst.has_process, sub_process_insts[process_inst])): - continue + # loop per process + for process_inst in process_insts: + # skip when the dataset is already known to not contain any sub process + if not any(map(dataset_inst.has_process, sub_process_insts[process_inst])): + continue - # TODO: use process_ids to correctly assign events to processes e.g. for sample stitching + if not self.plot_sub_processes: if process_inst.name in all_events.keys(): - all_events[process_inst.name] = ak.concatenate([all_events[process_inst.name], events]) + all_events[process_inst.name] = ak.concatenate([all_events[process_inst.name], + getattr(events, self.ml_model)]) else: - all_events[process_inst] = events - from IPython import embed; embed() - figs, _ = self.call_plot_func( - self.plot_function, - events=all_events, - config_inst=self.config_inst, - category_inst=category_inst, - ) + all_events[process_inst] = getattr(events, self.ml_model) + else: + if "process_ids" in events.fields: + for sub_process in sub_process_insts[process_inst]: + if sub_process.name in self.skip_processes: + continue + + process_mask = ak.where(events.process_ids == sub_process.id, True, False) + if sub_process.name in all_events.keys(): + all_events[sub_process.name] = ( + ak.concatenate([all_events[sub_process.name], + getattr(events[process_mask], self.ml_model)])) + else: + all_events[sub_process.name] = getattr(events[process_mask], self.ml_model) + else: + raise Exception("No `process_ids` column stored in the events! " + f"Process selection for {dataset} cannot not be applied!") + return all_events + class PlotMLResults(PlotMLResultsBase): + # override the plot_function parameter to be able to only choose between CM and ROC + def output(self): output = {"plot": super().output(), "array": self.target(f"plot__proc_{self.processes_repr}.parquet")} return output + + @law.decorator.log + @view_output_plots + def run(self): + category_inst = self.config_inst.get_category(self.branch_data.category) + with self.publish_step(f"plotting in {category_inst.name}"): # TODO what does this do? + all_events = self.prepare_inputs() + figs, _ = self.call_plot_func( # TODO implement the plotting function to work + self.plot_function, + events=all_events, + config_inst=self.config_inst, + category_inst=category_inst, + ) From 53f0008dd77823d9de34a703d6a768e098ba81f3 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Fri, 27 Oct 2023 17:33:42 +0200 Subject: [PATCH 10/36] working cm but buggy plot ratio --- columnflow/plotting/plot_ml_evaluation.py | 200 +++++++++++++++++++++- columnflow/tasks/ml.py | 10 +- 2 files changed, 207 insertions(+), 3 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index ae81be025..6b74136cb 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -1,7 +1,7 @@ # coding: utf-8 """ -Example plot functions for ML Evaluation +Useful plot functions for ML Evaluation """ from __future__ import annotations @@ -10,7 +10,27 @@ ak = maybe_import("awkward") od = maybe_import("order") +np = maybe_import("numpy") +sci = maybe_import("scinum") plt = maybe_import("matplotlib.pyplot") +hep = maybe_import("mplhep") +colors = maybe_import("matplotlib.colors") + +# Define a CF custom color map +cf_colors = { + "cf_green_cmap": colors.ListedColormap(["#212121", "#242723", "#262D25", "#283426", "#2A3A26", "#2C4227", "#2E4927", + "#305126", "#325A25", "#356224", "#386B22", "#3B7520", "#3F7F1E", "#43891B", + "#479418", "#4C9F14", "#52AA10", "#58B60C", "#5FC207", "#67cf02"]), + "cf_ygb_cmap": colors.ListedColormap(["#003675", "#005B83", "#008490", "#009A83", "#00A368", "#00AC49", "#00B428", + "#00BC06", "#0CC300", "#39C900", "#67cf02", "#72DB02", "#7EE605", "#8DF207", + "#9CFD09", "#AEFF0B", "#C1FF0E", "#D5FF10", "#EBFF12", "#FFFF14"]), + "cf_cmap": colors.ListedColormap(["#002C9C", "#00419F", "#0056A2", "#006BA4", "#0081A7", "#0098AA", "#00ADAB", + "#00B099", "#00B287", "#00B574", "#00B860", "#00BB4C", "#00BD38", "#00C023", + "#00C20D", "#06C500", "#1EC800", "#36CA00", "#4ECD01", "#67cf02"]), + "viridis": colors.ListedColormap(["#263DA8", "#1652CC", "#1063DB", "#1171D8", "#1380D5", "#0E8ED0", "#089DCC", + "#0DA7C2", "#1DAFB3", "#2DB7A3", "#52BA91", "#73BD80", "#94BE71", "#B2BC65", + "#D0BA59", "#E1BF4A", "#F4C53A", "#FCD12B", "#FAE61C", "#F9F90E"]), +} def plot_ml_evaluation( @@ -20,3 +40,181 @@ def plot_ml_evaluation( **kwargs, ) -> plt.Figure: return None + + +def plot_cm( + events: dict, + config_inst: od.Config, + category_inst: od.Category, + sample_weights: np.ndarray = None, + normalization: str = "row", + skip_uncertainties: bool = False, + *args, + **kwargs, +) -> tuple[plt.Figure, np.ndarray]: + """ Generates the figure of the confusion matrix given the output of the nodes + and a true labels array. The Cronfusion matrix can also be weighted + + Args: + events (dict): dictionary with the true labels as keys and the model output of \ + the events as values. + config_inst (od.Config): used configuration for the plot + category_inst (od.Category): used category instance, for which the plot is created + sample_weights (np.ndarray, optional): sample weights of the events. Defaults to None. + normalization (str, optional): type of normalization of the confusion matrix. Defaults to "row". + skip_uncertainties (bool, optional): calculate errors of the cm elements. Defaults to False. + + Returns: + plt.Figure: The plot to be saved in the task. The matrix has + + Raises: + AssertionError: If both predictions and labels have mismatched shapes, \ + or if `weights` is not `None` and its shape doesn't match `predictions`. + """ + + # defining some useful properties and output shapes + true_lables = list(events.keys()) + pred_lables = [s.removeprefix('score_') for s in list(events.values())[0].fields] + return_type = np.float32 if sample_weights else np.int32 + mat_shape = (len(true_lables), len(pred_lables)) + + def get_conf_matrix() -> np.ndarray: + # TODO implement weights assertion and processing + + result = np.zeros(shape=mat_shape, dtype=return_type) + counts = np.zeros(shape=mat_shape, dtype=return_type) + + # looping over the datasets + for ind, pred in enumerate(events.values()): + # remove awkward structure to use the numpy logic + pred = ak.to_numpy(pred) + pred = pred.view(float).reshape((pred.size, len(pred_lables))) + + # create predictions of the model output + pred = np.argmax(pred, axis=-1) + + for index, count in zip(*np.unique(pred, return_counts=True)): + result[ind, index] += count + counts[ind, index] += count + + if not skip_uncertainties: + vecNumber = np.vectorize(lambda n, count: sci.Number(n, float(n / np.sqrt(count)))) + result = vecNumber(result, counts) + + # Normalize Matrix if needed + if normalization is not None: + valid = {"row": 1, "column": 0} + assert (normalization in valid.keys()), ( + f"\"{normalization}\" is no valid argument for normalization. If givin, normalization \ + should only take \"row\" or \"column\"") + + row_sums = result.sum(axis=valid.get(normalization)) + result = result / row_sums[:, np.newaxis] + + return result + + def plot_confusion_matrix(cm: np.ndarray, + title="", + colormap: str = "cf_cmap", + cmap_label: str = "Accuracy", + digits: int = 3, + ) -> plt.figure: + """plots a givin confusion matrix + + Args: + cm (np.ndarray): _description_ + title (str, optional): _description_. Defaults to "Confusion matrix". + colormap (str, optional): _description_. Defaults to "cf_cmap". + cmap_label (str, optional): _description_. Defaults to "Accuracy". + digits (int, optional): _description_. Defaults to 3. + + Returns: + plt.figure: _description_ + """ + + # Some useful variables and functions + n_processes = cm.shape[0] + n_classes = cm.shape[1] + cmap = cf_colors.get(colormap, cf_colors["cf_cmap"]) + + def scale_font(class_number: int) -> int: + """function (defined emperically) to scale the font""" + if class_number > 10: + return max(8, int(- 8 / 10 * class_number + 23)) + else: + return int(class_number / 14 * (9 * class_number - 177) + 510 / 7) + + def get_errors(matrix): + """Useful for seperating the error from the data""" + if matrix.dtype.name == "object": + get_errors_vec = np.vectorize(lambda x: x.get(sci.UP, unc=True)) + return get_errors_vec(matrix) + else: + return np.zeros_like(matrix) + + def value_text(i, j): + """Format the inputs as 'Number +- Uncertainty' """ + import re + def fmt(v): + s = "{{:.{}f}}".format(digits).format(v) + return s if re.sub(r"(0|\.)", "", s) else ("<" + s[:-1] + "1") + if skip_uncertainties: + return fmt(values[i][j]) + else: + return "{}\n\u00B1{}".format(fmt(values[i][j]), fmt(np.nan_to_num(uncs[i][j]))) + + # Get values and (if available) their uncertenties + values = cm.astype(np.float32) + uncs = get_errors(cm) + + + # Setting some plotting values + thresh = values.max() / 2. + font_size = scale_font(n_classes) + + # Remove Major ticks and edit minor ticks + # plt.style.use(hep.style.CMS) + # hep.cms.label(llabel="private work", + # rlabel=title if title is not None else "") + minor_tick_length = max(int(120 / n_classes), 12) + minor_tick_width = max(6 / n_classes, 0.6) + xtick_marks = np.arange(n_classes) + ytick_marks = np.arange(n_processes) + plt.tick_params(axis="both", which="major", + bottom=False, top=False, left=False, right=False) + plt.tick_params(axis="both", which="minor", + bottom=True, top=True, left=True, right=True, + length=minor_tick_length, width=minor_tick_width) + plt.xticks(xtick_marks + 0.5, minor=True) + plt.yticks(ytick_marks + 0.49, minor=True) + plt.xticks(xtick_marks, pred_lables, rotation=0)#, fontsize=font_size) + plt.yticks(ytick_marks, true_lables)#, fontsize=font_size) + plt.xlabel("Predicted process", loc="right", labelpad=10) #,fontsize=font_size + 3) + plt.ylabel("True process", loc="top", labelpad=15) #, fontsize=font_size) + plt.tight_layout() + + # plotting + plt.imshow(values, interpolation="nearest", cmap=cmap) + + # Justify Color bar + colorbar = plt.colorbar(fraction=0.0471, pad=0.01) + colorbar.set_label(label=cmap_label)#, fontsize=font_size + 3) + # colorbar.ax.tick_params(labelsize=font_size) + plt.clim(0, max(1, values.max())) + + # Add Matrix Elemtns + # offset = 0.12 if len(class_labels) > 2 and len(class_labels) < 6 else 0.1 + # size_offset = 1 if len(class_labels) > 5 else 3 + for i in range(values.shape[0]): + for j in range(values.shape[1]): + plt.text(j, i, value_text(i, j), #fontdict={"size": font_size}, + horizontalalignment="center", verticalalignment="center", + color="white" if values[i, j] < thresh else "black") + + # Add Axes and plot labels + from IPython import embed; embed() + + return plt.gcf() + + cm = get_conf_matrix() + fig = plot_confusion_matrix(cm, *args, **kwargs) diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 74095e188..4f6417524 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -723,6 +723,12 @@ class PlotMLResultsBase( which will be used as classes for the plot; default: False", ) + skip_uncertainties = luigi.BoolParameter( + default=False, + significant=False, + description="when True, uncertainties are not displayed in the table; default: False", + ) + # upstream requirements reqs = Requirements( RemoteWorkflow.reqs, @@ -835,11 +841,11 @@ def output(self): @law.decorator.log @view_output_plots - def run(self): + def run(self, *args, **kwargs): category_inst = self.config_inst.get_category(self.branch_data.category) with self.publish_step(f"plotting in {category_inst.name}"): # TODO what does this do? all_events = self.prepare_inputs() - figs, _ = self.call_plot_func( # TODO implement the plotting function to work + figs, _ = self.call_plot_func( self.plot_function, events=all_events, config_inst=self.config_inst, From 7c40a9cda8dde6b603c9c895c4206c83d9f2845d Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Mon, 30 Oct 2023 18:56:53 +0100 Subject: [PATCH 11/36] commit for further testing --- columnflow/plotting/plot_ml_evaluation.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index 6b74136cb..a605358de 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -13,7 +13,7 @@ np = maybe_import("numpy") sci = maybe_import("scinum") plt = maybe_import("matplotlib.pyplot") -hep = maybe_import("mplhep") +hep = maybe_import("mplhep") colors = maybe_import("matplotlib.colors") # Define a CF custom color map @@ -187,10 +187,10 @@ def fmt(v): length=minor_tick_length, width=minor_tick_width) plt.xticks(xtick_marks + 0.5, minor=True) plt.yticks(ytick_marks + 0.49, minor=True) - plt.xticks(xtick_marks, pred_lables, rotation=0)#, fontsize=font_size) - plt.yticks(ytick_marks, true_lables)#, fontsize=font_size) - plt.xlabel("Predicted process", loc="right", labelpad=10) #,fontsize=font_size + 3) - plt.ylabel("True process", loc="top", labelpad=15) #, fontsize=font_size) + plt.xticks(xtick_marks, pred_lables, rotation=0, fontsize=font_size) + plt.yticks(ytick_marks, true_lables, fontsize=font_size) + plt.xlabel("Predicted process", loc="right", labelpad=10,fontsize=font_size + 3) + plt.ylabel("True process", loc="top", labelpad=15, fontsize=font_size) plt.tight_layout() # plotting @@ -198,8 +198,8 @@ def fmt(v): # Justify Color bar colorbar = plt.colorbar(fraction=0.0471, pad=0.01) - colorbar.set_label(label=cmap_label)#, fontsize=font_size + 3) - # colorbar.ax.tick_params(labelsize=font_size) + colorbar.set_label(label=cmap_label, fontsize=font_size + 3) + colorbar.ax.tick_params(labelsize=font_size) plt.clim(0, max(1, values.max())) # Add Matrix Elemtns @@ -207,7 +207,7 @@ def fmt(v): # size_offset = 1 if len(class_labels) > 5 else 3 for i in range(values.shape[0]): for j in range(values.shape[1]): - plt.text(j, i, value_text(i, j), #fontdict={"size": font_size}, + plt.text(j, i, value_text(i, j), fontdict={"size": font_size}, horizontalalignment="center", verticalalignment="center", color="white" if values[i, j] < thresh else "black") From 7eff01f6711ac4fa18113b2ebe05623afa2f03ef Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Wed, 1 Nov 2023 20:33:52 +0100 Subject: [PATCH 12/36] working plotting cm --- columnflow/plotting/plot_ml_evaluation.py | 139 +++++++++++++--------- columnflow/tasks/ml.py | 58 +++++++-- 2 files changed, 128 insertions(+), 69 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index a605358de..744e4a9b9 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -13,7 +13,7 @@ np = maybe_import("numpy") sci = maybe_import("scinum") plt = maybe_import("matplotlib.pyplot") -hep = maybe_import("mplhep") +hep = maybe_import("mplhep") colors = maybe_import("matplotlib.colors") # Define a CF custom color map @@ -49,6 +49,7 @@ def plot_cm( sample_weights: np.ndarray = None, normalization: str = "row", skip_uncertainties: bool = False, + x_labels: list[str] = None, *args, **kwargs, ) -> tuple[plt.Figure, np.ndarray]: @@ -73,12 +74,12 @@ def plot_cm( """ # defining some useful properties and output shapes - true_lables = list(events.keys()) - pred_lables = [s.removeprefix('score_') for s in list(events.values())[0].fields] + true_labels = list(events.keys()) + pred_labels = [s.removeprefix("score_") for s in list(events.values())[0].fields] return_type = np.float32 if sample_weights else np.int32 - mat_shape = (len(true_lables), len(pred_lables)) + mat_shape = (len(true_labels), len(pred_labels)) - def get_conf_matrix() -> np.ndarray: + def get_conf_matrix(*args, **kwargs) -> np.ndarray: # TODO implement weights assertion and processing result = np.zeros(shape=mat_shape, dtype=return_type) @@ -88,7 +89,7 @@ def get_conf_matrix() -> np.ndarray: for ind, pred in enumerate(events.values()): # remove awkward structure to use the numpy logic pred = ak.to_numpy(pred) - pred = pred.view(float).reshape((pred.size, len(pred_lables))) + pred = pred.view(float).reshape((pred.size, len(pred_labels))) # create predictions of the model output pred = np.argmax(pred, axis=-1) @@ -102,11 +103,11 @@ def get_conf_matrix() -> np.ndarray: result = vecNumber(result, counts) # Normalize Matrix if needed - if normalization is not None: + if normalization: valid = {"row": 1, "column": 0} assert (normalization in valid.keys()), ( - f"\"{normalization}\" is no valid argument for normalization. If givin, normalization \ - should only take \"row\" or \"column\"") + f"\"{normalization}\" is no valid argument for normalization. If givin, normalization " + "should only take \"row\" or \"column\"") row_sums = result.sum(axis=valid.get(normalization)) result = result / row_sums[:, np.newaxis] @@ -114,10 +115,13 @@ def get_conf_matrix() -> np.ndarray: return result def plot_confusion_matrix(cm: np.ndarray, - title="", - colormap: str = "cf_cmap", - cmap_label: str = "Accuracy", - digits: int = 3, + title="", + colormap: str = "cf_cmap", + cmap_label: str = "Accuracy", + digits: int = 3, + x_labels: list[str] = None, + *args, + **kwargs, ) -> plt.figure: """plots a givin confusion matrix @@ -131,18 +135,22 @@ def plot_confusion_matrix(cm: np.ndarray, Returns: plt.figure: _description_ """ + from mpl_toolkits.axes_grid1 import make_axes_locatable - # Some useful variables and functions - n_processes = cm.shape[0] - n_classes = cm.shape[1] - cmap = cf_colors.get(colormap, cf_colors["cf_cmap"]) + def calculate_font_size(): + # Get cell width + bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) + width, height = fig.dpi * bbox.width, fig.dpi * bbox.height - def scale_font(class_number: int) -> int: - """function (defined emperically) to scale the font""" - if class_number > 10: - return max(8, int(- 8 / 10 * class_number + 23)) - else: - return int(class_number / 14 * (9 * class_number - 177) + 510 / 7) + # Size of each cell in pixels + cell_width = width / n_classes + cell_height = height / n_processes + + # Calculate the font size based on the cell size to ensure font is not too large + font_size = min(cell_width, cell_height) / 10 + font_size = max(min(font_size, 18), 8) + + return font_size def get_errors(matrix): """Useful for seperating the error from the data""" @@ -163,58 +171,71 @@ def fmt(v): else: return "{}\n\u00B1{}".format(fmt(values[i][j]), fmt(np.nan_to_num(uncs[i][j]))) + # create the plot + plt.style.use(hep.style.CMS) + fig, ax = plt.subplots(dpi=300) + + # Some useful variables and functions + n_processes = cm.shape[0] + n_classes = cm.shape[1] + cmap = cf_colors.get(colormap, cf_colors["cf_cmap"]) + x_labels = x_labels if x_labels else [f"out{i}" for i in range(n_classes)] + font_ax = 20 + font_label = 20 + font_text = calculate_font_size() + # Get values and (if available) their uncertenties values = cm.astype(np.float32) uncs = get_errors(cm) - - # Setting some plotting values - thresh = values.max() / 2. - font_size = scale_font(n_classes) - # Remove Major ticks and edit minor ticks - # plt.style.use(hep.style.CMS) - # hep.cms.label(llabel="private work", - # rlabel=title if title is not None else "") - minor_tick_length = max(int(120 / n_classes), 12) + minor_tick_length = max(int(120 / n_classes), 12) / 2 minor_tick_width = max(6 / n_classes, 0.6) xtick_marks = np.arange(n_classes) ytick_marks = np.arange(n_processes) - plt.tick_params(axis="both", which="major", - bottom=False, top=False, left=False, right=False) - plt.tick_params(axis="both", which="minor", - bottom=True, top=True, left=True, right=True, - length=minor_tick_length, width=minor_tick_width) - plt.xticks(xtick_marks + 0.5, minor=True) - plt.yticks(ytick_marks + 0.49, minor=True) - plt.xticks(xtick_marks, pred_lables, rotation=0, fontsize=font_size) - plt.yticks(ytick_marks, true_lables, fontsize=font_size) - plt.xlabel("Predicted process", loc="right", labelpad=10,fontsize=font_size + 3) - plt.ylabel("True process", loc="top", labelpad=15, fontsize=font_size) - plt.tight_layout() - # plotting - plt.imshow(values, interpolation="nearest", cmap=cmap) + # plot the data + im = ax.imshow(values, interpolation="nearest", cmap=cmap) - # Justify Color bar - colorbar = plt.colorbar(fraction=0.0471, pad=0.01) - colorbar.set_label(label=cmap_label, fontsize=font_size + 3) - colorbar.ax.tick_params(labelsize=font_size) - plt.clim(0, max(1, values.max())) + # Plot settings + thresh = values.max() / 2. + ax.tick_params(axis="both", which="major", + bottom=False, top=False, left=False, right=False) + ax.tick_params(axis="both", which="minor", + bottom=True, top=True, left=True, right=True, + length=minor_tick_length, width=minor_tick_width) + ax.set_xticks(xtick_marks + 0.5, minor=True) + ax.set_yticks(ytick_marks + 0.5, minor=True) + ax.set_xticks(xtick_marks) + ax.set_xticklabels(x_labels, rotation=0, fontsize=font_label) + ax.set_yticks(ytick_marks) + ax.set_yticklabels(true_labels, fontsize=font_label) + ax.set_xlabel("Predicted process", loc="right", labelpad=10, fontsize=font_ax) + ax.set_ylabel("True process", loc="top", labelpad=15, fontsize=font_ax) + + # adding a color bar on a new axis and adjusting its values + divider = make_axes_locatable(ax) + cax = divider.append_axes("right", size="20%", pad=0.10) + colorbar = fig.colorbar(im, cax=cax) + colorbar.set_label(label=cmap_label, fontsize=font_ax) + colorbar.ax.tick_params(labelsize=font_ax - 5) + im.set_clim(0, max(1, values.max())) # Add Matrix Elemtns - # offset = 0.12 if len(class_labels) > 2 and len(class_labels) < 6 else 0.1 - # size_offset = 1 if len(class_labels) > 5 else 3 for i in range(values.shape[0]): for j in range(values.shape[1]): - plt.text(j, i, value_text(i, j), fontdict={"size": font_size}, + ax.text(j, i, value_text(i, j), fontdict={"size": font_text}, horizontalalignment="center", verticalalignment="center", color="white" if values[i, j] < thresh else "black") - # Add Axes and plot labels - from IPython import embed; embed() + # final touches + hep.cms.label(ax=ax, llabel="private work", + rlabel=title if title else "") + plt.tight_layout() + + return fig - return plt.gcf() + cm = get_conf_matrix(*args, **kwargs) + fig = plot_confusion_matrix(cm, x_labels=x_labels, *args, **kwargs) - cm = get_conf_matrix() - fig = plot_confusion_matrix(cm, *args, **kwargs) + return fig, cm diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 4f6417524..d53d4f458 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -697,7 +697,7 @@ class PlotMLResultsBase( Raises: NotImplementedError: This error is raised if a givin dataset contains more than one process. - Exception: This exception is raised if `plot_sub_processes` is used without providing the + ValueError: This error is raised if `plot_sub_processes` is used without providing the `process_ids` column in the data """ sandbox = dev_sandbox("bash::$CF_BASE/sandboxes/venv_columnar.sh") @@ -768,7 +768,7 @@ def workflow_requires(self, only_super: bool = False): def output(self): b = self.branch_data - return self.target(f"plot__proc_{self.processes_repr}__cat_{b.category}.pdf") + return self.target(f"plot__proc_{self.processes_repr}__cat_{b.category}{self.plot_suffix}.pdf") def prepare_inputs(self): import awkward as ak @@ -798,7 +798,6 @@ def prepare_inputs(self): category_mask = ak.where(ak.any(events.category_ids == leaf.id, axis=1), True, category_mask) events = events[category_mask] - # loop per process for process_inst in process_insts: # skip when the dataset is already known to not contain any sub process @@ -810,7 +809,7 @@ def prepare_inputs(self): all_events[process_inst.name] = ak.concatenate([all_events[process_inst.name], getattr(events, self.ml_model)]) else: - all_events[process_inst] = getattr(events, self.ml_model) + all_events[process_inst.name] = getattr(events, self.ml_model) else: if "process_ids" in events.fields: for sub_process in sub_process_insts[process_inst]: @@ -825,14 +824,46 @@ def prepare_inputs(self): else: all_events[sub_process.name] = getattr(events[process_mask], self.ml_model) else: - raise Exception("No `process_ids` column stored in the events! " + raise ValueError("No `process_ids` column stored in the events! " f"Process selection for {dataset} cannot not be applied!") return all_events class PlotMLResults(PlotMLResultsBase): - + """ + A task that generates plots for machine learning results. + + This task generates plots for machine learning results, based on the given + configuration and category. The plots can be either a confusion matrix (CM) or a + receiver operating characteristic (ROC) curve. This task uses the output of the + MergeMLEvaluation task as input and saves the plots with the corresponding array + used to create the plot. + + Attributes: + plot_function (str): The name of the plot function to use. \ + Can be either "plot_cm" or "plot_roc". + processes_repr (str): A string representation of the number of \ + processes used to generate the plot(s). + config_inst (Config): An instance of the Config class that contains \ + the configuration for this task. + branch_data (BranchData): An instance of the BranchData class that \ + contains the input data for this task. + """ # override the plot_function parameter to be able to only choose between CM and ROC + plot_function = luigi.ChoiceParameter( + default="plot_cm", + choices=["cm", "roc"], + description="The name of the plot function to use. \ + Can be either 'plot_cm' or 'plot_roc'.", + ) + + def prepare_plot_parameters(self): + params = self.get_plot_parameters() + + # parse x_label from general settings + x_labels = params.general_settings.get("x_labels", None) + if x_labels: + params.general_settings["x_labels"] = x_labels.replace("&", "$").split(";") def output(self): output = {"plot": super().output(), @@ -841,13 +872,20 @@ def output(self): @law.decorator.log @view_output_plots - def run(self, *args, **kwargs): + def run(self): + func_path = {"cm": "columnflow.plotting.plot_ml_evaluation.plot_cm", + "roc": "columnflow.plotting.plot_ml_evaluation.plot_roc"} category_inst = self.config_inst.get_category(self.branch_data.category) - with self.publish_step(f"plotting in {category_inst.name}"): # TODO what does this do? + self.prepare_plot_parameters() + with self.publish_step(f"plotting in {category_inst.name}"): all_events = self.prepare_inputs() - figs, _ = self.call_plot_func( - self.plot_function, + fig, array = self.call_plot_func( + func_path.get(self.plot_function, self.plot_function), events=all_events, config_inst=self.config_inst, category_inst=category_inst, + **self.get_plot_parameters(), ) + + self.output()["plot"].dump(fig, formatter="mpl") + self.output()["array"].dump(array, formatter="pickle") From a5eb2c885857dea6a84174c6164384de03468762 Mon Sep 17 00:00:00 2001 From: haddadanas <103462379+haddadanas@users.noreply.github.com> Date: Thu, 2 Nov 2023 00:20:01 +0100 Subject: [PATCH 13/36] fix color bar size bug --- columnflow/plotting/plot_ml_evaluation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index 744e4a9b9..5196d712b 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -215,7 +215,7 @@ def fmt(v): # adding a color bar on a new axis and adjusting its values divider = make_axes_locatable(ax) - cax = divider.append_axes("right", size="20%", pad=0.10) + cax = divider.append_axes("right", size="5%", pad=0.10) colorbar = fig.colorbar(im, cax=cax) colorbar.set_label(label=cmap_label, fontsize=font_ax) colorbar.ax.tick_params(labelsize=font_ax - 5) From 2807fb1a31b75ecc2d869b61fe298865230a7b91 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Fri, 3 Nov 2023 14:29:11 +0100 Subject: [PATCH 14/36] added sample weights + better docs --- columnflow/plotting/plot_ml_evaluation.py | 66 ++++++++++++++++------- 1 file changed, 46 insertions(+), 20 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index 5196d712b..6d3bc9ff4 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -46,10 +46,11 @@ def plot_cm( events: dict, config_inst: od.Config, category_inst: od.Category, - sample_weights: np.ndarray = None, + sample_weights: list | bool = False, normalization: str = "row", skip_uncertainties: bool = False, x_labels: list[str] = None, + y_labels: list[str] = None, *args, **kwargs, ) -> tuple[plt.Figure, np.ndarray]: @@ -61,32 +62,50 @@ def plot_cm( the events as values. config_inst (od.Config): used configuration for the plot category_inst (od.Category): used category instance, for which the plot is created - sample_weights (np.ndarray, optional): sample weights of the events. Defaults to None. + sample_weights (np.ndarray or bool, optional): sample weights of the events. If an explicit array is not + givin the weights are calculated based on the number of eventsDefaults to None. normalization (str, optional): type of normalization of the confusion matrix. Defaults to "row". skip_uncertainties (bool, optional): calculate errors of the cm elements. Defaults to False. + x_labels (list[str], optional): labels for the x-axis. Defaults to None. + y_labels (list[str], optional): labels for the y-axis. Defaults to None. + *args: Additional arguments to pass to the function. + **kwargs: Additional keyword arguments to pass to the function. Returns: - plt.Figure: The plot to be saved in the task. The matrix has + tuple[plt.Figure, np.ndarray]: The resulting plot and the confusion matrix. Raises: AssertionError: If both predictions and labels have mismatched shapes, \ or if `weights` is not `None` and its shape doesn't match `predictions`. - """ + AssertionError: If `normalization` is not one of `None`, `"row"`, `"column"`. + """ # defining some useful properties and output shapes true_labels = list(events.keys()) pred_labels = [s.removeprefix("score_") for s in list(events.values())[0].fields] return_type = np.float32 if sample_weights else np.int32 mat_shape = (len(true_labels), len(pred_labels)) - def get_conf_matrix(*args, **kwargs) -> np.ndarray: - # TODO implement weights assertion and processing - + def create_sample_weights(sample_weights) -> np.ndarray: + if not sample_weights: + return {label: 1 for label in true_labels} + else: + assert (isinstance(sample_weights, bool) or (len(sample_weights) == len(true_labels))), ( + f"Shape of sample_weights {sample_weights.shape} does not match " + f"shape of predictions {mat_shape}") + if isinstance(sample_weights, bool): + size = {label: len(event) for label, event in events.items()} + mean = np.mean(list(size.values())) + sample_weights = {label: mean / length for label, length in size.items()} + return sample_weights + + def get_conf_matrix(sample_weights, *args, **kwargs) -> np.ndarray: result = np.zeros(shape=mat_shape, dtype=return_type) counts = np.zeros(shape=mat_shape, dtype=return_type) + sample_weights = create_sample_weights(sample_weights) # looping over the datasets - for ind, pred in enumerate(events.values()): + for ind, (dataset, pred) in enumerate(events.items()): # remove awkward structure to use the numpy logic pred = ak.to_numpy(pred) pred = pred.view(float).reshape((pred.size, len(pred_labels))) @@ -95,7 +114,7 @@ def get_conf_matrix(*args, **kwargs) -> np.ndarray: pred = np.argmax(pred, axis=-1) for index, count in zip(*np.unique(pred, return_counts=True)): - result[ind, index] += count + result[ind, index] += count * sample_weights[dataset] counts[ind, index] += count if not skip_uncertainties: @@ -115,25 +134,31 @@ def get_conf_matrix(*args, **kwargs) -> np.ndarray: return result def plot_confusion_matrix(cm: np.ndarray, - title="", + title: str = "", colormap: str = "cf_cmap", cmap_label: str = "Accuracy", digits: int = 3, x_labels: list[str] = None, + y_labels: list[str] = None, *args, **kwargs, ) -> plt.figure: - """plots a givin confusion matrix + """ + Plots a confusion matrix. Args: - cm (np.ndarray): _description_ - title (str, optional): _description_. Defaults to "Confusion matrix". - colormap (str, optional): _description_. Defaults to "cf_cmap". - cmap_label (str, optional): _description_. Defaults to "Accuracy". - digits (int, optional): _description_. Defaults to 3. + cm (np.ndarray): The confusion matrix to plot. + title (str): The title of the plot, displayed in the top right corner. Defaults to ''. + colormap (str): The name of the colormap to use. Defaults to "cf_cmap". + cmap_label (str): The label of the colorbar. Defaults to "Accuracy". + digits (int): The number of digits to display for each value in the matrix. Defaults to 3. + x_labels (list[str]): The labels for the x-axis. If not provided, the labels will be "out" + y_labels (list[str]): The labels for the y-axis. If not provided, the dataset labels are used. + *args: Additional arguments to pass to the function. + **kwargs: Additional keyword arguments to pass to the function. Returns: - plt.figure: _description_ + plt.figure: The resulting plot. """ from mpl_toolkits.axes_grid1 import make_axes_locatable @@ -180,6 +205,7 @@ def fmt(v): n_classes = cm.shape[1] cmap = cf_colors.get(colormap, cf_colors["cf_cmap"]) x_labels = x_labels if x_labels else [f"out{i}" for i in range(n_classes)] + y_labels = y_labels if y_labels else true_labels font_ax = 20 font_label = 20 font_text = calculate_font_size() @@ -209,7 +235,7 @@ def fmt(v): ax.set_xticks(xtick_marks) ax.set_xticklabels(x_labels, rotation=0, fontsize=font_label) ax.set_yticks(ytick_marks) - ax.set_yticklabels(true_labels, fontsize=font_label) + ax.set_yticklabels(y_labels, fontsize=font_label) ax.set_xlabel("Predicted process", loc="right", labelpad=10, fontsize=font_ax) ax.set_ylabel("True process", loc="top", labelpad=15, fontsize=font_ax) @@ -235,7 +261,7 @@ def fmt(v): return fig - cm = get_conf_matrix(*args, **kwargs) - fig = plot_confusion_matrix(cm, x_labels=x_labels, *args, **kwargs) + cm = get_conf_matrix(sample_weights, *args, **kwargs) + fig = plot_confusion_matrix(cm, x_labels=x_labels, y_labels=y_labels, *args, **kwargs) return fig, cm From 23cf6aa8f9a3716b9e564aee9db1297ee3caa0b2 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Mon, 6 Nov 2023 14:46:43 +0100 Subject: [PATCH 15/36] Add ample_weights and save multiple plots in pdf --- columnflow/plotting/plot_ml_evaluation.py | 51 ++++++++++++++++------- columnflow/tasks/ml.py | 8 ++-- 2 files changed, 41 insertions(+), 18 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index 6d3bc9ff4..0392d9e61 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -5,6 +5,7 @@ """ from __future__ import annotations +from typing import Iterable from columnflow.util import maybe_import @@ -42,6 +43,39 @@ def plot_ml_evaluation( return None +def create_sample_weights(sample_weights: Iterable | bool | None, + events: dict, + true_labels: Iterable + ) -> dict: + """ Helper function to creates the sample weights for the events, if needed. + + Args: + sample_weights (np.ndarray or bool, optional): sample weights of the events. If an explicit array is not + givin the weights are calculated based on the number of eventsDefaults to None. + events (dict): dictionary with the true labels as keys and the model output of \ + the events as values. + true_labels (np.ndarray): true labels of the events + + Returns: + dict: sample weights of the events + + Raises: + AssertionError: If both predictions and labels have mismatched shapes, \ + or if `weights` is not `None` and its shape doesn't match `predictions`. + """ + if not sample_weights: + return {label: 1 for label in true_labels} + else: + assert (isinstance(sample_weights, bool) or (len(sample_weights) == len(true_labels))), ( + f"Shape of sample_weights {sample_weights.shape} does not match " + f"shape of predictions {len(true_labels)}") + if isinstance(sample_weights, bool): + size = {label: len(event) for label, event in events.items()} + mean = np.mean(list(size.values())) + sample_weights = {label: mean / length for label, length in size.items()} + return sample_weights + + def plot_cm( events: dict, config_inst: od.Config, @@ -86,23 +120,10 @@ def plot_cm( return_type = np.float32 if sample_weights else np.int32 mat_shape = (len(true_labels), len(pred_labels)) - def create_sample_weights(sample_weights) -> np.ndarray: - if not sample_weights: - return {label: 1 for label in true_labels} - else: - assert (isinstance(sample_weights, bool) or (len(sample_weights) == len(true_labels))), ( - f"Shape of sample_weights {sample_weights.shape} does not match " - f"shape of predictions {mat_shape}") - if isinstance(sample_weights, bool): - size = {label: len(event) for label, event in events.items()} - mean = np.mean(list(size.values())) - sample_weights = {label: mean / length for label, length in size.items()} - return sample_weights - def get_conf_matrix(sample_weights, *args, **kwargs) -> np.ndarray: result = np.zeros(shape=mat_shape, dtype=return_type) counts = np.zeros(shape=mat_shape, dtype=return_type) - sample_weights = create_sample_weights(sample_weights) + sample_weights = create_sample_weights(sample_weights, events, true_labels) # looping over the datasets for ind, (dataset, pred) in enumerate(events.items()): @@ -264,4 +285,4 @@ def fmt(v): cm = get_conf_matrix(sample_weights, *args, **kwargs) fig = plot_confusion_matrix(cm, x_labels=x_labels, y_labels=y_labels, *args, **kwargs) - return fig, cm + return [fig], cm diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index d53d4f458..b465259f6 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -873,19 +873,21 @@ def output(self): @law.decorator.log @view_output_plots def run(self): + from matplotlib.backends.backend_pdf import PdfPages func_path = {"cm": "columnflow.plotting.plot_ml_evaluation.plot_cm", "roc": "columnflow.plotting.plot_ml_evaluation.plot_roc"} category_inst = self.config_inst.get_category(self.branch_data.category) self.prepare_plot_parameters() with self.publish_step(f"plotting in {category_inst.name}"): all_events = self.prepare_inputs() - fig, array = self.call_plot_func( + figs, array = self.call_plot_func( func_path.get(self.plot_function, self.plot_function), events=all_events, config_inst=self.config_inst, category_inst=category_inst, **self.get_plot_parameters(), ) - - self.output()["plot"].dump(fig, formatter="mpl") self.output()["array"].dump(array, formatter="pickle") + with PdfPages(self.output()["plot"].abspath) as pdf: + for f in figs: + f.savefig(pdf, format="pdf") From c8d52383ab4693f9b35c5338c903bdbd4410115a Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Mon, 6 Nov 2023 14:47:24 +0100 Subject: [PATCH 16/36] fixed trailing comma --- columnflow/plotting/plot_ml_evaluation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index 0392d9e61..28ab1fef9 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -45,7 +45,7 @@ def plot_ml_evaluation( def create_sample_weights(sample_weights: Iterable | bool | None, events: dict, - true_labels: Iterable + true_labels: Iterable, ) -> dict: """ Helper function to creates the sample weights for the events, if needed. From b8506246732b9cd775887175203e80c9d62103bf Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Mon, 6 Nov 2023 17:46:53 +0100 Subject: [PATCH 17/36] linting fixes --- columnflow/plotting/plot_ml_evaluation.py | 238 +++++++++++----------- columnflow/tasks/ml.py | 56 ++--- 2 files changed, 151 insertions(+), 143 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index 28ab1fef9..adf2fc6d9 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -5,8 +5,10 @@ """ from __future__ import annotations -from typing import Iterable +import re + +from columnflow.types import Sequence from columnflow.util import maybe_import ak = maybe_import("awkward") @@ -17,102 +19,91 @@ hep = maybe_import("mplhep") colors = maybe_import("matplotlib.colors") -# Define a CF custom color map +# Define a CF custom color maps cf_colors = { - "cf_green_cmap": colors.ListedColormap(["#212121", "#242723", "#262D25", "#283426", "#2A3A26", "#2C4227", "#2E4927", - "#305126", "#325A25", "#356224", "#386B22", "#3B7520", "#3F7F1E", "#43891B", - "#479418", "#4C9F14", "#52AA10", "#58B60C", "#5FC207", "#67cf02"]), - "cf_ygb_cmap": colors.ListedColormap(["#003675", "#005B83", "#008490", "#009A83", "#00A368", "#00AC49", "#00B428", - "#00BC06", "#0CC300", "#39C900", "#67cf02", "#72DB02", "#7EE605", "#8DF207", - "#9CFD09", "#AEFF0B", "#C1FF0E", "#D5FF10", "#EBFF12", "#FFFF14"]), - "cf_cmap": colors.ListedColormap(["#002C9C", "#00419F", "#0056A2", "#006BA4", "#0081A7", "#0098AA", "#00ADAB", - "#00B099", "#00B287", "#00B574", "#00B860", "#00BB4C", "#00BD38", "#00C023", - "#00C20D", "#06C500", "#1EC800", "#36CA00", "#4ECD01", "#67cf02"]), - "viridis": colors.ListedColormap(["#263DA8", "#1652CC", "#1063DB", "#1171D8", "#1380D5", "#0E8ED0", "#089DCC", - "#0DA7C2", "#1DAFB3", "#2DB7A3", "#52BA91", "#73BD80", "#94BE71", "#B2BC65", - "#D0BA59", "#E1BF4A", "#F4C53A", "#FCD12B", "#FAE61C", "#F9F90E"]), + "cf_green_cmap": colors.ListedColormap([ + "#212121", "#242723", "#262D25", "#283426", "#2A3A26", "#2C4227", "#2E4927", + "#305126", "#325A25", "#356224", "#386B22", "#3B7520", "#3F7F1E", "#43891B", + "#479418", "#4C9F14", "#52AA10", "#58B60C", "#5FC207", "#67cf02", + ]), + "cf_ygb_cmap": colors.ListedColormap([ + "#003675", "#005B83", "#008490", "#009A83", "#00A368", "#00AC49", "#00B428", + "#00BC06", "#0CC300", "#39C900", "#67cf02", "#72DB02", "#7EE605", "#8DF207", + "#9CFD09", "#AEFF0B", "#C1FF0E", "#D5FF10", "#EBFF12", "#FFFF14", + ]), + "cf_cmap": colors.ListedColormap([ + "#002C9C", "#00419F", "#0056A2", "#006BA4", "#0081A7", "#0098AA", "#00ADAB", + "#00B099", "#00B287", "#00B574", "#00B860", "#00BB4C", "#00BD38", "#00C023", + "#00C20D", "#06C500", "#1EC800", "#36CA00", "#4ECD01", "#67cf02", + ]), + "viridis": colors.ListedColormap([ + "#263DA8", "#1652CC", "#1063DB", "#1171D8", "#1380D5", "#0E8ED0", "#089DCC", + "#0DA7C2", "#1DAFB3", "#2DB7A3", "#52BA91", "#73BD80", "#94BE71", "#B2BC65", + "#D0BA59", "#E1BF4A", "#F4C53A", "#FCD12B", "#FAE61C", "#F9F90E", + ]), } -def plot_ml_evaluation( - events: ak.Array, - config_inst: od.Config, - category_inst: od.Category, - **kwargs, -) -> plt.Figure: - return None - - -def create_sample_weights(sample_weights: Iterable | bool | None, - events: dict, - true_labels: Iterable, - ) -> dict: - """ Helper function to creates the sample weights for the events, if needed. - - Args: - sample_weights (np.ndarray or bool, optional): sample weights of the events. If an explicit array is not - givin the weights are calculated based on the number of eventsDefaults to None. - events (dict): dictionary with the true labels as keys and the model output of \ - the events as values. - true_labels (np.ndarray): true labels of the events - - Returns: - dict: sample weights of the events - - Raises: - AssertionError: If both predictions and labels have mismatched shapes, \ - or if `weights` is not `None` and its shape doesn't match `predictions`. - """ +def create_sample_weights( + sample_weights: Sequence[float] | bool | None, + events: dict, + true_labels: np.ndarray, +) -> dict: + """ + Helper function to create the sample weights for the events, if needed. + + :param sample_weights: sample weights of the events. If an explicit array is not given, + the weights are calculated based on the number of events. + :param events: dictionary with the true labels as keys and the model output of the events as values. + :param true_labels: true labels of the events + :return: sample weights of the events + :raises ValueError: If both predictions and labels have mismatched shapes, or + if `weights` is not `None` and its shape doesn't match `predictions`. + """ if not sample_weights: return {label: 1 for label in true_labels} - else: - assert (isinstance(sample_weights, bool) or (len(sample_weights) == len(true_labels))), ( - f"Shape of sample_weights {sample_weights.shape} does not match " + if (isinstance(sample_weights, bool) or (len(sample_weights) == len(true_labels))): + raise ValueError(f"Shape of sample_weights {len(sample_weights)} does not match " f"shape of predictions {len(true_labels)}") - if isinstance(sample_weights, bool): - size = {label: len(event) for label, event in events.items()} - mean = np.mean(list(size.values())) - sample_weights = {label: mean / length for label, length in size.items()} - return sample_weights + if isinstance(sample_weights, bool): + size = {label: len(event) for label, event in events.items()} + mean = np.mean(list(size.values())) + return {label: mean / length for label, length in size.items()} + return {label: weight for label, weight in zip(true_labels, sample_weights)} def plot_cm( - events: dict, - config_inst: od.Config, - category_inst: od.Category, - sample_weights: list | bool = False, - normalization: str = "row", - skip_uncertainties: bool = False, - x_labels: list[str] = None, - y_labels: list[str] = None, - *args, - **kwargs, + events: dict, + config_inst: od.Config, + category_inst: od.Category, + sample_weights: list | bool = False, + normalization: str = "row", + skip_uncertainties: bool = False, + x_labels: list[str] | None = None, + y_labels: list[str] | None = None, + *args, + **kwargs, ) -> tuple[plt.Figure, np.ndarray]: """ Generates the figure of the confusion matrix given the output of the nodes - and a true labels array. The Cronfusion matrix can also be weighted - - Args: - events (dict): dictionary with the true labels as keys and the model output of \ - the events as values. - config_inst (od.Config): used configuration for the plot - category_inst (od.Category): used category instance, for which the plot is created - sample_weights (np.ndarray or bool, optional): sample weights of the events. If an explicit array is not - givin the weights are calculated based on the number of eventsDefaults to None. - normalization (str, optional): type of normalization of the confusion matrix. Defaults to "row". - skip_uncertainties (bool, optional): calculate errors of the cm elements. Defaults to False. - x_labels (list[str], optional): labels for the x-axis. Defaults to None. - y_labels (list[str], optional): labels for the y-axis. Defaults to None. - *args: Additional arguments to pass to the function. - **kwargs: Additional keyword arguments to pass to the function. - - Returns: - tuple[plt.Figure, np.ndarray]: The resulting plot and the confusion matrix. - - Raises: - AssertionError: If both predictions and labels have mismatched shapes, \ - or if `weights` is not `None` and its shape doesn't match `predictions`. - AssertionError: If `normalization` is not one of `None`, `"row"`, `"column"`. - + and a true labels array. The Cronfusion matrix can also be weighted. + + :param events: dictionary with the true labels as keys and the model output of the events as values. + :param config_inst: used configuration for the plot. + :param category_inst: used category instance, for which the plot is created. + :param sample_weights: sample weights of the events. If an explicit array is not given, the weights are + calculated based on the number of events. Defaults to None. + :param normalization: type of normalization of the confusion matrix. Defaults to "row". + :param skip_uncertainties: calculate errors of the cm elements. Defaults to False. + :param x_labels: labels for the x-axis. Defaults to None. + :param y_labels: labels for the y-axis. Defaults to None. + :param *args: Additional arguments to pass to the function. + :param **kwargs: Additional keyword arguments to pass to the function. + + :return: The resulting plot and the confusion matrix. + + :raises AssertionError: If both predictions and labels have mismatched shapes, or if `weights` + is not `None` and its shape doesn't match `predictions`. + :raises AssertionError: If `normalization` is not one of `None`, `"row"`, `"column"`. """ # defining some useful properties and output shapes true_labels = list(events.keys()) @@ -154,32 +145,31 @@ def get_conf_matrix(sample_weights, *args, **kwargs) -> np.ndarray: return result - def plot_confusion_matrix(cm: np.ndarray, - title: str = "", - colormap: str = "cf_cmap", - cmap_label: str = "Accuracy", - digits: int = 3, - x_labels: list[str] = None, - y_labels: list[str] = None, - *args, - **kwargs, - ) -> plt.figure: + def plot_confusion_matrix( + cm: np.ndarray, + title: str = "", + colormap: str = "cf_cmap", + cmap_label: str = "Accuracy", + digits: int = 3, + x_labels: list[str] | None = None, + y_labels: list[str] | None = None, + *args, + **kwargs, + ) -> plt.figure: """ Plots a confusion matrix. - Args: - cm (np.ndarray): The confusion matrix to plot. - title (str): The title of the plot, displayed in the top right corner. Defaults to ''. - colormap (str): The name of the colormap to use. Defaults to "cf_cmap". - cmap_label (str): The label of the colorbar. Defaults to "Accuracy". - digits (int): The number of digits to display for each value in the matrix. Defaults to 3. - x_labels (list[str]): The labels for the x-axis. If not provided, the labels will be "out" - y_labels (list[str]): The labels for the y-axis. If not provided, the dataset labels are used. - *args: Additional arguments to pass to the function. - **kwargs: Additional keyword arguments to pass to the function. - - Returns: - plt.figure: The resulting plot. + :param cm: The confusion matrix to plot. + :param title: The title of the plot, displayed in the top right corner. Defaults to ''. + :param colormap: The name of the colormap to use. Defaults to "cf_cmap". + :param cmap_label: The label of the colorbar. Defaults to "Accuracy". + :param digits: The number of digits to display for each value in the matrix. Defaults to 3. + :param x_labels: The labels for the x-axis. If not provided, the labels will be "out" + :param y_labels: The labels for the y-axis. If not provided, the dataset labels are used. + :param *args: Additional arguments to pass to the function. + :param **kwargs: Additional keyword arguments to pass to the function. + + :return: The resulting plot. """ from mpl_toolkits.axes_grid1 import make_axes_locatable @@ -208,7 +198,6 @@ def get_errors(matrix): def value_text(i, j): """Format the inputs as 'Number +- Uncertainty' """ - import re def fmt(v): s = "{{:.{}f}}".format(digits).format(v) return s if re.sub(r"(0|\.)", "", s) else ("<" + s[:-1] + "1") @@ -246,11 +235,17 @@ def fmt(v): # Plot settings thresh = values.max() / 2. - ax.tick_params(axis="both", which="major", - bottom=False, top=False, left=False, right=False) - ax.tick_params(axis="both", which="minor", - bottom=True, top=True, left=True, right=True, - length=minor_tick_length, width=minor_tick_width) + ax.tick_params(axis="both", which="major", bottom=False, top=False, left=False, right=False) + ax.tick_params( + axis="both", + which="minor", + bottom=True, + top=True, + left=True, + right=True, + length=minor_tick_length, + width=minor_tick_width, + ) ax.set_xticks(xtick_marks + 0.5, minor=True) ax.set_yticks(ytick_marks + 0.5, minor=True) ax.set_xticks(xtick_marks) @@ -271,13 +266,18 @@ def fmt(v): # Add Matrix Elemtns for i in range(values.shape[0]): for j in range(values.shape[1]): - ax.text(j, i, value_text(i, j), fontdict={"size": font_text}, - horizontalalignment="center", verticalalignment="center", - color="white" if values[i, j] < thresh else "black") + ax.text( + j, + i, + value_text(i, j), + fontdict={"size": font_text}, + horizontalalignment="center", + verticalalignment="center", + color="white" if values[i, j] < thresh else "black", + ) # final touches - hep.cms.label(ax=ax, llabel="private work", - rlabel=title if title else "") + hep.cms.label(ax=ax, llabel="private work", rlabel=title if title else "") plt.tight_layout() return fig diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index b465259f6..38b6d0d3c 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -4,11 +4,11 @@ Tasks related to ML workflows. """ +from collections import OrderedDict + import law import luigi -from collections import OrderedDict - from columnflow.tasks.framework.base import Requirements, AnalysisTask, DatasetTask, wrapper_factory from columnflow.tasks.framework.mixins import ( CalibratorsMixin, @@ -27,6 +27,10 @@ from columnflow.tasks.reduction import MergeReducedEventsUser, MergeReducedEvents from columnflow.tasks.production import ProduceColumns from columnflow.util import dev_sandbox, safe_div, DotDict +from modules.columnflow.columnflow.util import maybe_import + + +ak = maybe_import("awkward") class PrepareMLEvents( @@ -705,22 +709,22 @@ class PlotMLResultsBase( plot_function = PlotBase.plot_function.copy( default="columnflow.plotting.plot_ml_evaluation.plot_ml_evaluation", add_default_to_description=True, - description="The full path of the desired plot function, that is to be called on the inputs. \ - The full path should be givin using the dot notation", + description="The full path of the desired plot function, that is to be called on the inputs." + "The full path should be givin using the dot notation", ) skip_processes = law.CSVParameter( default=("",), - description="names of processes to skip; These processes will not be displayed int he plot. \ - config; default: ('*',)", + description="names of processes to skip; These processes will not be displayed int he plot." + "config; default: ('*',)", brace_expand=True, ) plot_sub_processes = luigi.BoolParameter( default=False, significant=False, - description="when True, each process is divided into the different subprocesses \ - which will be used as classes for the plot; default: False", + description="when True, each process is divided into the different subprocesses" + "which will be used as classes for the plot; default: False", ) skip_uncertainties = luigi.BoolParameter( @@ -771,7 +775,6 @@ def output(self): return self.target(f"plot__proc_{self.processes_repr}__cat_{b.category}{self.plot_suffix}.pdf") def prepare_inputs(self): - import awkward as ak category_inst = self.config_inst.get_category(self.branch_data.category) leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] @@ -806,8 +809,9 @@ def prepare_inputs(self): if not self.plot_sub_processes: if process_inst.name in all_events.keys(): - all_events[process_inst.name] = ak.concatenate([all_events[process_inst.name], - getattr(events, self.ml_model)]) + all_events[process_inst.name] = ak.concatenate([ + all_events[process_inst.name], getattr(events, self.ml_model), + ]) else: all_events[process_inst.name] = getattr(events, self.ml_model) else: @@ -818,9 +822,10 @@ def prepare_inputs(self): process_mask = ak.where(events.process_ids == sub_process.id, True, False) if sub_process.name in all_events.keys(): - all_events[sub_process.name] = ( - ak.concatenate([all_events[sub_process.name], - getattr(events[process_mask], self.ml_model)])) + all_events[sub_process.name] = ak.concatenate([ + all_events[sub_process.name], + getattr(events[process_mask], self.ml_model), + ]) else: all_events[sub_process.name] = getattr(events[process_mask], self.ml_model) else: @@ -840,21 +845,20 @@ class PlotMLResults(PlotMLResultsBase): used to create the plot. Attributes: - plot_function (str): The name of the plot function to use. \ + plot_function (str): The name of the plot function to use. Can be either "plot_cm" or "plot_roc". - processes_repr (str): A string representation of the number of \ + processes_repr (str): A string representation of the number of processes used to generate the plot(s). - config_inst (Config): An instance of the Config class that contains \ + config_inst (Config): An instance of the Config class that contains the configuration for this task. - branch_data (BranchData): An instance of the BranchData class that \ + branch_data (BranchData): An instance of the BranchData class that contains the input data for this task. """ # override the plot_function parameter to be able to only choose between CM and ROC plot_function = luigi.ChoiceParameter( default="plot_cm", choices=["cm", "roc"], - description="The name of the plot function to use. \ - Can be either 'plot_cm' or 'plot_roc'.", + description="The name of the plot function to use. Can be either 'plot_cm' or 'plot_roc'.", ) def prepare_plot_parameters(self): @@ -866,16 +870,20 @@ def prepare_plot_parameters(self): params.general_settings["x_labels"] = x_labels.replace("&", "$").split(";") def output(self): - output = {"plot": super().output(), - "array": self.target(f"plot__proc_{self.processes_repr}.parquet")} + output = { + "plot": super().output(), + "array": self.target(f"plot__proc_{self.processes_repr}.parquet"), + } return output @law.decorator.log @view_output_plots def run(self): from matplotlib.backends.backend_pdf import PdfPages - func_path = {"cm": "columnflow.plotting.plot_ml_evaluation.plot_cm", - "roc": "columnflow.plotting.plot_ml_evaluation.plot_roc"} + func_path = { + "cm": "columnflow.plotting.plot_ml_evaluation.plot_cm", + "roc": "columnflow.plotting.plot_ml_evaluation.plot_roc", + } category_inst = self.config_inst.get_category(self.branch_data.category) self.prepare_plot_parameters() with self.publish_step(f"plotting in {category_inst.name}"): From ab5b398949188a9ada55f55cb48df27f49c32eb7 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Mon, 6 Nov 2023 17:53:03 +0100 Subject: [PATCH 18/36] fixed imports --- columnflow/tasks/ml.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 38b6d0d3c..fbe45ba48 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -27,7 +27,7 @@ from columnflow.tasks.reduction import MergeReducedEventsUser, MergeReducedEvents from columnflow.tasks.production import ProduceColumns from columnflow.util import dev_sandbox, safe_div, DotDict -from modules.columnflow.columnflow.util import maybe_import +from columnflow.util import maybe_import ak = maybe_import("awkward") From 62364dd97b791ce801b54547807e2026b3506e8f Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Tue, 7 Nov 2023 14:54:21 +0100 Subject: [PATCH 19/36] removed unneeded None as type from weights --- columnflow/plotting/plot_ml_evaluation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index adf2fc6d9..db743f82c 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -45,7 +45,7 @@ def create_sample_weights( - sample_weights: Sequence[float] | bool | None, + sample_weights: Sequence[float] | bool, events: dict, true_labels: np.ndarray, ) -> dict: @@ -62,7 +62,7 @@ def create_sample_weights( """ if not sample_weights: return {label: 1 for label in true_labels} - if (isinstance(sample_weights, bool) or (len(sample_weights) == len(true_labels))): + if not (isinstance(sample_weights, bool) or (len(sample_weights) == len(true_labels))): raise ValueError(f"Shape of sample_weights {len(sample_weights)} does not match " f"shape of predictions {len(true_labels)}") if isinstance(sample_weights, bool): From 247dbb59c08fd1e4b704b5142893092813956d33 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Fri, 10 Nov 2023 15:43:12 +0100 Subject: [PATCH 20/36] initial fixes --- columnflow/plotting/plot_ml_evaluation.py | 54 +++++++------ columnflow/tasks/ml.py | 92 +++++++++++------------ 2 files changed, 75 insertions(+), 71 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index db743f82c..2909d1e26 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -47,7 +47,7 @@ def create_sample_weights( sample_weights: Sequence[float] | bool, events: dict, - true_labels: np.ndarray, + true_labels: Sequence[str], ) -> dict: """ Helper function to create the sample weights for the events, if needed. @@ -58,7 +58,7 @@ def create_sample_weights( :param true_labels: true labels of the events :return: sample weights of the events :raises ValueError: If both predictions and labels have mismatched shapes, or - if `weights` is not `None` and its shape doesn't match `predictions`. + if *weights* is not *None* and its shape doesn't match the predictions length. """ if not sample_weights: return {label: 1 for label in true_labels} @@ -69,7 +69,7 @@ def create_sample_weights( size = {label: len(event) for label, event in events.items()} mean = np.mean(list(size.values())) return {label: mean / length for label, length in size.items()} - return {label: weight for label, weight in zip(true_labels, sample_weights)} + return dict(zip(true_labels, sample_weights)) def plot_cm( @@ -85,25 +85,25 @@ def plot_cm( **kwargs, ) -> tuple[plt.Figure, np.ndarray]: """ Generates the figure of the confusion matrix given the output of the nodes - and a true labels array. The Cronfusion matrix can also be weighted. + and an array of true labels. The Cronfusion matrix can also be weighted. :param events: dictionary with the true labels as keys and the model output of the events as values. :param config_inst: used configuration for the plot. :param category_inst: used category instance, for which the plot is created. :param sample_weights: sample weights of the events. If an explicit array is not given, the weights are - calculated based on the number of events. Defaults to None. - :param normalization: type of normalization of the confusion matrix. Defaults to "row". - :param skip_uncertainties: calculate errors of the cm elements. Defaults to False. - :param x_labels: labels for the x-axis. Defaults to None. - :param y_labels: labels for the y-axis. Defaults to None. + calculated based on the number of events. + :param normalization: type of normalization of the confusion matrix. If not provided, the matrix is row normalized. + :param skip_uncertainties: If true, no uncertainty of the cells will be shown in the plot. + :param x_labels: labels for the x-axis. + :param y_labels: labels for the y-axis. :param *args: Additional arguments to pass to the function. :param **kwargs: Additional keyword arguments to pass to the function. :return: The resulting plot and the confusion matrix. - :raises AssertionError: If both predictions and labels have mismatched shapes, or if `weights` - is not `None` and its shape doesn't match `predictions`. - :raises AssertionError: If `normalization` is not one of `None`, `"row"`, `"column"`. + :raises AssertionError: If both predictions and labels have mismatched shapes, or if *weights* + is not *None* and its shape doesn't match *predictions*. + :raises AssertionError: If *normalization* is not one of *None*, "row", "column". """ # defining some useful properties and output shapes true_labels = list(events.keys()) @@ -136,9 +136,11 @@ def get_conf_matrix(sample_weights, *args, **kwargs) -> np.ndarray: # Normalize Matrix if needed if normalization: valid = {"row": 1, "column": 0} - assert (normalization in valid.keys()), ( - f"\"{normalization}\" is no valid argument for normalization. If givin, normalization " - "should only take \"row\" or \"column\"") + if normalization not in valid.keys(): + raise ValueError( + f"\"{normalization}\" is not a valid argument for normalization. If given, normalization " + "should only take \"row\" or \"column\"", + ) row_sums = result.sum(axis=valid.get(normalization)) result = result / row_sums[:, np.newaxis] @@ -160,12 +162,13 @@ def plot_confusion_matrix( Plots a confusion matrix. :param cm: The confusion matrix to plot. - :param title: The title of the plot, displayed in the top right corner. Defaults to ''. - :param colormap: The name of the colormap to use. Defaults to "cf_cmap". - :param cmap_label: The label of the colorbar. Defaults to "Accuracy". - :param digits: The number of digits to display for each value in the matrix. Defaults to 3. + :param title: The title of the plot, displayed in the top right corner. + :param colormap: The name of the colormap to use. Can be selected from the following: + "cf_cmap", "cf_green_cmap", "cf_ygb_cmap", "viridis". + :param cmap_label: The label of the colorbar. + :param digits: The number of digits to display for each value in the matrix. :param x_labels: The labels for the x-axis. If not provided, the labels will be "out" - :param y_labels: The labels for the y-axis. If not provided, the dataset labels are used. + :param y_labels: The labels for the y-axis. If not provided, the dataset names are used. :param *args: Additional arguments to pass to the function. :param **kwargs: Additional keyword arguments to pass to the function. @@ -189,7 +192,9 @@ def calculate_font_size(): return font_size def get_errors(matrix): - """Useful for seperating the error from the data""" + """ + Useful for seperating the error from the data + """ if matrix.dtype.name == "object": get_errors_vec = np.vectorize(lambda x: x.get(sci.UP, unc=True)) return get_errors_vec(matrix) @@ -197,14 +202,15 @@ def get_errors(matrix): return np.zeros_like(matrix) def value_text(i, j): - """Format the inputs as 'Number +- Uncertainty' """ + """ + Format the inputs as 'Number +- Uncertainty' + """ def fmt(v): s = "{{:.{}f}}".format(digits).format(v) return s if re.sub(r"(0|\.)", "", s) else ("<" + s[:-1] + "1") if skip_uncertainties: return fmt(values[i][j]) - else: - return "{}\n\u00B1{}".format(fmt(values[i][j]), fmt(np.nan_to_num(uncs[i][j]))) + return "{}\n\u00B1{}".format(fmt(values[i][j]), fmt(np.nan_to_num(uncs[i][j]))) # create the plot plt.style.use(hep.style.CMS) diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index fbe45ba48..2b40f00c0 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -5,6 +5,7 @@ """ from collections import OrderedDict +from typing import Dict, TypeVar import law import luigi @@ -695,42 +696,39 @@ class PlotMLResultsBase( law.LocalWorkflow, RemoteWorkflow, ): - """A base class, used for the implementation of the ML plotting tasks. This class implements - a `plot_function` parameter for choosing a desired plotting function and a `prepare_inputs` method, + """ + A base class, used for the implementation of the ML plotting tasks. This class implements + a ``plot_function`` parameter for choosing a desired plotting function and a ``prepare_inputs`` method, that returns a dict with the chosen events. - - Raises: - NotImplementedError: This error is raised if a givin dataset contains more than one process. - ValueError: This error is raised if `plot_sub_processes` is used without providing the - `process_ids` column in the data """ sandbox = dev_sandbox("bash::$CF_BASE/sandboxes/venv_columnar.sh") + Self = TypeVar("Self", bound="PlotMLResultsBase") # TODO add comment + plot_function = PlotBase.plot_function.copy( default="columnflow.plotting.plot_ml_evaluation.plot_ml_evaluation", add_default_to_description=True, - description="The full path of the desired plot function, that is to be called on the inputs." - "The full path should be givin using the dot notation", + description="the full path given using the dot notation of the desired plot function.", ) skip_processes = law.CSVParameter( default=("",), - description="names of processes to skip; These processes will not be displayed int he plot." - "config; default: ('*',)", + description="names of processes to skip; these processes will not be included in the plots." + "config; default: ('',)", brace_expand=True, ) plot_sub_processes = luigi.BoolParameter( default=False, significant=False, - description="when True, each process is divided into the different subprocesses" - "which will be used as classes for the plot; default: False", + description="when True, each process is divided into the different subprocesses; " + "this option requires a ``process_ids`` column to be stored in the events; default: False", ) skip_uncertainties = luigi.BoolParameter( default=False, significant=False, - description="when True, uncertainties are not displayed in the table; default: False", + description="when True, count uncertainties (if available) are not included in the plot; default: False", ) # upstream requirements @@ -739,18 +737,18 @@ class PlotMLResultsBase( MergeMLEvaluation=MergeMLEvaluation, ) - def store_parts(self): + def store_parts(self: Self): parts = super().store_parts() parts.insert_before("version", "plot", f"datasets_{self.datasets_repr}") return parts - def create_branch_map(self): + def create_branch_map(self: Self): return [ DotDict({"category": cat_name}) for cat_name in sorted(self.categories) ] - def requires(self): + def requires(self: Self): return { d: self.reqs.MergeMLEvaluation.req( self, @@ -761,7 +759,7 @@ def requires(self): for d in self.datasets } - def workflow_requires(self, only_super: bool = False): + def workflow_requires(self: Self, only_super: bool = False): reqs = super().workflow_requires() if only_super: return reqs @@ -770,12 +768,22 @@ def workflow_requires(self, only_super: bool = False): return reqs - def output(self): + def output(self: Self): b = self.branch_data return self.target(f"plot__proc_{self.processes_repr}__cat_{b.category}{self.plot_suffix}.pdf") - def prepare_inputs(self): + def prepare_inputs(self: Self) -> Dict[str, ak.Array]: + """prepare the inputs for the plot function, based on the given configuration and category. + + Raises: + NotImplementedError: This error is raised if a givin dataset contains more than one process. + ValueError: This error is raised if ``plot_sub_processes`` is used without providing the + ``process_ids`` column in the data + Returns: + Dict[str, ak.Array]: A dictionary with the dataset names as keys and + the corresponding predictions as values. + """ category_inst = self.config_inst.get_category(self.branch_data.category) leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] process_insts = list(map(self.config_inst.get_process, self.processes)) @@ -815,22 +823,21 @@ def prepare_inputs(self): else: all_events[process_inst.name] = getattr(events, self.ml_model) else: - if "process_ids" in events.fields: - for sub_process in sub_process_insts[process_inst]: - if sub_process.name in self.skip_processes: - continue - - process_mask = ak.where(events.process_ids == sub_process.id, True, False) - if sub_process.name in all_events.keys(): - all_events[sub_process.name] = ak.concatenate([ - all_events[sub_process.name], - getattr(events[process_mask], self.ml_model), - ]) - else: - all_events[sub_process.name] = getattr(events[process_mask], self.ml_model) - else: + if "process_ids" not in events.fields: raise ValueError("No `process_ids` column stored in the events! " f"Process selection for {dataset} cannot not be applied!") + for sub_process in sub_process_insts[process_inst]: + if sub_process.name in self.skip_processes: + continue + + process_mask = ak.where(events.process_ids == sub_process.id, True, False) + if sub_process.name in all_events.keys(): + all_events[sub_process.name] = ak.concatenate([ + all_events[sub_process.name], + getattr(events[process_mask], self.ml_model), + ]) + else: + all_events[sub_process.name] = getattr(events[process_mask], self.ml_model) return all_events @@ -838,27 +845,17 @@ class PlotMLResults(PlotMLResultsBase): """ A task that generates plots for machine learning results. - This task generates plots for machine learning results, based on the given + This task generates plots for machine learning results based on the given configuration and category. The plots can be either a confusion matrix (CM) or a receiver operating characteristic (ROC) curve. This task uses the output of the MergeMLEvaluation task as input and saves the plots with the corresponding array used to create the plot. - - Attributes: - plot_function (str): The name of the plot function to use. - Can be either "plot_cm" or "plot_roc". - processes_repr (str): A string representation of the number of - processes used to generate the plot(s). - config_inst (Config): An instance of the Config class that contains - the configuration for this task. - branch_data (BranchData): An instance of the BranchData class that - contains the input data for this task. """ # override the plot_function parameter to be able to only choose between CM and ROC plot_function = luigi.ChoiceParameter( default="plot_cm", choices=["cm", "roc"], - description="The name of the plot function to use. Can be either 'plot_cm' or 'plot_roc'.", + description="The name of the plot function to use. Can be either 'cm' or 'roc'.", ) def prepare_plot_parameters(self): @@ -867,7 +864,7 @@ def prepare_plot_parameters(self): # parse x_label from general settings x_labels = params.general_settings.get("x_labels", None) if x_labels: - params.general_settings["x_labels"] = x_labels.replace("&", "$").split(";") + params.general_settings["x_labels"] = x_labels.split(";") def output(self): output = { @@ -893,6 +890,7 @@ def run(self): events=all_events, config_inst=self.config_inst, category_inst=category_inst, + skip_uncertainties=self.skip_uncertainties, **self.get_plot_parameters(), ) self.output()["array"].dump(array, formatter="pickle") From cdd55a6ffee78b9f9e07e0090f4e486fecfe9bfb Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Fri, 10 Nov 2023 15:44:32 +0100 Subject: [PATCH 21/36] Add ROC curve plotting template to plot_ml_evaluation.py --- columnflow/plotting/plot_ml_evaluation.py | 148 ++++++++++++++++++++++ 1 file changed, 148 insertions(+) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index 2909d1e26..a16d74cd8 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -292,3 +292,151 @@ def fmt(v): fig = plot_confusion_matrix(cm, x_labels=x_labels, y_labels=y_labels, *args, **kwargs) return [fig], cm + + +def plot_roc( + events: dict, + config_inst: od.Config, + category_inst: od.Category, + sample_weights: Sequence | bool = False, + skip_uncertainties: bool = False, + n_thresholds: int = 200 + 1, + *args, + **kwargs, +) -> tuple[plt.Figure, dict]: + """ Generates the figure of the ROC curve given the events output of the ML evaluation + + Args: + events (dict): dictionary with the true labels as keys and the model output of + the events as values. + config_inst (od.Config): used configuration for the plot + category_inst (od.Category): used category instance, for which the plot is created + sample_weights (np.ndarray or bool, optional): sample weights of the events. If an explicit array is not + givin the weights are calculated based on the number of eventsDefaults to None. + skip_uncertainties (bool, optional): calculate errors of the cm elements. Defaults to False. + n_thresholds (int): number of thresholds used for the ROC curve + *args: Additional arguments to pass to the function. + **kwargs: Additional keyword arguments to pass to the function. + + Returns: + tuple[plt.Figure, dict]: The resulting plot and the ROC curve data. + + Raises: + AssertionError: If both predictions and labels have mismatched shapes, + or if *weights* is not *None* and its shape doesn't match *predictions*. + AssertionError: If *normalization* is not one of *None*, "row", "column". + + """ + # defining some useful properties and output shapes + thresholds = np.geomspace(1e-6, 1, n_thresholds) + weights = create_sample_weights(sample_weights, events, list(events.keys())) + discriminators = list(events.values())[0].fields + + def create_histograms(events: dict, sample_weights: dict, *args, **kwargs) -> dict: + hists = {} + for cls, predictions in events.items(): + hists[cls] = {} + for disc in discriminators: + hists[cls][disc] = weights[cls] * ak.to_numpy(np.histogram(predictions[disc], bins=thresholds)[0]) + return hists + + def binary_roc_data(positiv_hist: np.ndarray, negativ_hist: np.ndarray, *args, **kwargs) -> tuple: + """ + Compute binary Receiver operating characteristic (ROC) values. + Used as a helper function for the multi-dimensional ROC curve + """ + + # Calculate the different rates + fn = np.cumsum(positiv_hist) + tn = np.cumsum(negativ_hist) + tp = fn[-1] - fn + fp = tn[-1] - tn + + tpr = tp / (tp + fn) + fpr = fp / (fp + tn) + + return fpr, tpr + + def roc_curve_data( + evaluation_type: str, + histograms: dict, + *args, + **kwargs + ) -> dict: + """ + Compute Receiver operating characteristic (ROC) values for a multi-dimensional output. + """ + + def one_vs_rest(names): + result = {} + for ind, cls_name in enumerate(names): + positiv_inputs = model_output[:, ind] + fpr, tpr, th = binary_roc_data(true_labels=(true_labels == ind), + model_output_positive=positiv_inputs, + sample_weights=sample_weights, + *args, + thresholds=thresholds, + errors=errors, + output_length=output_length) + result[f"{cls_name}_vs_rest"] = {"fpr": fpr, + "tpr": tpr, + "thresholds": th} + return result + + def one_vs_one(names): + result = {} + for pos_ind, cls_name in enumerate(names): + for neg_ind, cls_name2 in enumerate(names): + if (pos_ind == neg_ind): + continue + + # Event selection masks only for the 2 classes analysed + inputs_mask = np.logical_or(true_labels == pos_ind, + true_labels == neg_ind) + select_input = model_output[inputs_mask] + select_labels = true_labels[inputs_mask] + select_weights = None if sample_weights is None else sample_weights[inputs_mask] + + positiv_inputs = select_input[:, pos_ind] + fpr, tpr, th = binary_roc_data(true_labels=(select_labels == pos_ind), + model_output_positive=positiv_inputs, + sample_weights=select_weights, + *args, + thresholds=thresholds, + errors=errors, + output_length=output_length) + result[f"{cls_name}_vs_{cls_name2}"] = {"fpr": fpr, + "tpr": tpr, + "thresholds": th} + return result + + is_input_valid(true_labels, model_output) + + # reshape in case only predictions for the positive class are givin + if model_output.ndim != 2: + model_output = model_output.reshape((model_output.size, 1)) + + # Generate class names if not givin + if class_names is None: + class_names = list(range(model_output.shape[1])) + + assert (len(class_names) == model_output.shape[1]), ( + "Number of givin class names does not match the number of output nodes in the *model_output* argument!") + + # Cast trues labels to class numers + if true_labels.dtype.name == "bool": + true_labels = np.logical_not(true_labels).astype(dtype=np.int32) + + # Map true labels to integers if needed + if "int" not in true_labels.dtype.name: + for ind, name in enumerate(class_names): + true_labels = np.where(true_labels == name, ind, true_labels) + + # Choose the evaluation type + if (evaluation_type == "OvO"): + return one_vs_one(class_names) + elif (evaluation_type == "OvR"): + return one_vs_rest(class_names) + else: + raise ValueError("Illeagal Argument! Evaluation Type can only be choosen as \"OvO\" (One vs One) \ + or \"OvR\" (One vs Rest)") From 697e019408ac9a558089335eb5a40b9a141bd222 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Mon, 13 Nov 2023 14:34:26 +0100 Subject: [PATCH 22/36] apply review comments :) --- columnflow/plotting/plot_ml_evaluation.py | 3 +- columnflow/tasks/ml.py | 92 ++++++++++++++--------- 2 files changed, 57 insertions(+), 38 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index 2909d1e26..a2e5fd2ad 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -198,8 +198,7 @@ def get_errors(matrix): if matrix.dtype.name == "object": get_errors_vec = np.vectorize(lambda x: x.get(sci.UP, unc=True)) return get_errors_vec(matrix) - else: - return np.zeros_like(matrix) + return np.zeros_like(matrix) def value_text(i, j): """ diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 2b40f00c0..ae4344d64 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -3,9 +3,9 @@ """ Tasks related to ML workflows. """ +from __future__ import annotations from collections import OrderedDict -from typing import Dict, TypeVar import law import luigi @@ -29,6 +29,7 @@ from columnflow.tasks.production import ProduceColumns from columnflow.util import dev_sandbox, safe_div, DotDict from columnflow.util import maybe_import +from columnflow.types import Dict, TypeVar, List ak = maybe_import("awkward") @@ -703,8 +704,6 @@ class PlotMLResultsBase( """ sandbox = dev_sandbox("bash::$CF_BASE/sandboxes/venv_columnar.sh") - Self = TypeVar("Self", bound="PlotMLResultsBase") # TODO add comment - plot_function = PlotBase.plot_function.copy( default="columnflow.plotting.plot_ml_evaluation.plot_ml_evaluation", add_default_to_description=True, @@ -713,8 +712,8 @@ class PlotMLResultsBase( skip_processes = law.CSVParameter( default=("",), - description="names of processes to skip; these processes will not be included in the plots." - "config; default: ('',)", + description="comma seperated list of process names to skip; these processes will not be included in the plots. " + "default: ('',)", brace_expand=True, ) @@ -722,7 +721,9 @@ class PlotMLResultsBase( default=False, significant=False, description="when True, each process is divided into the different subprocesses; " - "this option requires a ``process_ids`` column to be stored in the events; default: False", + "this option requires a ``process_ids`` column to be stored in the events; " + "the ``process_ids`` column assignes a subprocess id number (predefined in the config) to each event; " + "default: False", ) skip_uncertainties = luigi.BoolParameter( @@ -737,18 +738,18 @@ class PlotMLResultsBase( MergeMLEvaluation=MergeMLEvaluation, ) - def store_parts(self: Self): + def store_parts(self: PlotMLResultsBase): parts = super().store_parts() parts.insert_before("version", "plot", f"datasets_{self.datasets_repr}") return parts - def create_branch_map(self: Self): + def create_branch_map(self: PlotMLResultsBase): return [ DotDict({"category": cat_name}) for cat_name in sorted(self.categories) ] - def requires(self: Self): + def requires(self: PlotMLResultsBase): return { d: self.reqs.MergeMLEvaluation.req( self, @@ -759,7 +760,7 @@ def requires(self: Self): for d in self.datasets } - def workflow_requires(self: Self, only_super: bool = False): + def workflow_requires(self: PlotMLResultsBase, only_super: bool = False): reqs = super().workflow_requires() if only_super: return reqs @@ -768,21 +769,23 @@ def workflow_requires(self: Self, only_super: bool = False): return reqs - def output(self: Self): + def output(self: PlotMLResultsBase) -> Dict[str, List]: b = self.branch_data - return self.target(f"plot__proc_{self.processes_repr}__cat_{b.category}{self.plot_suffix}.pdf") + return {"plots": [ + self.target(name) + for name in self.get_plot_names(f"plot__proc_{self.processes_repr}__cat_{b.category}") + ]} - def prepare_inputs(self: Self) -> Dict[str, ak.Array]: - """prepare the inputs for the plot function, based on the given configuration and category. + def prepare_inputs(self: PlotMLResultsBase) -> Dict[str, ak.Array]: + """ + prepare the inputs for the plot function, based on the given configuration and category. - Raises: - NotImplementedError: This error is raised if a givin dataset contains more than one process. - ValueError: This error is raised if ``plot_sub_processes`` is used without providing the - ``process_ids`` column in the data + :raises NotImplementedError: This error is raised if a given dataset contains more than one process. + :raises ValueError: This error is raised if ``plot_sub_processes`` is used without providing the + ``process_ids`` column in the data - Returns: - Dict[str, ak.Array]: A dictionary with the dataset names as keys and - the corresponding predictions as values. + :return: Dict[str, ak.Array]: A dictionary with the dataset names as keys and + the corresponding predictions as values. """ category_inst = self.config_inst.get_category(self.branch_data.category) leaf_category_insts = category_inst.get_leaf_categories() or [category_inst] @@ -851,6 +854,7 @@ class PlotMLResults(PlotMLResultsBase): MergeMLEvaluation task as input and saves the plots with the corresponding array used to create the plot. """ + # override the plot_function parameter to be able to only choose between CM and ROC plot_function = luigi.ChoiceParameter( default="plot_cm", @@ -858,31 +862,38 @@ class PlotMLResults(PlotMLResultsBase): description="The name of the plot function to use. Can be either 'cm' or 'roc'.", ) - def prepare_plot_parameters(self): + def prepare_plot_parameters(self: PlotMLResults): + """ + Helper function to prepare the plot parameters for the plot function. + Implemented to parse the axes labels from the general settings. + """ params = self.get_plot_parameters() - # parse x_label from general settings - x_labels = params.general_settings.get("x_labels", None) - if x_labels: - params.general_settings["x_labels"] = x_labels.split(";") + # parse x_label and y_label from general settings + for label in ["x_labels", "y_labels"]: + if label in params.general_settings.keys(): + params.general_settings[label] = params.general_settings[label].split(";") - def output(self): - output = { - "plot": super().output(), - "array": self.target(f"plot__proc_{self.processes_repr}.parquet"), + def output(self: PlotMLResults): + b = self.branch_data + return {"plots": [ + self.target(name) + for name in self.get_plot_names(f"0_plot__proc_{self.processes_repr}__cat_{b.category}") + ], + "array": self.target(f"plot__proc_{self.processes_repr}.parquet") } - return output @law.decorator.log @view_output_plots - def run(self): - from matplotlib.backends.backend_pdf import PdfPages + def run(self: PlotMLResults): func_path = { "cm": "columnflow.plotting.plot_ml_evaluation.plot_cm", "roc": "columnflow.plotting.plot_ml_evaluation.plot_roc", } category_inst = self.config_inst.get_category(self.branch_data.category) self.prepare_plot_parameters() + + # call the plot function with self.publish_step(f"plotting in {category_inst.name}"): all_events = self.prepare_inputs() figs, array = self.call_plot_func( @@ -893,7 +904,16 @@ def run(self): skip_uncertainties=self.skip_uncertainties, **self.get_plot_parameters(), ) + + # save the outputs self.output()["array"].dump(array, formatter="pickle") - with PdfPages(self.output()["plot"].abspath) as pdf: - for f in figs: - f.savefig(pdf, format="pdf") + for file_path in self.output()["plots"]: + if file_path.ext() == "pdf": + from matplotlib.backends.backend_pdf import PdfPages + with PdfPages(file_path.abspath) as pdf: + for f in figs: + f.savefig(pdf, format="pdf") + continue + + for index, f in enumerate(figs): + f.savefig(file_path.abs_dirname + "/" + f"{index}_{file_path.basename[2:]}", format=file_path.ext()) From 9805f6c5ee63065eb7e83c2548f5100c51565664 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Mon, 13 Nov 2023 14:35:38 +0100 Subject: [PATCH 23/36] did not see some linting errors --- columnflow/tasks/ml.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index ae4344d64..91e06812b 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -29,7 +29,7 @@ from columnflow.tasks.production import ProduceColumns from columnflow.util import dev_sandbox, safe_div, DotDict from columnflow.util import maybe_import -from columnflow.types import Dict, TypeVar, List +from columnflow.types import Dict, List ak = maybe_import("awkward") @@ -880,7 +880,7 @@ def output(self: PlotMLResults): self.target(name) for name in self.get_plot_names(f"0_plot__proc_{self.processes_repr}__cat_{b.category}") ], - "array": self.target(f"plot__proc_{self.processes_repr}.parquet") + "array": self.target(f"plot__proc_{self.processes_repr}.parquet"), } @law.decorator.log From 86949dcbcb2758f4f317eb7c494d099dab4ca07a Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Tue, 14 Nov 2023 14:53:42 +0100 Subject: [PATCH 24/36] fixed issues with uncert and column normalization --- columnflow/plotting/plot_ml_evaluation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index a2e5fd2ad..f47a51dfc 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -130,7 +130,7 @@ def get_conf_matrix(sample_weights, *args, **kwargs) -> np.ndarray: counts[ind, index] += count if not skip_uncertainties: - vecNumber = np.vectorize(lambda n, count: sci.Number(n, float(n / np.sqrt(count)))) + vecNumber = np.vectorize(lambda n, count: sci.Number(n, float(n / np.sqrt(count) if count else 0))) result = vecNumber(result, counts) # Normalize Matrix if needed @@ -143,7 +143,7 @@ def get_conf_matrix(sample_weights, *args, **kwargs) -> np.ndarray: ) row_sums = result.sum(axis=valid.get(normalization)) - result = result / row_sums[:, np.newaxis] + result = result / row_sums[:, np.newaxis] if valid.get(normalization) else result / row_sums return result From 923d9a59f48bb56a688c6f68e84c7f717c695972 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Tue, 14 Nov 2023 15:17:48 +0100 Subject: [PATCH 25/36] Added tests for CM --- tests/run_tests | 6 ++ tests/test_plotting.py | 181 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+) create mode 100644 tests/test_plotting.py diff --git a/tests/run_tests b/tests/run_tests index cdf9c122d..fa2669a89 100755 --- a/tests/run_tests +++ b/tests/run_tests @@ -32,6 +32,12 @@ action() { ret="$?" [ "${gret}" = "0" ] && gret="${ret}" + # test_plotting + echo + bash "${this_dir}/run_test" test_plotting "${cf_dir}/sandboxes/venv_columnar${dev}.sh" + ret="$?" + [ "${gret}" = "0" ] && gret="${ret}" + return "${gret}" } action "$@" diff --git a/tests/test_plotting.py b/tests/test_plotting.py new file mode 100644 index 000000000..aca56b16a --- /dev/null +++ b/tests/test_plotting.py @@ -0,0 +1,181 @@ +""" +Test the plot_ml_evaluation module. +""" + +__all__ = ["TestPlotCM"] + +import unittest +from unittest.mock import MagicMock + +from columnflow.util import maybe_import +from columnflow.plotting.plot_ml_evaluation import plot_cm + + +np = maybe_import("numpy") +ak = maybe_import("awkward") +plt = maybe_import("matplotlib.pyplot") + + +class TestPlotCM(unittest.TestCase): + + def setUp(self): + self.events = { + "dataset_1": ak.Array({ + "out1": [0.1, 0.1, 0.3, 0.5], + "out2": [0.2, 0.2, 0.4, 0.2], + "out3": [0.7, 0.7, 0.3, 0.3], + }), + "dataset_2": ak.Array({ + "out1": [0.2, 0.2, 0.4, 0.3], + "out2": [0.3, 0.3, 0.3, 0.2], + "out3": [0.5, 0.5, 0.3, 0.5], + }), + } + self.config_inst = MagicMock() + self.category_inst = MagicMock() + self.sample_weights = [1, 2] + self.normalization = "row" + self.skip_uncertainties = False + self.x_labels = ["out1", "out2", "out3"] + self.y_labels = ["dataset_1", "dataset_2"] + self.weighted_matrix = np.array([[0.25, 0.25, 0.5], [0.25, 0, 0.75]]) + self.unweighted_matrix = np.array([[0.25, 0.25, 0.5], [0.25, 0, 0.75]]) + self.not_normalized_matrix = np.array([[1, 1, 2], [1, 0, 3]]) + self.column_normalized_matrix = np.array([[0.5, 1, 0.4], [0.5, 0, 0.6]]) + + def test_plot_cm(self): + fig, cm = plot_cm( + events=self.events, + config_inst=self.config_inst, + category_inst=self.category_inst, + sample_weights=self.sample_weights, + normalization=self.normalization, + skip_uncertainties=self.skip_uncertainties, + x_labels=self.x_labels, + y_labels=self.y_labels, + ) + self.assertIsInstance(fig, list) + self.assertIsInstance(fig[0], plt.Figure) + self.assertIsInstance(cm, np.ndarray) + self.assertEqual(cm.shape, (2, 3)) + self.assertEqual(cm.tolist(), self.weighted_matrix.tolist()) + + def test_plot_cm_no_weights(self): + fig, cm = plot_cm( + events=self.events, + config_inst=self.config_inst, + category_inst=self.category_inst, + normalization=self.normalization, + skip_uncertainties=self.skip_uncertainties, + x_labels=self.x_labels, + y_labels=self.y_labels, + ) + self.assertIsInstance(fig, list) + self.assertIsInstance(fig[0], plt.Figure) + self.assertIsInstance(cm, np.ndarray) + self.assertEqual(cm.shape, (2, 3)) + self.assertEqual(cm.tolist(), self.unweighted_matrix.tolist()) + + def test_plot_cm_skip_uncertainties(self): + fig, cm = plot_cm( + events=self.events, + config_inst=self.config_inst, + category_inst=self.category_inst, + sample_weights=self.sample_weights, + normalization=self.normalization, + skip_uncertainties=True, + x_labels=self.x_labels, + y_labels=self.y_labels, + ) + self.assertIsInstance(fig, list) + self.assertIsInstance(fig[0], plt.Figure) + self.assertIsInstance(cm, np.ndarray) + self.assertEqual(cm.shape, (2, 3)) + self.assertEqual(cm.tolist(), self.weighted_matrix.tolist()) + + def test_plot_cm_no_labels(self): + fig, cm = plot_cm( + events=self.events, + config_inst=self.config_inst, + category_inst=self.category_inst, + sample_weights=self.sample_weights, + normalization=self.normalization, + skip_uncertainties=self.skip_uncertainties, + ) + x_labels = ["out0", "out1", "out2"] + y_labels = ["dataset_1", "dataset_2"] + self.assertEqual([t.get_text() for t in fig[0].axes[0].get_xticklabels()], x_labels) + self.assertEqual([t.get_text() for t in fig[0].axes[0].get_yticklabels()], y_labels) + + def test_plot_cm_labels(self): + x_labels = ["vbf", "ggf", "other"] + y_labels = ["Higgs", "Graviton"] + fig, cm = plot_cm( + events=self.events, + config_inst=self.config_inst, + category_inst=self.category_inst, + sample_weights=self.sample_weights, + normalization=self.normalization, + skip_uncertainties=self.skip_uncertainties, + x_labels=x_labels, + y_labels=y_labels, + ) + self.assertEqual([t.get_text() for t in fig[0].axes[0].get_xticklabels()], x_labels) + self.assertEqual([t.get_text() for t in fig[0].axes[0].get_yticklabels()], y_labels) + + def test_plot_cm_invalid_normalization(self): + with self.assertRaises(ValueError): + plot_cm( + events=self.events, + config_inst=self.config_inst, + category_inst=self.category_inst, + sample_weights=self.sample_weights, + normalization="invalid", + skip_uncertainties=self.skip_uncertainties, + x_labels=self.x_labels, + y_labels=self.y_labels, + ) + + def test_plot_cm_no_normalization(self): + fig, cm = plot_cm( + events=self.events, + config_inst=self.config_inst, + category_inst=self.category_inst, + normalization=None, + skip_uncertainties=self.skip_uncertainties, + x_labels=self.x_labels, + y_labels=self.y_labels, + ) + self.assertEqual(cm.shape, (2, 3)) + self.assertEqual(cm.tolist(), self.not_normalized_matrix.tolist()) + + def test_plot_cm_column_normalization(self): + fig, cm = plot_cm( + events=self.events, + config_inst=self.config_inst, + category_inst=self.category_inst, + normalization="column", + skip_uncertainties=self.skip_uncertainties, + x_labels=self.x_labels, + y_labels=self.y_labels, + ) + self.assertEqual(cm.shape, (2, 3)) + self.assertEqual(cm.tolist(), self.column_normalized_matrix.tolist()) + + def test_plot_cm_mismatched_weights_shape(self): + sample_weights = [1, 2, 3] + with self.assertRaises(ValueError): + plot_cm( + events=self.events, + config_inst=self.config_inst, + category_inst=self.category_inst, + sample_weights=sample_weights, + normalization=self.normalization, + skip_uncertainties=self.skip_uncertainties, + x_labels=self.x_labels, + y_labels=self.y_labels, + ) + + +if __name__ == "__main__": + unittest.main() From 98d6ad0f43726562942c1758ca3a5cb63bad044f Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Thu, 16 Nov 2023 17:19:34 +0100 Subject: [PATCH 26/36] fixed ROC logic, only plotting missing --- columnflow/plotting/plot_ml_evaluation.py | 177 ++++++++++++---------- 1 file changed, 95 insertions(+), 82 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index 0cdeb59c3..c3981028c 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -5,10 +5,11 @@ """ from __future__ import annotations +from calendar import c import re -from columnflow.types import Sequence +from columnflow.types import Sequence, Dict, List, Tuple from columnflow.util import maybe_import ak = maybe_import("awkward") @@ -300,9 +301,11 @@ def plot_roc( sample_weights: Sequence | bool = False, skip_uncertainties: bool = False, n_thresholds: int = 200 + 1, + skip_discriminators: list[str] = [], + evaluation_type: str = "OvR", *args, **kwargs, -) -> tuple[plt.Figure, dict]: +) -> tuple[List[plt.Figure], dict]: """ Generates the figure of the ROC curve given the events output of the ML evaluation Args: @@ -330,16 +333,26 @@ def plot_roc( thresholds = np.geomspace(1e-6, 1, n_thresholds) weights = create_sample_weights(sample_weights, events, list(events.keys())) discriminators = list(events.values())[0].fields + figs = [] - def create_histograms(events: dict, sample_weights: dict, *args, **kwargs) -> dict: + if evaluation_type not in ["OvO", "OvR"]: + raise ValueError("Illeagal Argument! Evaluation Type can only be choosen as \"OvO\" (One vs One) \ + or \"OvR\" (One vs Rest)") + + def create_histograms(events: dict, sample_weights: dict, *args, **kwargs) -> Dict[str, Dict[str, np.ndarray]]: hists = {} - for cls, predictions in events.items(): - hists[cls] = {} - for disc in discriminators: - hists[cls][disc] = weights[cls] * ak.to_numpy(np.histogram(predictions[disc], bins=thresholds)[0]) + for disc in discriminators: + hists[disc] = {} + for cls, predictions in events.items(): + hists[disc][cls] = weights[cls] * ak.to_numpy(np.histogram(predictions[disc], bins=thresholds)[0]) return hists - def binary_roc_data(positiv_hist: np.ndarray, negativ_hist: np.ndarray, *args, **kwargs) -> tuple: + def binary_roc_data( + positiv_hist: np.ndarray, + negativ_hist: np.ndarray, + *args, + **kwargs, + ) -> Tuple[np.ndarray, np.ndarray]: """ Compute binary Receiver operating characteristic (ROC) values. Used as a helper function for the multi-dimensional ROC curve @@ -361,81 +374,81 @@ def roc_curve_data( histograms: dict, *args, **kwargs - ) -> dict: + ) -> Dict[str, Dict[str, np.ndarray]]: """ Compute Receiver operating characteristic (ROC) values for a multi-dimensional output. """ - def one_vs_rest(names): - result = {} - for ind, cls_name in enumerate(names): - positiv_inputs = model_output[:, ind] - fpr, tpr, th = binary_roc_data(true_labels=(true_labels == ind), - model_output_positive=positiv_inputs, - sample_weights=sample_weights, - *args, - thresholds=thresholds, - errors=errors, - output_length=output_length) - result[f"{cls_name}_vs_rest"] = {"fpr": fpr, - "tpr": tpr, - "thresholds": th} - return result - - def one_vs_one(names): - result = {} - for pos_ind, cls_name in enumerate(names): - for neg_ind, cls_name2 in enumerate(names): - if (pos_ind == neg_ind): - continue - - # Event selection masks only for the 2 classes analysed - inputs_mask = np.logical_or(true_labels == pos_ind, - true_labels == neg_ind) - select_input = model_output[inputs_mask] - select_labels = true_labels[inputs_mask] - select_weights = None if sample_weights is None else sample_weights[inputs_mask] - - positiv_inputs = select_input[:, pos_ind] - fpr, tpr, th = binary_roc_data(true_labels=(select_labels == pos_ind), - model_output_positive=positiv_inputs, - sample_weights=select_weights, - *args, - thresholds=thresholds, - errors=errors, - output_length=output_length) - result[f"{cls_name}_vs_{cls_name2}"] = {"fpr": fpr, - "tpr": tpr, - "thresholds": th} - return result - - is_input_valid(true_labels, model_output) - - # reshape in case only predictions for the positive class are givin - if model_output.ndim != 2: - model_output = model_output.reshape((model_output.size, 1)) - - # Generate class names if not givin - if class_names is None: - class_names = list(range(model_output.shape[1])) - - assert (len(class_names) == model_output.shape[1]), ( - "Number of givin class names does not match the number of output nodes in the *model_output* argument!") - - # Cast trues labels to class numers - if true_labels.dtype.name == "bool": - true_labels = np.logical_not(true_labels).astype(dtype=np.int32) - - # Map true labels to integers if needed - if "int" not in true_labels.dtype.name: - for ind, name in enumerate(class_names): - true_labels = np.where(true_labels == name, ind, true_labels) - - # Choose the evaluation type - if (evaluation_type == "OvO"): - return one_vs_one(class_names) - elif (evaluation_type == "OvR"): - return one_vs_rest(class_names) - else: - raise ValueError("Illeagal Argument! Evaluation Type can only be choosen as \"OvO\" (One vs One) \ - or \"OvR\" (One vs Rest)") + result = {} + + for disc in discriminators: + tmp = {} + + if disc in skip_discriminators: + continue + + # Choose the evaluation type + if (evaluation_type == "OvO"): + for pos_cls, pos_hist in histograms[disc].items(): + for neg_cls, neg_hist in histograms[disc].items(): + + fpr, tpr = binary_roc_data( + positiv_hist=pos_hist, + negativ_hist=neg_hist, + *args, + **kwargs, + ) + tmp[f"{pos_cls}_vs_{neg_hist}"] = {"fpr": fpr, "tpr": tpr} + + elif (evaluation_type == "OvR"): + for pos_cls, pos_hist in histograms[disc].items(): + neg_hist = np.zeros_like(pos_hist) + for neg_cls, neg_pred in histograms[disc].items(): + if (pos_cls == neg_cls): + continue + neg_hist += neg_pred + + fpr, tpr = binary_roc_data( + positiv_hist=pos_hist, + negativ_hist=neg_hist, + *args, + **kwargs, + ) + tmp[f"{pos_cls}_vs_rest"] = {"fpr": fpr, "tpr": tpr} + + result[disc] = tmp + + return result + + def plot_roc_curve( + roc_data: dict, + title: str, + cmap_label: str = "Accuracy", + digits: int = 3, + *args, + **kwargs, + ) -> plt.figure: + """ + Plots a ROC curve. + + :param roc_data: The ROC curve data to plot. + :param title: The title of the plot, displayed in the top right corner. + :param colormap: The name of the colormap to use. Can be selected from the following: + "cf_cmap", "cf_green_cmap", "cf_ygb_cmap", "viridis". + :param cmap_label: The label of the colorbar. + :param digits: The number of digits to display for each value in the matrix. + :param *args: Additional arguments to pass to the function. + :param **kwargs: Additional keyword arguments to pass to the function. + + :return: The resulting plot. + """ + title = ax.set_title("\n".join(wrap( + "Some really really long long long title I really really need - and just can't - just can't - make it any - simply any - shorter - at all.", 60))) + + histograms = create_histograms(events, weights, *args, **kwargs) + + results = roc_curve_data(evaluation_type, histograms, *args, **kwargs) + from IPython import embed + embed(header="ROC Curve Data") + results["thresholds"] = thresholds + return figs, results From 3175f4b379202610c256eeaf21e58d03905600e6 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Thu, 16 Nov 2023 17:22:05 +0100 Subject: [PATCH 27/36] fixed wrong typing --- columnflow/plotting/plot_ml_evaluation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index f47a51dfc..2a3c8e4d4 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -8,7 +8,7 @@ import re -from columnflow.types import Sequence +from columnflow.types import Sequence, List, Tuple from columnflow.util import maybe_import ak = maybe_import("awkward") @@ -83,7 +83,7 @@ def plot_cm( y_labels: list[str] | None = None, *args, **kwargs, -) -> tuple[plt.Figure, np.ndarray]: +) -> Tuple[List[plt.Figure], np.ndarray]: """ Generates the figure of the confusion matrix given the output of the nodes and an array of true labels. The Cronfusion matrix can also be weighted. From 4368c388b0b26b77a771dae79b53beca555dd253 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Fri, 17 Nov 2023 14:27:27 +0100 Subject: [PATCH 28/36] ROC curve working :) --- columnflow/plotting/plot_ml_evaluation.py | 132 ++++++++++++++-------- columnflow/tasks/ml.py | 14 ++- 2 files changed, 98 insertions(+), 48 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index de23c2e91..f9bb5403a 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -5,7 +5,6 @@ """ from __future__ import annotations -from calendar import c import re @@ -82,6 +81,8 @@ def plot_cm( skip_uncertainties: bool = False, x_labels: list[str] | None = None, y_labels: list[str] | None = None, + cms_rlabel: str = "", + cms_llabel: str = "private work", *args, **kwargs, ) -> Tuple[List[plt.Figure], np.ndarray]: @@ -150,12 +151,13 @@ def get_conf_matrix(sample_weights, *args, **kwargs) -> np.ndarray: def plot_confusion_matrix( cm: np.ndarray, - title: str = "", colormap: str = "cf_cmap", cmap_label: str = "Accuracy", digits: int = 3, x_labels: list[str] | None = None, y_labels: list[str] | None = None, + cms_rlabel: str = "", + cms_llabel: str = "private work", *args, **kwargs, ) -> plt.figure: @@ -283,13 +285,16 @@ def fmt(v): ) # final touches - hep.cms.label(ax=ax, llabel="private work", rlabel=title if title else "") + hep.cms.label(ax=ax, llabel={"pw": "private work"}.get(cms_llabel, cms_llabel), rlabel=cms_rlabel) plt.tight_layout() return fig cm = get_conf_matrix(sample_weights, *args, **kwargs) + print("Confusion matrix calculated!") + fig = plot_confusion_matrix(cm, x_labels=x_labels, y_labels=y_labels, *args, **kwargs) + print("Confusion matrix plotted!") return [fig], cm @@ -299,35 +304,36 @@ def plot_roc( config_inst: od.Config, category_inst: od.Category, sample_weights: Sequence | bool = False, - skip_uncertainties: bool = False, n_thresholds: int = 200 + 1, skip_discriminators: list[str] = [], evaluation_type: str = "OvR", + cms_rlabel: str = "", + cms_llabel: str = "private work", *args, **kwargs, ) -> tuple[List[plt.Figure], dict]: - """ Generates the figure of the ROC curve given the events output of the ML evaluation - - Args: - events (dict): dictionary with the true labels as keys and the model output of - the events as values. - config_inst (od.Config): used configuration for the plot - category_inst (od.Category): used category instance, for which the plot is created - sample_weights (np.ndarray or bool, optional): sample weights of the events. If an explicit array is not - givin the weights are calculated based on the number of eventsDefaults to None. - skip_uncertainties (bool, optional): calculate errors of the cm elements. Defaults to False. - n_thresholds (int): number of thresholds used for the ROC curve - *args: Additional arguments to pass to the function. - **kwargs: Additional keyword arguments to pass to the function. - - Returns: - tuple[plt.Figure, dict]: The resulting plot and the ROC curve data. - - Raises: - AssertionError: If both predictions and labels have mismatched shapes, - or if *weights* is not *None* and its shape doesn't match *predictions*. - AssertionError: If *normalization* is not one of *None*, "row", "column". + """ + Generates the figure of the ROC curve given the output of the nodes + and an array of true labels. The ROC curve can also be weighted. + + :param events: dictionary with the true labels as keys and the model output of the events as values. + :param config_inst: used configuration for the plot. + :param category_inst: used category instance, for which the plot is created. + :param sample_weights: sample weights of the events. If an explicit array is not given, the weights are + calculated based on the number of events. + :param n_thresholds: number of thresholds to use for the ROC curve. + :param skip_discriminators: list of discriminators to skip. + :param evaluation_type: type of evaluation to use for the ROC curve. If not provided, the type is "OvR". + :param cms_rlabel: right label of the CMS label. + :param cms_llabel: left label of the CMS label. + :param *args: Additional arguments to pass to the function. + :param **kwargs: Additional keyword arguments to pass to the function. + + :return: The resulting plot and the ROC curve. + :raises AssertionError: If both predictions and labels have mismatched shapes, or if *weights* + is not *None* and its shape doesn't match *predictions*. + :raises AssertionError: If *normalization* is not one of *None*, "row", "column". """ # defining some useful properties and output shapes thresholds = np.geomspace(1e-6, 1, n_thresholds) @@ -340,6 +346,9 @@ def plot_roc( or \"OvR\" (One vs Rest)") def create_histograms(events: dict, sample_weights: dict, *args, **kwargs) -> Dict[str, Dict[str, np.ndarray]]: + """ + Helper function to create histograms for the different discriminators and classes. + """ hists = {} for disc in discriminators: hists[disc] = {} @@ -357,7 +366,6 @@ def binary_roc_data( Compute binary Receiver operating characteristic (ROC) values. Used as a helper function for the multi-dimensional ROC curve """ - # Calculate the different rates fn = np.cumsum(positiv_hist) tn = np.cumsum(negativ_hist) @@ -373,12 +381,11 @@ def roc_curve_data( evaluation_type: str, histograms: dict, *args, - **kwargs + **kwargs, ) -> Dict[str, Dict[str, np.ndarray]]: """ Compute Receiver operating characteristic (ROC) values for a multi-dimensional output. """ - result = {} for disc in discriminators: @@ -398,7 +405,7 @@ def roc_curve_data( *args, **kwargs, ) - tmp[f"{pos_cls}_vs_{neg_hist}"] = {"fpr": fpr, "tpr": tpr} + tmp[f"{pos_cls}_vs_{neg_cls}"] = {"fpr": fpr, "tpr": tpr} elif (evaluation_type == "OvR"): for pos_cls, pos_hist in histograms[disc].items(): @@ -423,32 +430,67 @@ def roc_curve_data( def plot_roc_curve( roc_data: dict, title: str, - cmap_label: str = "Accuracy", - digits: int = 3, + cms_rlabel: str = "", + cms_llabel: str = "private work", *args, **kwargs, ) -> plt.figure: """ Plots a ROC curve. + """ + def auc_score(fpr: list, tpr: list, *args) -> np.float64: + """ + Compute the area under the curve using the trapezoidal rule. + """ + sign = 1 + if np.any(np.diff(fpr) < 0): + if np.all(np.diff(fpr) <= 0): + sign = -1 + else: + raise ValueError("x is neither increasing nor decreasing : {}.".format(fpr)) - :param roc_data: The ROC curve data to plot. - :param title: The title of the plot, displayed in the top right corner. - :param colormap: The name of the colormap to use. Can be selected from the following: - "cf_cmap", "cf_green_cmap", "cf_ygb_cmap", "viridis". - :param cmap_label: The label of the colorbar. - :param digits: The number of digits to display for each value in the matrix. - :param *args: Additional arguments to pass to the function. - :param **kwargs: Additional keyword arguments to pass to the function. + return sign * np.trapz(tpr, fpr) - :return: The resulting plot. - """ - title = ax.set_title("\n".join(wrap( - "Some really really long long long title I really really need - and just can't - just can't - make it any - simply any - shorter - at all.", 60))) + fpr = roc_data["fpr"] + tpr = roc_data["tpr"] + plt.style.use(hep.style.CMS) + fig, ax = plt.subplots(dpi=300) + ax.set_xlabel("FPR", loc="right", labelpad=10, fontsize=25) + ax.set_ylabel("TPR", loc="top", labelpad=15, fontsize=25) + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + + ax.plot( + fpr, + tpr, + color="#67cf02", + label=f"AUC = {auc_score(roc_data['fpr'], roc_data['tpr']):.3f}", + ) + ax.plot([0, 1], [0, 1], color="black", linestyle="--") + + # setting titles and legend + ax.legend(loc="lower right", fontsize=15) + ax.set_title(title, wrap=True, pad=50, fontsize=30) + hep.cms.label(ax=ax, llabel={"pw": "private work"}.get(cms_llabel, cms_llabel), rlabel=cms_rlabel) + plt.tight_layout() + + return fig + + # create historgrams and calculate FPR and TPR histograms = create_histograms(events, weights, *args, **kwargs) + print("histograms created!") results = roc_curve_data(evaluation_type, histograms, *args, **kwargs) - from IPython import embed - embed(header="ROC Curve Data") + print("ROC data calculated!") + + # plotting + for disc, roc in results.items(): + for cls, roc_data in roc.items(): + title = rf"{cls.replace('_vs_', ' VS ')} with {disc}$" + figs.append(plot_roc_curve(roc_data, title, cms_llabel=cms_llabel, cms_rlabel=cms_rlabel, *args, **kwargs)) + print("ROC curves plotted!") + results["thresholds"] = thresholds + return figs, results diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 62c67a99f..b52dee209 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -1114,9 +1114,13 @@ def output(self: PlotMLResults): b = self.branch_data return {"plots": [ self.target(name) - for name in self.get_plot_names(f"0_plot__proc_{self.processes_repr}__cat_{b.category}") + for name in self.get_plot_names( + f"plot__{self.plot_function}__proc_{self.processes_repr}__cat_{b.category}/plot__0", + ) ], - "array": self.target(f"plot__proc_{self.processes_repr}.parquet"), + "array": self.target( + f"plot__{self.plot_function}__proc_{self.processes_repr}__cat_{b.category}/data.parquet", + ), } @law.decorator.log @@ -1138,6 +1142,7 @@ def run(self: PlotMLResults): config_inst=self.config_inst, category_inst=category_inst, skip_uncertainties=self.skip_uncertainties, + cms_llabel=self.cms_label, **self.get_plot_parameters(), ) @@ -1152,4 +1157,7 @@ def run(self: PlotMLResults): continue for index, f in enumerate(figs): - f.savefig(file_path.abs_dirname + "/" + f"{index}_{file_path.basename[2:]}", format=file_path.ext()) + f.savefig( + file_path.abs_dirname + "/" + file_path.basename.replace("0", str(index)), + format=file_path.ext(), + ) From 8b3f78a55bd099a3e0acb54b44b114aeae73cde7 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Fri, 17 Nov 2023 15:41:06 +0100 Subject: [PATCH 29/36] added ROC tests --- columnflow/plotting/plot_ml_evaluation.py | 4 +- tests/test_plotting.py | 80 ++++++++++++++++++----- 2 files changed, 66 insertions(+), 18 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index f9bb5403a..b68e6b1e8 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -331,9 +331,9 @@ def plot_roc( :return: The resulting plot and the ROC curve. - :raises AssertionError: If both predictions and labels have mismatched shapes, or if *weights* + :raises ValueError: If both predictions and labels have mismatched shapes, or if *weights* is not *None* and its shape doesn't match *predictions*. - :raises AssertionError: If *normalization* is not one of *None*, "row", "column". + :raises ValueError: If *normalization* is not one of *None*, "row", "column". """ # defining some useful properties and output shapes thresholds = np.geomspace(1e-6, 1, n_thresholds) diff --git a/tests/test_plotting.py b/tests/test_plotting.py index aca56b16a..a33a825be 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -2,13 +2,13 @@ Test the plot_ml_evaluation module. """ -__all__ = ["TestPlotCM"] +__all__ = ["TestPlotCM", "TestPlotROC"] import unittest from unittest.mock import MagicMock from columnflow.util import maybe_import -from columnflow.plotting.plot_ml_evaluation import plot_cm +from columnflow.plotting.plot_ml_evaluation import plot_cm, plot_roc np = maybe_import("numpy") @@ -57,8 +57,8 @@ def test_plot_cm(self): self.assertIsInstance(fig, list) self.assertIsInstance(fig[0], plt.Figure) self.assertIsInstance(cm, np.ndarray) - self.assertEqual(cm.shape, (2, 3)) - self.assertEqual(cm.tolist(), self.weighted_matrix.tolist()) + self.assertTupleEqual(cm.shape, (2, 3)) + self.assertListEqual(cm.tolist(), self.weighted_matrix.tolist()) def test_plot_cm_no_weights(self): fig, cm = plot_cm( @@ -73,8 +73,8 @@ def test_plot_cm_no_weights(self): self.assertIsInstance(fig, list) self.assertIsInstance(fig[0], plt.Figure) self.assertIsInstance(cm, np.ndarray) - self.assertEqual(cm.shape, (2, 3)) - self.assertEqual(cm.tolist(), self.unweighted_matrix.tolist()) + self.assertTupleEqual(cm.shape, (2, 3)) + self.assertListEqual(cm.tolist(), self.unweighted_matrix.tolist()) def test_plot_cm_skip_uncertainties(self): fig, cm = plot_cm( @@ -90,8 +90,8 @@ def test_plot_cm_skip_uncertainties(self): self.assertIsInstance(fig, list) self.assertIsInstance(fig[0], plt.Figure) self.assertIsInstance(cm, np.ndarray) - self.assertEqual(cm.shape, (2, 3)) - self.assertEqual(cm.tolist(), self.weighted_matrix.tolist()) + self.assertTupleEqual(cm.shape, (2, 3)) + self.assertListEqual(cm.tolist(), self.weighted_matrix.tolist()) def test_plot_cm_no_labels(self): fig, cm = plot_cm( @@ -104,8 +104,8 @@ def test_plot_cm_no_labels(self): ) x_labels = ["out0", "out1", "out2"] y_labels = ["dataset_1", "dataset_2"] - self.assertEqual([t.get_text() for t in fig[0].axes[0].get_xticklabels()], x_labels) - self.assertEqual([t.get_text() for t in fig[0].axes[0].get_yticklabels()], y_labels) + self.assertListEqual([t.get_text() for t in fig[0].axes[0].get_xticklabels()], x_labels) + self.assertListEqual([t.get_text() for t in fig[0].axes[0].get_yticklabels()], y_labels) def test_plot_cm_labels(self): x_labels = ["vbf", "ggf", "other"] @@ -120,8 +120,8 @@ def test_plot_cm_labels(self): x_labels=x_labels, y_labels=y_labels, ) - self.assertEqual([t.get_text() for t in fig[0].axes[0].get_xticklabels()], x_labels) - self.assertEqual([t.get_text() for t in fig[0].axes[0].get_yticklabels()], y_labels) + self.assertListEqual([t.get_text() for t in fig[0].axes[0].get_xticklabels()], x_labels) + self.assertListEqual([t.get_text() for t in fig[0].axes[0].get_yticklabels()], y_labels) def test_plot_cm_invalid_normalization(self): with self.assertRaises(ValueError): @@ -146,8 +146,8 @@ def test_plot_cm_no_normalization(self): x_labels=self.x_labels, y_labels=self.y_labels, ) - self.assertEqual(cm.shape, (2, 3)) - self.assertEqual(cm.tolist(), self.not_normalized_matrix.tolist()) + self.assertTupleEqual(cm.shape, (2, 3)) + self.assertListEqual(cm.tolist(), self.not_normalized_matrix.tolist()) def test_plot_cm_column_normalization(self): fig, cm = plot_cm( @@ -159,8 +159,8 @@ def test_plot_cm_column_normalization(self): x_labels=self.x_labels, y_labels=self.y_labels, ) - self.assertEqual(cm.shape, (2, 3)) - self.assertEqual(cm.tolist(), self.column_normalized_matrix.tolist()) + self.assertTupleEqual(cm.shape, (2, 3)) + self.assertListEqual(cm.tolist(), self.column_normalized_matrix.tolist()) def test_plot_cm_mismatched_weights_shape(self): sample_weights = [1, 2, 3] @@ -177,5 +177,53 @@ def test_plot_cm_mismatched_weights_shape(self): ) +class TestPlotROC(unittest.TestCase): + + def setUp(self): + self.events = { + "dataset_1": ak.Array({ + "out1": [0.1, 0.1, 0.3, 0.5], + "out2": [0.2, 0.2, 0.4, 0.2], + }), + "dataset_2": ak.Array({ + "out1": [0.2, 0.2, 0.4, 0.3], + "out2": [0.3, 0.3, 0.3, 0.2], + }), + "dataset_3": ak.Array({ + "out1": [0.1, 0.7, 0.3, 0.5], + "out2": [0.2, 0.2, 0.4, 0.2], + }), + } + self.N_discriminators = 2 + self.config_inst = MagicMock() + self.category_inst = MagicMock() + self.results_out1 = { + "dataset_1": (0.75, 0.5), + "dataset_2": (0.5, 0.625), + "dataset_3": (0.5, 0.625), + } + + def test_plot_roc_returns_figures_and_results(self): + figs, results = plot_roc(self.events, self.config_inst, self.category_inst) + self.assertIsInstance(figs, list) + self.assertIsInstance(results, dict) + + def test_plot_roc_returns_correct_number_of_figures(self): + figs_ovr, _ = plot_roc(self.events, self.config_inst, self.category_inst, evaluation_type="OvR") + figs_ovo, _ = plot_roc(self.events, self.config_inst, self.category_inst, evaluation_type="OvO") + + self.assertEqual(len(figs_ovr), self.N_discriminators * len(self.events)) + self.assertEqual(len(figs_ovo), self.N_discriminators * len(self.events) * (len(self.events))) + + def test_plot_roc_returns_correct_results(self): + _, results = plot_roc(self.events, self.config_inst, self.category_inst, evaluation_type="OvR") + # not implemented yet + pass + + def test_plot_roc_raises_value_error_for_invalid_evaluation_type(self): + with self.assertRaises(ValueError): + plot_roc(self.events, self.config_inst, self.category_inst, evaluation_type="InvalidType") + + if __name__ == "__main__": unittest.main() From 855c084737b5c9607a5a57dd3a788f7aa240298e Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Fri, 17 Nov 2023 14:27:27 +0100 Subject: [PATCH 30/36] ROC curve working :) --- columnflow/plotting/plot_ml_evaluation.py | 132 ++++++++++++++-------- columnflow/tasks/ml.py | 14 ++- 2 files changed, 98 insertions(+), 48 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index de23c2e91..b68e6b1e8 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -5,7 +5,6 @@ """ from __future__ import annotations -from calendar import c import re @@ -82,6 +81,8 @@ def plot_cm( skip_uncertainties: bool = False, x_labels: list[str] | None = None, y_labels: list[str] | None = None, + cms_rlabel: str = "", + cms_llabel: str = "private work", *args, **kwargs, ) -> Tuple[List[plt.Figure], np.ndarray]: @@ -150,12 +151,13 @@ def get_conf_matrix(sample_weights, *args, **kwargs) -> np.ndarray: def plot_confusion_matrix( cm: np.ndarray, - title: str = "", colormap: str = "cf_cmap", cmap_label: str = "Accuracy", digits: int = 3, x_labels: list[str] | None = None, y_labels: list[str] | None = None, + cms_rlabel: str = "", + cms_llabel: str = "private work", *args, **kwargs, ) -> plt.figure: @@ -283,13 +285,16 @@ def fmt(v): ) # final touches - hep.cms.label(ax=ax, llabel="private work", rlabel=title if title else "") + hep.cms.label(ax=ax, llabel={"pw": "private work"}.get(cms_llabel, cms_llabel), rlabel=cms_rlabel) plt.tight_layout() return fig cm = get_conf_matrix(sample_weights, *args, **kwargs) + print("Confusion matrix calculated!") + fig = plot_confusion_matrix(cm, x_labels=x_labels, y_labels=y_labels, *args, **kwargs) + print("Confusion matrix plotted!") return [fig], cm @@ -299,35 +304,36 @@ def plot_roc( config_inst: od.Config, category_inst: od.Category, sample_weights: Sequence | bool = False, - skip_uncertainties: bool = False, n_thresholds: int = 200 + 1, skip_discriminators: list[str] = [], evaluation_type: str = "OvR", + cms_rlabel: str = "", + cms_llabel: str = "private work", *args, **kwargs, ) -> tuple[List[plt.Figure], dict]: - """ Generates the figure of the ROC curve given the events output of the ML evaluation - - Args: - events (dict): dictionary with the true labels as keys and the model output of - the events as values. - config_inst (od.Config): used configuration for the plot - category_inst (od.Category): used category instance, for which the plot is created - sample_weights (np.ndarray or bool, optional): sample weights of the events. If an explicit array is not - givin the weights are calculated based on the number of eventsDefaults to None. - skip_uncertainties (bool, optional): calculate errors of the cm elements. Defaults to False. - n_thresholds (int): number of thresholds used for the ROC curve - *args: Additional arguments to pass to the function. - **kwargs: Additional keyword arguments to pass to the function. - - Returns: - tuple[plt.Figure, dict]: The resulting plot and the ROC curve data. - - Raises: - AssertionError: If both predictions and labels have mismatched shapes, - or if *weights* is not *None* and its shape doesn't match *predictions*. - AssertionError: If *normalization* is not one of *None*, "row", "column". + """ + Generates the figure of the ROC curve given the output of the nodes + and an array of true labels. The ROC curve can also be weighted. + + :param events: dictionary with the true labels as keys and the model output of the events as values. + :param config_inst: used configuration for the plot. + :param category_inst: used category instance, for which the plot is created. + :param sample_weights: sample weights of the events. If an explicit array is not given, the weights are + calculated based on the number of events. + :param n_thresholds: number of thresholds to use for the ROC curve. + :param skip_discriminators: list of discriminators to skip. + :param evaluation_type: type of evaluation to use for the ROC curve. If not provided, the type is "OvR". + :param cms_rlabel: right label of the CMS label. + :param cms_llabel: left label of the CMS label. + :param *args: Additional arguments to pass to the function. + :param **kwargs: Additional keyword arguments to pass to the function. + :return: The resulting plot and the ROC curve. + + :raises ValueError: If both predictions and labels have mismatched shapes, or if *weights* + is not *None* and its shape doesn't match *predictions*. + :raises ValueError: If *normalization* is not one of *None*, "row", "column". """ # defining some useful properties and output shapes thresholds = np.geomspace(1e-6, 1, n_thresholds) @@ -340,6 +346,9 @@ def plot_roc( or \"OvR\" (One vs Rest)") def create_histograms(events: dict, sample_weights: dict, *args, **kwargs) -> Dict[str, Dict[str, np.ndarray]]: + """ + Helper function to create histograms for the different discriminators and classes. + """ hists = {} for disc in discriminators: hists[disc] = {} @@ -357,7 +366,6 @@ def binary_roc_data( Compute binary Receiver operating characteristic (ROC) values. Used as a helper function for the multi-dimensional ROC curve """ - # Calculate the different rates fn = np.cumsum(positiv_hist) tn = np.cumsum(negativ_hist) @@ -373,12 +381,11 @@ def roc_curve_data( evaluation_type: str, histograms: dict, *args, - **kwargs + **kwargs, ) -> Dict[str, Dict[str, np.ndarray]]: """ Compute Receiver operating characteristic (ROC) values for a multi-dimensional output. """ - result = {} for disc in discriminators: @@ -398,7 +405,7 @@ def roc_curve_data( *args, **kwargs, ) - tmp[f"{pos_cls}_vs_{neg_hist}"] = {"fpr": fpr, "tpr": tpr} + tmp[f"{pos_cls}_vs_{neg_cls}"] = {"fpr": fpr, "tpr": tpr} elif (evaluation_type == "OvR"): for pos_cls, pos_hist in histograms[disc].items(): @@ -423,32 +430,67 @@ def roc_curve_data( def plot_roc_curve( roc_data: dict, title: str, - cmap_label: str = "Accuracy", - digits: int = 3, + cms_rlabel: str = "", + cms_llabel: str = "private work", *args, **kwargs, ) -> plt.figure: """ Plots a ROC curve. + """ + def auc_score(fpr: list, tpr: list, *args) -> np.float64: + """ + Compute the area under the curve using the trapezoidal rule. + """ + sign = 1 + if np.any(np.diff(fpr) < 0): + if np.all(np.diff(fpr) <= 0): + sign = -1 + else: + raise ValueError("x is neither increasing nor decreasing : {}.".format(fpr)) - :param roc_data: The ROC curve data to plot. - :param title: The title of the plot, displayed in the top right corner. - :param colormap: The name of the colormap to use. Can be selected from the following: - "cf_cmap", "cf_green_cmap", "cf_ygb_cmap", "viridis". - :param cmap_label: The label of the colorbar. - :param digits: The number of digits to display for each value in the matrix. - :param *args: Additional arguments to pass to the function. - :param **kwargs: Additional keyword arguments to pass to the function. + return sign * np.trapz(tpr, fpr) - :return: The resulting plot. - """ - title = ax.set_title("\n".join(wrap( - "Some really really long long long title I really really need - and just can't - just can't - make it any - simply any - shorter - at all.", 60))) + fpr = roc_data["fpr"] + tpr = roc_data["tpr"] + plt.style.use(hep.style.CMS) + fig, ax = plt.subplots(dpi=300) + ax.set_xlabel("FPR", loc="right", labelpad=10, fontsize=25) + ax.set_ylabel("TPR", loc="top", labelpad=15, fontsize=25) + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + + ax.plot( + fpr, + tpr, + color="#67cf02", + label=f"AUC = {auc_score(roc_data['fpr'], roc_data['tpr']):.3f}", + ) + ax.plot([0, 1], [0, 1], color="black", linestyle="--") + + # setting titles and legend + ax.legend(loc="lower right", fontsize=15) + ax.set_title(title, wrap=True, pad=50, fontsize=30) + hep.cms.label(ax=ax, llabel={"pw": "private work"}.get(cms_llabel, cms_llabel), rlabel=cms_rlabel) + plt.tight_layout() + + return fig + + # create historgrams and calculate FPR and TPR histograms = create_histograms(events, weights, *args, **kwargs) + print("histograms created!") results = roc_curve_data(evaluation_type, histograms, *args, **kwargs) - from IPython import embed - embed(header="ROC Curve Data") + print("ROC data calculated!") + + # plotting + for disc, roc in results.items(): + for cls, roc_data in roc.items(): + title = rf"{cls.replace('_vs_', ' VS ')} with {disc}$" + figs.append(plot_roc_curve(roc_data, title, cms_llabel=cms_llabel, cms_rlabel=cms_rlabel, *args, **kwargs)) + print("ROC curves plotted!") + results["thresholds"] = thresholds + return figs, results diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 62c67a99f..b52dee209 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -1114,9 +1114,13 @@ def output(self: PlotMLResults): b = self.branch_data return {"plots": [ self.target(name) - for name in self.get_plot_names(f"0_plot__proc_{self.processes_repr}__cat_{b.category}") + for name in self.get_plot_names( + f"plot__{self.plot_function}__proc_{self.processes_repr}__cat_{b.category}/plot__0", + ) ], - "array": self.target(f"plot__proc_{self.processes_repr}.parquet"), + "array": self.target( + f"plot__{self.plot_function}__proc_{self.processes_repr}__cat_{b.category}/data.parquet", + ), } @law.decorator.log @@ -1138,6 +1142,7 @@ def run(self: PlotMLResults): config_inst=self.config_inst, category_inst=category_inst, skip_uncertainties=self.skip_uncertainties, + cms_llabel=self.cms_label, **self.get_plot_parameters(), ) @@ -1152,4 +1157,7 @@ def run(self: PlotMLResults): continue for index, f in enumerate(figs): - f.savefig(file_path.abs_dirname + "/" + f"{index}_{file_path.basename[2:]}", format=file_path.ext()) + f.savefig( + file_path.abs_dirname + "/" + file_path.basename.replace("0", str(index)), + format=file_path.ext(), + ) From 4a073315d0bf2b1523932816894344ab2f459511 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Fri, 17 Nov 2023 15:49:10 +0100 Subject: [PATCH 31/36] adapted to upcoming ROC curve --- columnflow/plotting/plot_ml_evaluation.py | 20 +++++++++++++++++--- columnflow/tasks/ml.py | 14 +++++++++++--- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index 2a3c8e4d4..0cd00f734 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -81,6 +81,8 @@ def plot_cm( skip_uncertainties: bool = False, x_labels: list[str] | None = None, y_labels: list[str] | None = None, + cms_rlabel: str = "", + cms_llabel: str = "private work", *args, **kwargs, ) -> Tuple[List[plt.Figure], np.ndarray]: @@ -149,12 +151,13 @@ def get_conf_matrix(sample_weights, *args, **kwargs) -> np.ndarray: def plot_confusion_matrix( cm: np.ndarray, - title: str = "", colormap: str = "cf_cmap", cmap_label: str = "Accuracy", digits: int = 3, x_labels: list[str] | None = None, y_labels: list[str] | None = None, + cms_rlabel: str = "", + cms_llabel: str = "private work", *args, **kwargs, ) -> plt.figure: @@ -282,12 +285,23 @@ def fmt(v): ) # final touches - hep.cms.label(ax=ax, llabel="private work", rlabel=title if title else "") + hep.cms.label(ax=ax, llabel={"pw": "private work"}.get(cms_llabel, cms_llabel), rlabel=cms_rlabel) plt.tight_layout() return fig cm = get_conf_matrix(sample_weights, *args, **kwargs) - fig = plot_confusion_matrix(cm, x_labels=x_labels, y_labels=y_labels, *args, **kwargs) + print("Confusion matrix calculated!") + + fig = plot_confusion_matrix( + cm, + x_labels=x_labels, + y_labels=y_labels, + cms_llabel=cms_llabel, + cms_rlabel=cms_rlabel, + *args, + **kwargs + ) + print("Confusion matrix plotted!") return [fig], cm diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 62c67a99f..b52dee209 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -1114,9 +1114,13 @@ def output(self: PlotMLResults): b = self.branch_data return {"plots": [ self.target(name) - for name in self.get_plot_names(f"0_plot__proc_{self.processes_repr}__cat_{b.category}") + for name in self.get_plot_names( + f"plot__{self.plot_function}__proc_{self.processes_repr}__cat_{b.category}/plot__0", + ) ], - "array": self.target(f"plot__proc_{self.processes_repr}.parquet"), + "array": self.target( + f"plot__{self.plot_function}__proc_{self.processes_repr}__cat_{b.category}/data.parquet", + ), } @law.decorator.log @@ -1138,6 +1142,7 @@ def run(self: PlotMLResults): config_inst=self.config_inst, category_inst=category_inst, skip_uncertainties=self.skip_uncertainties, + cms_llabel=self.cms_label, **self.get_plot_parameters(), ) @@ -1152,4 +1157,7 @@ def run(self: PlotMLResults): continue for index, f in enumerate(figs): - f.savefig(file_path.abs_dirname + "/" + f"{index}_{file_path.basename[2:]}", format=file_path.ext()) + f.savefig( + file_path.abs_dirname + "/" + file_path.basename.replace("0", str(index)), + format=file_path.ext(), + ) From be90e08ef63a38f13bb98bc1faad20ecbcf7607b Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Fri, 17 Nov 2023 15:52:23 +0100 Subject: [PATCH 32/36] trailing comma --- columnflow/plotting/plot_ml_evaluation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index 0cd00f734..e08d24a7f 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -300,7 +300,7 @@ def fmt(v): cms_llabel=cms_llabel, cms_rlabel=cms_rlabel, *args, - **kwargs + **kwargs, ) print("Confusion matrix plotted!") From eddf53031c4317fc9520125ce9ca0ed137ed0db0 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Fri, 1 Dec 2023 16:13:18 +0100 Subject: [PATCH 33/36] improved tests = minor linting fixes --- columnflow/plotting/plot_ml_evaluation.py | 16 +- tests/test_plotting.py | 203 +++++++++++++--------- 2 files changed, 129 insertions(+), 90 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index 9124936fd..2923ce4d4 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -344,7 +344,7 @@ def plot_roc( :raises ValueError: If *normalization* is not one of *None*, "row", "column". """ # defining some useful properties and output shapes - thresholds = np.geomspace(1e-6, 1, n_thresholds) + thresholds = np.linspace(0, 1, n_thresholds) weights = create_sample_weights(sample_weights, events, list(events.keys())) discriminators = list(events.values())[0].fields figs = [] @@ -361,7 +361,8 @@ def create_histograms(events: dict, sample_weights: dict, *args, **kwargs) -> Di for disc in discriminators: hists[disc] = {} for cls, predictions in events.items(): - hists[disc][cls] = weights[cls] * ak.to_numpy(np.histogram(predictions[disc], bins=thresholds)[0]) + hists[disc][cls] = (sample_weights[cls] * + ak.to_numpy(np.histogram(predictions[disc], bins=thresholds)[0])) return hists def binary_roc_data( @@ -495,8 +496,15 @@ def auc_score(fpr: list, tpr: list, *args) -> np.float64: # plotting for disc, roc in results.items(): for cls, roc_data in roc.items(): - title = rf"{cls.replace('_vs_', ' VS ')} with {disc}$" - figs.append(plot_roc_curve(roc_data, title, cms_llabel=cms_llabel, cms_rlabel=cms_rlabel, *args, **kwargs)) + title = rf"{cls.replace('_vs_', ' VS ')} with {disc}" + figs.append(plot_roc_curve( + roc_data=roc_data, + title=title, + cms_rlabel=cms_rlabel, + cms_llabel=cms_llabel, + *args, + **kwargs, + )) print("ROC curves plotted!") results["thresholds"] = thresholds diff --git a/tests/test_plotting.py b/tests/test_plotting.py index a33a825be..80e7d8da5 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -4,8 +4,10 @@ __all__ = ["TestPlotCM", "TestPlotROC"] +import io import unittest from unittest.mock import MagicMock +from contextlib import redirect_stdout from columnflow.util import maybe_import from columnflow.plotting.plot_ml_evaluation import plot_cm, plot_roc @@ -38,22 +40,26 @@ def setUp(self): self.skip_uncertainties = False self.x_labels = ["out1", "out2", "out3"] self.y_labels = ["dataset_1", "dataset_2"] + + # The following results are calculated by hand self.weighted_matrix = np.array([[0.25, 0.25, 0.5], [0.25, 0, 0.75]]) self.unweighted_matrix = np.array([[0.25, 0.25, 0.5], [0.25, 0, 0.75]]) self.not_normalized_matrix = np.array([[1, 1, 2], [1, 0, 3]]) self.column_normalized_matrix = np.array([[0.5, 1, 0.4], [0.5, 0, 0.6]]) + self.text_trap = io.StringIO() def test_plot_cm(self): - fig, cm = plot_cm( - events=self.events, - config_inst=self.config_inst, - category_inst=self.category_inst, - sample_weights=self.sample_weights, - normalization=self.normalization, - skip_uncertainties=self.skip_uncertainties, - x_labels=self.x_labels, - y_labels=self.y_labels, - ) + with redirect_stdout(self.text_trap): + fig, cm = plot_cm( + events=self.events, + config_inst=self.config_inst, + category_inst=self.category_inst, + sample_weights=self.sample_weights, + normalization=self.normalization, + skip_uncertainties=self.skip_uncertainties, + x_labels=self.x_labels, + y_labels=self.y_labels, + ) self.assertIsInstance(fig, list) self.assertIsInstance(fig[0], plt.Figure) self.assertIsInstance(cm, np.ndarray) @@ -61,15 +67,16 @@ def test_plot_cm(self): self.assertListEqual(cm.tolist(), self.weighted_matrix.tolist()) def test_plot_cm_no_weights(self): - fig, cm = plot_cm( - events=self.events, - config_inst=self.config_inst, - category_inst=self.category_inst, - normalization=self.normalization, - skip_uncertainties=self.skip_uncertainties, - x_labels=self.x_labels, - y_labels=self.y_labels, - ) + with redirect_stdout(self.text_trap): + fig, cm = plot_cm( + events=self.events, + config_inst=self.config_inst, + category_inst=self.category_inst, + normalization=self.normalization, + skip_uncertainties=self.skip_uncertainties, + x_labels=self.x_labels, + y_labels=self.y_labels, + ) self.assertIsInstance(fig, list) self.assertIsInstance(fig[0], plt.Figure) self.assertIsInstance(cm, np.ndarray) @@ -77,16 +84,17 @@ def test_plot_cm_no_weights(self): self.assertListEqual(cm.tolist(), self.unweighted_matrix.tolist()) def test_plot_cm_skip_uncertainties(self): - fig, cm = plot_cm( - events=self.events, - config_inst=self.config_inst, - category_inst=self.category_inst, - sample_weights=self.sample_weights, - normalization=self.normalization, - skip_uncertainties=True, - x_labels=self.x_labels, - y_labels=self.y_labels, - ) + with redirect_stdout(self.text_trap): + fig, cm = plot_cm( + events=self.events, + config_inst=self.config_inst, + category_inst=self.category_inst, + sample_weights=self.sample_weights, + normalization=self.normalization, + skip_uncertainties=True, + x_labels=self.x_labels, + y_labels=self.y_labels, + ) self.assertIsInstance(fig, list) self.assertIsInstance(fig[0], plt.Figure) self.assertIsInstance(cm, np.ndarray) @@ -94,14 +102,15 @@ def test_plot_cm_skip_uncertainties(self): self.assertListEqual(cm.tolist(), self.weighted_matrix.tolist()) def test_plot_cm_no_labels(self): - fig, cm = plot_cm( - events=self.events, - config_inst=self.config_inst, - category_inst=self.category_inst, - sample_weights=self.sample_weights, - normalization=self.normalization, - skip_uncertainties=self.skip_uncertainties, - ) + with redirect_stdout(self.text_trap): + fig, cm = plot_cm( + events=self.events, + config_inst=self.config_inst, + category_inst=self.category_inst, + sample_weights=self.sample_weights, + normalization=self.normalization, + skip_uncertainties=self.skip_uncertainties, + ) x_labels = ["out0", "out1", "out2"] y_labels = ["dataset_1", "dataset_2"] self.assertListEqual([t.get_text() for t in fig[0].axes[0].get_xticklabels()], x_labels) @@ -110,21 +119,22 @@ def test_plot_cm_no_labels(self): def test_plot_cm_labels(self): x_labels = ["vbf", "ggf", "other"] y_labels = ["Higgs", "Graviton"] - fig, cm = plot_cm( - events=self.events, - config_inst=self.config_inst, - category_inst=self.category_inst, - sample_weights=self.sample_weights, - normalization=self.normalization, - skip_uncertainties=self.skip_uncertainties, - x_labels=x_labels, - y_labels=y_labels, - ) + with redirect_stdout(self.text_trap): + fig, cm = plot_cm( + events=self.events, + config_inst=self.config_inst, + category_inst=self.category_inst, + sample_weights=self.sample_weights, + normalization=self.normalization, + skip_uncertainties=self.skip_uncertainties, + x_labels=x_labels, + y_labels=y_labels, + ) self.assertListEqual([t.get_text() for t in fig[0].axes[0].get_xticklabels()], x_labels) self.assertListEqual([t.get_text() for t in fig[0].axes[0].get_yticklabels()], y_labels) def test_plot_cm_invalid_normalization(self): - with self.assertRaises(ValueError): + with self.assertRaises(ValueError), redirect_stdout(self.text_trap): plot_cm( events=self.events, config_inst=self.config_inst, @@ -137,34 +147,36 @@ def test_plot_cm_invalid_normalization(self): ) def test_plot_cm_no_normalization(self): - fig, cm = plot_cm( - events=self.events, - config_inst=self.config_inst, - category_inst=self.category_inst, - normalization=None, - skip_uncertainties=self.skip_uncertainties, - x_labels=self.x_labels, - y_labels=self.y_labels, - ) + with redirect_stdout(self.text_trap): + fig, cm = plot_cm( + events=self.events, + config_inst=self.config_inst, + category_inst=self.category_inst, + normalization=None, + skip_uncertainties=self.skip_uncertainties, + x_labels=self.x_labels, + y_labels=self.y_labels, + ) self.assertTupleEqual(cm.shape, (2, 3)) self.assertListEqual(cm.tolist(), self.not_normalized_matrix.tolist()) def test_plot_cm_column_normalization(self): - fig, cm = plot_cm( - events=self.events, - config_inst=self.config_inst, - category_inst=self.category_inst, - normalization="column", - skip_uncertainties=self.skip_uncertainties, - x_labels=self.x_labels, - y_labels=self.y_labels, - ) + with redirect_stdout(self.text_trap): + fig, cm = plot_cm( + events=self.events, + config_inst=self.config_inst, + category_inst=self.category_inst, + normalization="column", + skip_uncertainties=self.skip_uncertainties, + x_labels=self.x_labels, + y_labels=self.y_labels, + ) self.assertTupleEqual(cm.shape, (2, 3)) self.assertListEqual(cm.tolist(), self.column_normalized_matrix.tolist()) def test_plot_cm_mismatched_weights_shape(self): sample_weights = [1, 2, 3] - with self.assertRaises(ValueError): + with self.assertRaises(ValueError), redirect_stdout(self.text_trap): plot_cm( events=self.events, config_inst=self.config_inst, @@ -182,46 +194,65 @@ class TestPlotROC(unittest.TestCase): def setUp(self): self.events = { "dataset_1": ak.Array({ - "out1": [0.1, 0.1, 0.3, 0.5], - "out2": [0.2, 0.2, 0.4, 0.2], + "out1": [0.9, 0.9, 0.7, 0.4], + "out2": [0.1, 0.1, 0.3, 0.6], }), "dataset_2": ak.Array({ - "out1": [0.2, 0.2, 0.4, 0.3], - "out2": [0.3, 0.3, 0.3, 0.2], - }), - "dataset_3": ak.Array({ - "out1": [0.1, 0.7, 0.3, 0.5], - "out2": [0.2, 0.2, 0.4, 0.2], + "out1": [0.2, 0.2, 0.4, 0.8], + "out2": [0.8, 0.8, 0.6, 0.2], }), } self.N_discriminators = 2 self.config_inst = MagicMock() self.category_inst = MagicMock() - self.results_out1 = { - "dataset_1": (0.75, 0.5), - "dataset_2": (0.5, 0.625), - "dataset_3": (0.5, 0.625), + # The following results are calculated by hand + self.results_dataset1_as_signal = { + "out1": { + "fpr": [0.5, 0.25, 0.25, 0], + "tpr": [1, 0.75, 0.5, 0], + }, + "out2": { + "fpr": [0.75, 0.75, 0.5, 0], + "tpr": [0.5, 0.25, 0, 0], + + }, } + self.text_trap = io.StringIO() def test_plot_roc_returns_figures_and_results(self): - figs, results = plot_roc(self.events, self.config_inst, self.category_inst) + with redirect_stdout(self.text_trap): + figs, results = plot_roc(self.events, self.config_inst, self.category_inst) self.assertIsInstance(figs, list) self.assertIsInstance(results, dict) def test_plot_roc_returns_correct_number_of_figures(self): - figs_ovr, _ = plot_roc(self.events, self.config_inst, self.category_inst, evaluation_type="OvR") - figs_ovo, _ = plot_roc(self.events, self.config_inst, self.category_inst, evaluation_type="OvO") + with redirect_stdout(self.text_trap): + figs_ovr, _ = plot_roc(self.events, self.config_inst, self.category_inst, evaluation_type="OvR") + figs_ovo, _ = plot_roc(self.events, self.config_inst, self.category_inst, evaluation_type="OvO") self.assertEqual(len(figs_ovr), self.N_discriminators * len(self.events)) self.assertEqual(len(figs_ovo), self.N_discriminators * len(self.events) * (len(self.events))) def test_plot_roc_returns_correct_results(self): - _, results = plot_roc(self.events, self.config_inst, self.category_inst, evaluation_type="OvR") - # not implemented yet - pass + with redirect_stdout(self.text_trap): + _, results = plot_roc( + self.events, + self.config_inst, + self.category_inst, + n_thresholds=5, + evaluation_type="OvR", + ) + fpr_out1 = results["out1"]["dataset_1_vs_rest"]["fpr"].tolist() + tpr_out1 = results["out1"]["dataset_1_vs_rest"]["tpr"].tolist() + fpr_out2 = results["out2"]["dataset_1_vs_rest"]["fpr"].tolist() + tpr_out2 = results["out2"]["dataset_1_vs_rest"]["tpr"].tolist() + self.assertListEqual(fpr_out1, self.results_dataset1_as_signal["out1"]["fpr"]) + self.assertListEqual(tpr_out1, self.results_dataset1_as_signal["out1"]["tpr"]) + self.assertListEqual(fpr_out2, self.results_dataset1_as_signal["out2"]["fpr"]) + self.assertListEqual(tpr_out2, self.results_dataset1_as_signal["out2"]["tpr"]) def test_plot_roc_raises_value_error_for_invalid_evaluation_type(self): - with self.assertRaises(ValueError): + with self.assertRaises(ValueError), redirect_stdout(self.text_trap): plot_roc(self.events, self.config_inst, self.category_inst, evaluation_type="InvalidType") From e8f98591499dec9439a78efb885452ea5651d841 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Thu, 7 Dec 2023 18:37:45 +0100 Subject: [PATCH 34/36] fixed --- columnflow/plotting/plot_ml_evaluation.py | 51 ++++++++++++----------- columnflow/tasks/ml.py | 36 ++++++++-------- 2 files changed, 45 insertions(+), 42 deletions(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index e08d24a7f..dfa97696d 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -8,7 +8,7 @@ import re -from columnflow.types import Sequence, List, Tuple +from columnflow.types import Sequence from columnflow.util import maybe_import ak = maybe_import("awkward") @@ -19,7 +19,7 @@ hep = maybe_import("mplhep") colors = maybe_import("matplotlib.colors") -# Define a CF custom color maps +# define a CF custom color maps cf_colors = { "cf_green_cmap": colors.ListedColormap([ "#212121", "#242723", "#262D25", "#283426", "#2A3A26", "#2C4227", "#2E4927", @@ -85,23 +85,26 @@ def plot_cm( cms_llabel: str = "private work", *args, **kwargs, -) -> Tuple[List[plt.Figure], np.ndarray]: - """ Generates the figure of the confusion matrix given the output of the nodes +) -> tuple[list[plt.Figure], np.ndarray]: + """ + Generates the figure of the confusion matrix given the output of the nodes and an array of true labels. The Cronfusion matrix can also be weighted. - :param events: dictionary with the true labels as keys and the model output of the events as values. - :param config_inst: used configuration for the plot. - :param category_inst: used category instance, for which the plot is created. - :param sample_weights: sample weights of the events. If an explicit array is not given, the weights are - calculated based on the number of events. - :param normalization: type of normalization of the confusion matrix. If not provided, the matrix is row normalized. + :param events: Dictionary with the true labels as keys and the model output of the events as values. + :param config_inst: The used config instance, for which the plot is created. + :param category_inst: The used category instance, for which the plot is created. + :param sample_weights: Sample weights applied to the confusion matrix the events. + If an explicit array is not given, the weights are calculated based on the number of events when set to *True*. + :param normalization: The type of normalization of the confusion matrix. + This parameter takes 'row', 'col' or '' (empty string) as argument. + If not provided, the matrix is row normalized. :param skip_uncertainties: If true, no uncertainty of the cells will be shown in the plot. - :param x_labels: labels for the x-axis. - :param y_labels: labels for the y-axis. + :param x_labels: The labels for the x-axis. If not provided, the labels will be 'out' + :param y_labels: The labels for the y-axis. If not provided, the dataset names are used. :param *args: Additional arguments to pass to the function. :param **kwargs: Additional keyword arguments to pass to the function. - :return: The resulting plot and the confusion matrix. + :return: Returns the resulting plot and the confusion matrix. :raises AssertionError: If both predictions and labels have mismatched shapes, or if *weights* is not *None* and its shape doesn't match *predictions*. @@ -135,13 +138,13 @@ def get_conf_matrix(sample_weights, *args, **kwargs) -> np.ndarray: vecNumber = np.vectorize(lambda n, count: sci.Number(n, float(n / np.sqrt(count) if count else 0))) result = vecNumber(result, counts) - # Normalize Matrix if needed + # normalize Matrix if needed if normalization: valid = {"row": 1, "column": 0} if normalization not in valid.keys(): raise ValueError( - f"\"{normalization}\" is not a valid argument for normalization. If given, normalization " - "should only take \"row\" or \"column\"", + f"'{normalization}' is not a valid argument for normalization. If given, normalization " + "should only take 'row' or 'column'", ) row_sums = result.sum(axis=valid.get(normalization)) @@ -180,15 +183,15 @@ def plot_confusion_matrix( from mpl_toolkits.axes_grid1 import make_axes_locatable def calculate_font_size(): - # Get cell width + # get cell width bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) width, height = fig.dpi * bbox.width, fig.dpi * bbox.height - # Size of each cell in pixels + # size of each cell in pixels cell_width = width / n_classes cell_height = height / n_processes - # Calculate the font size based on the cell size to ensure font is not too large + # calculate the font size based on the cell size to ensure font is not too large font_size = min(cell_width, cell_height) / 10 font_size = max(min(font_size, 18), 8) @@ -218,7 +221,7 @@ def fmt(v): plt.style.use(hep.style.CMS) fig, ax = plt.subplots(dpi=300) - # Some useful variables and functions + # some useful variables and functions n_processes = cm.shape[0] n_classes = cm.shape[1] cmap = cf_colors.get(colormap, cf_colors["cf_cmap"]) @@ -228,11 +231,11 @@ def fmt(v): font_label = 20 font_text = calculate_font_size() - # Get values and (if available) their uncertenties + # get values and (if available) their uncertenties values = cm.astype(np.float32) uncs = get_errors(cm) - # Remove Major ticks and edit minor ticks + # remove Major ticks and edit minor ticks minor_tick_length = max(int(120 / n_classes), 12) / 2 minor_tick_width = max(6 / n_classes, 0.6) xtick_marks = np.arange(n_classes) @@ -241,7 +244,7 @@ def fmt(v): # plot the data im = ax.imshow(values, interpolation="nearest", cmap=cmap) - # Plot settings + # plot settings thresh = values.max() / 2. ax.tick_params(axis="both", which="major", bottom=False, top=False, left=False, right=False) ax.tick_params( @@ -271,7 +274,7 @@ def fmt(v): colorbar.ax.tick_params(labelsize=font_ax - 5) im.set_clim(0, max(1, values.max())) - # Add Matrix Elemtns + # add Matrix Elemtns for i in range(values.shape[0]): for j in range(values.shape[1]): ax.text( diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index b52dee209..78e113355 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -5,9 +5,7 @@ """ from __future__ import annotations -from collections import OrderedDict - -from collections import defaultdict +from collections import OrderedDict, defaultdict import law import luigi @@ -32,7 +30,6 @@ from columnflow.tasks.production import ProduceColumns from columnflow.util import dev_sandbox, safe_div, DotDict from columnflow.util import maybe_import -from columnflow.types import Dict, List ak = maybe_import("awkward") @@ -947,9 +944,9 @@ class PlotMLResultsBase( ) skip_processes = law.CSVParameter( - default=("",), + default=(), description="comma seperated list of process names to skip; these processes will not be included in the plots. " - "default: ('',)", + "default: ()", brace_expand=True, ) @@ -1005,14 +1002,14 @@ def workflow_requires(self: PlotMLResultsBase, only_super: bool = False): return reqs - def output(self: PlotMLResultsBase) -> Dict[str, List]: + def output(self: PlotMLResultsBase) -> dict[str, list]: b = self.branch_data return {"plots": [ self.target(name) for name in self.get_plot_names(f"plot__proc_{self.processes_repr}__cat_{b.category}") ]} - def prepare_inputs(self: PlotMLResultsBase) -> Dict[str, ak.Array]: + def prepare_inputs(self: PlotMLResultsBase) -> dict[str, ak.Array]: """ prepare the inputs for the plot function, based on the given configuration and category. @@ -1020,7 +1017,7 @@ def prepare_inputs(self: PlotMLResultsBase) -> Dict[str, ak.Array]: :raises ValueError: This error is raised if ``plot_sub_processes`` is used without providing the ``process_ids`` column in the data - :return: Dict[str, ak.Array]: A dictionary with the dataset names as keys and + :return: dict[str, ak.Array]: A dictionary with the dataset names as keys and the corresponding predictions as values. """ category_inst = self.config_inst.get_category(self.branch_data.category) @@ -1063,8 +1060,10 @@ def prepare_inputs(self: PlotMLResultsBase) -> Dict[str, ak.Array]: all_events[process_inst.name] = getattr(events, self.ml_model) else: if "process_ids" not in events.fields: - raise ValueError("No `process_ids` column stored in the events! " - f"Process selection for {dataset} cannot not be applied!") + raise ValueError( + "No `process_ids` column stored in the events! " + f"Process selection for {dataset} cannot not be applied!", + ) for sub_process in sub_process_insts[process_inst]: if sub_process.name in self.skip_processes: continue @@ -1112,15 +1111,16 @@ def prepare_plot_parameters(self: PlotMLResults): def output(self: PlotMLResults): b = self.branch_data - return {"plots": [ - self.target(name) - for name in self.get_plot_names( - f"plot__{self.plot_function}__proc_{self.processes_repr}__cat_{b.category}/plot__0", - ) - ], + return { + "plots": [ + self.target(name) + for name in self.get_plot_names( + f"plot__{self.plot_function}__proc_{self.processes_repr}__cat_{b.category}/plot__0", + ) + ], "array": self.target( f"plot__{self.plot_function}__proc_{self.processes_repr}__cat_{b.category}/data.parquet", - ), + ), } @law.decorator.log From f755318171c6a8e33ee713eb0003c4ce63cac171 Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Thu, 7 Dec 2023 18:47:07 +0100 Subject: [PATCH 35/36] fixed linting --- columnflow/plotting/plot_ml_evaluation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/columnflow/plotting/plot_ml_evaluation.py b/columnflow/plotting/plot_ml_evaluation.py index 87d36fd29..e55940fe4 100644 --- a/columnflow/plotting/plot_ml_evaluation.py +++ b/columnflow/plotting/plot_ml_evaluation.py @@ -355,7 +355,7 @@ def plot_roc( if evaluation_type not in ["OvO", "OvR"]: raise ValueError( "Illeagal Argument! Evaluation Type can only be choosen as 'OvO' (One vs One)" - "or 'OvR' (One vs Rest)" + "or 'OvR' (One vs Rest)", ) def create_histograms(events: dict, sample_weights: dict, *args, **kwargs) -> dict[str, dict[str, np.ndarray]]: From 86410ad48a0bef9c888af4262c40fb2f3a2acb1c Mon Sep 17 00:00:00 2001 From: Anas Haddad Date: Thu, 7 Dec 2023 18:51:26 +0100 Subject: [PATCH 36/36] fixed linting --- columnflow/tasks/ml.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/columnflow/tasks/ml.py b/columnflow/tasks/ml.py index 78e113355..6c1652abb 100644 --- a/columnflow/tasks/ml.py +++ b/columnflow/tasks/ml.py @@ -28,8 +28,7 @@ from columnflow.tasks.framework.decorators import view_output_plots from columnflow.tasks.reduction import MergeReducedEventsUser, MergeReducedEvents from columnflow.tasks.production import ProduceColumns -from columnflow.util import dev_sandbox, safe_div, DotDict -from columnflow.util import maybe_import +from columnflow.util import dev_sandbox, safe_div, DotDict, maybe_import ak = maybe_import("awkward")