diff --git a/HeterogeneousCore/SonicTriton/BuildFile.xml b/HeterogeneousCore/SonicTriton/BuildFile.xml index b574f395f4d12..1961d534af2ef 100644 --- a/HeterogeneousCore/SonicTriton/BuildFile.xml +++ b/HeterogeneousCore/SonicTriton/BuildFile.xml @@ -1,6 +1,7 @@ + diff --git a/HeterogeneousCore/SonicTriton/interface/TritonClient.h b/HeterogeneousCore/SonicTriton/interface/TritonClient.h index 99ca5f8765fe7..96e30bea35534 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonClient.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonClient.h @@ -40,6 +40,15 @@ class TritonClient : public SonicClient { //for fillDescriptions static void fillPSetDescription(edm::ParameterSetDescription& iDesc); + static std::pair splitNameConverter(const std::string& fullname) { + static const std::string indicator("_DataConverter:"); + size_t dcpos = fullname.find(indicator); + if (dcpos != std::string::npos) + return {fullname.substr(0,dcpos), fullname.substr(dcpos+indicator.size())}; + else + return {fullname, ""}; + } + protected: //helper bool getResults(std::shared_ptr results); diff --git a/HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h b/HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h new file mode 100644 index 0000000000000..dc21bb8a8a415 --- /dev/null +++ b/HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h @@ -0,0 +1,43 @@ +#ifndef HeterogeneousCore_SonicTriton_TritonConverterBase +#define HeterogeneousCore_SonicTriton_TritonConverterBase + +#include "FWCore/ParameterSet/interface/ParameterSet.h" +#include "DataFormats/Common/interface/Handle.h" + +#include + +template +class TritonConverterBase { +//class needs to be templated since the convert functions require the data type, but need to also be virtual, and virtual member function templates are not allowed in C++ +public: + TritonConverterBase(const std::string convName) + : converterName_(convName), byteSize_(sizeof(DT)) {} + TritonConverterBase(const std::string convName, size_t byteSize) + : converterName_(convName), byteSize_(byteSize) {} + TritonConverterBase(const TritonConverterBase&) = delete; + virtual ~TritonConverterBase() = default; + TritonConverterBase& operator=(const TritonConverterBase&) = delete; + + virtual const uint8_t* convertIn (const DT* in) const = 0; + virtual const DT* convertOut (const uint8_t* in) const = 0; + + const int64_t byteSize() const { return byteSize_; } + + const std::string& name() const { return converterName_; } + + virtual void clear() const {} + +private: + const std::string converterName_; + const int64_t byteSize_; +}; + +#include "FWCore/PluginManager/interface/PluginFactory.h" + +template +using TritonConverterFactory = edmplugin::PluginFactory*()>; + +#define DEFINE_TRITON_CONVERTER(input, type, name) DEFINE_EDM_PLUGIN(TritonConverterFactory, type, name) +#define DEFINE_TRITON_CONVERTER_SIMPLE(input, type) DEFINE_EDM_PLUGIN(TritonConverterFactory, type, #type) + +#endif diff --git a/HeterogeneousCore/SonicTriton/interface/TritonData.h b/HeterogeneousCore/SonicTriton/interface/TritonData.h index 50808de4a1216..a3acfb42183b9 100644 --- a/HeterogeneousCore/SonicTriton/interface/TritonData.h +++ b/HeterogeneousCore/SonicTriton/interface/TritonData.h @@ -4,6 +4,9 @@ #include "FWCore/Utilities/interface/Exception.h" #include "FWCore/Utilities/interface/Span.h" +#include "FWCore/PluginManager/interface/PluginFactory.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h" + #include #include #include @@ -69,6 +72,39 @@ class TritonData { void setResult(std::shared_ptr result) { result_ = result; } IO* data() { return data_.get(); } + std::string defaultConverter(const std::string name) const { + if (!name.empty()) { + return name; + } + else { + std::string base = "StandardConverter"; + if (dtype_ == inference::DataType::TYPE_INT64) { + return "Int64"+base; + } + else if (dtype_ == inference::DataType::TYPE_FP32) { + return "Float"+base; + } else { + throw cms::Exception("ConverterErrors") << "Unable to create default converter for " << name_ << " of " << dname_ << " type\n"; + } + } + } + + void setConverterParams(const std::string& convName) { + converterName_ = convName; + } + + template + std::shared_ptr> createConverter() const { + using ConverterType = std::shared_ptr>; + //this construction catches bad any_cast without throwing std exception + if (auto ptr = std::any_cast(&converter_)) { + } else { + converter_ = ConverterType(TritonConverterFactory
::get()->create(converterName_)); + converter_clear_ = std::bind(&TritonConverterBase
::clear, std::any_cast(converter_).get()); + } + return std::any_cast(converter_); + } + //helpers bool anyNeg(const ShapeView& vec) const { return std::any_of(vec.begin(), vec.end(), [](int64_t i) { return i < 0; }); @@ -93,6 +129,9 @@ class TritonData { int64_t byteSize_; std::any holder_; std::shared_ptr result_; + mutable std::any converter_; + std::string converterName_; + mutable std::function converter_clear_; }; using TritonInputData = TritonData; diff --git a/HeterogeneousCore/SonicTriton/plugins/BuildFile.xml b/HeterogeneousCore/SonicTriton/plugins/BuildFile.xml new file mode 100644 index 0000000000000..0427e58fdb43a --- /dev/null +++ b/HeterogeneousCore/SonicTriton/plugins/BuildFile.xml @@ -0,0 +1,5 @@ + + + + + diff --git a/HeterogeneousCore/SonicTriton/plugins/converters/FloatApFixed16Converter.cc b/HeterogeneousCore/SonicTriton/plugins/converters/FloatApFixed16Converter.cc new file mode 100644 index 0000000000000..62b97d405eac9 --- /dev/null +++ b/HeterogeneousCore/SonicTriton/plugins/converters/FloatApFixed16Converter.cc @@ -0,0 +1,44 @@ +#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h" + +#include +#include "ap_fixed.h" + +template +class FloatApFixed16Converter : public TritonConverterBase { +public: + FloatApFixed16Converter() : TritonConverterBase("FloatApFixed16F"+std::to_string(I)+"Converter", 2) {} + + const uint8_t* convertIn(const float* in) const override { + auto temp_vec = std::make_shared>>(std::move(this->makeVecIn(in))); + inputHolder_.push_back(temp_vec); + return reinterpret_cast(temp_vec->data()); + } + const float* convertOut(const uint8_t* in) const override { + auto temp_vec = std::make_shared>(std::move(this->makeVecOut(reinterpret_cast*>(in)))); + outputHolder_.push_back(temp_vec); + return temp_vec->data(); + } + + void clear() const override { + inputHolder_.clear(); + outputHolder_.clear(); + } + +private: + std::vector> makeVecIn(const float* in) const { + unsigned int nfeat = sizeof(in) / sizeof(float); + std::vector> temp_storage(in, in + nfeat); + return temp_storage; + } + + std::vector makeVecOut(const ap_fixed<16, I>* in) const { + unsigned int nfeat = sizeof(in) / sizeof(ap_fixed<16, I>); + std::vector temp_storage(in, in + nfeat); + return temp_storage; + } + + mutable std::vector>>> inputHolder_; + mutable std::vector>> outputHolder_; +}; + +DEFINE_TRITON_CONVERTER(float, FloatApFixed16Converter<6>, "FloatApFixed16F6Converter"); diff --git a/HeterogeneousCore/SonicTriton/plugins/converters/FloatStandardConverter.cc b/HeterogeneousCore/SonicTriton/plugins/converters/FloatStandardConverter.cc new file mode 100644 index 0000000000000..b2056bb85129f --- /dev/null +++ b/HeterogeneousCore/SonicTriton/plugins/converters/FloatStandardConverter.cc @@ -0,0 +1,11 @@ +#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h" + +class FloatStandardConverter : public TritonConverterBase { +public: + FloatStandardConverter() : TritonConverterBase("FloatStandardConverter") {} + + const uint8_t* convertIn(const float* in) const override { return reinterpret_cast(in); } + const float* convertOut(const uint8_t* in) const override { return reinterpret_cast(in); } +}; + +DEFINE_TRITON_CONVERTER_SIMPLE(float, FloatStandardConverter); diff --git a/HeterogeneousCore/SonicTriton/plugins/converters/Int64StandardConverter.cc b/HeterogeneousCore/SonicTriton/plugins/converters/Int64StandardConverter.cc new file mode 100644 index 0000000000000..8e1381834c369 --- /dev/null +++ b/HeterogeneousCore/SonicTriton/plugins/converters/Int64StandardConverter.cc @@ -0,0 +1,11 @@ +#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h" + +class Int64StandardConverter : public TritonConverterBase { +public: + Int64StandardConverter() : TritonConverterBase("Int64StandardConverter") {} + + const uint8_t* convertIn(const int64_t* in) const override { return reinterpret_cast(in); } + const int64_t* convertOut(const uint8_t* in) const override { return reinterpret_cast(in); } +}; + +DEFINE_TRITON_CONVERTER_SIMPLE(int64_t, Int64StandardConverter); diff --git a/HeterogeneousCore/SonicTriton/src/TritonClient.cc b/HeterogeneousCore/SonicTriton/src/TritonClient.cc index 98380e6546f4d..cc48385e41d5a 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonClient.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonClient.cc @@ -79,6 +79,12 @@ TritonClient::TritonClient(const edm::ParameterSet& params) if (!msg_str.empty()) throw cms::Exception("ModelErrors") << msg_str; + const std::vector& inputConverterDefs = params.getParameterSetVector("inputConverters"); + std::unordered_map inConvMap; + for (const auto& converterDef : inputConverterDefs) { + inConvMap[converterDef.getParameter("inputName")] = converterDef.getParameter("converterName"); + } + //setup input map std::stringstream io_msg; if (verbose_) @@ -86,10 +92,16 @@ TritonClient::TritonClient(const edm::ParameterSet& params) << "\n"; inputsTriton_.reserve(nicInputs.size()); for (const auto& nicInput : nicInputs) { - const auto& iname = nicInput.name(); + const std::string iname_full = nicInput.name(); + const auto& [iname, iconverter] = TritonClient::splitNameConverter(iname_full); auto [curr_itr, success] = input_.emplace( std::piecewise_construct, std::forward_as_tuple(iname), std::forward_as_tuple(iname, nicInput, noBatch_)); auto& curr_input = curr_itr->second; + if ( inConvMap.find(iname) == inConvMap.end() ) { + curr_input.setConverterParams(curr_input.defaultConverter(iconverter)); + } else { + curr_input.setConverterParams(inConvMap[iname]); + } inputsTriton_.push_back(curr_input.data()); if (verbose_) { io_msg << " " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize() @@ -101,18 +113,30 @@ TritonClient::TritonClient(const edm::ParameterSet& params) const auto& v_outputs = params.getUntrackedParameter>("outputs"); std::unordered_set s_outputs(v_outputs.begin(), v_outputs.end()); + const std::vector& outputConverterDefs = params.getParameterSetVector("outputConverters"); + std::unordered_map outConvMap; + for (const auto& converterDef : outputConverterDefs) { + outConvMap[converterDef.getParameter("outputName")] = converterDef.getParameter("converterName"); + } + //setup output map if (verbose_) io_msg << "Model outputs: " << "\n"; outputsTriton_.reserve(nicOutputs.size()); for (const auto& nicOutput : nicOutputs) { - const auto& oname = nicOutput.name(); + const std::string oname_full = nicOutput.name(); + const auto& [oname, oconverter] = TritonClient::splitNameConverter(oname_full); if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end()) continue; auto [curr_itr, success] = output_.emplace( std::piecewise_construct, std::forward_as_tuple(oname), std::forward_as_tuple(oname, nicOutput, noBatch_)); auto& curr_output = curr_itr->second; + if ( outConvMap.find(oname) == outConvMap.end() ) { + curr_output.setConverterParams(curr_output.defaultConverter(oconverter)); + } else { + curr_output.setConverterParams(outConvMap[oname]); + } outputsTriton_.push_back(curr_output.data()); if (verbose_) { io_msg << " " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize() @@ -336,10 +360,19 @@ inference::ModelStatistics TritonClient::getServerSideStatus() const { //for fillDescriptions void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) { + edm::ParameterSetDescription descInConverter; + descInConverter.add("converterName"); + descInConverter.add("inputName"); + edm::ParameterSetDescription descOutConverter; + descOutConverter.add("converterName"); + descOutConverter.add("outputName"); + std::vector blankVPSet; edm::ParameterSetDescription descClient; fillBasePSetDescription(descClient); descClient.add("modelName"); descClient.add("modelVersion", ""); + descClient.addVPSet("inputConverters", descInConverter, blankVPSet); + descClient.addVPSet("outputConverters", descOutConverter, blankVPSet); //server parameters should not affect the physics results descClient.addUntracked("batchSize"); descClient.addUntracked("address"); diff --git a/HeterogeneousCore/SonicTriton/src/TritonData.cc b/HeterogeneousCore/SonicTriton/src/TritonData.cc index 258671be07691..a885aa0cc45fb 100644 --- a/HeterogeneousCore/SonicTriton/src/TritonData.cc +++ b/HeterogeneousCore/SonicTriton/src/TritonData.cc @@ -1,5 +1,6 @@ #include "HeterogeneousCore/SonicTriton/interface/TritonData.h" #include "HeterogeneousCore/SonicTriton/interface/triton_utils.h" +#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h" #include "FWCore/MessageLogger/interface/MessageLogger.h" #include "model_config.pb.h" @@ -116,14 +117,16 @@ void TritonInputData::toServer(std::shared_ptr> ptr) { //shape must be specified for variable dims or if batch size changes data_->SetShape(fullShape_); - if (byteSize_ != sizeof(DT)) - throw cms::Exception("TritonDataError") << name_ << " input(): inconsistent byte size " << sizeof(DT) + auto converter = createConverter
(); + + if (byteSize_ != converter->byteSize()) + throw cms::Exception("TritonDataError") << name_ << " input(): inconsistent byte size " << converter->byteSize() << " (should be " << byteSize_ << " for " << dname_ << ")"; int64_t nInput = sizeShape(); for (unsigned i0 = 0; i0 < batchSize_; ++i0) { const DT* arr = data_in[i0].data(); - triton_utils::throwIfError(data_->AppendRaw(reinterpret_cast(arr), nInput * byteSize_), + triton_utils::throwIfError(data_->AppendRaw(converter->convertIn(arr), nInput * byteSize_), name_ + " input(): unable to set data for batch entry " + std::to_string(i0)); } @@ -138,7 +141,9 @@ TritonOutput
TritonOutputData::fromServer() const { throw cms::Exception("TritonDataError") << name_ << " output(): missing result"; } - if (byteSize_ != sizeof(DT)) { + auto converter = createConverter
(); + + if (byteSize_ != converter->byteSize()) { throw cms::Exception("TritonDataError") << name_ << " output(): inconsistent byte size " << sizeof(DT) << " (should be " << byteSize_ << " for " << dname_ << ")"; } @@ -147,14 +152,14 @@ TritonOutput
TritonOutputData::fromServer() const { TritonOutput
dataOut; const uint8_t* r0; size_t contentByteSize; - size_t expectedContentByteSize = nOutput * byteSize_ * batchSize_; + size_t expectedContentByteSize = nOutput * converter->byteSize() * batchSize_; triton_utils::throwIfError(result_->RawData(name_, &r0, &contentByteSize), "output(): unable to get raw"); if (contentByteSize != expectedContentByteSize) { throw cms::Exception("TritonDataError") << name_ << " output(): unexpected content byte size " << contentByteSize << " (expected " << expectedContentByteSize << ")"; } - const DT* r1 = reinterpret_cast(r0); + const DT* r1 = converter->convertOut(r0); dataOut.reserve(batchSize_); for (unsigned i0 = 0; i0 < batchSize_; ++i0) { auto offset = i0 * nOutput; @@ -168,11 +173,13 @@ template <> void TritonInputData::reset() { data_->Reset(); holder_.reset(); + converter_clear_(); } template <> void TritonOutputData::reset() { result_.reset(); + converter_clear_(); } //explicit template instantiation declarations diff --git a/HeterogeneousCore/SonicTriton/src/pluginFactories.cc b/HeterogeneousCore/SonicTriton/src/pluginFactories.cc new file mode 100644 index 0000000000000..101532a5de3c1 --- /dev/null +++ b/HeterogeneousCore/SonicTriton/src/pluginFactories.cc @@ -0,0 +1,4 @@ +#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h" + +EDM_REGISTER_PLUGINFACTORY(TritonConverterFactory, "TritonConverterFloatFactory"); +EDM_REGISTER_PLUGINFACTORY(TritonConverterFactory, "TritonConverterInt64Factory");