Skip to content

Commit

Permalink
option to save cfgs + figures for application
Browse files Browse the repository at this point in the history
  • Loading branch information
Marcellocosti committed Dec 10, 2024
1 parent 4618d2e commit 9cd6920
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 5 deletions.
30 changes: 25 additions & 5 deletions ML/MLApplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
import argparse
import yaml
import matplotlib.pyplot as plt

from hipe4ml.model_handler import ModelHandler
from hipe4ml.tree_handler import TreeHandler
Expand All @@ -22,6 +23,13 @@ def main(): #pylint: disable=too-many-statements, too-many-branches
inputCfg = yaml.load(ymlCfgFile, yaml.FullLoader)
print('Loading analysis configuration: Done!')

if inputCfg.get('savecfg'):
# Save the YAML file to the folder
if not os.path.isdir(os.path.expanduser(inputCfg['standalone_appl']['output_dir'])):
os.makedirs(os.path.expanduser(inputCfg['standalone_appl']['output_dir']))
with open(f'{os.path.expanduser(inputCfg["standalone_appl"]["output_dir"])}/cfg.yml', 'w') as ymlOutFile:
yaml.dump(inputCfg, ymlOutFile, default_flow_style=False)

PtBins = [[a, b] for a, b in zip(inputCfg['pt_ranges']['min'], inputCfg['pt_ranges']['max'])]
OutputLabels = [inputCfg['output']['out_labels']['Bkg'],
inputCfg['output']['out_labels']['Prompt']]
Expand Down Expand Up @@ -49,7 +57,7 @@ def main(): #pylint: disable=too-many-statements, too-many-branches
else:
DataHandler = TreeHandler(inputFile, treename)

DataHandler.slice_data_frame('pt_cand', PtBins, True)
DataHandler.slice_data_frame('fPt', PtBins, True)
print(f'Loading and preparing data files {inputFile}: Done!')

print('Applying ML model to dataframes: ...', end='\r')
Expand All @@ -67,10 +75,10 @@ def main(): #pylint: disable=too-many-statements, too-many-branches
if not isinstance(ColumnsToSaveFinal, list):
print('\033[91mERROR: column_to_save_list must be defined!\033[0m')
sys.exit()
if 'inv_mass' not in ColumnsToSaveFinal:
print('\033[93mWARNING: inv_mass is not going to be saved in the output dataframe!\033[0m')
if 'pt_cand' not in ColumnsToSaveFinal:
print('\033[93mWARNING: pt_cand is not going to be saved in the output dataframe!\033[0m')
if 'fM' not in ColumnsToSaveFinal:
print('\033[93mWARNING: fM is not going to be saved in the output dataframe!\033[0m')
if 'fPt' not in ColumnsToSaveFinal:
print('\033[93mWARNING: fPt is not going to be saved in the output dataframe!\033[0m')
if 'pt_B' in ColumnsToSaveFinal and 'pt_B' not in DataDfPtSel.columns:
ColumnsToSaveFinal.remove('pt_B') # only in MC
DataDfPtSel = DataDfPtSel.loc[:, ColumnsToSaveFinal]
Expand All @@ -80,6 +88,18 @@ def main(): #pylint: disable=too-many-statements, too-many-branches
for Pred, Lab in enumerate(OutputLabels):
DataDfPtSel[f'ML_output_{Lab}'] = yPred[:, Pred]
DataDfPtSel.to_parquet(f'{OutPutDirPt}/{outName}_pT_{PtBin[0]}_{PtBin[1]}_ModelApplied.parquet.gzip')

if inputCfg.get('savedistrs'):
plt.figure(figsize=(10, 6))
for col in DataDfPtSel.columns:
if 'ML_output' in col:
plt.hist(DataDfPtSel[col], bins=100, alpha=0.5, label=col, log=True)
plt.title(f'Distributions of ML Outputs for {outName}')
plt.xlabel('Score')
plt.ylabel('Frequency (log scale)')
plt.legend()
plt.savefig(f"{OutPutDirPt}/{outName}Distrs.pdf", format="pdf", bbox_inches="tight")

del DataDfPtSel
print('Applying ML model to dataframes: Done!')

Expand Down
7 changes: 7 additions & 0 deletions ML/MLClassification.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,13 @@ def main(): #pylint: disable=too-many-statements
inputCfg = yaml.load(ymlCfgFile, yaml.FullLoader)
print('Loading analysis configuration: Done!')

if inputCfg.get('savecfg'):
# Save the YAML file to the folder
if not os.path.isdir(os.path.expanduser(inputCfg['output']['dir'])):
os.makedirs(os.path.expanduser(inputCfg['output']['dir']))
with open(f'{os.path.expanduser(inputCfg["output"]["dir"])}/cfg.yml', 'w') as ymlOutFile:
yaml.dump(inputCfg, ymlOutFile, default_flow_style=False)

print('Loading and preparing data files: ...', end='\r')
PromptHandler = TreeHandler(inputCfg['input']['prompt'], inputCfg['input']['treename'])
FDHandler = None if inputCfg['input']['FD'] is None else TreeHandler(inputCfg['input']['FD'],
Expand Down

0 comments on commit 9cd6920

Please sign in to comment.