Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PWGJE] Add histograms for data-driven methods #8965

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 64 additions & 39 deletions PWGJE/Tasks/bjetTaggingML.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ using namespace o2::framework::expressions;

struct BJetTaggingML {

struct bjetParams {
struct BJetParams {
float mJetpT = 0.0;
float mJetEta = 0.0;
float mJetPhi = 0.0;
Expand All @@ -52,7 +52,7 @@ struct BJetTaggingML {
float mJetMass = 0.0;
};

struct bjetTrackParams {
struct BJetTrackParams {
double mTrackpT = 0.0;
double mTrackEta = 0.0;
double mDotProdTrackJet = 0.0;
Expand All @@ -66,7 +66,7 @@ struct BJetTaggingML {
double mDeltaRTrackVertex = 0.0;
};

struct bjetSVParams {
struct BJetSVParams {
double mSVpT = 0.0;
double mDeltaRSVJet = 0.0;
double mSVMass = 0.0;
Expand All @@ -83,7 +83,7 @@ struct BJetTaggingML {

HistogramRegistry registry;

static constexpr double defaultCutsMl[1][2] = {{0.5, 0.5}};
static constexpr double DefaultCutsMl[1][2] = {{0.5, 0.5}};

// event level configurables
Configurable<float> vertexZCut{"vertexZCut", 10.0f, "Accepted z-vertex range"};
Expand Down Expand Up @@ -111,20 +111,21 @@ struct BJetTaggingML {
Configurable<int> nJetConst{"nJetConst", 10, "maximum number of jet consistuents to be used for ML evaluation"};

Configurable<bool> useQuarkDef{"useQuarkDef", true, "Flag whether to use quarks or hadrons for determining the jet flavor"};
Configurable<bool> doDataDriven{"doDataDriven", false, "Flag whether to use fill THnSpase for data driven methods"};

Configurable<float> svReductionFactor{"svReductionFactor", 1.0, "factor for how many SVs to keep"};

Configurable<std::vector<double>> jetRadii{"jetRadii", std::vector<double>{0.4}, "jet resolution parameters"};

Configurable<std::vector<double>> binsPtMl{"binsPtMl", std::vector<double>{5., 1000.}, "pT bin limits for ML application"};
Configurable<std::vector<int>> cutDirMl{"cutDirMl", std::vector<int>{cuts_ml::CutSmaller, cuts_ml::CutNot}, "Whether to reject score values greater or smaller than the threshold"};
Configurable<LabeledArray<double>> cutsMl{"cutsMl", {defaultCutsMl[0], 1, 2, {"pT bin 0"}, {"score for default b-jet tagging", "uncer 1"}}, "ML selections per pT bin"};
Configurable<LabeledArray<double>> cutsMl{"cutsMl", {DefaultCutsMl[0], 1, 2, {"pT bin 0"}, {"score for default b-jet tagging", "uncer 1"}}, "ML selections per pT bin"};
Configurable<int> nClassesMl{"nClassesMl", 2, "Number of classes in ML model"};
Configurable<std::vector<std::string>> namesInputFeatures{"namesInputFeatures", std::vector<std::string>{"feature1", "feature2"}, "Names of ML model input features"};

Configurable<std::string> ccdbUrl{"ccdbUrl", "http://alice-ccdb.cern.ch", "url of the ccdb repository"};
Configurable<std::vector<std::string>> modelPathsCCDB{"modelPathsCCDB", std::vector<std::string>{"Users/h/hahassan"}, "Paths of models on CCDB"};
Configurable<std::vector<std::string>> onnxFileNames{"onnxFileNames", std::vector<std::string>{"ML_bjets/01-MVA/Models/LHC23d4_5_20_90Percent/model.onnx"}, "ONNX file names for each pT bin (if not from CCDB full path)"};
Configurable<std::vector<std::string>> onnxFileNames{"onnxFileNames", std::vector<std::string>{"ML_bjets/Models/LHC24g4_70_200/model.onnx"}, "ONNX file names for each pT bin (if not from CCDB full path)"};
Configurable<int64_t> timestampCCDB{"timestampCCDB", -1, "timestamp of the ONNX file for ML model used to query in CCDB"};
Configurable<bool> loadModelsFromCCDB{"loadModelsFromCCDB", false, "Flag to enable or disable the loading of models from CCDB"};

Expand Down Expand Up @@ -158,6 +159,15 @@ struct BJetTaggingML {
registry.add("h2_jetMass_jetpT", "Jet mass;#it{p}_{T,jet} (GeV/#it{c});#it{m}_{jet} (GeV/#it{c}^{2})", {HistType::kTH2F, {{200, 0., 200.}, {50, 0, 50.0}}});
registry.add("h2_SVMass_jetpT", "Secondary vertex mass;#it{p}_{T,jet} (GeV/#it{c});#it{m}_{SV} (GeV/#it{c}^{2})", {HistType::kTH2F, {{200, 0., 200.}, {50, 0, 10}}});

if (doDataDriven) {
registry.add("hSparse_Incljets", "Inclusive jets Info;#it{p}_{T,jet} (GeV/#it{c});Score;#it{m}_{jet} (GeV/#it{c}^{2});#it{m}_{SV} (GeV/#it{c}^{2});SVfE;", {HistType::kTHnSparseF, {{200, 0., 200.}, {120, -0.1, 1.1}, {50, 0, 50}, {50, 0, 10}, {50, 0, 1}}});
if (doprocessMCJets) {
registry.add("hSparse_bjets", "Tagged b-jets Info;#it{p}_{T,jet} (GeV/#it{c});Score;#it{m}_{jet} (GeV/#it{c}^{2});#it{m}_{SV} (GeV/#it{c}^{2});SVfE;", {HistType::kTHnSparseF, {{200, 0., 200.}, {120, -0.1, 1.1}, {50, 0, 50}, {50, 0, 10}, {50, 0, 1}}});
registry.add("hSparse_cjets", "Tagged c-jets Info;#it{p}_{T,jet} (GeV/#it{c});Score;#it{m}_{jet} (GeV/#it{c}^{2});#it{m}_{SV} (GeV/#it{c}^{2});SVfE;", {HistType::kTHnSparseF, {{200, 0., 200.}, {120, -0.1, 1.1}, {50, 0, 50}, {50, 0, 10}, {50, 0, 1}}});
registry.add("hSparse_lfjets", "Tagged lf-jets Info;#it{p}_{T,jet} (GeV/#it{c});Score;#it{m}_{jet} (GeV/#it{c}^{2});#it{m}_{SV} (GeV/#it{c}^{2});SVfE;", {HistType::kTHnSparseF, {{200, 0., 200.}, {120, -0.1, 1.1}, {50, 0, 50}, {50, 0, 10}, {50, 0, 1}}});
}
}

if (doprocessMCJets) {

registry.add("h2_score_jetpT_bjet", "ML scores for b-jets;#it{p}_{T,jet} (GeV/#it{c});Score", {HistType::kTH2F, {{200, 0., 200.}, {120, -0.1, 1.1}}});
Expand Down Expand Up @@ -224,7 +234,7 @@ struct BJetTaggingML {
using JetTracksMCDwID = soa::Join<aod::JetTracksMCD, aod::JTrackExtras, aod::JTrackPIs>;
using DataJets = soa::Filtered<soa::Join<aod::ChargedJets, aod::ChargedJetConstituents, aod::DataSecondaryVertex3ProngIndices>>;

std::vector<std::vector<float>> getInputsForML(bjetParams jetparams, std::vector<bjetTrackParams>& tracksParams, std::vector<bjetSVParams>& svsParams)
std::vector<std::vector<float>> getInputsForML(BJetParams jetparams, std::vector<BJetTrackParams>& tracksParams, std::vector<BJetSVParams>& svsParams)
{
std::vector<float> jetInput = {jetparams.mJetpT, jetparams.mJetEta, jetparams.mJetPhi, static_cast<float>(jetparams.mNTracks), static_cast<float>(jetparams.mNSV), jetparams.mJetMass};
std::vector<float> tracksInputFlat;
Expand Down Expand Up @@ -268,7 +278,7 @@ struct BJetTaggingML {

// Looping over the SV info and writing them to a table
template <typename AnalysisJet, typename AnyTracks, typename SecondaryVertices>
void analyzeJetSVInfo(AnalysisJet const& myJet, AnyTracks const& /*allTracks*/, SecondaryVertices const& /*allSVs*/, std::vector<bjetSVParams>& svsParams, int jetFlavor = 0, double eventweight = 1.0)
void analyzeJetSVInfo(AnalysisJet const& myJet, AnyTracks const& /*allTracks*/, SecondaryVertices const& /*allSVs*/, std::vector<BJetSVParams>& svsParams, int jetFlavor = 0, double eventweight = 1.0)
{
using SVType = typename SecondaryVertices::iterator;

Expand All @@ -294,7 +304,7 @@ struct BJetTaggingML {
double energySV = candSV.e();

if (svsParams.size() < (svReductionFactor * myJet.template tracks_as<AnyTracks>().size())) {
svsParams.emplace_back(bjetSVParams{candSV.pt(), deltaRJetSV, massSV, energySV / myJet.energy(), candSV.impactParameterXY(), candSV.cpa(), candSV.chi2PCA(), candSV.dispersion(), candSV.decayLengthXY(), candSV.errorDecayLengthXY(), candSV.decayLength(), candSV.errorDecayLength()});
svsParams.emplace_back(BJetSVParams{candSV.pt(), deltaRJetSV, massSV, energySV / myJet.energy(), candSV.impactParameterXY(), candSV.cpa(), candSV.chi2PCA(), candSV.dispersion(), candSV.decayLengthXY(), candSV.errorDecayLengthXY(), candSV.decayLength(), candSV.errorDecayLength()});
}

registry.fill(HIST("h2_LxyS_jetpT"), myJet.pt(), candSV.decayLengthXY() / candSV.errorDecayLengthXY(), eventweight);
Expand All @@ -320,10 +330,10 @@ struct BJetTaggingML {
}

template <typename AnyCollision, typename AnalysisJet, typename AnyTracks, typename SecondaryVertices>
void analyzeJetTrackInfo(AnyCollision const& /*collision*/, AnalysisJet const& analysisJet, AnyTracks const& /*allTracks*/, SecondaryVertices const& /*allSVs*/, std::vector<bjetTrackParams>& tracksParams, int jetFlavor = 0, double eventweight = 1.0)
void analyzeJetTrackInfo(AnyCollision const& /*collision*/, AnalysisJet const& analysisJet, AnyTracks const& /*allTracks*/, SecondaryVertices const& /*allSVs*/, std::vector<BJetTrackParams>& tracksParams, int jetFlavor = 0, double eventweight = 1.0)
{

for (auto& constituent : analysisJet.template tracks_as<AnyTracks>()) {
for (const auto& constituent : analysisJet.template tracks_as<AnyTracks>()) {

if (constituent.pt() < trackPtMin) {
continue;
Expand All @@ -333,11 +343,11 @@ struct BJetTaggingML {
double dotProduct = RecoDecay::dotProd(std::array<float, 3>{analysisJet.px(), analysisJet.py(), analysisJet.pz()}, std::array<float, 3>{constituent.px(), constituent.py(), constituent.pz()});
int sign = jettaggingutilities::getGeoSign(analysisJet, constituent);

float RClosestSV = 10.;
float rClosestSV = 10.;
for (const auto& candSV : analysisJet.template secondaryVertices_as<SecondaryVertices>()) {
double deltaRTrackSV = jetutilities::deltaR(constituent, candSV);
if (deltaRTrackSV < RClosestSV) {
RClosestSV = deltaRTrackSV;
if (deltaRTrackSV < rClosestSV) {
rClosestSV = deltaRTrackSV;
}
}

Expand All @@ -357,10 +367,10 @@ struct BJetTaggingML {
}
}

tracksParams.emplace_back(bjetTrackParams{constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), RClosestSV});
tracksParams.emplace_back(BJetTrackParams{constituent.pt(), constituent.eta(), dotProduct, dotProduct / analysisJet.p(), deltaRJetTrack, std::abs(constituent.dcaXY()) * sign, constituent.sigmadcaXY(), std::abs(constituent.dcaXYZ()) * sign, constituent.sigmadcaXYZ(), constituent.p() / analysisJet.p(), rClosestSV});
}

auto compare = [](bjetTrackParams& tr1, bjetTrackParams& tr2) {
auto compare = [](BJetTrackParams& tr1, BJetTrackParams& tr2) {
return (tr1.mSignedIP2D / tr1.mSignedIP2DSign) > (tr2.mSignedIP2D / tr2.mSignedIP2DSign);
};

Expand All @@ -384,7 +394,7 @@ struct BJetTaggingML {
for (const auto& analysisJet : alljets) {

bool jetIncluded = false;
for (auto jetR : jetRadiiValues) {
for (const auto& jetR : jetRadiiValues) {
if (analysisJet.r() == static_cast<int>(jetR * 100)) {
jetIncluded = true;
break;
Expand All @@ -395,22 +405,22 @@ struct BJetTaggingML {
continue;
}

std::vector<bjetTrackParams> tracksParams;
std::vector<bjetSVParams> SVsParams;
std::vector<BJetTrackParams> tracksParams;
std::vector<BJetSVParams> svsParams;

analyzeJetSVInfo(analysisJet, allTracks, allSVs, SVsParams);
analyzeJetSVInfo(analysisJet, allTracks, allSVs, svsParams);
analyzeJetTrackInfo(collision, analysisJet, allTracks, allSVs, tracksParams);

int nSVs = analysisJet.template secondaryVertices_as<aod::DataSecondaryVertex3Prongs>().size();

registry.fill(HIST("h2_nTracks_jetpT"), analysisJet.pt(), tracksParams.size());
registry.fill(HIST("h2_nSV_jetpT"), analysisJet.pt(), nSVs < 250 ? nSVs : 249);

bjetParams jetparam = {analysisJet.pt(), analysisJet.eta(), analysisJet.phi(), static_cast<int>(tracksParams.size()), static_cast<int>(nSVs), analysisJet.mass()};
BJetParams jetparam = {analysisJet.pt(), analysisJet.eta(), analysisJet.phi(), static_cast<int>(tracksParams.size()), static_cast<int>(nSVs), analysisJet.mass()};
tracksParams.resize(nJetConst); // resize to the number of inputs of the ML
SVsParams.resize(nJetConst); // resize to the number of inputs of the ML
svsParams.resize(nJetConst); // resize to the number of inputs of the ML

auto inputML = getInputsForML(jetparam, tracksParams, SVsParams);
auto inputML = getInputsForML(jetparam, tracksParams, svsParams);

std::vector<float> output;
// bool isSelectedMl = bMlResponse.isSelectedMl(inputML, analysisJet.pt(), output);
Expand All @@ -419,6 +429,10 @@ struct BJetTaggingML {
registry.fill(HIST("h2_score_jetpT"), analysisJet.pt(), output[0]);

registry.fill(HIST("h2_jetMass_jetpT"), analysisJet.pt(), analysisJet.mass());

if (doDataDriven) {
registry.fill(HIST("hSparse_Incljets"), analysisJet.pt(), output[0], analysisJet.mass(), svsParams[0].mSVMass, svsParams[0].mSVfE);
}
}
}
PROCESS_SWITCH(BJetTaggingML, processDataJets, "jet information in Data", false);
Expand All @@ -427,8 +441,8 @@ struct BJetTaggingML {
using MCPJetTable = soa::Filtered<soa::Join<aod::ChargedMCParticleLevelJets, aod::ChargedMCParticleLevelJetConstituents, aod::ChargedMCParticleLevelJetsMatchedToChargedMCDetectorLevelJets, aod::ChargedMCParticleLevelJetEventWeights>>;
using FilteredCollisionMCD = soa::Filtered<soa::Join<aod::JCollisions, aod::JCollisionPIs, aod::JMcCollisionLbs>>;

Preslice<aod::JMcParticles> McParticlesPerCollision = aod::jmcparticle::mcCollisionId;
Preslice<MCPJetTable> McPJetsPerCollision = aod::jet::mcCollisionId;
Preslice<aod::JMcParticles> mcParticlesPerCollision = aod::jmcparticle::mcCollisionId;
Preslice<MCPJetTable> mcpJetsPerCollision = aod::jet::mcCollisionId;

void processMCJets(FilteredCollisionMCD::iterator const& collision, MCDJetTable const& MCDjets, MCPJetTable const& MCPjets, JetTracksMCDwID const& allTracks, aod::JetParticles const& MCParticles, aod::MCDSecondaryVertex3Prongs const& allSVs)
{
Expand All @@ -438,13 +452,13 @@ struct BJetTaggingML {

registry.fill(HIST("h_vertexZ"), collision.posZ());

auto const mcParticlesPerColl = MCParticles.sliceBy(McParticlesPerCollision, collision.mcCollisionId());
auto const mcPJetsPerColl = MCPjets.sliceBy(McPJetsPerCollision, collision.mcCollisionId());
auto const mcParticlesPerColl = MCParticles.sliceBy(mcParticlesPerCollision, collision.mcCollisionId());
auto const mcPJetsPerColl = MCPjets.sliceBy(mcpJetsPerCollision, collision.mcCollisionId());

for (const auto& analysisJet : MCDjets) {

bool jetIncluded = false;
for (auto jetR : jetRadiiValues) {
for (const auto& jetR : jetRadiiValues) {
if (analysisJet.r() == static_cast<int>(jetR * 100)) {
jetIncluded = true;
break;
Expand All @@ -461,32 +475,32 @@ struct BJetTaggingML {
continue;
}

std::vector<bjetTrackParams> tracksParams;
std::vector<bjetSVParams> SVsParams;
std::vector<BJetTrackParams> tracksParams;
std::vector<BJetSVParams> svsParams;

int jetFlavor = 0;

for (auto& mcpjet : analysisJet.template matchedJetGeo_as<MCPJetTable>()) {
for (const auto& mcpjet : analysisJet.template matchedJetGeo_as<MCPJetTable>()) {
if (useQuarkDef) {
jetFlavor = jettaggingutilities::getJetFlavor(mcpjet, mcParticlesPerColl);
} else {
jetFlavor = jettaggingutilities::getJetFlavorHadron(mcpjet, mcParticlesPerColl);
}
}

analyzeJetSVInfo(analysisJet, allTracks, allSVs, SVsParams, jetFlavor, eventWeight);
analyzeJetSVInfo(analysisJet, allTracks, allSVs, svsParams, jetFlavor, eventWeight);
analyzeJetTrackInfo(collision, analysisJet, allTracks, allSVs, tracksParams, jetFlavor, eventWeight);

int nSVs = analysisJet.template secondaryVertices_as<aod::MCDSecondaryVertex3Prongs>().size();

registry.fill(HIST("h2_nTracks_jetpT"), analysisJet.pt(), tracksParams.size());
registry.fill(HIST("h2_nSV_jetpT"), analysisJet.pt(), nSVs < 250 ? nSVs : 249);

bjetParams jetparam = {analysisJet.pt(), analysisJet.eta(), analysisJet.phi(), static_cast<int>(tracksParams.size()), static_cast<int>(nSVs), analysisJet.mass()};
BJetParams jetparam = {analysisJet.pt(), analysisJet.eta(), analysisJet.phi(), static_cast<int>(tracksParams.size()), static_cast<int>(nSVs), analysisJet.mass()};
tracksParams.resize(nJetConst); // resize to the number of inputs of the ML
SVsParams.resize(nJetConst); // resize to the number of inputs of the ML
svsParams.resize(nJetConst); // resize to the number of inputs of the ML

auto inputML = getInputsForML(jetparam, tracksParams, SVsParams);
auto inputML = getInputsForML(jetparam, tracksParams, svsParams);

std::vector<float> output;
// bool isSelectedMl = bMlResponse.isSelectedMl(inputML, analysisJet.pt(), output);
Expand All @@ -496,6 +510,17 @@ struct BJetTaggingML {

registry.fill(HIST("h2_jetMass_jetpT"), analysisJet.pt(), analysisJet.mass(), eventWeight);

if (doDataDriven) {
registry.fill(HIST("hSparse_Incljets"), analysisJet.pt(), output[0], analysisJet.mass(), svsParams[0].mSVMass, svsParams[0].mSVfE);
if (jetFlavor == 2) {
registry.fill(HIST("hSparse_bjets"), analysisJet.pt(), output[0], analysisJet.mass(), svsParams[0].mSVMass, svsParams[0].mSVfE);
} else if (jetFlavor == 1) {
registry.fill(HIST("hSparse_cjets"), analysisJet.pt(), output[0], analysisJet.mass(), svsParams[0].mSVMass, svsParams[0].mSVfE);
} else {
registry.fill(HIST("hSparse_lfjets"), analysisJet.pt(), output[0], analysisJet.mass(), svsParams[0].mSVMass, svsParams[0].mSVfE);
}
}

if (jetFlavor == 2) {
registry.fill(HIST("h2_score_jetpT_bjet"), analysisJet.pt(), output[0], eventWeight);
registry.fill(HIST("h2_jetMass_jetpT_bjet"), analysisJet.pt(), analysisJet.mass(), eventWeight);
Expand All @@ -510,7 +535,7 @@ struct BJetTaggingML {
registry.fill(HIST("h_jetpT_detector_lfjet"), analysisJet.pt(), eventWeight);
}

for (auto& mcpjet : analysisJet.template matchedJetGeo_as<MCPJetTable>()) {
for (const auto& mcpjet : analysisJet.template matchedJetGeo_as<MCPJetTable>()) {
if (mcpjet.pt() > pTHatMaxMCP * pTHat) {
continue;
}
Expand All @@ -529,7 +554,7 @@ struct BJetTaggingML {
for (const auto& mcpjet : mcPJetsPerColl) {

bool jetIncluded = false;
for (auto jetR : jetRadiiValues) {
for (const auto& jetR : jetRadiiValues) {
if (mcpjet.r() == static_cast<int>(jetR * 100)) {
jetIncluded = true;
break;
Expand Down Expand Up @@ -576,7 +601,7 @@ struct BJetTaggingML {
for (const auto& mcpjet : MCPjets) {

bool jetIncluded = false;
for (auto jetR : jetRadiiValues) {
for (const auto& jetR : jetRadiiValues) {
if (mcpjet.r() == static_cast<int>(jetR * 100)) {
jetIncluded = true;
break;
Expand Down Expand Up @@ -615,5 +640,5 @@ struct BJetTaggingML {

// Workflow entry point: returns a WorkflowSpec holding a single analysis task,
// BJetTaggingML, adapted from the given ConfigContext and registered under the
// task name "bjet-tagging-ml".
// NOTE(review): the two return statements below are identical apart from the
// trailing o2-linter suppression comment on the second one; as written only the
// first can execute. They look like the before/after sides of a diff hunk —
// confirm which single line the real file keeps.
WorkflowSpec defineDataProcessing(ConfigContext const& cfgc)
{
return WorkflowSpec{adaptAnalysisTask<BJetTaggingML>(cfgc, TaskName{"bjet-tagging-ml"})};
return WorkflowSpec{adaptAnalysisTask<BJetTaggingML>(cfgc, TaskName{"bjet-tagging-ml"})}; // o2-linter: disable=name/o2-task,name/workflow-file
}
Loading