Skip to content

Commit

Permalink
Fix treelite (#12938)
Browse files Browse the repository at this point in the history
  • Loading branch information
ktf authored Jan 14, 2020
1 parent 87c9d3a commit a1031a0
Show file tree
Hide file tree
Showing 25 changed files with 4,422 additions and 492 deletions.
2 changes: 1 addition & 1 deletion ML/AliExternalBDT.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ bool AliExternalBDT::LoadLightGBMModel(std::string path) {
}

bool AliExternalBDT::LoadModelLibrary(std::string path) {
const int status = TreelitePredictorLoad(path.data(), 1, 1, &fPredictor);
const int status = TreelitePredictorLoad(path.data(), 1, &fPredictor);
if (status != 0) {
std::cerr << "Library loading failed" << std::endl;
return false;
Expand Down
145 changes: 145 additions & 0 deletions ML/AliMLModelHandler.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
// Copyright CERN. This software is distributed under the terms of the GNU
// General Public License v3 (GPL Version 3).
//
// See http://www.gnu.org/licenses/ for full licensing information.
//
// In applying this license CERN does not waive the privileges and immunities
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.

/// \file AliMLModelHandler.cxx
/// \author [email protected], [email protected], [email protected]

#include "AliMLModelHandler.h"

#include <map>
#include "yaml-cpp/yaml.h"

#include <TFile.h>
#include <TGrid.h>
#include <TSystem.h>
#include "AliLog.h"
#include "AliExternalBDT.h"

/// \cond CLASSIMP
ClassImp(AliMLModelHandler);
/// \endcond

//_______________________________________________________________________________
AliMLModelHandler::AliMLModelHandler() : TNamed(), fModel{nullptr}, fPath{}, fLibrary{}, fScoreCut{} {
//
// Default constructor
//
}

//_______________________________________________________________________________
AliMLModelHandler::AliMLModelHandler(const YAML::Node &node)
: TNamed(), fModel{nullptr}, fPath{node["path"].as<std::string>()},
fLibrary{node["library"].as<std::string>()}, fScoreCut{node["cut"].as<double>()} {
//
// Standard constructor
//
fModel = new AliExternalBDT();
}

AliMLModelHandler::~AliMLModelHandler() {
//
// Destructor
//
if(fModel)
delete fModel;
}

//_______________________________________________________________________________
AliMLModelHandler::AliMLModelHandler(const AliMLModelHandler &source)
: TNamed(source.GetName(), source.GetTitle()), fModel{nullptr}, fPath{source.fPath},
fLibrary{source.fLibrary}, fScoreCut{source.fScoreCut} {
//
// Copy constructor
//
fModel = new AliExternalBDT(*source.fModel);
}

AliMLModelHandler &AliMLModelHandler::operator=(const AliMLModelHandler &source) {
//
// Assignment operator
//
if (&source == this) return *this;

TNamed::operator=(source);

if(fModel)
delete fModel;
fModel = new AliExternalBDT(*source.fModel);

fPath = source.fPath;
fLibrary = source.fLibrary;
fScoreCut = source.fScoreCut;

return *this;
}

//_______________________________________________________________________________
bool AliMLModelHandler::CompileModel() {

std::map<std::string, int> libraryMap = {{"kXGBoost", AliMLModelHandler::kXGBoost},
{"kLightGBM", AliMLModelHandler::kLightGBM},
{"kModelLibrary", AliMLModelHandler::kModelLibrary}};

std::string localpath = ImportFile(fPath);

switch (libraryMap[GetLibrary()]) {
case kXGBoost: {
return fModel->LoadXGBoostModel(localpath.data());
break;
}
case kLightGBM: {
return fModel->LoadLightGBMModel(localpath.data());
break;
}
case kModelLibrary: {
return fModel->LoadModelLibrary(localpath.data());
break;
}
default: {
return fModel->LoadXGBoostModel(localpath.data());
break;
}
}
}

//_______________________________________________________________________________
std::string AliMLModelHandler::ImportFile(std::string path) {
std::string modelname = path.substr(path.find_last_of("/") + 1);

// check if file is in current directory
if (path.find("/") == std::string::npos) {
bool checkFile = gSystem->AccessPathName(gSystem->ExpandPathName(path.c_str()));
if (checkFile) {
AliFatalClass(Form("Error file %s not found! Exit", path.data()));
}
return path;
}

// check if file is on alien
if (path.find("alien:") != std::string::npos) {
if (gGrid == nullptr) {
TGrid::Connect("alien://");
if (gGrid == nullptr) {
AliFatalClass("Connection to GRID not established! Exit");
}
}
}

std::string newpath = gSystem->pwd() + std::string("/") + modelname.data();
std::string oldpath = gDirectory->GetPath();

bool cpStatus = TFile::Cp(path.data(), newpath.data());
if (!cpStatus) {
AliFatalClass(Form("Error in coping file %s in the working directory! Exit", path.data()));
}

gDirectory->Cd(oldpath.data());

return newpath;
}
58 changes: 58 additions & 0 deletions ML/AliMLModelHandler.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#ifndef ALIMLMODELHANDLER_H
#define ALIMLMODELHANDLER_H

// Copyright CERN. This software is distributed under the terms of the GNU
// General Public License v3 (GPL Version 3).
//
// See http://www.gnu.org/licenses/ for full licensing information.
//
// In applying this license CERN does not waive the privileges and immunities
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.

/// \file AliMLModelHandler.h
/// \brief Utility class to store the compiled model and it's information
/// \author [email protected], [email protected], [email protected]

#include <string>

#include "TNamed.h"

namespace YAML {
class Node;
}
class AliExternalBDT;

class AliMLModelHandler : public TNamed {
public:
enum {kXGBoost, kLightGBM, kModelLibrary};

AliMLModelHandler();
AliMLModelHandler(const YAML::Node &node);
virtual ~AliMLModelHandler();

AliMLModelHandler(const AliMLModelHandler &source);
AliMLModelHandler &operator=(const AliMLModelHandler &source);

AliExternalBDT *GetModel() { return fModel; }
std::string const &GetPath() const { return fPath; }
std::string const &GetLibrary() const { return fLibrary; }
double const &GetScoreCut() const { return fScoreCut; }

bool CompileModel();
static std::string ImportFile(std::string path);

private:
AliExternalBDT *fModel; //!<!

std::string fPath; ///
std::string fLibrary; ///

double fScoreCut; ///

/// \cond CLASSIMP
ClassDef(AliMLModelHandler, 1); ///
/// \endcond
};

#endif
Loading

0 comments on commit a1031a0

Please sign in to comment.