Skip to content

Commit

Permalink
added async option
Browse files Browse the repository at this point in the history
  • Loading branch information
violatingcp committed Nov 23, 2019
1 parent af278d0 commit f28e1b5
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 42 deletions.
4 changes: 3 additions & 1 deletion AnalysisFW/interface/TFClientRemoteTRT.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class ClientData {
unsigned noutput_;
const float *input_;
float *output_;
bool async_;
edm::WaitingTaskWithArenaHolder holder_;

std::mutex mutex_;
Expand All @@ -43,7 +44,7 @@ class TFClientRemoteTRT : public TFClientBase {
public:
//constructors (timeout in seconds)
TFClientRemoteTRT() : TFClientBase() {}
TFClientRemoteTRT(unsigned numStreams, const std::string& address, int port, unsigned timeout,const std::string& model_name, unsigned batchSize, unsigned ninput, unsigned noutput);
TFClientRemoteTRT(unsigned numStreams, const std::string& address, int port, unsigned timeout,const std::string& model_name, unsigned batchSize, unsigned ninput, unsigned noutput,bool async);

//input is "image" in tensor form
void predict(unsigned dataID, const float* img, float* result, edm::WaitingTaskWithArenaHolder holder);
Expand All @@ -56,6 +57,7 @@ class TFClientRemoteTRT : public TFClientBase {
std::string modelName_;
unsigned ninput_;
unsigned noutput_;
bool async_;
std::unique_ptr<nic::InferContext> *context_;
std::shared_ptr<nic::InferContext::Input>* nicinput_;
};
Expand Down
5 changes: 3 additions & 2 deletions AnalysisFW/plugins/HcalProducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ void HcalProducer::preallocate(edm::PreallocationConfiguration const& iPrealloc)
extraParams_.getParameter<std::string>("modelname"),
batchSize_,
ninput_,
noutput_
);
noutput_,
extraParams_.getParameter<bool>("async")
);
edm::LogInfo("HcalProducer") << "Connected to remote server";
}
}
Expand Down
3 changes: 2 additions & 1 deletion AnalysisFW/plugins/JetImageProducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ void JetImageProducer::preallocate(edm::PreallocationConfiguration const& iPreal
extraParams_.getParameter<std::string>("modelname"),
batchSize_,
ninput_,
noutput_
noutput_,
extraParams_.getParameter<bool>("async")
);
edm::LogInfo("JetImageProducer") << "Connected to remote server";
}
Expand Down
4 changes: 3 additions & 1 deletion AnalysisFW/python/HcalTest_mc_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
options.register("streams", 0, VarParsing.multiplicity.singleton, VarParsing.varType.int)
options.register("batchsize", 1, VarParsing.multiplicity.singleton, VarParsing.varType.int)
options.register("modelname","facile", VarParsing.multiplicity.singleton, VarParsing.varType.string)
# NOTE(review): "async" became a reserved keyword in Python 3.7, so the later
# uses in this config (attribute access `options.async` and the `async =`
# keyword argument) are SyntaxErrors under Python 3. If this config must run
# on Python 3, rename the option (e.g. "asynchronous") or read it via
# getattr(options, "async") — confirm the target CMSSW/Python version.
options.register("async",False, VarParsing.multiplicity.singleton, VarParsing.varType.bool)
options.parseArguments()

if len(options.params)>0 and options.remote:
Expand Down Expand Up @@ -57,7 +58,8 @@
address = cms.string(options.address),
port = cms.int32(options.port),
timeout = cms.uint32(options.timeout),
modelname = cms.string(options.modelname)
modelname = cms.string(options.modelname),
async = cms.bool(options.async)
)

# Let it run
Expand Down
4 changes: 3 additions & 1 deletion AnalysisFW/python/jetImageTest_mc_cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
options.register("batchsize", 10, VarParsing.multiplicity.singleton, VarParsing.varType.int)
#options.register("modelname","resnet50_netdef", VarParsing.multiplicity.singleton, VarParsing.varType.string)
options.register("modelname","resnet50_ensemble", VarParsing.multiplicity.singleton, VarParsing.varType.string)
# NOTE(review): "async" became a reserved keyword in Python 3.7, so the later
# uses in this config (attribute access `options.async` and the `async =`
# keyword argument) are SyntaxErrors under Python 3. If this config must run
# on Python 3, rename the option (e.g. "asynchronous") or read it via
# getattr(options, "async") — confirm the target CMSSW/Python version.
options.register("async",False, VarParsing.multiplicity.singleton, VarParsing.varType.bool)
options.parseArguments()

if len(options.params)>0 and options.remote:
Expand Down Expand Up @@ -58,7 +59,8 @@
address = cms.string(options.address),
port = cms.int32(options.port),
timeout = cms.uint32(options.timeout),
modelname = cms.string(options.modelname)
modelname = cms.string(options.modelname),
async = cms.bool(options.async),
)
else:
process.jetImageProducer.remote = cms.bool(False)
Expand Down
88 changes: 52 additions & 36 deletions AnalysisFW/src/TFClientRemoteTRT.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ ClientData::ClientData() :
dataID_(0),
timeout_(0),
batchSize_(1),
output_(nullptr)
output_(nullptr),
async_(true)
{
}

Expand All @@ -38,8 +39,6 @@ void ClientData::predict(){

auto t2 = std::chrono::high_resolution_clock::now();
std::vector<int64_t> input_shape;
input_shape.push_back(15);
input->SetShape(input_shape);
// NOTE(review): every iteration passes the same source pointer &input_[0] —
// it is never advanced by i0*ninput_, so unless SetRaw appends per-call from
// a caller-managed cursor, all batch slots receive element 0's data. Confirm
// against the TRT inference-server client's InferContext::Input::SetRaw
// semantics. The returned nic::Error err1 is also discarded unchecked.
for(unsigned i0 = 0; i0 < batchSize_; i0++) {
nic::Error err1 = input->SetRaw(reinterpret_cast<const uint8_t*>(&input_[0]),ninput_ * sizeof(float));
}
Expand All @@ -48,48 +47,64 @@ void ClientData::predict(){
edm::LogInfo("TFClientRemoteTRT") << "Image array time: " << time2;

std::map<std::string, std::unique_ptr<nic::InferContext::Result>> results;
nic::Error erro0 = ctx->AsyncRun(
[t3,this](nic::InferContext* ctx, const std::shared_ptr<nic::InferContext::Request>& request) {
//get results
std::map<std::string, std::unique_ptr<nic::InferContext::Result>> results;
//this function interface will change in the next tensorrtis version
bool is_ready = false;
ctx->GetAsyncRunResults(&results, &is_ready, request, false);
if(is_ready == false) throw cms::Exception("BadCallback") << "Callback executed before request was ready";
if(async_) {
nic::Error erro0 = ctx->AsyncRun(
[t3,this](nic::InferContext* ctx, const std::shared_ptr<nic::InferContext::Request>& request) {
//get results
std::map<std::string, std::unique_ptr<nic::InferContext::Result>> results;
//this function interface will change in the next tensorrtis version
bool is_ready = false;
ctx->GetAsyncRunResults(&results, &is_ready, request, false);
if(is_ready == false) throw cms::Exception("BadCallback") << "Callback executed before request was ready";

//check time
auto t4 = std::chrono::high_resolution_clock::now();
auto time3 = std::chrono::duration_cast<std::chrono::microseconds>(t4-t3).count();
edm::LogInfo("TFClientRemoteTRT") << "Remote time: " << time3;

//check result
std::exception_ptr exceptionPtr;
const std::unique_ptr<nic::InferContext::Result>& result = results.begin()->second;
for(unsigned i0 = 0; i0 < this->batchSize_; i0++) {
const uint8_t* r0;
size_t content_byte_size;
result->GetRaw(i0, &r0,&content_byte_size);
const float *lVal = reinterpret_cast<const float*>(r0);
for(unsigned i1 = 0; i1 < this->noutput_; i1++) this->output_[i0*noutput_+i1] = lVal[i1]; //This should be replaced with a memcpy
}
auto t5 = std::chrono::high_resolution_clock::now();
auto time4 = std::chrono::duration_cast<std::chrono::microseconds>(t5-t4).count();
edm::LogInfo("TFClientRemoteTRT") << "Output time: " << time4;

//check time
auto t4 = std::chrono::high_resolution_clock::now();
auto time3 = std::chrono::duration_cast<std::chrono::microseconds>(t4-t3).count();
edm::LogInfo("TFClientRemoteTRT") << "Remote time: " << time3;

//check result
std::exception_ptr exceptionPtr;
const std::unique_ptr<nic::InferContext::Result>& result = results.begin()->second;
for(unsigned i0 = 0; i0 < this->batchSize_; i0++) {
const uint8_t* r0;
size_t content_byte_size;
result->GetRaw(i0, &r0,&content_byte_size);
const float *lVal = reinterpret_cast<const float*>(r0);
for(unsigned i1 = 0; i1 < this->noutput_; i1++) this->output_[i0*noutput_+i1] = lVal[i1]; //This should be replaced with a memcpy
}
auto t5 = std::chrono::high_resolution_clock::now();
auto time4 = std::chrono::duration_cast<std::chrono::microseconds>(t5-t4).count();
edm::LogInfo("TFClientRemoteTRT") << "Output time: " << time4;

//finish
this->holder_.doneWaiting(exceptionPtr);
}
);
//finish
this->holder_.doneWaiting(exceptionPtr);
}
);
} else {
// Synchronous inference path (async_ == false): blocks on ctx->Run and
// unpacks results inline. This duplicates the unpacking logic of the async
// callback above — a shared helper would keep the two paths in sync.
std::map<std::string, std::unique_ptr<nic::InferContext::Result>> results;
// NOTE(review): err0 is never checked; if Run fails, `results` may be empty
// and results.begin()->second below dereferences end() — undefined behavior.
nic::Error err0 = ctx->Run(&results);
// NOTE(review): exceptionPtr is declared but never assigned, so doneWaiting
// always reports success even when Run failed.
std::exception_ptr exceptionPtr;
// Only the first named output is consumed.
const std::unique_ptr<nic::InferContext::Result>& result = results.begin()->second;
for(unsigned i0 = 0; i0 < batchSize_; i0++) {
const uint8_t* r0;
size_t content_byte_size;
// GetRaw yields a pointer into the result buffer for batch slot i0;
// content_byte_size is fetched but not validated against noutput_*sizeof(float).
result->GetRaw(i0, &r0,&content_byte_size);
const float *lVal = reinterpret_cast<const float*>(r0);
for(unsigned i1 = 0; i1 < noutput_; i1++) output_[i0*noutput_+i1] = lVal[i1]; //This should be replaced with a memcpy
}
// Release the framework's waiting task now that output_ is filled.
this->holder_.doneWaiting(exceptionPtr);
}
}

//based on: tensor-rt-client simple_example
TFClientRemoteTRT::TFClientRemoteTRT(unsigned numStreams, const std::string& address, int port, unsigned timeout,const std::string& model_name,unsigned batchSize,unsigned ninput,unsigned noutput) :
TFClientRemoteTRT::TFClientRemoteTRT(unsigned numStreams, const std::string& address, int port, unsigned timeout,const std::string& model_name,unsigned batchSize,unsigned ninput,unsigned noutput,bool async) :
TFClientBase(),
streamData_(numStreams),
timeout_(timeout),
batchSize_(batchSize),
ninput_(ninput),
noutput_(noutput)
noutput_(noutput),
async_(async)
{
url_=address+":"+std::to_string(port);
modelName_ = model_name;
Expand All @@ -112,6 +127,7 @@ void TFClientRemoteTRT::predict(unsigned dataID, const float* img, float* result
streamData.output_= result;
streamData.holder_= std::move(holder);
streamData.input_ = img;
streamData.async_ = async_;
}
streamData.predict();
edm::LogInfo("TFClientRemoteTRT") << "Async predict request sent";
Expand Down

0 comments on commit f28e1b5

Please sign in to comment.