diff --git a/TMVA_CrossValidation.ipynb b/TMVA_CrossValidation.ipynb new file mode 100644 index 0000000..bec911e --- /dev/null +++ b/TMVA_CrossValidation.ipynb @@ -0,0 +1,692 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "# TMVA Cross Validation Example " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generate Data\n", + "\n", + "We define the function to generate data" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "TTree *genTree(Int_t nPoints, Double_t offset, Double_t scale, UInt_t seed = 100)\n", + "{\n", + " TRandom3 rng(seed);\n", + " Float_t x = 0;\n", + " Float_t y = 0;\n", + " UInt_t eventID = 0;\n", + "\n", + " TTree *data = new TTree();\n", + " data->Branch(\"x\", &x, \"x/F\");\n", + " data->Branch(\"y\", &y, \"y/F\");\n", + " data->Branch(\"eventID\", &eventID, \"eventID/I\");\n", + "\n", + " for (Int_t n = 0; n < nPoints; ++n) {\n", + " x = rng.Gaus(offset, scale);\n", + " y = rng.Gaus(offset, scale);\n", + "\n", + " // For our simple example it is enough that the id's are uniformly\n", + " // distributed and independent of the data.\n", + " ++eventID;\n", + "\n", + " data->Fill();\n", + " }\n", + "\n", + " // Important: Disconnects the tree from the memory locations of x and y.\n", + " data->ResetBranchAddresses();\n", + " return data;\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Define Output File\n", + "\n", + "We declare the file for output" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "TMVA::Tools::Instance();\n", + "\n", + "auto outputFile = TFile::Open(\"CV_Output.root\", \"RECREATE\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## TMVA Factory\n", + "\n", + "Start by creating the Factory class. We can use the factory to choose the methods whose performance you'd like to investigate. \n", + "\n", + "The factory is the major TMVA object you have to interact with. Here is the list of parameters you need to pass\n", + "\n", + " - The first argument is the base of the name of all the output\n", + "weightfiles in the directory weight/ that will be created with the \n", + "method parameters \n", + "\n", + " - The second argument is the output file for the training results\n", + " \n", + " - The third argument is a string option defining some general configuration for the TMVA session. For example all TMVA output can be suppressed by removing the \"!\" (not) in front of the \"Silent\" argument in the option string\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "TMVA::Factory factory(\"TMVAClassification\", outputFile,\n", + " \"!V:ROC:!Silent:Color:!DrawProgressBar:AnalysisType=Classification\" ); " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## DataLoader\n", + "\n", + "The next step is to declare the DataLoader class which provides the interface from TMVA to the input data \n", + "\n", + "### Define input variables\n", + "\n", + "Through the DataLoader we define the input variables that will be used for the MVA training." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "TMVA::DataLoader * loader = new TMVA::DataLoader(\"dataset\");\n", + "\n", + "loader->AddVariable(\"x\", 'F');\n", + "loader->AddVariable(\"y\", 'F');" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup Dataset(s)\n", + "\n", + "Define input data file and signal and background trees" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "DataSetInfo : [dataset] : Added class \"Signal\"\n", + " : Add Tree of type Signal with 1000 events\n", + "DataSetInfo : [dataset] : Added class \"Background\"\n", + " : Add Tree of type Background with 1000 events\n", + " : Dataset[dataset] : Class index : 0 name : Signal\n", + " : Dataset[dataset] : Class index : 1 name : Background\n" + ] + } + ], + "source": [ + "// Generate signal and background data\n", + "TTree *tsignal = genTree(1000, 1.0, 1.0, 100);\n", + "TTree *tbackground = genTree(1000, -1.0, 1.0, 101);\n", + "\n", + "// Register this data in the dataloader\n", + "loader->AddSignalTree(tsignal, 1.0);\n", + "loader->AddBackgroundTree(tbackground, 1.0); \n", + "\n", + "// Tell the factory how to use the training and testing events\n", + "//\n", + "// If no numbers of events are given, half of the events in the tree are used \n", + "// for training, and the other half for testing:\n", + "// loader->PrepareTrainingAndTestTree( mycut, \"SplitMode=random:!V\" );\n", + "// To also specify the number of testing events, use:\n", + "// loader->PrepareTrainingAndTestTree( mycut,\n", + "// \"NSigTrain=3000:NBkgTrain=3000:NSigTest=3000:NBkgTest=3000:SplitMode=Random:!V\" );\n", + "\n", + "loader->PrepareTrainingAndTestTree(\"\",\n", + " \"nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:NormMode=NumEvents:!V\"); " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Run Cross Validation\n", + "\n", + "### Format\n", + "\n", + " - The first argument is the method to be used i.e classfication, regression etc\n", + "\n", + " - The second argument is the data loader object\n", + " \n", + " - The third argument is the output file object\n", + " \n", + " - The fourth argument is a string option defining the options for the cross validation.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "TString cvOptions = \"!V:!Silent:ModelPersistence:AnalysisType=Classification:SplitType=RandomStratified:NumFolds=5\";\n", + " \":SplitExpr=\"\"\";\n", + "\n", + "auto cv = new TMVA::CrossValidation(\"TMVACrossValidation\",loader,outputFile,cvOptions);" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Booking Methods\n", + "\n", + "\n", + "We Book here the different MVA method we want to use. \n", + "We specify the method using the appropriate enumeration, defined in *TMVA::Types*.\n", + "See the file *TMVA/Types.h* for all possible MVA methods available. \n", + "In addition, we specify via an option string all the method parameters. For all possible options, default parameter values, see the corresponding documentation in the TMVA Users Guide. \n", + "\n", + "Note that with the booking one can also specify individual variable tranformations to be done before using the method.\n", + "For example *VarTransform=Decorrelate* will decorrelate the inputs. " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "//cv->BookMethod(TMVA::Types::kBDT, \"BDT\",\n", + "// \"NTrees=10:MinNodeSize=2.5%:MaxDepth=2:nCuts=20\");\n", + "\n", + "cv->BookMethod(TMVA::Types::kFisher, \"Fisher\",\n", + " \"!H:!V:Fisher:VarTransform=None\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Perform the Cross Validation: Train/Test the booked methods" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " : Evaluate method: Fisher\n", + "
Factory : Booking method: Fisher_fold1\n", + " : \n", + "
Fisher_fold1 : Results for Fisher coefficients:\n", + " : -----------------------\n", + " : Variable: Coefficient:\n", + " : -----------------------\n", + " : x: +0.478\n", + " : y: +0.437\n", + " : (offset): +0.007\n", + " : -----------------------\n", + " : Elapsed time for training with 1600 events: 0.00155 sec \n", + "
Fisher_fold1 : [dataset] : Evaluation of Fisher_fold1 on training sample (1600 events)\n", + " : Elapsed time for evaluation of 1600 events: 0.000658 sec \n", + " : Creating xml weight file: dataset/weights/TMVACrossValidation_Fisher_fold1.weights.xml\n", + " : Creating standalone class: dataset/weights/TMVACrossValidation_Fisher_fold1.class.C\n", + "
Factory : Test all methods\n", + "
Factory : Test method: Fisher_fold1 for Classification performance\n", + " : \n", + "
Fisher_fold1 : [dataset] : Evaluation of Fisher_fold1 on testing sample (400 events)\n", + " : Elapsed time for evaluation of 400 events: 0.000234 sec \n", + "
Factory : Evaluate all methods\n", + "
Factory : Evaluate classifier: Fisher_fold1\n", + " : \n", + "
Fisher_fold1 : [dataset] : Loop over test events and fill histograms with classifier response...\n", + " : \n", + "
Factory : Thank you for using TMVA!\n", + " : For citation information, please visit: http://tmva.sf.net/citeTMVA.html\n", + "
Factory : Booking method: Fisher_fold2\n", + " : \n", + "
Fisher_fold2 : Results for Fisher coefficients:\n", + " : -----------------------\n", + " : Variable: Coefficient:\n", + " : -----------------------\n", + " : x: +0.479\n", + " : y: +0.458\n", + " : (offset): +0.003\n", + " : -----------------------\n", + " : Elapsed time for training with 1600 events: 0.00145 sec \n", + "
Fisher_fold2 : [dataset] : Evaluation of Fisher_fold2 on training sample (1600 events)\n", + " : Elapsed time for evaluation of 1600 events: 0.000693 sec \n", + " : Creating xml weight file: dataset/weights/TMVACrossValidation_Fisher_fold2.weights.xml\n", + " : Creating standalone class: dataset/weights/TMVACrossValidation_Fisher_fold2.class.C\n", + "
Factory : Test all methods\n", + "
Factory : Test method: Fisher_fold2 for Classification performance\n", + " : \n", + "
Fisher_fold2 : [dataset] : Evaluation of Fisher_fold2 on testing sample (400 events)\n", + " : Elapsed time for evaluation of 400 events: 0.000261 sec \n", + "
Factory : Evaluate all methods\n", + "
Factory : Evaluate classifier: Fisher_fold2\n", + " : \n", + "
Fisher_fold2 : [dataset] : Loop over test events and fill histograms with classifier response...\n", + " : \n", + "
Factory : Thank you for using TMVA!\n", + " : For citation information, please visit: http://tmva.sf.net/citeTMVA.html\n", + "
Factory : Booking method: Fisher_fold3\n", + " : \n", + "
Fisher_fold3 : Results for Fisher coefficients:\n", + " : -----------------------\n", + " : Variable: Coefficient:\n", + " : -----------------------\n", + " : x: +0.469\n", + " : y: +0.447\n", + " : (offset): +0.014\n", + " : -----------------------\n", + " : Elapsed time for training with 1600 events: 0.00133 sec \n", + "
Fisher_fold3 : [dataset] : Evaluation of Fisher_fold3 on training sample (1600 events)\n", + " : Elapsed time for evaluation of 1600 events: 0.000668 sec \n", + " : Creating xml weight file: dataset/weights/TMVACrossValidation_Fisher_fold3.weights.xml\n", + " : Creating standalone class: dataset/weights/TMVACrossValidation_Fisher_fold3.class.C\n", + "
Factory : Test all methods\n", + "
Factory : Test method: Fisher_fold3 for Classification performance\n", + " : \n", + "
Fisher_fold3 : [dataset] : Evaluation of Fisher_fold3 on testing sample (400 events)\n", + " : Elapsed time for evaluation of 400 events: 0.00021 sec \n", + "
Factory : Evaluate all methods\n", + "
Factory : Evaluate classifier: Fisher_fold3\n", + " : \n", + "
Fisher_fold3 : [dataset] : Loop over test events and fill histograms with classifier response...\n", + " : \n", + "
Factory : Thank you for using TMVA!\n", + " : For citation information, please visit: http://tmva.sf.net/citeTMVA.html\n", + "
Factory : Booking method: Fisher_fold4\n", + " : \n", + "
Fisher_fold4 : Results for Fisher coefficients:\n", + " : -----------------------\n", + " : Variable: Coefficient:\n", + " : -----------------------\n", + " : x: +0.488\n", + " : y: +0.469\n", + " : (offset): +0.005\n", + " : -----------------------\n", + " : Elapsed time for training with 1600 events: 0.00144 sec \n", + "
Fisher_fold4 : [dataset] : Evaluation of Fisher_fold4 on training sample (1600 events)\n", + " : Elapsed time for evaluation of 1600 events: 0.000677 sec \n", + " : Creating xml weight file: dataset/weights/TMVACrossValidation_Fisher_fold4.weights.xml\n", + " : Creating standalone class: dataset/weights/TMVACrossValidation_Fisher_fold4.class.C\n", + "
Factory : Test all methods\n", + "
Factory : Test method: Fisher_fold4 for Classification performance\n", + " : \n", + "
Fisher_fold4 : [dataset] : Evaluation of Fisher_fold4 on testing sample (400 events)\n", + " : Elapsed time for evaluation of 400 events: 0.000218 sec \n", + "
Factory : Evaluate all methods\n", + "
Factory : Evaluate classifier: Fisher_fold4\n", + " : \n", + "
Fisher_fold4 : [dataset] : Loop over test events and fill histograms with classifier response...\n", + " : \n", + "
Factory : Thank you for using TMVA!\n", + " : For citation information, please visit: http://tmva.sf.net/citeTMVA.html\n", + "
Factory : Booking method: Fisher_fold5\n", + " : \n", + "
Fisher_fold5 : Results for Fisher coefficients:\n", + " : -----------------------\n", + " : Variable: Coefficient:\n", + " : -----------------------\n", + " : x: +0.456\n", + " : y: +0.449\n", + " : (offset): +0.018\n", + " : -----------------------\n", + " : Elapsed time for training with 1600 events: 0.00137 sec \n", + "
Fisher_fold5 : [dataset] : Evaluation of Fisher_fold5 on training sample (1600 events)\n", + " : Elapsed time for evaluation of 1600 events: 0.00066 sec \n", + " : Creating xml weight file: dataset/weights/TMVACrossValidation_Fisher_fold5.weights.xml\n", + " : Creating standalone class: dataset/weights/TMVACrossValidation_Fisher_fold5.class.C\n", + "
Factory : Test all methods\n", + "
Factory : Test method: Fisher_fold5 for Classification performance\n", + " : \n", + "
Fisher_fold5 : [dataset] : Evaluation of Fisher_fold5 on testing sample (400 events)\n", + " : Elapsed time for evaluation of 400 events: 0.000221 sec \n", + "
Factory : Evaluate all methods\n", + "
Factory : Evaluate classifier: Fisher_fold5\n", + " : \n", + "
Fisher_fold5 : [dataset] : Loop over test events and fill histograms with classifier response...\n", + " : \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "
Factory : Thank you for using TMVA!\n", + " : For citation information, please visit: http://tmva.sf.net/citeTMVA.html\n", + "
Factory : Booking method: Fisher\n", + " : \n", + " : Reading weightfile: dataset/weights/TMVACrossValidation_Fisher_fold1.weights.xml\n", + " : Reading weight file: dataset/weights/TMVACrossValidation_Fisher_fold1.weights.xml\n", + " : Reading weightfile: dataset/weights/TMVACrossValidation_Fisher_fold2.weights.xml\n", + " : Reading weight file: dataset/weights/TMVACrossValidation_Fisher_fold2.weights.xml\n", + " : Reading weightfile: dataset/weights/TMVACrossValidation_Fisher_fold3.weights.xml\n", + " : Reading weight file: dataset/weights/TMVACrossValidation_Fisher_fold3.weights.xml\n", + " : Reading weightfile: dataset/weights/TMVACrossValidation_Fisher_fold4.weights.xml\n", + " : Reading weight file: dataset/weights/TMVACrossValidation_Fisher_fold4.weights.xml\n", + " : Reading weightfile: dataset/weights/TMVACrossValidation_Fisher_fold5.weights.xml\n", + " : Reading weight file: dataset/weights/TMVACrossValidation_Fisher_fold5.weights.xml\n", + "
Factory : [dataset] : Create Transformation \"I\" with events from all classes.\n", + " : \n", + "
: Transformation, Variable selection : \n", + " : Input : variable 'x' <---> Output : variable 'x'\n", + " : Input : variable 'y' <---> Output : variable 'y'\n", + "
TFHandler_Factory : Variable Mean RMS [ Min Max ]\n", + " : -----------------------------------------------------------\n", + " : x: -0.014176 1.4057 [ -4.1075 4.0969 ]\n", + " : y: -0.0066866 1.4203 [ -4.8520 4.0761 ]\n", + " : -----------------------------------------------------------\n", + " : Ranking input variables (method unspecific)...\n", + "
IdTransformation : Ranking result (top variable is best ranked)\n", + " : --------------------------\n", + " : Rank : Variable : Separation\n", + " : --------------------------\n", + " : 1 : x : 5.433e-01\n", + " : 2 : y : 5.234e-01\n", + " : --------------------------\n", + " : Elapsed time for training with 2000 events: 7.87e-06 sec \n", + "
Fisher : [dataset] : Evaluation of Fisher on training sample (2000 events)\n", + " : Elapsed time for evaluation of 2000 events: 0.00207 sec \n", + " : Creating xml weight file: dataset/weights/TMVACrossValidation_Fisher.weights.xml\n", + " : Creating standalone class: dataset/weights/TMVACrossValidation_Fisher.class.C\n", + " : MakeClassSpecificHeader not implemented for CrossValidation\n", + " : MakeClassSpecific not implemented for CrossValidation\n", + "
Factory : Test all methods\n", + "
Factory : Test method: Fisher for Classification performance\n", + " : \n", + "
Fisher : [dataset] : Evaluation of Fisher on testing sample (2000 events)\n", + " : Elapsed time for evaluation of 2000 events: 0.00205 sec \n", + "
Factory : Evaluate all methods\n", + "
Factory : Evaluate classifier: Fisher\n", + " : \n", + "
Fisher : [dataset] : Loop over test events and fill histograms with classifier response...\n", + " : \n", + "
TFHandler_Fisher : Variable Mean RMS [ Min Max ]\n", + " : -----------------------------------------------------------\n", + " : x: -0.014176 1.4057 [ -4.1075 4.0969 ]\n", + " : y: -0.0066866 1.4203 [ -4.8520 4.0761 ]\n", + " : -----------------------------------------------------------\n", + " : \n", + " : Evaluation results ranked by best signal efficiency and purity (area)\n", + " : -------------------------------------------------------------------------------------------------------------------\n", + " : DataSet MVA \n", + " : Name: Method: ROC-integ\n", + " : dataset Fisher : 0.971\n", + " : -------------------------------------------------------------------------------------------------------------------\n", + " : \n", + " : Testing efficiency compared to training efficiency (overtraining check)\n", + " : -------------------------------------------------------------------------------------------------------------------\n", + " : DataSet MVA Signal efficiency: from test sample (from training sample) \n", + " : Name: Method: @B=0.01 @B=0.10 @B=0.30 \n", + " : -------------------------------------------------------------------------------------------------------------------\n", + " : dataset Fisher : 0.660 (0.660) 0.922 (0.922) 0.980 (0.980)\n", + " : -------------------------------------------------------------------------------------------------------------------\n", + " : \n", + "
Dataset:dataset : Created tree 'TestTree' with 2000 events\n", + " : \n", + "
Dataset:dataset : Created tree 'TrainTree' with 2000 events\n", + " : \n", + "
Factory : Thank you for using TMVA!\n", + " : For citation information, please visit: http://tmva.sf.net/citeTMVA.html\n", + " : Evaluation done.\n" + ] + } + ], + "source": [ + "// Run cross-validation\n", + "cv->Evaluate();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cross Validation Result" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "
CrossValidation : ==== Results ====\n", + " : Fold 0 ROC-Int : 0.9749\n", + " : Fold 1 ROC-Int : 0.9711\n", + " : Fold 2 ROC-Int : 0.9766\n", + " : Fold 3 ROC-Int : 0.9632\n", + " : Fold 4 ROC-Int : 0.9706\n", + " : ------------------------\n", + " : Average ROC-Int : 0.9713\n", + " : Std-Dev ROC-Int : 0.0052\n" + ] + } + ], + "source": [ + "TMVA::CrossValidationResult & result = (TMVA::CrossValidationResult &) cv->GetResults()[0];\n", + "\n", + "result.Print();\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plot ROC Curves\n", + "We enable JavaScript visualisation for the plots" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "%jsroot on" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "result.Draw();" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plot Average ROC Curve" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "result.DrawAvgROCCurve(kFALSE, \"CrossValidation Avg ROC Curve\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plot ROC Curves and the Average ROC Curve" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "result.DrawAvgROCCurve(kTRUE, \"CrossValidation ROC Curves and Avg ROC Curve\");" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Close outputfile to save all output information (evaluation result of methods)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "outputFile->Close();" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ROOT C++", + "language": "c++", + "name": "root" + }, + "language_info": { + "codemirror_mode": "text/x-c++src", + "file_extension": ".C", + "mimetype": " text/x-c++src", + "name": "c++" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}