From 9cd6920d95b691698c2ab075e3e109993ad46d65 Mon Sep 17 00:00:00 2001 From: marcellocosti Date: Tue, 10 Dec 2024 15:55:21 +0100 Subject: [PATCH] option to save cfgs + figures for application --- ML/MLApplication.py | 30 +++++++++++++++++++++++++----- ML/MLClassification.py | 7 +++++++ 2 files changed, 32 insertions(+), 5 deletions(-) diff --git a/ML/MLApplication.py b/ML/MLApplication.py index e1ed48bd..18317350 100644 --- a/ML/MLApplication.py +++ b/ML/MLApplication.py @@ -7,6 +7,7 @@ import sys import argparse import yaml +import matplotlib.pyplot as plt from hipe4ml.model_handler import ModelHandler from hipe4ml.tree_handler import TreeHandler @@ -22,6 +23,13 @@ def main(): #pylint: disable=too-many-statements, too-many-branches inputCfg = yaml.load(ymlCfgFile, yaml.FullLoader) print('Loading analysis configuration: Done!') + if inputCfg.get('savecfg'): + # Save the YAML file to the folder + if not os.path.isdir(os.path.expanduser(inputCfg['standalone_appl']['output_dir'])): + os.makedirs(os.path.expanduser(inputCfg['standalone_appl']['output_dir'])) + with open(f'{os.path.expanduser(inputCfg["standalone_appl"]["output_dir"])}/cfg.yml', 'w') as ymlOutFile: + yaml.dump(inputCfg, ymlOutFile, default_flow_style=False) + PtBins = [[a, b] for a, b in zip(inputCfg['pt_ranges']['min'], inputCfg['pt_ranges']['max'])] OutputLabels = [inputCfg['output']['out_labels']['Bkg'], inputCfg['output']['out_labels']['Prompt']] @@ -49,7 +57,7 @@ def main(): #pylint: disable=too-many-statements, too-many-branches else: DataHandler = TreeHandler(inputFile, treename) - DataHandler.slice_data_frame('pt_cand', PtBins, True) + DataHandler.slice_data_frame('fPt', PtBins, True) print(f'Loading and preparing data files {inputFile}: Done!') print('Applying ML model to dataframes: ...', end='\r') @@ -67,10 +75,10 @@ def main(): #pylint: disable=too-many-statements, too-many-branches if not isinstance(ColumnsToSaveFinal, list): print('\033[91mERROR: column_to_save_list must be defined!\033[0m') sys.exit() - if 'inv_mass' not in ColumnsToSaveFinal: - print('\033[93mWARNING: inv_mass is not going to be saved in the output dataframe!\033[0m') - if 'pt_cand' not in ColumnsToSaveFinal: - print('\033[93mWARNING: pt_cand is not going to be saved in the output dataframe!\033[0m') + if 'fM' not in ColumnsToSaveFinal: + print('\033[93mWARNING: fM is not going to be saved in the output dataframe!\033[0m') + if 'fPt' not in ColumnsToSaveFinal: + print('\033[93mWARNING: fPt is not going to be saved in the output dataframe!\033[0m') if 'pt_B' in ColumnsToSaveFinal and 'pt_B' not in DataDfPtSel.columns: ColumnsToSaveFinal.remove('pt_B') # only in MC DataDfPtSel = DataDfPtSel.loc[:, ColumnsToSaveFinal] @@ -80,6 +88,18 @@ def main(): #pylint: disable=too-many-statements, too-many-branches for Pred, Lab in enumerate(OutputLabels): DataDfPtSel[f'ML_output_{Lab}'] = yPred[:, Pred] DataDfPtSel.to_parquet(f'{OutPutDirPt}/{outName}_pT_{PtBin[0]}_{PtBin[1]}_ModelApplied.parquet.gzip') + + if inputCfg.get('savedistrs'): + plt.figure(figsize=(10, 6)) + for col in DataDfPtSel.columns: + if 'ML_output' in col: + plt.hist(DataDfPtSel[col], bins=100, alpha=0.5, label=col, log=True) + plt.title(f'Distributions of ML Outputs for {outName}') + plt.xlabel('Score') + plt.ylabel('Frequency (log scale)') + plt.legend() + plt.savefig(f"{OutPutDirPt}/{outName}Distrs.pdf", format="pdf", bbox_inches="tight") + del DataDfPtSel print('Applying ML model to dataframes: Done!') diff --git a/ML/MLClassification.py b/ML/MLClassification.py index 0ae3350d..f320668e 100644 --- a/ML/MLClassification.py +++ b/ML/MLClassification.py @@ -315,6 +315,13 @@ def main(): #pylint: disable=too-many-statements inputCfg = yaml.load(ymlCfgFile, yaml.FullLoader) print('Loading analysis configuration: Done!') + if inputCfg.get('savecfg'): + # Save the YAML file to the folder + if not os.path.isdir(os.path.expanduser(inputCfg['output']['dir'])): + os.makedirs(os.path.expanduser(inputCfg['output']['dir'])) + with open(f'{os.path.expanduser(inputCfg["output"]["dir"])}/cfg.yml', 'w') as ymlOutFile: + yaml.dump(inputCfg, ymlOutFile, default_flow_style=False) + print('Loading and preparing data files: ...', end='\r') PromptHandler = TreeHandler(inputCfg['input']['prompt'], inputCfg['input']['treename']) FDHandler = None if inputCfg['input']['FD'] is None else TreeHandler(inputCfg['input']['FD'],