From a7e014b849fc2804ad93f8bdb92c775fb83b741d Mon Sep 17 00:00:00 2001 From: Hadi Hassan Date: Sun, 25 Aug 2024 20:26:59 +0300 Subject: [PATCH] PWGJE: Task for tagging the b-jet using ML model (#7286) * PWGJE: Task for tagging the b-jet using ML model * using static_cast * Removing unsued variable * Removing the TreeMerger --- PWGJE/Core/JetTaggingUtilities.h | 73 +++- PWGJE/Tasks/CMakeLists.txt | 4 + PWGJE/Tasks/bjetTaggingML.cxx | 592 +++++++++++++++++++++++++++++++ PWGJE/Tasks/bjetTreeCreator.cxx | 57 ++- Tools/ML/model.h | 33 ++ 5 files changed, 740 insertions(+), 19 deletions(-) create mode 100644 PWGJE/Tasks/bjetTaggingML.cxx diff --git a/PWGJE/Core/JetTaggingUtilities.h b/PWGJE/Core/JetTaggingUtilities.h index 590b495f570..919fb2d83b7 100644 --- a/PWGJE/Core/JetTaggingUtilities.h +++ b/PWGJE/Core/JetTaggingUtilities.h @@ -46,6 +46,28 @@ namespace jettaggingutilities { const int cmTomum = 10000; // using cm -> #mum for impact parameter (dca) +//________________________________________________________________________ +bool isBHadron(int pc) +{ + std::vector bPdG = {511, 521, 10511, 10521, 513, 523, 10513, 10523, 20513, 20523, 20513, 20523, 515, 525, 531, 10531, 533, 10533, + 20533, 535, 541, 10541, 543, 10543, 20543, 545, 551, 10551, 100551, 110551, 200551, 210551, 553, 10553, 20553, + 30553, 100553, 110553, 120553, 130553, 200553, 210553, 220553, 300553, 9000533, 9010553, 555, 10555, 20555, + 100555, 110555, 120555, 200555, 557, 100557, 5122, 5112, 5212, 5222, 5114, 5214, 5224, 5132, 5232, 5312, 5322, + 5314, 5324, 5332, 5334, 5142, 5242, 5412, 5422, 5414, 5424, 5342, 5432, 5434, 5442, 5444, 5512, 5522, 5514, 5524, + 5532, 5534, 5542, 5544, 5554}; + + return (std::find(bPdG.begin(), bPdG.end(), std::abs(pc)) != bPdG.end()); +} +//________________________________________________________________________ +bool isCHadron(int pc) +{ + std::vector bPdG = {411, 421, 10411, 10421, 413, 423, 10413, 10423, 20431, 20423, 415, 425, 431, 10431, 433, 10433, 20433, 435, 
441, + 10441, 100441, 443, 10443, 20443, 100443, 30443, 9000443, 9010443, 9020443, 445, 100445, 4122, 4222, 4212, 4112, + 4224, 4214, 4114, 4232, 4132, 4322, 4312, 4324, 4314, 4332, 4334, 4412, 4422, 4414, 4424, 4432, 4434, 4444}; + + return (std::find(bPdG.begin(), bPdG.end(), std::abs(pc)) != bPdG.end()); +} + /** * returns the globalIndex of the earliest mother of a particle in the shower. returns -1 if a suitable mother is not found * @@ -292,11 +314,7 @@ int jetOrigin(T const& jet, U const& particles, float dRMax = 0.25) template int16_t getJetFlavor(AnyJet const& jet, AllMCParticles const& mcparticles) { - const int arraySize = 99; - - std::array countpartcode; - int count = 0; - + bool charmQuark = false; for (auto& mcpart : mcparticles) { int pdgcode = mcpart.pdgCode(); if (TMath::Abs(pdgcode) == 21 || (TMath::Abs(pdgcode) >= 1 && TMath::Abs(pdgcode) <= 5)) { @@ -305,19 +323,48 @@ int16_t getJetFlavor(AnyJet const& jet, AllMCParticles const& mcparticles) if (dR < jet.r() / 100.f) { if (TMath::Abs(pdgcode) == 5) { return JetTaggingSpecies::beauty; // Beauty jet - } else { - if (count > arraySize - 1) - return 0; - countpartcode[count] = pdgcode; - count++; + } else if (TMath::Abs(pdgcode) == 4) { + charmQuark = true; + } + } + } + } + + if (charmQuark) { + return JetTaggingSpecies::charm; // Charm jet + } + + return JetTaggingSpecies::lightflavour; // Light flavor jet +} + +/** + * return the jet flavor if it finds a HF hadron inside the jet: 0 for lf-jet, 1 for c-jet, 2 for b-jet + * + * @param AnyJet the jet that we need to study its flavor + * @param AllMCParticles a vector of all the mc particles stack + */ +template +int16_t getJetFlavorHadron(AnyJet const& jet, AllMCParticles const& mcparticles) +{ + bool charmHadron = false; + + for (auto& mcpart : mcparticles) { + int pdgcode = mcpart.pdgCode(); + if (isBHadron(pdgcode) || isCHadron(pdgcode)) { + double dR = jetutilities::deltaR(jet, mcpart); + + if (dR < jet.r() / 100.f) { + if (isBHadron(pdgcode)) 
{ + return JetTaggingSpecies::beauty; // Beauty jet + } else if (isCHadron(pdgcode)) { + charmHadron = true; } } } } - for (int ij = 0; ij < count; ij++) { - if (TMath::Abs(countpartcode[ij]) == 4) - return JetTaggingSpecies::charm; // Charm jet + if (charmHadron) { + return JetTaggingSpecies::charm; // Charm jet } return JetTaggingSpecies::lightflavour; // Light flavor jet diff --git a/PWGJE/Tasks/CMakeLists.txt b/PWGJE/Tasks/CMakeLists.txt index 5ead2097f72..eaa7bb9ec2b 100644 --- a/PWGJE/Tasks/CMakeLists.txt +++ b/PWGJE/Tasks/CMakeLists.txt @@ -164,5 +164,9 @@ if(FastJet_FOUND) SOURCES fulljetspectrapp.cxx PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::PWGJECore O2Physics::AnalysisCore COMPONENT_NAME Analysis) + o2physics_add_dpl_workflow(bjet-tagging-ml + SOURCES bjetTaggingML.cxx + PUBLIC_LINK_LIBRARIES O2::Framework O2Physics::PWGJECore O2Physics::AnalysisCore O2Physics::MLCore + COMPONENT_NAME Analysis) endif() diff --git a/PWGJE/Tasks/bjetTaggingML.cxx b/PWGJE/Tasks/bjetTaggingML.cxx new file mode 100644 index 00000000000..55e56762624 --- /dev/null +++ b/PWGJE/Tasks/bjetTaggingML.cxx @@ -0,0 +1,592 @@ +// Copyright 2019-2020 CERN and copyright holders of ALICE O2. +// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders. +// All rights not expressly granted are reserved. +// +// This software is distributed under the terms of the GNU General Public +// License v3 (GPL Version 3), copied verbatim in the file "COPYING". +// +// In applying this license CERN does not waive the privileges and immunities +// granted to it by virtue of its status as an Intergovernmental Organization +// or submit itself to any jurisdiction. 
+ +/// \file bjetTaggingML.cxx +/// \brief Task for tagging the beauty jets using ML algorithm (in onnx format) loaded from ccdb +/// +/// \author Hadi Hassan , University of Jyväskylä + +#include "Framework/AnalysisDataModel.h" +#include "Framework/AnalysisTask.h" +#include "Framework/ASoA.h" +#include "Framework/HistogramRegistry.h" +#include "Framework/runDataProcessing.h" +#include "PWGJE/Core/JetUtilities.h" +#include "PWGJE/Core/JetDerivedDataUtilities.h" +#include "PWGJE/Core/JetTaggingUtilities.h" +#include "PWGJE/DataModel/JetTagging.h" +#include "PWGJE/DataModel/Jet.h" + +#include "Common/Core/trackUtilities.h" +#include "Common/Core/TrackSelection.h" +#include "Common/Core/TrackSelectionDefaults.h" +#include "Common/DataModel/EventSelection.h" +#include "Common/DataModel/TrackSelectionTables.h" +#include "Common/Core/RecoDecay.h" +#include "Tools/ML/MlResponse.h" + +using namespace o2; +using namespace o2::framework; +using namespace o2::framework::expressions; + +struct BJetTaggingML { + + struct bjetParams { + float mJetpT = 0.0; + float mJetEta = 0.0; + float mJetPhi = 0.0; + int mNTracks = -1; + int mNSV = -1; + float mJetMass = 0.0; + }; + + struct bjetTrackParams { + double mTrackpT = 0.0; + double mTrackEta = 0.0; + double mDotProdTrackJet = 0.0; + double mDotProdTrackJetOverJet = 0.0; + double mDeltaRJetTrack = 0.0; + double mSignedIP2D = 0.0; + double mSignedIP2DSign = 0.0; + double mSignedIP3D = 0.0; + double mSignedIP3DSign = 0.0; + double mMomFraction = 0.0; + double mDeltaRTrackVertex = 0.0; + }; + + struct bjetSVParams { + double mSVpT = 0.0; + double mDeltaRSVJet = 0.0; + double mSVMass = 0.0; + double mSVfE = 0.0; + double mIPXY = 0.0; + double mCPA = 0.0; + double mChi2PCA = 0.0; + double mDecayLength2D = 0.0; + double mDecayLength2DError = 0.0; + double mDecayLength3D = 0.0; + double mDecayLength3DError = 0.0; + }; + + HistogramRegistry registry; + + static constexpr double defaultCutsMl[1][2] = {{0.5, 0.5}}; + + // event level 
configurables + Configurable vertexZCut{"vertexZCut", 10.0f, "Accepted z-vertex range"}; + Configurable eventSelections{"eventSelections", "sel8", "choose event selection"}; + + // track level configurables + Configurable trackPtMin{"trackPtMin", 0.5, "minimum track pT"}; + Configurable trackPtMax{"trackPtMax", 1000.0, "maximum track pT"}; + Configurable trackEtaMin{"trackEtaMin", -0.9, "minimum track eta"}; + Configurable trackEtaMax{"trackEtaMax", 0.9, "maximum track eta"}; + + // track level configurables + Configurable svPtMin{"svPtMin", 0.5, "minimum SV pT"}; + + // jet level configurables + Configurable jetPtMin{"jetPtMin", 5.0, "minimum jet pT"}; + Configurable jetPtMax{"jetPtMax", 1000.0, "maximum jet pT"}; + Configurable jetEtaMin{"jetEtaMin", -99.0, "minimum jet pseudorapidity"}; + Configurable jetEtaMax{"jetEtaMax", 99.0, "maximum jet pseudorapidity"}; + Configurable nJetConst{"nJetConst", 10, "maximum number of jet consistuents to be used for ML evaluation"}; + + Configurable useQuarkDef{"useQuarkDef", true, "Flag whether to use quarks or hadrons for determining the jet flavor"}; + + Configurable svReductionFactor{"svReductionFactor", 1.0, "factor for how many SVs to keep"}; + + Configurable> jetRadii{"jetRadii", std::vector{0.4}, "jet resolution parameters"}; + + Configurable> binsPtMl{"binsPtMl", std::vector{5., 1000.}, "pT bin limits for ML application"}; + Configurable> cutDirMl{"cutDirMl", std::vector{cuts_ml::CutSmaller, cuts_ml::CutNot}, "Whether to reject score values greater or smaller than the threshold"}; + Configurable> cutsMl{"cutsMl", {defaultCutsMl[0], 1, 2, {"pT bin 0"}, {"score for default b-jet tagging", "uncer 1"}}, "ML selections per pT bin"}; + Configurable nClassesMl{"nClassesMl", (int8_t)2, "Number of classes in ML model"}; + Configurable> namesInputFeatures{"namesInputFeatures", std::vector{"feature1", "feature2"}, "Names of ML model input features"}; + + Configurable ccdbUrl{"ccdbUrl", "http://alice-ccdb.cern.ch", "url of the 
ccdb repository"}; + Configurable> modelPathsCCDB{"modelPathsCCDB", std::vector{"Users/h/hahassan"}, "Paths of models on CCDB"}; + Configurable> onnxFileNames{"onnxFileNames", std::vector{"ML_bjets/01-MVA/Models/LHC23d4_5_20_90Percent/model.onnx"}, "ONNX file names for each pT bin (if not from CCDB full path)"}; + Configurable timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"}; + Configurable loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"}; + + o2::analysis::MlResponse bMlResponse; + o2::ccdb::CcdbApi ccdbApi; + + int eventSelection = -1; + + std::vector jetRadiiValues; + + void init(InitContext const&) + { + // Seed the random number generator using current time + std::srand(static_cast(std::time(nullptr))); + + jetRadiiValues = (std::vector)jetRadii; + + eventSelection = jetderiveddatautilities::initialiseEventSelection(static_cast(eventSelections)); + + registry.add("h_vertexZ", "Vertex Z;#it{Z} (cm)", {HistType::kTH1F, {{40, -20.0, 20.0}}}); + + registry.add("h2_score_jetpT", "ML scores for inclusive jets;#it{p}_{T,jet} (GeV/#it{c});Score", {HistType::kTH2F, {{200, 0., 200.}, {120, -0.1, 1.1}}}); + + registry.add("h2_nTracks_jetpT", "Number of tracks;#it{p}_{T,jet} (GeV/#it{c});nTracks", {HistType::kTH2F, {{200, 0., 200.}, {100, 0, 100.0}}}); + registry.add("h2_nSV_jetpT", "Number of secondary vertices;#it{p}_{T,jet} (GeV/#it{c});nSVs", {HistType::kTH2F, {{200, 0., 200.}, {250, 0, 250.0}}}); + + registry.add("h2_SIPs2D_jetpT", "2D IP significance;#it{p}_{T,jet} (GeV/#it{c});IPs", {HistType::kTH2F, {{200, 0., 200.}, {100, -50.0, 50.0}}}); + registry.add("h2_SIPs3D_jetpT", "3D IP significance;#it{p}_{T,jet} (GeV/#it{c});IPs", {HistType::kTH2F, {{200, 0., 200.}, {100, -50.0, 50.0}}}); + registry.add("h2_LxyS_jetpT", "Decay length in XY;#it{p}_{T,jet} (GeV/#it{c});S#it{L}_{xy}", {HistType::kTH2F, {{200, 0., 200.}, {100, 0., 100.0}}}); + 
registry.add("h2_Dispersion_jetpT", "SV dispersion;#it{p}_{T,jet} (GeV/#it{c});Dispersion", {HistType::kTH2F, {{200, 0., 200.}, {100, 0, 50.0}}}); + registry.add("h2_jetMass_jetpT", "Jet mass;#it{p}_{T,jet} (GeV/#it{c});#it{m}_{jet} (GeV/#it{c}^{2})", {HistType::kTH2F, {{200, 0., 200.}, {50, 0, 50.0}}}); + registry.add("h2_SVMass_jetpT", "Secondary vertex mass;#it{p}_{T,jet} (GeV/#it{c});#it{m}_{SV} (GeV/#it{c}^{2})", {HistType::kTH2F, {{200, 0., 200.}, {50, 0, 10}}}); + + if (doprocessMCJets) { + + registry.add("h2_score_jetpT_bjet", "ML scores for b-jets;#it{p}_{T,jet} (GeV/#it{c});Score", {HistType::kTH2F, {{200, 0., 200.}, {120, -0.1, 1.1}}}); + registry.add("h2_SIPs2D_jetpT_bjet", "2D IP significance b-jets;#it{p}_{T,jet} (GeV/#it{c});IPs", {HistType::kTH2F, {{200, 0., 200.}, {100, -50.0, 50.0}}}); + registry.add("h2_SIPs3D_jetpT_bjet", "3D IP significance b-jets;#it{p}_{T,jet} (GeV/#it{c});IPs", {HistType::kTH2F, {{200, 0., 200.}, {100, -50.0, 50.0}}}); + registry.add("h2_LxyS_jetpT_bjet", "Decay length in XY b-jets;#it{p}_{T,jet} (GeV/#it{c});S#it{L}_{xy}", {HistType::kTH2F, {{200, 0., 200.}, {100, 0., 100.0}}}); + registry.add("h2_Dispersion_jetpT_bjet", "SV dispersion b-jets;#it{p}_{T,jet} (GeV/#it{c});Dispersion", {HistType::kTH2F, {{200, 0., 200.}, {100, 0, 50.0}}}); + registry.add("h2_jetMass_jetpT_bjet", "Jet mass b-jets;#it{p}_{T,jet} (GeV/#it{c});#it{m}_{jet} (GeV/#it{c}^{2})", {HistType::kTH2F, {{200, 0., 200.}, {50, 0, 50.0}}}); + registry.add("h2_SVMass_jetpT_bjet", "Secondary vertex mass b-jets;#it{p}_{T,jet} (GeV/#it{c});#it{m}_{SV} (GeV/#it{c}^{2})", {HistType::kTH2F, {{200, 0., 200.}, {50, 0, 10.0}}}); + + registry.add("h2_score_jetpT_cjet", "ML scores for c-jets;#it{p}_{T,jet} (GeV/#it{c});Score", {HistType::kTH2F, {{200, 0., 200.}, {120, -0.1, 1.1}}}); + registry.add("h2_SIPs2D_jetpT_cjet", "2D IP significance c-jets;#it{p}_{T,jet} (GeV/#it{c});IPs", {HistType::kTH2F, {{200, 0., 200.}, {100, -50.0, 50.0}}}); + 
registry.add("h2_SIPs3D_jetpT_cjet", "3D IP significance c-jets;#it{p}_{T,jet} (GeV/#it{c});IPs", {HistType::kTH2F, {{200, 0., 200.}, {100, -50.0, 50.0}}}); + registry.add("h2_LxyS_jetpT_cjet", "Decay length in XY c-jets;#it{p}_{T,jet} (GeV/#it{c});S#it{L}_{xy}", {HistType::kTH2F, {{200, 0., 200.}, {100, 0., 100.0}}}); + registry.add("h2_Dispersion_jetpT_cjet", "SV dispersion c-jets;#it{p}_{T,jet} (GeV/#it{c});Dispersion", {HistType::kTH2F, {{200, 0., 200.}, {100, 0, 50.0}}}); + registry.add("h2_jetMass_jetpT_cjet", "Jet mass c-jets;#it{p}_{T,jet} (GeV/#it{c});#it{m}_{jet} (GeV/#it{c}^{2})", {HistType::kTH2F, {{200, 0., 200.}, {50, 0, 50.0}}}); + registry.add("h2_SVMass_jetpT_cjet", "Secondary vertex mass c-jets;#it{p}_{T,jet} (GeV/#it{c});#it{m}_{SV} (GeV/#it{c}^{2})", {HistType::kTH2F, {{200, 0., 200.}, {50, 0, 10.0}}}); + + registry.add("h2_score_jetpT_lfjet", "ML scores for lf-jets;#it{p}_{T,jet} (GeV/#it{c});Score", {HistType::kTH2F, {{200, 0., 200.}, {120, -0.1, 1.1}}}); + registry.add("h2_SIPs2D_jetpT_lfjet", "2D IP significance lf-jet;#it{p}_{T,jet} (GeV/#it{c});IPs", {HistType::kTH2F, {{200, 0., 200.}, {100, -50.0, 50.0}}}); + registry.add("h2_SIPs3D_jetpT_lfjet", "3D IP significance lf-jet;#it{p}_{T,jet} (GeV/#it{c});IPs", {HistType::kTH2F, {{200, 0., 200.}, {100, -50.0, 50.0}}}); + registry.add("h2_LxyS_jetpT_lfjet", "Decay length in XY lf-jet;#it{p}_{T,jet} (GeV/#it{c});S#it{L}_{xy}", {HistType::kTH2F, {{200, 0., 200.}, {100, 0., 100.0}}}); + registry.add("h2_Dispersion_jetpT_lfjet", "SV dispersion lf-jet;#it{p}_{T,jet} (GeV/#it{c});Dispersion", {HistType::kTH2F, {{200, 0., 200.}, {100, 0, 50.0}}}); + registry.add("h2_jetMass_jetpT_lfjet", "Jet mass lf-jet;#it{p}_{T,jet} (GeV/#it{c});#it{m}_{jet} (GeV/#it{c}^{2})", {HistType::kTH2F, {{200, 0., 200.}, {50, 0, 50.0}}}); + registry.add("h2_SVMass_jetpT_lfjet", "Secondary vertex mass lf-jet;#it{p}_{T,jet} (GeV/#it{c});#it{m}_{SV} (GeV/#it{c}^{2})", {HistType::kTH2F, {{200, 0., 200.}, {50, 0, 10.0}}}); + + 
registry.add("h_jetpT_detector_bjet", "Jet transverse momentum b-jets;#it{p}_{T,jet} (GeV/#it{c})", {HistType::kTH1F, {{200, 0., 200.0}}}); + registry.add("h_jetpT_detector_cjet", "Jet transverse momentum c-jets;#it{p}_{T,jet} (GeV/#it{c})", {HistType::kTH1F, {{200, 0., 200.0}}}); + registry.add("h_jetpT_detector_lfjet", "Jet transverse momentum lf-jet;#it{p}_{T,jet} (GeV/#it{c})", {HistType::kTH1F, {{200, 0., 200.0}}}); + + registry.add("h_jetpT_particle_DetColl", "Jet transverse momentum particle level inclusive jets (Detector-level collisions);#it{p}_{T,jet} (GeV/#it{c})", {HistType::kTH1F, {{200, 0., 200.0}}}); + registry.add("h_jetpT_particle_DetColl_bjet", "Jet transverse momentum particle level b-jets (Detector-level collisions);#it{p}_{T,jet} (GeV/#it{c})", {HistType::kTH1F, {{200, 0., 200.0}}}); + registry.add("h_jetpT_particle_DetColl_cjet", "Jet transverse momentum particle level c-jets (Detector-level collisions);#it{p}_{T,jet} (GeV/#it{c})", {HistType::kTH1F, {{200, 0., 200.0}}}); + registry.add("h_jetpT_particle_DetColl_lfjet", "Jet transverse momentum particle level lf-jet (Detector-level collisions);#it{p}_{T,jet} (GeV/#it{c})", {HistType::kTH1F, {{200, 0., 200.0}}}); + + registry.add("h_jetpT_particle_bjet", "Jet transverse momentum particle level b-jets;#it{p}_{T,jet} (GeV/#it{c})", {HistType::kTH1F, {{200, 0., 200.0}}}); + registry.add("h_jetpT_particle_cjet", "Jet transverse momentum particle level c-jets;#it{p}_{T,jet} (GeV/#it{c})", {HistType::kTH1F, {{200, 0., 200.0}}}); + registry.add("h_jetpT_particle_lfjet", "Jet transverse momentum particle level lf-jet;#it{p}_{T,jet} (GeV/#it{c})", {HistType::kTH1F, {{200, 0., 200.0}}}); + + registry.add("h2_Response_DetjetpT_PartjetpT_bjet", "Response matrix b-jets;#it{p}_{T,jet}^{det} (GeV/#it{c});#it{p}_{T,jet}^{part} (GeV/#it{c})", {HistType::kTH2F, {{200, 0., 200.}, {200, 0., 200.}}}); + registry.add("h2_Response_DetjetpT_PartjetpT_cjet", "Response matrix c-jets;#it{p}_{T,jet}^{det} 
(GeV/#it{c});#it{p}_{T,jet}^{part} (GeV/#it{c})", {HistType::kTH2F, {{200, 0., 200.}, {200, 0., 200.}}}); + registry.add("h2_Response_DetjetpT_PartjetpT_lfjet", "Response matrix lf-jet;#it{p}_{T,jet}^{det} (GeV/#it{c});#it{p}_{T,jet}^{part} (GeV/#it{c})", {HistType::kTH2F, {{200, 0., 200.}, {200, 0., 200.}}}); + } + + bMlResponse.configure(binsPtMl, cutsMl, cutDirMl, nClassesMl); + if (loadModelsFromCCDB) { + ccdbApi.init(ccdbUrl); + bMlResponse.setModelPathsCCDB(onnxFileNames, ccdbApi, modelPathsCCDB, timestampCCDB); + } else { + bMlResponse.setModelPathsLocal(onnxFileNames); + } + // bMlResponse.cacheInputFeaturesIndices(namesInputFeatures); + bMlResponse.init(); + } + + // FIXME filtering only works when you loop directly over the list, but if you loop over it as a constituent they will not be filtered + Filter collisionFilter = nabs(aod::jcollision::posZ) < vertexZCut; + Filter trackCuts = (aod::jtrack::pt > trackPtMin && aod::jtrack::pt < trackPtMax && aod::jtrack::eta > trackEtaMin && aod::jtrack::eta < trackEtaMax); + Filter partCuts = (aod::jmcparticle::pt >= trackPtMin && aod::jmcparticle::pt < trackPtMax); + Filter jetFilter = (aod::jet::pt >= jetPtMin && aod::jet::pt <= jetPtMax && aod::jet::eta < jetEtaMax - aod::jet::r / 100.f && aod::jet::eta > jetEtaMin + aod::jet::r / 100.f); + + using FilteredCollision = soa::Filtered>; + using JetTrackswID = soa::Join; + using JetTracksMCDwID = soa::Join; + using OriginalTracks = soa::Join; + using DataJets = soa::Filtered>; + + std::vector> getInputsForML(bjetParams jetparams, std::vector& tracksParams, std::vector& svsParams) + { + std::vector jetInput = {jetparams.mJetpT, jetparams.mJetEta, jetparams.mJetPhi, static_cast(jetparams.mNTracks), static_cast(jetparams.mNSV), jetparams.mJetMass}; + std::vector tracksInputFlat; + std::vector svsInputFlat; + + for (int iconstit = 0; iconstit < nJetConst; iconstit++) { + + tracksInputFlat.push_back(tracksParams[iconstit].mTrackpT); + 
tracksInputFlat.push_back(tracksParams[iconstit].mTrackEta); + tracksInputFlat.push_back(tracksParams[iconstit].mDotProdTrackJet); + tracksInputFlat.push_back(tracksParams[iconstit].mDotProdTrackJetOverJet); + tracksInputFlat.push_back(tracksParams[iconstit].mDeltaRJetTrack); + tracksInputFlat.push_back(tracksParams[iconstit].mSignedIP2D); + tracksInputFlat.push_back(tracksParams[iconstit].mSignedIP2DSign); + tracksInputFlat.push_back(tracksParams[iconstit].mSignedIP3D); + tracksInputFlat.push_back(tracksParams[iconstit].mSignedIP3DSign); + tracksInputFlat.push_back(tracksParams[iconstit].mMomFraction); + tracksInputFlat.push_back(tracksParams[iconstit].mDeltaRTrackVertex); + + svsInputFlat.push_back(svsParams[iconstit].mSVpT); + svsInputFlat.push_back(svsParams[iconstit].mDeltaRSVJet); + svsInputFlat.push_back(svsParams[iconstit].mSVMass); + svsInputFlat.push_back(svsParams[iconstit].mSVfE); + svsInputFlat.push_back(svsParams[iconstit].mIPXY); + svsInputFlat.push_back(svsParams[iconstit].mCPA); + svsInputFlat.push_back(svsParams[iconstit].mChi2PCA); + svsInputFlat.push_back(svsParams[iconstit].mDecayLength2D); + svsInputFlat.push_back(svsParams[iconstit].mDecayLength2DError); + svsInputFlat.push_back(svsParams[iconstit].mDecayLength3D); + svsInputFlat.push_back(svsParams[iconstit].mDecayLength3DError); + } + + std::vector> totalInput; + totalInput.push_back(jetInput); + totalInput.push_back(tracksInputFlat); + totalInput.push_back(svsInputFlat); + + return totalInput; + } + + // Looping over the SV info and writing them to a table + template + void analyzeJetSVInfo(AnalysisJet const& myJet, AnyTracks const& /*allTracks*/, SecondaryVertices const& /*allSVs*/, std::vector& svsParams, int jetFlavor = 0, double eventweight = 1.0) + { + using SVType = typename SecondaryVertices::iterator; + + // Min-heap to store the top 30 SVs by decayLengthXY/errorDecayLengthXY + auto compare = [](SVType& sv1, SVType& sv2) { + return (sv1.decayLengthXY() / sv1.errorDecayLengthXY()) > 
(sv2.decayLengthXY() / sv2.errorDecayLengthXY()); + }; + + auto svs = myJet.template secondaryVertices_as(); + + // Sort the SVs based on their decay length significance in descending order + // This is needed in order to select longest SVs since some jets could have thousands of SVs + std::sort(svs.begin(), svs.end(), compare); + + for (const auto& candSV : svs) { + + if (candSV.pt() < svPtMin) { + continue; + } + + double deltaRJetSV = jetutilities::deltaR(myJet, candSV); + double massSV = candSV.m(); + double energySV = candSV.e(); + + if (svsParams.size() < (svReductionFactor * myJet.template tracks_as().size())) { + svsParams.emplace_back(bjetSVParams{candSV.pt(), deltaRJetSV, massSV, energySV / myJet.energy(), candSV.impactParameterXY(), candSV.cpa(), candSV.chi2PCA(), candSV.decayLengthXY(), candSV.errorDecayLengthXY(), candSV.decayLength(), candSV.errorDecayLength()}); + } + + registry.fill(HIST("h2_LxyS_jetpT"), myJet.pt(), candSV.decayLengthXY() / candSV.errorDecayLengthXY(), eventweight); + registry.fill(HIST("h2_Dispersion_jetpT"), myJet.pt(), candSV.chi2PCA(), eventweight); + registry.fill(HIST("h2_SVMass_jetpT"), myJet.pt(), massSV, eventweight); + + if (doprocessMCJets) { + if (jetFlavor == 2) { + registry.fill(HIST("h2_LxyS_jetpT_bjet"), myJet.pt(), candSV.decayLengthXY() / candSV.errorDecayLengthXY(), eventweight); + registry.fill(HIST("h2_Dispersion_jetpT_bjet"), myJet.pt(), candSV.chi2PCA(), eventweight); + registry.fill(HIST("h2_SVMass_jetpT_bjet"), myJet.pt(), massSV, eventweight); + } else if (jetFlavor == 1) { + registry.fill(HIST("h2_LxyS_jetpT_cjet"), myJet.pt(), candSV.decayLengthXY() / candSV.errorDecayLengthXY(), eventweight); + registry.fill(HIST("h2_Dispersion_jetpT_cjet"), myJet.pt(), candSV.chi2PCA(), eventweight); + registry.fill(HIST("h2_SVMass_jetpT_cjet"), myJet.pt(), massSV, eventweight); + } else { + registry.fill(HIST("h2_LxyS_jetpT_lfjet"), myJet.pt(), candSV.decayLengthXY() / candSV.errorDecayLengthXY(), eventweight); + 
registry.fill(HIST("h2_Dispersion_jetpT_lfjet"), myJet.pt(), candSV.chi2PCA(), eventweight); + registry.fill(HIST("h2_SVMass_jetpT_lfjet"), myJet.pt(), massSV, eventweight); + } + } + } + } + + template + void analyzeJetTrackInfo(AnyCollision const& collision, AnalysisJet const& analysisJet, AnyTracks const& /*allTracks*/, SecondaryVertices const& /*allSVs*/, std::vector& tracksParams, int jetFlavor = 0, double eventweight = 1.0) + { + + for (auto& jconstituent : analysisJet.template tracks_as()) { + + if (jconstituent.pt() < trackPtMin) { + continue; + } + + auto constituent = jconstituent.template track_as(); + double deltaRJetTrack = jetutilities::deltaR(analysisJet, constituent); + double dotProduct = RecoDecay::dotProd(std::array{analysisJet.px(), analysisJet.py(), analysisJet.pz()}, std::array{constituent.px(), constituent.py(), constituent.pz()}); + int sign = jettaggingutilities::getGeoSign(collision, analysisJet, constituent); + + float RClosestSV = 10.; + for (const auto& candSV : analysisJet.template secondaryVertices_as()) { + double deltaRTrackSV = jetutilities::deltaR(constituent, candSV); + if (deltaRTrackSV < RClosestSV) { + RClosestSV = deltaRTrackSV; + } + } + + float dcaXYZ(0.), sigmaDcaXYZ2(0.); + dcaXYZ = getDcaXYZ(constituent, &sigmaDcaXYZ2); + + registry.fill(HIST("h2_SIPs2D_jetpT"), analysisJet.pt(), sign * TMath::Abs(constituent.dcaXY()) / TMath::Sqrt(constituent.sigmaDcaXY2()), eventweight); + registry.fill(HIST("h2_SIPs3D_jetpT"), analysisJet.pt(), sign * dcaXYZ / TMath::Sqrt(sigmaDcaXYZ2), eventweight); + + if (doprocessMCJets) { + if (jetFlavor == 2) { + registry.fill(HIST("h2_SIPs2D_jetpT_bjet"), analysisJet.pt(), sign * TMath::Abs(constituent.dcaXY()) / TMath::Sqrt(constituent.sigmaDcaXY2()), eventweight); + registry.fill(HIST("h2_SIPs3D_jetpT_bjet"), analysisJet.pt(), sign * dcaXYZ / TMath::Sqrt(sigmaDcaXYZ2), eventweight); + } else if (jetFlavor == 1) { + registry.fill(HIST("h2_SIPs2D_jetpT_cjet"), analysisJet.pt(), sign * 
TMath::Abs(constituent.dcaXY()) / TMath::Sqrt(constituent.sigmaDcaXY2()), eventweight); + registry.fill(HIST("h2_SIPs3D_jetpT_cjet"), analysisJet.pt(), sign * dcaXYZ / TMath::Sqrt(sigmaDcaXYZ2), eventweight); + } else { + registry.fill(HIST("h2_SIPs2D_jetpT_lfjet"), analysisJet.pt(), sign * TMath::Abs(constituent.dcaXY()) / TMath::Sqrt(constituent.sigmaDcaXY2()), eventweight); + registry.fill(HIST("h2_SIPs3D_jetpT_lfjet"), analysisJet.pt(), sign * dcaXYZ / TMath::Sqrt(sigmaDcaXYZ2), eventweight); + } + } + + tracksParams.emplace_back(bjetTrackParams{constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, TMath::Abs(constituent.dcaXY()) * sign, TMath::Sqrt(constituent.sigmaDcaXY2()), dcaXYZ * sign, TMath::Sqrt(sigmaDcaXYZ2), constituent.p() / analysisJet.p(), RClosestSV}); + } + + auto compare = [](bjetTrackParams& tr1, bjetTrackParams& tr2) { + return (tr1.mSignedIP2D / tr1.mSignedIP2DSign) > (tr2.mSignedIP2D / tr2.mSignedIP2DSign); + }; + + // Sort the tracks based on their IP significance in descending order + std::sort(tracksParams.begin(), tracksParams.end(), compare); + } + + void processDummy(FilteredCollision::iterator const& /*collision*/) + { + } + PROCESS_SWITCH(BJetTaggingML, processDummy, "Dummy process function turned on by default", true); + + void processDataJets(FilteredCollision::iterator const& collision, DataJets const& alljets, JetTrackswID const& allTracks, OriginalTracks const& /*allOrigTracks*/, aod::DataSecondaryVertex3Prongs const& allSVs) + { + if (!jetderiveddatautilities::selectCollision(collision, eventSelection)) { + return; + } + + registry.fill(HIST("h_vertexZ"), collision.posZ()); + + for (const auto& analysisJet : alljets) { + + bool jetIncluded = false; + for (auto jetR : jetRadiiValues) { + if (analysisJet.r() == static_cast(jetR * 100)) { + jetIncluded = true; + break; + } + } + + if (!jetIncluded) { + continue; + } + + std::vector tracksParams; + std::vector SVsParams; + + 
analyzeJetSVInfo(analysisJet, allTracks, allSVs, SVsParams); + analyzeJetTrackInfo(collision, analysisJet, allTracks, allSVs, tracksParams); + + registry.fill(HIST("h2_nTracks_jetpT"), analysisJet.pt(), tracksParams.size()); + registry.fill(HIST("h2_nSV_jetpT"), analysisJet.pt(), SVsParams.size() < 250 ? SVsParams.size() : 249); + + bjetParams jetparam = {analysisJet.pt(), analysisJet.eta(), analysisJet.phi(), static_cast(tracksParams.size()), static_cast(SVsParams.size()), analysisJet.mass()}; + tracksParams.resize(nJetConst); // resize to the number of inputs of the ML + SVsParams.resize(nJetConst); // resize to the number of inputs of the ML + + auto inputML = getInputsForML(jetparam, tracksParams, SVsParams); + + std::vector output; + // bool isSelectedMl = bMlResponse.isSelectedMl(inputML, analysisJet.pt(), output); + bMlResponse.isSelectedMl(inputML, analysisJet.pt(), output); + + registry.fill(HIST("h2_score_jetpT"), analysisJet.pt(), output[0]); + + registry.fill(HIST("h2_jetMass_jetpT"), analysisJet.pt(), analysisJet.mass()); + } + } + PROCESS_SWITCH(BJetTaggingML, processDataJets, "jet information in Data", false); + + using MCDJetTable = soa::Filtered>; + using MCPJetTable = soa::Filtered>; + using FilteredCollisionMCD = soa::Filtered>; + + Preslice McParticlesPerCollision = aod::jmcparticle::mcCollisionId; + Preslice McPJetsPerCollision = aod::jet::mcCollisionId; + + void processMCJets(FilteredCollisionMCD::iterator const& collision, MCDJetTable const& MCDjets, MCPJetTable const& MCPjets, JetTracksMCDwID const& allTracks, JetParticles const& MCParticles, aod::MCDSecondaryVertex3Prongs const& allSVs, OriginalTracks const& /*origTracks*/) + { + if (!jetderiveddatautilities::selectCollision(collision, eventSelection)) { + return; + } + + registry.fill(HIST("h_vertexZ"), collision.posZ()); + + auto const mcParticlesPerColl = MCParticles.sliceBy(McParticlesPerCollision, collision.mcCollisionId()); + auto const mcPJetsPerColl = 
MCPjets.sliceBy(McPJetsPerCollision, collision.mcCollisionId()); + + for (const auto& analysisJet : MCDjets) { + + bool jetIncluded = false; + for (auto jetR : jetRadiiValues) { + if (analysisJet.r() == static_cast(jetR * 100)) { + jetIncluded = true; + break; + } + } + + if (!jetIncluded) { + continue; + } + + std::vector tracksParams; + std::vector SVsParams; + + float eventWeight = analysisJet.eventWeight(); + int jetFlavor = 0; + + for (auto& mcpjet : analysisJet.template matchedJetGeo_as()) { + if (useQuarkDef) { + jetFlavor = jettaggingutilities::getJetFlavor(mcpjet, mcParticlesPerColl); + } else { + jetFlavor = jettaggingutilities::getJetFlavorHadron(mcpjet, mcParticlesPerColl); + } + } + + analyzeJetSVInfo(analysisJet, allTracks, allSVs, SVsParams, jetFlavor, eventWeight); + analyzeJetTrackInfo(collision, analysisJet, allTracks, allSVs, tracksParams, jetFlavor, eventWeight); + + registry.fill(HIST("h2_nTracks_jetpT"), analysisJet.pt(), tracksParams.size()); + registry.fill(HIST("h2_nSV_jetpT"), analysisJet.pt(), SVsParams.size() < 250 ? 
SVsParams.size() : 249); + + bjetParams jetparam = {analysisJet.pt(), analysisJet.eta(), analysisJet.phi(), static_cast(tracksParams.size()), static_cast(SVsParams.size()), analysisJet.mass()}; + tracksParams.resize(nJetConst); // resize to the number of inputs of the ML + SVsParams.resize(nJetConst); // resize to the number of inputs of the ML + + auto inputML = getInputsForML(jetparam, tracksParams, SVsParams); + + std::vector output; + // bool isSelectedMl = bMlResponse.isSelectedMl(inputML, analysisJet.pt(), output); + bMlResponse.isSelectedMl(inputML, analysisJet.pt(), output); + + registry.fill(HIST("h2_score_jetpT"), analysisJet.pt(), output[0], eventWeight); + + registry.fill(HIST("h2_jetMass_jetpT"), analysisJet.pt(), analysisJet.mass(), eventWeight); + + if (jetFlavor == 2) { + registry.fill(HIST("h2_score_jetpT_bjet"), analysisJet.pt(), output[0], eventWeight); + registry.fill(HIST("h2_jetMass_jetpT_bjet"), analysisJet.pt(), analysisJet.mass(), eventWeight); + registry.fill(HIST("h_jetpT_detector_bjet"), analysisJet.pt(), eventWeight); + } else if (jetFlavor == 1) { + registry.fill(HIST("h2_score_jetpT_cjet"), analysisJet.pt(), output[0], eventWeight); + registry.fill(HIST("h2_jetMass_jetpT_cjet"), analysisJet.pt(), analysisJet.mass(), eventWeight); + registry.fill(HIST("h_jetpT_detector_cjet"), analysisJet.pt(), eventWeight); + } else { + registry.fill(HIST("h2_score_jetpT_lfjet"), analysisJet.pt(), output[0], eventWeight); + registry.fill(HIST("h2_jetMass_jetpT_lfjet"), analysisJet.pt(), analysisJet.mass(), eventWeight); + registry.fill(HIST("h_jetpT_detector_lfjet"), analysisJet.pt(), eventWeight); + } + + for (auto& mcpjet : analysisJet.template matchedJetGeo_as()) { + if (jetFlavor == 2) { + registry.fill(HIST("h2_Response_DetjetpT_PartjetpT_bjet"), analysisJet.pt(), mcpjet.pt(), eventWeight); + } else if (jetFlavor == 1) { + registry.fill(HIST("h2_Response_DetjetpT_PartjetpT_cjet"), analysisJet.pt(), mcpjet.pt(), eventWeight); + } else { + 
registry.fill(HIST("h2_Response_DetjetpT_PartjetpT_lfjet"), analysisJet.pt(), mcpjet.pt(), eventWeight);
+        }
+      }
+    }
+
+    // For filling histograms used for the jet matching efficiency
+    for (const auto& mcpjet : mcPJetsPerColl) {
+
+      bool jetIncluded = false;
+      for (auto jetR : jetRadiiValues) {
+        if (mcpjet.r() == static_cast(jetR * 100)) {
+          jetIncluded = true;
+          break;
+        }
+      }
+
+      if (!jetIncluded) {
+        continue;
+      }
+
+      int8_t jetFlavor = 0;
+
+      if (useQuarkDef) {
+        jetFlavor = jettaggingutilities::getJetFlavor(mcpjet, mcParticlesPerColl);
+      } else {
+        jetFlavor = jettaggingutilities::getJetFlavorHadron(mcpjet, mcParticlesPerColl);
+      }
+
+      float eventWeight = mcpjet.eventWeight();
+
+      registry.fill(HIST("h_jetpT_particle_DetColl"), mcpjet.pt(), eventWeight);
+
+      if (jetFlavor == 2) {
+        registry.fill(HIST("h_jetpT_particle_DetColl_bjet"), mcpjet.pt(), eventWeight);
+      } else if (jetFlavor == 1) {
+        registry.fill(HIST("h_jetpT_particle_DetColl_cjet"), mcpjet.pt(), eventWeight);
+      } else {
+        registry.fill(HIST("h_jetpT_particle_DetColl_lfjet"), mcpjet.pt(), eventWeight);
+      }
+    }
+  }
+  PROCESS_SWITCH(BJetTaggingML, processMCJets, "jet information in MC", false);
+
+  Filter mccollisionFilter = nabs(aod::jmccollision::posZ) < vertexZCut;
+  using FilteredCollisionMCP = soa::Filtered;
+  // Fills h_jetpT_particle_{bjet,cjet,lfjet} with the event-weighted pT of every accepted jet in MCPjets, split by jet flavour
+  void processMCTruthJets(FilteredCollisionMCP::iterator const& /*collision*/, MCPJetTable const& MCPjets, JetParticles const& MCParticles)
+  {
+    // A jet is accepted only if its r() equals one of the configured jetRadiiValues scaled by 100
+    for (const auto& mcpjet : MCPjets) {
+
+      bool jetIncluded = false;
+      for (auto jetR : jetRadiiValues) {
+        if (mcpjet.r() == static_cast(jetR * 100)) {
+          jetIncluded = true;
+          break;
+        }
+      }
+
+      if (!jetIncluded) {
+        continue;
+      }
+
+      int8_t jetFlavor = 0; // NOTE(review): getJetFlavor() is declared to return int16_t; the narrowing to int8_t looks unintended - confirm
+      // Flavour is taken from quarks when useQuarkDef is set, otherwise from heavy-flavour hadrons
+      if (useQuarkDef) {
+        jetFlavor = jettaggingutilities::getJetFlavor(mcpjet, MCParticles);
+      } else {
+        jetFlavor = jettaggingutilities::getJetFlavorHadron(mcpjet, MCParticles);
+      }
+
+      float eventWeight = mcpjet.eventWeight();
+      // Flavour codes: 2 -> b-jet, 1 -> c-jet, any other value is filled as light flavour
+      if (jetFlavor == 2) {
+        registry.fill(HIST("h_jetpT_particle_bjet"), mcpjet.pt(), eventWeight);
+      } else if (jetFlavor == 1) {
+        registry.fill(HIST("h_jetpT_particle_cjet"), mcpjet.pt(), eventWeight);
+      } else {
+        registry.fill(HIST("h_jetpT_particle_lfjet"), mcpjet.pt(), eventWeight);
+      }
+    }
+  }
+  PROCESS_SWITCH(BJetTaggingML, processMCTruthJets, "truth jet information", false);
+};
+// Workflow entry point: adapts the analysis task and registers it under the name "bjet-tagging-ml"
+WorkflowSpec defineDataProcessing(ConfigContext const& cfgc)
+{
+  return WorkflowSpec{adaptAnalysisTask(cfgc, TaskName{"bjet-tagging-ml"})};
+}
diff --git a/PWGJE/Tasks/bjetTreeCreator.cxx b/PWGJE/Tasks/bjetTreeCreator.cxx
index 6f51e12d3cf..5a03fa5273f 100644
--- a/PWGJE/Tasks/bjetTreeCreator.cxx
+++ b/PWGJE/Tasks/bjetTreeCreator.cxx
@@ -162,12 +162,17 @@ struct BJetTreeCreator {
   Configurable vertexZCut{"vertexZCut", 10.0f, "Accepted z-vertex range"};
   Configurable eventSelections{"eventSelections", "sel8", "choose event selection"};

+  Configurable> jetPtBins{"jetPtBins", std::vector{5, 1000}, "jet pT bins for reduction"};
+  Configurable> jetReductionFactors{"jetReductionFactors", std::vector{0.0}, "jet reduction factors"};
+
   // track level configurables
   Configurable trackPtMin{"trackPtMin", 0.5, "minimum track pT"};
   Configurable trackPtMax{"trackPtMax", 1000.0, "maximum track pT"};
   Configurable trackEtaMin{"trackEtaMin", -0.9, "minimum track eta"};
   Configurable trackEtaMax{"trackEtaMax", 0.9, "maximum track eta"};

+  Configurable useQuarkDef{"useQuarkDef", true, "Flag whether to use quarks or hadrons for determining the jet flavor"};
+
   // track level configurables
   Configurable svPtMin{"svPtMin", 0.5, "minimum SV pT"};

@@ -187,6 +192,8 @@ struct BJetTreeCreator {
   int eventSelection = -1;

   std::vector jetRadiiValues;
+  std::vector jetPtBinsReduction;
+  std::vector jetReductionFactorsPt;

   void init(InitContext const&)
   {
@@ -194,6 +201,8 @@ struct BJetTreeCreator {
     std::srand(static_cast(std::time(nullptr)));

     jetRadiiValues = (std::vector)jetRadii;
+    jetPtBinsReduction = (std::vector)jetPtBins;
+
jetReductionFactorsPt = (std::vector)jetReductionFactors;

     eventSelection = jetderiveddatautilities::initialiseEventSelection(static_cast(eventSelections));

@@ -252,11 +261,41 @@ struct BJetTreeCreator {
   Filter jetFilter = (aod::jet::pt >= jetPtMin && aod::jet::pt <= jetPtMax && aod::jet::eta < jetEtaMax - aod::jet::r / 100.f && aod::jet::eta > jetEtaMin + aod::jet::r / 100.f);

   using FilteredCollision = soa::Filtered>;
-  using JetTrackswID = soa::Filtered>;
-  using JetTracksMCDwID = soa::Filtered>;
+  using JetTrackswID = soa::Join;
+  using JetTracksMCDwID = soa::Join;
   using OriginalTracks = soa::Join;
   using DataJets = soa::Filtered>;

+  // Returns the jet reduction factor configured for the pT bin that contains jetPT.
+  // Bin i spans [jetPtBinsReduction[i], jetPtBinsReduction[i + 1]) and maps to jetReductionFactorsPt[i].
+  // Callers in this file skip a jet when rand() / RAND_MAX is below the returned value, so 0 keeps every jet.
+  // @param jetPT transverse momentum of the jet
+  // @return factor of the matching bin; the last factor above the last bin edge; the first factor
+  //         below the first bin edge; 0 when no factor is configured
+  double getReductionFactor(double jetPT)
+  {
+    const size_t nFactors = jetReductionFactorsPt.size();
+    if (nFactors == 0) {
+      return 0.0; // nothing configured: back()/front() on an empty vector would be undefined behaviour
+    }
+
+    // "ibin + 1 < size()" instead of "ibin < size() - 1": the latter wraps around for an empty bin list
+    for (size_t ibin = 0; ibin + 1 < jetPtBinsReduction.size(); ++ibin) {
+      if (jetPT >= jetPtBinsReduction[ibin] && jetPT < jetPtBinsReduction[ibin + 1]) {
+        // Bins without a factor of their own reuse the last configured one instead of reading out of range
+        return jetReductionFactorsPt[ibin < nFactors ? ibin : nFactors - 1];
+      }
+    }
+
+    // If jetPT is above the last bin, use the last reduction factor
+    if (!jetPtBinsReduction.empty() && jetPT >= jetPtBinsReduction.back()) {
+      return jetReductionFactorsPt.back();
+    }
+
+    // If jetPT is below the first bin, return the first reduction factor
+    return jetReductionFactorsPt.front();
+  }
+
   // Looping over the SV info and writing them to a table
   template
   void analyzeJetSVInfo(AnalysisJet const& myJet, AnyTracks const& /*allTracks*/, SecondaryVertices const& /*allSVs*/, std::vector& svIndices, int jetFlavor = 0, double eventweight = 1.0)
@@ -390,6 +429,10 @@ struct BJetTreeCreator {
         continue;
       }

+      if (static_cast(std::rand()) / RAND_MAX < getReductionFactor(analysisJet.pt())) {
+        continue;
+      }
+
       std::vector tracksIndices;
       std::vector SVsIndices;

@@ -452,9 +495,18 @@ struct BJetTreeCreator {
       // jetFlavor = jettaggingutilities::jetTrackFromHFShower(analysisJet, nonFilteredTracks, mcParticlesPerColl, hftrack);

       for (auto& mcpjet : analysisJet.template matchedJetGeo_as()) {
-        jetFlavor =
jettaggingutilities::getJetFlavor(mcpjet, mcParticlesPerColl);
-        // jetFlavor = jettaggingutilities::mcpJetFromHFShower(mcpjet, mcParticlesPerColl, (float)(mcpjet.r() / 100.));
+        if (useQuarkDef) {
+          jetFlavor = jettaggingutilities::getJetFlavor(mcpjet, mcParticlesPerColl);
+        } else {
+          jetFlavor = jettaggingutilities::getJetFlavorHadron(mcpjet, mcParticlesPerColl);
+          // jetFlavor = jettaggingutilities::mcpJetFromHFShower(mcpjet, mcParticlesPerColl, (float)(mcpjet.r() / 100.));
+        }
+      }
+
+      if (jetFlavor == 0 && (static_cast(std::rand()) / RAND_MAX < getReductionFactor(analysisJet.pt()))) {
+        continue;
       }
+
       analyzeJetSVInfo(analysisJet, allTracks, allSVs, SVsIndices, jetFlavor, eventWeight);
       analyzeJetTrackInfo(collision, analysisJet, allTracks, allSVs, tracksIndices, jetFlavor, eventWeight);

@@ -513,8 +554,12 @@ struct BJetTreeCreator {
       }

       int16_t jetFlavor = 0;
-      jetFlavor = jettaggingutilities::getJetFlavor(mcpjet, MCParticles);
-      // jetFlavor = jettaggingutilities::mcpJetFromHFShower(mcpjet, mcParticlesPerColl, (float)(mcpjet.r() / 100.));
+      if (useQuarkDef) {
+        jetFlavor = jettaggingutilities::getJetFlavor(mcpjet, MCParticles);
+      } else {
+        jetFlavor = jettaggingutilities::getJetFlavorHadron(mcpjet, MCParticles);
+        // jetFlavor = jettaggingutilities::mcpJetFromHFShower(mcpjet, MCParticles, (float)(mcpjet.r() / 100.));
+      }

       float eventWeight = mcpjet.eventWeight();

diff --git a/Tools/ML/model.h b/Tools/ML/model.h
index bb1ac84a8ae..caca77f1a25 100644
--- a/Tools/ML/model.h
+++ b/Tools/ML/model.h
@@ -111,6 +111,39 @@ class OnnxModel
     return evalModel(inputTensors);
   }

+  // For 2D inputs: wraps every inner vector of `input` in its own tensor and forwards the tensors to the tensor-based evalModel
+  template
+  T* evalModel(std::vector>& input)
+  {
+    std::vector inputTensors;
+    // NOTE(review): mem_info is declared only in this preprocessor branch and used only in the #else branch further down - presumably both test the same header; confirm
+#if !__has_include()
+    Ort::MemoryInfo mem_info = Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault);
+#endif
+
+    for (size_t iinput = 0; iinput < input.size(); iinput++) {
+      [[maybe_unused]] int totalSize = 1; // product of mInputShapes[iinput] dimensions after the first; NOTE(review): also used in the division below, so [[maybe_unused]] looks unnecessary
+      int64_t size = input[iinput].size();
+      for (size_t idim = 1; idim < mInputShapes[iinput].size(); idim++) {
+        totalSize *= mInputShapes[iinput][idim];
+      }
+      assert(size % totalSize == 0); // NOTE(review): nothing guards against totalSize == 0 here or in the division below - confirm the stored shapes cannot contain a zero dimension
+      // The leading dimension is deduced from the data size; the remaining dimensions are copied from mInputShapes
+      std::vector inputShape{static_cast(size / totalSize)};
+      for (size_t idim = 1; idim < mInputShapes[iinput].size(); idim++) {
+        inputShape.push_back(mInputShapes[iinput][idim]);
+      }
+      // The tensors are created over input[iinput].data(), i.e. directly from the caller's buffer pointer
+#if __has_include()
+      inputTensors.emplace_back(Ort::Experimental::Value::CreateTensor(input[iinput].data(), size, inputShape));
+#else
+      inputTensors.emplace_back(Ort::Value::CreateTensor(mem_info, input[iinput].data(), size, inputShape.data(), inputShape.size()));
+#endif
+    }
+
+    return evalModel(inputTensors);
+  }
+
   // Reset session
 #if __has_include()
   void resetSession() { mSession.reset(new Ort::Experimental::Session{*mEnv, modelPath, sessionOptions}); }