Skip to content

Commit

Permalink
PIDML: evaluate FSE + self-attention network (AliceO2Group#7162)
Browse files Browse the repository at this point in the history
* PIDML evaluate FSE + self-attention network (#5)

* remove detector count setting and reorder network arguments (with NaNs if detector not available)

* update README.md

* markdownlint changes

* MegaLinter fixes (#6)

* fix include missing file, the same way it was before

* readd pLimits to ONNXinterface and pass it to ONNXmodel

* Please consider the following formatting changes (#9)

* improve qaPidML according to new approach

---------

Co-authored-by: ALICE Builder <[email protected]>
  • Loading branch information
mytkom and alibuild authored Aug 7, 2024
1 parent dff47c2 commit 8247afd
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 127 deletions.
3 changes: 1 addition & 2 deletions Tools/PIDML/KaonPidTask.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ struct KaonPidTask {
Configurable<std::string> cfgCCDBURL{"ccdb-url", "http://alice-ccdb.cern.ch", "URL of the CCDB repository"};
Configurable<int> cfgPid{"pid", 321, "PID to predict"};
Configurable<double> cfgCertainty{"certainty", 0.5, "Minimum certainty above which the model accepts a particular type of particle"};
Configurable<uint32_t> cfgDetector{"detector", kTPCTOFTRD, "What detectors to use: 0: TPC only, 1: TPC + TOF, 2: TPC + TOF + TRD"};
Configurable<uint64_t> cfgTimestamp{"timestamp", 0, "Fixed timestamp"};
Configurable<bool> cfgUseCCDB{"useCCDB", false, "Whether to autofetch ML model from CCDB. If false, local file will be used."};

Expand All @@ -85,7 +84,7 @@ struct KaonPidTask {
if (cfgUseCCDB) {
ccdbApi.init(cfgCCDBURL); // Initializes ccdbApi when cfgUseCCDB is set to 'true'
}
pidModel = std::make_shared<PidONNXModel>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, cfgTimestamp.value, cfgPid.value, static_cast<PidMLDetector>(cfgDetector.value), cfgCertainty.value);
pidModel = std::make_shared<PidONNXModel>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, cfgTimestamp.value, cfgPid.value, cfgCertainty.value);

histos.add("hChargePos", ";z;", kTH1F, {{3, -1.5, 1.5}});
histos.add("hChargeNeg", ";z;", kTH1F, {{3, -1.5, 1.5}});
Expand Down
65 changes: 50 additions & 15 deletions Tools/PIDML/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# PID ML in O2

Particle identification is essential in most of the analyzes. The PID ML interface will help you to make use of the machine learning models to improve purity and efficiency of particle kinds for your analysis. A single model is tailored to a specific particle kind, e.g., pions with PID 211. For each track, the model returns a float value in [0, 1] which measures the ''certainty'' of the model that this track is of given kind.
Particle identification is essential in most of the analyzes.
The PID ML interface will help you to make use of the machine learning models to improve purity and efficiency of particle kinds for your analysis.
A single model is tailored to a specific particle kind, e.g., pions with PID 211. For each track, the model returns a float value in [0, 1] which measures the ''certainty'' of the model that this track is of given kind.

## PidONNXModel

Expand All @@ -11,12 +13,16 @@ This class represents a single ML model from an ONNX file. It requires the follo
- CCDB Api instance created in an analysis task
- timestamp of the input analysis data -- neded to choose appropriate model
- PID to be checked
- detector setup: what detectors should be used for identification. It is described by enum PidMLDetector. Currently available setups: TPC, TPC+TOF, TPC+TOF+TRD
- minimum certainty for accepting a track to be of given PID
- *p* limits array - specifiying p limits for each detector configuration (TPC, TPC+TOF, TPC+TOF+TRD)

Let's assume your `PidONNXModel` instance is named `pidModel`. Then, inside your analysis task `process()` function, you can iterate over tracks and call: `pidModel.applyModel(track);` to get the certainty of the model. You can also use `pidModel.applyModelBoolean(track);` to receive a true/false answer, whether the track can be accepted based on the minimum certainty provided to the `PidONNXModel` constructor.
Let's assume your `PidONNXModel` instance is named `pidModel`.
Then, inside your analysis task `process()` function, you can iterate over tracks and call: `pidModel.applyModel(track);` to get the certainty of the model.
You can also use `pidModel.applyModelBoolean(track);` to receive a true/false answer, whether the track can be accepted based on the minimum certainty provided to the `PidONNXModel` constructor.

You can check [a simple analysis task example](https://github.com/AliceO2Group/O2Physics/blob/master/Tools/PIDML/simpleApplyPidOnnxModel.cxx). It uses configurable parameters and shows how to calculate the data timestamp. Note that the calculation of the timestamp requires subscribing to `aod::Collisions` and `aod::BCsWithTimestamps`. For Hyperloop tests, you can set `cfgUseFixedTimestamp` to true with `cfgTimestamp` set to the default value.
You can check [a simple analysis task example](https://github.com/AliceO2Group/O2Physics/blob/master/Tools/PIDML/simpleApplyPidOnnxModel.cxx).
It uses configurable parameters and shows how to calculate the data timestamp. Note that the calculation of the timestamp requires subscribing to `aod::Collisions` and `aod::BCsWithTimestamps`.
For Hyperloop tests, you can set `cfgUseFixedTimestamp` to true with `cfgTimestamp` set to the default value.

On the other hand, it is possible to use locally stored models, and then the timestamp is not used, so it can be a dummy value. `processTracksOnly` presents how to analyze on local-only PID ML models.

Expand All @@ -31,10 +37,10 @@ This is a wrapper around PidONNXModel that contains several models. It has the p

Then, obligatory parameters for the interface:
- a vector of int output PIDs
- a 2-dimensional LabeledArray of *p*T limits for each PID, for each detector configuration. It describes the minimum *p*T values at which each next detector should be included for predicting given PID
- a 2-dimensional LabeledArray of *p* limits for each PID, for each detector configuration. It describes the minimum *p* values at which each next detector should be included for predicting given PID
- a vector of minimum certainties for each PID for accepting a track to be of this PID
- boolean flag: whether to switch on auto mode. If true, then *p*T limits and minimum certainties can be passed as an empty array and an empty vector, and the interface will fill them with default configuration:
- *p*T limits: same values for all PIDs: 0.0 (TPC), 0.5 (TPC + TOF), 0.8 (TPC + TOF + TRD)
- *p* limits: same values for all PIDs: 0.0 (TPC), 0.5 (TPC + TOF), 0.8 (TPC + TOF + TRD)
- minimum certainties: 0.5 for all PIDs

You can use the interface in the same way as the model, by calling `applyModel(track)` or `applyModelBoolean(track)`. The interface will then call the respective method of the model selected with the aforementioned interface parameters.
Expand All @@ -48,20 +54,49 @@ There is again [a simple analysis task example](https://github.com/AliceO2Group/
Currently, only models for run 285064 (timestamp interval: 1524176895000 - 1524212953000) are uploaded to CCDB, so you can use hardcoded timestamp 1524176895000 for tests.

Both model and interface analysis examples can be run with a script:

### Script for Run2 Converted to Run3 data
```bash
#!/bin/bash

config_file="my-config.json"

o2-analysis-tracks-extra-converter --configuration json://$config_file -b |
o2-analysis-timestamp --configuration json://$config_file -b |
o2-analysis-trackextension --configuration json://$config_file -b |
o2-analysis-trackselection --configuration json://$config_file -b |
o2-analysis-multiplicity-table --configuration json://$config_file -b |
o2-analysis-bc-converter --configuration json://$config_file -b |
o2-analysis-collision-converter --configuration json://$config_file -b |
o2-analysis-zdc-converter --configuration json://$config_file -b |
o2-analysis-pid-tof-base --configuration json://$config_file -b |
o2-analysis-pid-tof-beta --configuration json://$config_file -b |
o2-analysis-pid-tof-full --configuration json://$config_file -b |
o2-analysis-pid-tpc-full --configuration json://$config_file -b |
o2-analysis-pid-tpc-base --configuration json://$config_file -b |
o2-analysis-simple-apply-pid-onnx-model --configuration json://$config_file -b
```
Remember to set every setting, which states that helper task should process Run2 data to `true`.

### Script for Run3 data
```bash
#!/bin/bash

config_file="my-config.json"

o2-analysis-timestamp --configuration json://$config_file -b |
o2-analysis-trackextension --configuration json://$config_file -b |
o2-analysis-trackselection --configuration json://$config_file -b |
o2-analysis-multiplicity-table --configuration json://$config_file -b |
o2-analysis-fdd-converter --configuration json://$config_file -b |
o2-analysis-pid-tof-base --configuration json://$config_file -b |
o2-analysis-pid-tof-beta --configuration json://$config_file -b |
o2-analysis-pid-tof-full --configuration json://$config_file -b |
o2-analysis-pid-tpc-full --configuration json://$config_file -b |
o2-analysis-simple-apply-pid-onnx-model --configuration json://$config_file -b
o2-analysis-event-selection --configuration json://$config_file -b |
o2-analysis-trackselection --configuration json://$config_file -b |
o2-analysis-multiplicity-table --configuration json://$config_file -b |
o2-analysis-track-propagation --configuration json://$config_file -b |
o2-analysis-pid-tof-base --configuration json://$config_file -b |
o2-analysis-pid-tof-beta --configuration json://$config_file -b |
o2-analysis-pid-tof-full --configuration json://$config_file -b |
o2-analysis-pid-tpc-full --configuration json://$config_file -b |
o2-analysis-pid-tpc-base --configuration json://$config_file -b |
o2-analysis-simple-apply-pid-onnx-model --configuration json://$config_file -b
```
Remember to set every setting, which states that helper task should process Run3 data to `true`.


Replace "model" with "interface" in the last line if you want to run the interface workflow.
30 changes: 10 additions & 20 deletions Tools/PIDML/pidOnnxInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,25 @@ auto certainties_v = std::vector<double>{certainties, certainties + nPids};

// default values for the cuts
constexpr double cuts[nPids][nCutVars] = {{0.0, 0.5, 0.8}, {0.0, 0.5, 0.8}, {0.0, 0.5, 0.8}, {0.0, 0.5, 0.8}, {0.0, 0.5, 0.8}, {0.0, 0.5, 0.8}};

// row labels
static const std::vector<std::string> pidLabels = {
"211", "321", "2212", "0211", "0321", "02212"};
// column labels
static const std::vector<std::string> cutVarLabels = {
"TPC", "TPC + TOF", "TPC + TOF + TRD"};

} // namespace pidml_pt_cuts

struct PidONNXInterface {
PidONNXInterface(std::string& localPath, std::string& ccdbPath, bool useCCDB, o2::ccdb::CcdbApi& ccdbApi, uint64_t timestamp, std::vector<int> const& pids, o2::framework::LabeledArray<double> const& pTLimits, std::vector<double> const& minCertainties, bool autoMode) : mNPids{pids.size()}, mPTLimits{pTLimits}
PidONNXInterface(std::string& localPath, std::string& ccdbPath, bool useCCDB, o2::ccdb::CcdbApi& ccdbApi, uint64_t timestamp, std::vector<int> const& pids, o2::framework::LabeledArray<double> const& pLimits, std::vector<double> const& minCertainties, bool autoMode) : mNPids{pids.size()}, mPLimits{pLimits}
{
if (pids.size() == 0) {
LOG(fatal) << "PID ML Interface needs at least 1 output pid to predict";
}
std::set<int> tmp;
for (auto& pid : pids) {
if (!tmp.insert(pid).second) {
LOG(fatal) << "PID M Interface: output pids cannot repeat!";
LOG(fatal) << "PID ML Interface: output pids cannot repeat!";
}
}

Expand All @@ -68,9 +68,7 @@ struct PidONNXInterface {
minCertaintiesFilled = minCertainties;
}
for (std::size_t i = 0; i < mNPids; i++) {
for (uint32_t j = 0; j < kNDetectors; j++) {
mModels.emplace_back(localPath, ccdbPath, useCCDB, ccdbApi, timestamp, pids[i], (PidMLDetector)(kTPCOnly + j), minCertaintiesFilled[i]);
}
mModels.emplace_back(localPath, ccdbPath, useCCDB, ccdbApi, timestamp, pids[i], minCertaintiesFilled[i], mPLimits[i]);
}
}
PidONNXInterface() = default;
Expand All @@ -84,12 +82,8 @@ struct PidONNXInterface {
float applyModel(const T& track, int pid)
{
for (std::size_t i = 0; i < mNPids; i++) {
if (mModels[i * kNDetectors].mPid == pid) {
for (uint32_t j = 0; j < kNDetectors; j++) {
if (track.pt() >= mPTLimits[i][j] && (j == kNDetectors - 1 || track.pt() < mPTLimits[i][j + 1])) {
return mModels[i * kNDetectors + j].applyModel(track);
}
}
if (mModels[i].mPid == pid) {
return mModels[i].applyModel(track);
}
}
LOG(error) << "No suitable PID ML model found for track: " << track.globalIndex() << " from collision: " << track.collision().globalIndex() << " and expected pid: " << pid;
Expand All @@ -100,12 +94,8 @@ struct PidONNXInterface {
bool applyModelBoolean(const T& track, int pid)
{
for (std::size_t i = 0; i < mNPids; i++) {
if (mModels[i * kNDetectors].mPid == pid) {
for (uint32_t j = 0; j < kNDetectors; j++) {
if (track.pt() >= mPTLimits[i][j] && (j == kNDetectors - 1 || track.pt() < mPTLimits[i][j + 1])) {
return mModels[i * kNDetectors + j].applyModelBoolean(track);
}
}
if (mModels[i].mPid == pid) {
return mModels[i].applyModelBoolean(track);
}
}
LOG(error) << "No suitable PID ML model found for track: " << track.globalIndex() << " from collision: " << track.collision().globalIndex() << " and expected pid: " << pid;
Expand All @@ -116,12 +106,12 @@ struct PidONNXInterface {
void fillDefaultConfiguration(std::vector<double>& minCertainties)
{
// FIXME: A more sophisticated strategy should be based on pid values as well
mPTLimits = o2::framework::LabeledArray{pidml_pt_cuts::cuts[0], pidml_pt_cuts::nPids, pidml_pt_cuts::nCutVars, pidml_pt_cuts::pidLabels, pidml_pt_cuts::cutVarLabels};
mPLimits = o2::framework::LabeledArray{pidml_pt_cuts::cuts[0], pidml_pt_cuts::nPids, pidml_pt_cuts::nCutVars, pidml_pt_cuts::pidLabels, pidml_pt_cuts::cutVarLabels};
minCertainties = std::vector<double>(mNPids, 0.5);
}

std::vector<PidONNXModel> mModels;
std::size_t mNPids;
o2::framework::LabeledArray<double> mPTLimits;
o2::framework::LabeledArray<double> mPLimits;
};
#endif // TOOLS_PIDML_PIDONNXINTERFACE_H_
Loading

0 comments on commit 8247afd

Please sign in to comment.