Merge branch 'ROC_PlotTask' of github.com:haddadanas/columnflow into ROC_and_tests

Conflicts:
	columnflow/plotting/plot_ml_evaluation.py
	columnflow/tasks/ml.py
haddadanas committed Dec 7, 2023
2 parents befe144 + 86410ad commit b46f3c8
Showing 4 changed files with 474 additions and 2 deletions.
207 changes: 207 additions & 0 deletions columnflow/plotting/plot_ml_evaluation.py
@@ -308,3 +308,210 @@ def fmt(v):
    print("Confusion matrix plotted!")

    return [fig], cm


def plot_roc(
    events: dict,
    config_inst: od.Config,
    category_inst: od.Category,
    sample_weights: Sequence | bool = False,
    n_thresholds: int = 200 + 1,
    skip_discriminators: list[str] = [],
    evaluation_type: str = "OvR",
    cms_rlabel: str = "",
    cms_llabel: str = "private work",
    *args,
    **kwargs,
) -> tuple[list[plt.Figure], dict]:
"""
Generates the figure of the ROC curve given the output of the nodes
and an array of true labels. The ROC curve can also be weighted.
:param events: dictionary with the true labels as keys and the model output of the events as values.
:param config_inst: used configuration for the plot.
:param category_inst: used category instance, for which the plot is created.
:param sample_weights: sample weights of the events. If an explicit array is not given, the weights are
calculated based on the number of events.
:param n_thresholds: number of thresholds to use for the ROC curve.
:param skip_discriminators: list of discriminators to skip.
:param evaluation_type: type of evaluation to use for the ROC curve. If not provided, the type is 'OvR'.
:param cms_rlabel: right label of the CMS label.
:param cms_llabel: left label of the CMS label.
:param *args: Additional arguments to pass to the function.
:param **kwargs: Additional keyword arguments to pass to the function.
:return: The resulting plot and the ROC curve.
:raises ValueError: If both predictions and labels have mismatched shapes, or if *weights*
is not *None* and its shape doesn't match *predictions*.
:raises ValueError: If *normalization* is not one of *None*, 'row', 'column'.
"""
    # defining some useful properties and output shapes
    thresholds = np.linspace(0, 1, n_thresholds)
    weights = create_sample_weights(sample_weights, events, list(events.keys()))
    discriminators = list(events.values())[0].fields
    figs = []

    if evaluation_type not in ["OvO", "OvR"]:
        raise ValueError(
            "Illegal argument! The evaluation type can only be chosen as 'OvO' (one vs. one) "
            "or 'OvR' (one vs. rest)",
        )

    def create_histograms(events: dict, sample_weights: dict, *args, **kwargs) -> dict[str, dict[str, np.ndarray]]:
        """
        Helper function to create histograms for the different discriminators and classes.
        """
        hists = {}
        for disc in discriminators:
            hists[disc] = {}
            for cls, predictions in events.items():
                hists[disc][cls] = (sample_weights[cls] *
                    ak.to_numpy(np.histogram(predictions[disc], bins=thresholds)[0]))
        return hists

    def binary_roc_data(
        positiv_hist: np.ndarray,
        negativ_hist: np.ndarray,
        *args,
        **kwargs,
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Compute binary receiver operating characteristic (ROC) values.
        Used as a helper function for the multi-dimensional ROC curve.
        """
        # calculate the different rates: events below a given threshold are classified as negative,
        # so the cumulative sums count false negatives and true negatives per threshold
        fn = np.cumsum(positiv_hist)
        tn = np.cumsum(negativ_hist)
        tp = fn[-1] - fn
        fp = tn[-1] - tn

        tpr = tp / (tp + fn)
        fpr = fp / (fp + tn)

        return fpr, tpr

    def roc_curve_data(
        evaluation_type: str,
        histograms: dict,
        *args,
        **kwargs,
    ) -> dict[str, dict[str, np.ndarray]]:
        """
        Compute receiver operating characteristic (ROC) values for a multi-dimensional output.
        """
        result = {}

        for disc in discriminators:
            tmp = {}

            if disc in skip_discriminators:
                continue

            # choose the evaluation type
            if (evaluation_type == "OvO"):
                for pos_cls, pos_hist in histograms[disc].items():
                    for neg_cls, neg_hist in histograms[disc].items():

                        fpr, tpr = binary_roc_data(
                            positiv_hist=pos_hist,
                            negativ_hist=neg_hist,
                            *args,
                            **kwargs,
                        )
                        tmp[f"{pos_cls}_vs_{neg_cls}"] = {"fpr": fpr, "tpr": tpr}

            elif (evaluation_type == "OvR"):
                for pos_cls, pos_hist in histograms[disc].items():
                    # aggregate all remaining classes into a single "rest" histogram
                    neg_hist = np.zeros_like(pos_hist)
                    for neg_cls, neg_pred in histograms[disc].items():
                        if (pos_cls == neg_cls):
                            continue
                        neg_hist += neg_pred

                    fpr, tpr = binary_roc_data(
                        positiv_hist=pos_hist,
                        negativ_hist=neg_hist,
                        *args,
                        **kwargs,
                    )
                    tmp[f"{pos_cls}_vs_rest"] = {"fpr": fpr, "tpr": tpr}

            result[disc] = tmp

        return result

    def plot_roc_curve(
        roc_data: dict,
        title: str,
        cms_rlabel: str = "",
        cms_llabel: str = "private work",
        *args,
        **kwargs,
    ) -> plt.Figure:
        """
        Plots a ROC curve.
        """
        def auc_score(fpr: list, tpr: list, *args) -> np.float64:
            """
            Compute the area under the curve using the trapezoidal rule.
            """
            sign = 1
            if np.any(np.diff(fpr) < 0):
                if np.all(np.diff(fpr) <= 0):
                    sign = -1
                else:
                    raise ValueError("x is neither increasing nor decreasing: {}.".format(fpr))

            return sign * np.trapz(tpr, fpr)

        fpr = roc_data["fpr"]
        tpr = roc_data["tpr"]

        plt.style.use(hep.style.CMS)
        fig, ax = plt.subplots(dpi=300)
        ax.set_xlabel("FPR", loc="right", labelpad=10, fontsize=25)
        ax.set_ylabel("TPR", loc="top", labelpad=15, fontsize=25)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)

        ax.plot(
            fpr,
            tpr,
            color="#67cf02",
            label=f"AUC = {auc_score(roc_data['fpr'], roc_data['tpr']):.3f}",
        )
        ax.plot([0, 1], [0, 1], color="black", linestyle="--")

        # setting titles and legend
        ax.legend(loc="lower right", fontsize=15)
        ax.set_title(title, wrap=True, pad=50, fontsize=30)
        hep.cms.label(ax=ax, llabel={"pw": "private work"}.get(cms_llabel, cms_llabel), rlabel=cms_rlabel)
        plt.tight_layout()

        return fig

    # create histograms and calculate FPR and TPR
    histograms = create_histograms(events, weights, *args, **kwargs)
    print("histograms created!")

    results = roc_curve_data(evaluation_type, histograms, *args, **kwargs)
    print("ROC data calculated!")

    # plotting
    for disc, roc in results.items():
        for cls, roc_data in roc.items():
            title = rf"{cls.replace('_vs_', ' VS ')} with {disc}"
            figs.append(plot_roc_curve(
                roc_data=roc_data,
                title=title,
                cms_rlabel=cms_rlabel,
                cms_llabel=cms_llabel,
                *args,
                **kwargs,
            ))
    print("ROC curves plotted!")

    results["thresholds"] = thresholds

    return figs, results
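
The core of the new code is the cumulative-histogram trick in binary_roc_data: the scores are histogrammed once per class, and FN/TN/TP/FP are then read off per threshold from cumulative sums instead of rescanning the events for every threshold. A minimal standalone sketch of that technique, not part of this commit and using made-up toy score populations and binning chosen here for illustration:

import numpy as np

# toy discriminator scores for one "positive" and one "negative" class (illustrative only)
rng = np.random.default_rng(seed=42)
signal_scores = np.clip(rng.normal(0.7, 0.15, 10_000), 0, 1)
background_scores = np.clip(rng.normal(0.3, 0.15, 10_000), 0, 1)

# histogram both populations on a common set of threshold edges, as plot_roc does
thresholds = np.linspace(0, 1, 201)
pos_hist, _ = np.histogram(signal_scores, bins=thresholds)
neg_hist, _ = np.histogram(background_scores, bins=thresholds)

# same logic as binary_roc_data: everything below a threshold is classified as negative
fn = np.cumsum(pos_hist)   # positives below the threshold -> false negatives
tn = np.cumsum(neg_hist)   # negatives below the threshold -> true negatives
tp = fn[-1] - fn           # positives above the threshold -> true positives
fp = tn[-1] - tn           # negatives above the threshold -> false positives

tpr = tp / (tp + fn)
fpr = fp / (fp + tn)

# fpr decreases as the threshold increases, hence the sign flip (as in auc_score)
auc = -np.trapz(tpr, fpr)
print(f"AUC = {auc:.3f}")

For these well-separated toy populations the printed AUC comes out well above 0.5, while two identical populations would give roughly 0.5, which is a quick sanity check of the sign handling in auc_score.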
3 changes: 1 addition & 2 deletions columnflow/tasks/ml.py
@@ -28,8 +28,7 @@
from columnflow.tasks.framework.decorators import view_output_plots
from columnflow.tasks.reduction import MergeReducedEventsUser, MergeReducedEvents
from columnflow.tasks.production import ProduceColumns
from columnflow.util import dev_sandbox, safe_div, DotDict
from columnflow.util import maybe_import
from columnflow.util import dev_sandbox, safe_div, DotDict, maybe_import


ak = maybe_import("awkward")
6 changes: 6 additions & 0 deletions tests/run_tests
@@ -32,6 +32,12 @@ action() {
ret="$?"
[ "${gret}" = "0" ] && gret="${ret}"

# test_plotting
echo
bash "${this_dir}/run_test" test_plotting "${cf_dir}/sandboxes/venv_columnar${dev}.sh"
ret="$?"
[ "${gret}" = "0" ] && gret="${ret}"

return "${gret}"
}
action "$@"
