-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
373 additions
and
326 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
#ifndef TFCLIENTBASE_H | ||
#define TFCLIENTBASE_H | ||
|
||
#include "FWCore/MessageLogger/interface/MessageLogger.h" | ||
|
||
#include "tensorflow/core/framework/tensor.h" | ||
|
||
//base class for local and remote clients | ||
class TFClientBase { | ||
public: | ||
//constructor | ||
TFClientBase() {} | ||
//destructor | ||
virtual ~TFClientBase() {} | ||
|
||
//input is "image" in tensor form | ||
virtual bool predict(const tensorflow::Tensor& img, tensorflow::Tensor& result) const { | ||
return true; | ||
} | ||
}; | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#ifndef TFCLIENTLOCAL_H | ||
#define TFCLIENTLOCAL_H | ||
|
||
#include "FWCore/MessageLogger/interface/MessageLogger.h" | ||
#include "TFClientBase.h" | ||
|
||
#include <string> | ||
#include <vector> | ||
#include <memory> | ||
|
||
#include "tensorflow/core/framework/tensor.h" | ||
#include "tensorflow/core/graph/default_device.h" | ||
|
||
class TFClientLocal : public TFClientBase { | ||
public: | ||
//constructors (timeout in seconds) | ||
TFClientLocal() : TFClientBase() {} | ||
TFClientLocal(const std::string& featurizer_file, const std::string& classifier_file); | ||
|
||
//input is "image" in tensor form | ||
bool predict(const tensorflow::Tensor& img, tensorflow::Tensor& result) const override; | ||
|
||
private: | ||
void loadModel(const std::string& featurizer_file, const std::string& classifier_file); | ||
std::vector<tensorflow::Tensor> runFeaturizer(const tensorflow::Tensor& inputImage) const; | ||
std::vector<tensorflow::Tensor> runClassifier(const tensorflow::Tensor& inputClassifier) const; | ||
tensorflow::Tensor createFeatureList(const tensorflow::Tensor& input) const; | ||
|
||
//members | ||
tensorflow::GraphDef* graphDefFeaturizer_; | ||
tensorflow::GraphDef* graphDefClassifier_; | ||
}; | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#ifndef TFCLIENTREMOTE_H | ||
#define TFCLIENTREMOTE_H | ||
|
||
#include "FWCore/MessageLogger/interface/MessageLogger.h" | ||
#include "TFClientBase.h" | ||
|
||
#include <string> | ||
#include <memory> | ||
|
||
#include "grpc++/create_channel.h" | ||
#include "tensorflow_serving/apis/prediction_service.grpc.pb.h" | ||
#include "tensorflow/core/framework/tensor.h" | ||
#include "tensorflow/core/util/command_line_flags.h" | ||
|
||
class TFClientRemote : public TFClientBase { | ||
public: | ||
//constructors (timeout in seconds) | ||
TFClientRemote() : TFClientBase() {} | ||
TFClientRemote(const std::string& address, int port, unsigned timeout); | ||
|
||
//input is "image" in tensor form | ||
bool predict(const tensorflow::Tensor& img, tensorflow::Tensor& result) const override; | ||
|
||
private: | ||
std::shared_ptr<grpc::Channel> channel_; | ||
std::unique_ptr<tensorflow::serving::PredictionService::Stub> stub_; | ||
unsigned timeout_; | ||
}; | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.