Skip to content

Commit

Permalink
refactor and clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
kpedro88 committed Sep 17, 2018
1 parent bc9cb6c commit 79c91c6
Show file tree
Hide file tree
Showing 10 changed files with 373 additions and 326 deletions.
1 change: 1 addition & 0 deletions AnalysisFW/BuildFile.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
<use name="FWCore/PluginManager"/>
<use name="PhysicsTools/TensorFlow"/>
<use name="grpc"/>
<use name="tensorflow-cc"/>
<use name="tensorflow-serving"/>
<export>
<lib name="1"/>
Expand Down
94 changes: 0 additions & 94 deletions AnalysisFW/interface/TFClient.h

This file was deleted.

22 changes: 22 additions & 0 deletions AnalysisFW/interface/TFClientBase.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef TFCLIENTBASE_H
#define TFCLIENTBASE_H

#include "FWCore/MessageLogger/interface/MessageLogger.h"

#include "tensorflow/core/framework/tensor.h"

//base class for local and remote clients
class TFClientBase {
public:
//constructor
TFClientBase() {}
//destructor
virtual ~TFClientBase() {}

//input is "image" in tensor form
virtual bool predict(const tensorflow::Tensor& img, tensorflow::Tensor& result) const {
return true;
}
};

#endif
34 changes: 34 additions & 0 deletions AnalysisFW/interface/TFClientLocal.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#ifndef TFCLIENTLOCAL_H
#define TFCLIENTLOCAL_H

#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "TFClientBase.h"

#include <string>
#include <vector>
#include <memory>

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/default_device.h"

class TFClientLocal : public TFClientBase {
public:
//constructors (timeout in seconds)
TFClientLocal() : TFClientBase() {}
TFClientLocal(const std::string& featurizer_file, const std::string& classifier_file);

//input is "image" in tensor form
bool predict(const tensorflow::Tensor& img, tensorflow::Tensor& result) const override;

private:
void loadModel(const std::string& featurizer_file, const std::string& classifier_file);
std::vector<tensorflow::Tensor> runFeaturizer(const tensorflow::Tensor& inputImage) const;
std::vector<tensorflow::Tensor> runClassifier(const tensorflow::Tensor& inputClassifier) const;
tensorflow::Tensor createFeatureList(const tensorflow::Tensor& input) const;

//members
tensorflow::GraphDef* graphDefFeaturizer_;
tensorflow::GraphDef* graphDefClassifier_;
};

#endif
30 changes: 30 additions & 0 deletions AnalysisFW/interface/TFClientRemote.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef TFCLIENTREMOTE_H
#define TFCLIENTREMOTE_H

#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "TFClientBase.h"

#include <string>
#include <memory>

#include "grpc++/create_channel.h"
#include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/util/command_line_flags.h"

class TFClientRemote : public TFClientBase {
public:
//constructors (timeout in seconds)
TFClientRemote() : TFClientBase() {}
TFClientRemote(const std::string& address, int port, unsigned timeout);

//input is "image" in tensor form
bool predict(const tensorflow::Tensor& img, tensorflow::Tensor& result) const override;

private:
std::shared_ptr<grpc::Channel> channel_;
std::unique_ptr<tensorflow::serving::PredictionService::Stub> stub_;
unsigned timeout_;
};

#endif
2 changes: 1 addition & 1 deletion AnalysisFW/plugins/BuildFile.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<use name="FWCore/Framework"/>
<use name="FWCore/PluginManager"/>
<use name="SonicCMS/AnalysisFW"/>
<use name="PhysicsTools/TensorFlow"/>
<use name="tensorflow-cc"/>
<flags EDM_PLUGIN="1"/>
</library>

Loading

0 comments on commit 79c91c6

Please sign in to comment.