diff --git a/AnalysisFW/interface/TFClientRemoteTRT.h b/AnalysisFW/interface/TFClientRemoteTRT.h index 93a010c..74e4735 100644 --- a/AnalysisFW/interface/TFClientRemoteTRT.h +++ b/AnalysisFW/interface/TFClientRemoteTRT.h @@ -31,6 +31,7 @@ class ClientData { unsigned noutput_; const float *input_; float *output_; + bool async_; edm::WaitingTaskWithArenaHolder holder_; std::mutex mutex_; @@ -43,7 +44,7 @@ class TFClientRemoteTRT : public TFClientBase { public: //constructors (timeout in seconds) TFClientRemoteTRT() : TFClientBase() {} - TFClientRemoteTRT(unsigned numStreams, const std::string& address, int port, unsigned timeout,const std::string& model_name, unsigned batchSize, unsigned ninput, unsigned noutput); + TFClientRemoteTRT(unsigned numStreams, const std::string& address, int port, unsigned timeout,const std::string& model_name, unsigned batchSize, unsigned ninput, unsigned noutput,bool async); //input is "image" in tensor form void predict(unsigned dataID, const float* img, float* result, edm::WaitingTaskWithArenaHolder holder); @@ -56,6 +57,7 @@ class TFClientRemoteTRT : public TFClientBase { std::string modelName_; unsigned ninput_; unsigned noutput_; + bool async_; std::unique_ptr *context_; std::shared_ptr* nicinput_; }; diff --git a/AnalysisFW/plugins/HcalProducer.cc b/AnalysisFW/plugins/HcalProducer.cc index 70eb5f0..306613f 100644 --- a/AnalysisFW/plugins/HcalProducer.cc +++ b/AnalysisFW/plugins/HcalProducer.cc @@ -87,8 +87,9 @@ void HcalProducer::preallocate(edm::PreallocationConfiguration const& iPrealloc) extraParams_.getParameter("modelname"), batchSize_, ninput_, - noutput_ - ); + noutput_, + extraParams_.getParameter("async") + ); edm::LogInfo("HcalProducer") << "Connected to remote server"; } } diff --git a/AnalysisFW/plugins/JetImageProducer.cc b/AnalysisFW/plugins/JetImageProducer.cc index ce32928..fbbfd6d 100644 --- a/AnalysisFW/plugins/JetImageProducer.cc +++ b/AnalysisFW/plugins/JetImageProducer.cc @@ -104,7 +104,8 @@ void JetImageProducer::preallocate(edm::PreallocationConfiguration const& iPreal extraParams_.getParameter("modelname"), batchSize_, ninput_, - noutput_ + noutput_, + extraParams_.getParameter("async") ); edm::LogInfo("JetImageProducer") << "Connected to remote server"; } diff --git a/AnalysisFW/python/HcalTest_mc_cfg.py b/AnalysisFW/python/HcalTest_mc_cfg.py index 36f2400..725cec9 100644 --- a/AnalysisFW/python/HcalTest_mc_cfg.py +++ b/AnalysisFW/python/HcalTest_mc_cfg.py @@ -15,6 +15,7 @@ options.register("streams", 0, VarParsing.multiplicity.singleton, VarParsing.varType.int) options.register("batchsize", 1, VarParsing.multiplicity.singleton, VarParsing.varType.int) options.register("modelname","facile", VarParsing.multiplicity.singleton, VarParsing.varType.string) +options.register("async",False, VarParsing.multiplicity.singleton, VarParsing.varType.bool) options.parseArguments() if len(options.params)>0 and options.remote: @@ -57,7 +58,8 @@ address = cms.string(options.address), port = cms.int32(options.port), timeout = cms.uint32(options.timeout), - modelname = cms.string(options.modelname) + modelname = cms.string(options.modelname), + async = cms.bool(options.async) ) # Let it run diff --git a/AnalysisFW/python/jetImageTest_mc_cfg.py b/AnalysisFW/python/jetImageTest_mc_cfg.py index 260a4f4..e7114b8 100644 --- a/AnalysisFW/python/jetImageTest_mc_cfg.py +++ b/AnalysisFW/python/jetImageTest_mc_cfg.py @@ -14,6 +14,7 @@ options.register("batchsize", 10, VarParsing.multiplicity.singleton, VarParsing.varType.int) #options.register("modelname","resnet50_netdef", VarParsing.multiplicity.singleton, VarParsing.varType.string) options.register("modelname","resnet50_ensemble", VarParsing.multiplicity.singleton, VarParsing.varType.string) +options.register("async",False, VarParsing.multiplicity.singleton, VarParsing.varType.bool) options.parseArguments() if len(options.params)>0 and options.remote: @@ -58,7 +59,8 @@ address = cms.string(options.address), port = cms.int32(options.port), timeout = cms.uint32(options.timeout), - modelname = cms.string(options.modelname) + modelname = cms.string(options.modelname), + async = cms.bool(options.async), ) else: process.jetImageProducer.remote = cms.bool(False) diff --git a/AnalysisFW/src/TFClientRemoteTRT.cc b/AnalysisFW/src/TFClientRemoteTRT.cc index 67d2ac8..d7e7261 100644 --- a/AnalysisFW/src/TFClientRemoteTRT.cc +++ b/AnalysisFW/src/TFClientRemoteTRT.cc @@ -13,7 +13,8 @@ ClientData::ClientData() : dataID_(0), timeout_(0), batchSize_(1), - output_(nullptr) + output_(nullptr), + async_(true) { } @@ -38,8 +39,6 @@ void ClientData::predict(){ auto t2 = std::chrono::high_resolution_clock::now(); std::vector input_shape; - input_shape.push_back(15); - input->SetShape(input_shape); for(unsigned i0 = 0; i0 < batchSize_; i0++) { nic::Error err1 = input->SetRaw(reinterpret_cast(&input_[0]),ninput_ * sizeof(float)); } @@ -48,48 +47,64 @@ void ClientData::predict(){ edm::LogInfo("TFClientRemoteTRT") << "Image array time: " << time2; std::map> results; - nic::Error erro0 = ctx->AsyncRun( - [t3,this](nic::InferContext* ctx, const std::shared_ptr& request) { - //get results - std::map> results; - //this function interface will change in the next tensorrtis version - bool is_ready = false; - ctx->GetAsyncRunResults(&results, &is_ready, request, false); - if(is_ready == false) throw cms::Exception("BadCallback") << "Callback executed before request was ready"; + if(async_) { + nic::Error erro0 = ctx->AsyncRun( + [t3,this](nic::InferContext* ctx, const std::shared_ptr& request) { + //get results + std::map> results; + //this function interface will change in the next tensorrtis version + bool is_ready = false; + ctx->GetAsyncRunResults(&results, &is_ready, request, false); + if(is_ready == false) throw cms::Exception("BadCallback") << "Callback executed before request was ready"; + + //check time + auto t4 = std::chrono::high_resolution_clock::now(); + auto time3 = std::chrono::duration_cast(t4-t3).count(); + edm::LogInfo("TFClientRemoteTRT") << "Remote time: " << time3; + + //check result + std::exception_ptr exceptionPtr; + const std::unique_ptr& result = results.begin()->second; + for(unsigned i0 = 0; i0 < this->batchSize_; i0++) { + const uint8_t* r0; + size_t content_byte_size; + result->GetRaw(i0, &r0,&content_byte_size); + const float *lVal = reinterpret_cast(r0); + for(unsigned i1 = 0; i1 < this->noutput_; i1++) this->output_[i0*noutput_+i1] = lVal[i1]; //This should be replaced with a memcpy + } + auto t5 = std::chrono::high_resolution_clock::now(); + auto time4 = std::chrono::duration_cast(t5-t4).count(); + edm::LogInfo("TFClientRemoteTRT") << "Output time: " << time4; - //check time - auto t4 = std::chrono::high_resolution_clock::now(); - auto time3 = std::chrono::duration_cast(t4-t3).count(); - edm::LogInfo("TFClientRemoteTRT") << "Remote time: " << time3; - - //check result - std::exception_ptr exceptionPtr; - const std::unique_ptr& result = results.begin()->second; - for(unsigned i0 = 0; i0 < this->batchSize_; i0++) { - const uint8_t* r0; - size_t content_byte_size; - result->GetRaw(i0, &r0,&content_byte_size); - const float *lVal = reinterpret_cast(r0); - for(unsigned i1 = 0; i1 < this->noutput_; i1++) this->output_[i0*noutput_+i1] = lVal[i1]; //This should be replaced with a memcpy - } - auto t5 = std::chrono::high_resolution_clock::now(); - auto time4 = std::chrono::duration_cast(t5-t4).count(); - edm::LogInfo("TFClientRemoteTRT") << "Output time: " << time4; - - //finish - this->holder_.doneWaiting(exceptionPtr); - } - ); + //finish + this->holder_.doneWaiting(exceptionPtr); + } + ); + } else { + std::map> results; + nic::Error err0 = ctx->Run(&results); + std::exception_ptr exceptionPtr; + const std::unique_ptr& result = results.begin()->second; + for(unsigned i0 = 0; i0 < batchSize_; i0++) { + const uint8_t* r0; + size_t content_byte_size; + result->GetRaw(i0, &r0,&content_byte_size); + const float *lVal = reinterpret_cast(r0); + for(unsigned i1 = 0; i1 < noutput_; i1++) output_[i0*noutput_+i1] = lVal[i1]; //This should be replaced with a memcpy + } + this->holder_.doneWaiting(exceptionPtr); + } } //based on: tensor-rt-client simple_example -TFClientRemoteTRT::TFClientRemoteTRT(unsigned numStreams, const std::string& address, int port, unsigned timeout,const std::string& model_name,unsigned batchSize,unsigned ninput,unsigned noutput) : +TFClientRemoteTRT::TFClientRemoteTRT(unsigned numStreams, const std::string& address, int port, unsigned timeout,const std::string& model_name,unsigned batchSize,unsigned ninput,unsigned noutput,bool async) : TFClientBase(), streamData_(numStreams), timeout_(timeout), batchSize_(batchSize), ninput_(ninput), - noutput_(noutput) + noutput_(noutput), + async_(async) { url_=address+":"+std::to_string(port); modelName_ = model_name; @@ -112,6 +127,7 @@ void TFClientRemoteTRT::predict(unsigned dataID, const float* img, float* result streamData.output_= result; streamData.holder_= std::move(holder); streamData.input_ = img; + streamData.async_ = async_; } streamData.predict(); edm::LogInfo("TFClientRemoteTRT") << "Async predict request sent";