diff --git a/HeterogeneousCore/SonicTriton/BuildFile.xml b/HeterogeneousCore/SonicTriton/BuildFile.xml
index b574f395f4d12..1961d534af2ef 100644
--- a/HeterogeneousCore/SonicTriton/BuildFile.xml
+++ b/HeterogeneousCore/SonicTriton/BuildFile.xml
@@ -1,6 +1,7 @@
+
diff --git a/HeterogeneousCore/SonicTriton/interface/TritonClient.h b/HeterogeneousCore/SonicTriton/interface/TritonClient.h
index 99ca5f8765fe7..96e30bea35534 100644
--- a/HeterogeneousCore/SonicTriton/interface/TritonClient.h
+++ b/HeterogeneousCore/SonicTriton/interface/TritonClient.h
@@ -40,6 +40,15 @@ class TritonClient : public SonicClient {
//for fillDescriptions
static void fillPSetDescription(edm::ParameterSetDescription& iDesc);
+ static std::pair splitNameConverter(const std::string& fullname) {
+ static const std::string indicator("_DataConverter:");
+ size_t dcpos = fullname.find(indicator);
+ if (dcpos != std::string::npos)
+ return {fullname.substr(0,dcpos), fullname.substr(dcpos+indicator.size())};
+ else
+ return {fullname, ""};
+ }
+
protected:
//helper
bool getResults(std::shared_ptr results);
diff --git a/HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h b/HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h
new file mode 100644
index 0000000000000..dc21bb8a8a415
--- /dev/null
+++ b/HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h
@@ -0,0 +1,43 @@
+#ifndef HeterogeneousCore_SonicTriton_TritonConverterBase
+#define HeterogeneousCore_SonicTriton_TritonConverterBase
+
+#include "FWCore/ParameterSet/interface/ParameterSet.h"
+#include "DataFormats/Common/interface/Handle.h"
+
+#include
+
+template
+class TritonConverterBase {
+//class needs to be templated since the convert functions require the data type, but need to also be virtual, and virtual member function templates are not allowed in C++
+public:
+ TritonConverterBase(const std::string convName)
+ : converterName_(convName), byteSize_(sizeof(DT)) {}
+ TritonConverterBase(const std::string convName, size_t byteSize)
+ : converterName_(convName), byteSize_(byteSize) {}
+ TritonConverterBase(const TritonConverterBase&) = delete;
+ virtual ~TritonConverterBase() = default;
+ TritonConverterBase& operator=(const TritonConverterBase&) = delete;
+
+ virtual const uint8_t* convertIn (const DT* in) const = 0;
+ virtual const DT* convertOut (const uint8_t* in) const = 0;
+
+ const int64_t byteSize() const { return byteSize_; }
+
+ const std::string& name() const { return converterName_; }
+
+ virtual void clear() const {}
+
+private:
+ const std::string converterName_;
+ const int64_t byteSize_;
+};
+
+#include "FWCore/PluginManager/interface/PluginFactory.h"
+
+template
+using TritonConverterFactory = edmplugin::PluginFactory*()>;
+
+#define DEFINE_TRITON_CONVERTER(input, type, name) DEFINE_EDM_PLUGIN(TritonConverterFactory, type, name)
+#define DEFINE_TRITON_CONVERTER_SIMPLE(input, type) DEFINE_EDM_PLUGIN(TritonConverterFactory, type, #type)
+
+#endif
diff --git a/HeterogeneousCore/SonicTriton/interface/TritonData.h b/HeterogeneousCore/SonicTriton/interface/TritonData.h
index 50808de4a1216..a3acfb42183b9 100644
--- a/HeterogeneousCore/SonicTriton/interface/TritonData.h
+++ b/HeterogeneousCore/SonicTriton/interface/TritonData.h
@@ -4,6 +4,9 @@
#include "FWCore/Utilities/interface/Exception.h"
#include "FWCore/Utilities/interface/Span.h"
+#include "FWCore/PluginManager/interface/PluginFactory.h"
+#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h"
+
#include
#include
#include
@@ -69,6 +72,39 @@ class TritonData {
void setResult(std::shared_ptr result) { result_ = result; }
IO* data() { return data_.get(); }
+ std::string defaultConverter(const std::string name) const {
+ if (!name.empty()) {
+ return name;
+ }
+ else {
+ std::string base = "StandardConverter";
+ if (dtype_ == inference::DataType::TYPE_INT64) {
+ return "Int64"+base;
+ }
+ else if (dtype_ == inference::DataType::TYPE_FP32) {
+ return "Float"+base;
+ } else {
+ throw cms::Exception("ConverterErrors") << "Unable to create default converter for " << name_ << " of " << dname_ << " type\n";
+ }
+ }
+ }
+
+ void setConverterParams(const std::string& convName) {
+ converterName_ = convName;
+ }
+
+ template
+ std::shared_ptr> createConverter() const {
+ using ConverterType = std::shared_ptr>;
+ //this construction catches bad any_cast without throwing std exception
+ if (auto ptr = std::any_cast(&converter_)) {
+ } else {
+ converter_ = ConverterType(TritonConverterFactory::get()->create(converterName_));
+ converter_clear_ = std::bind(&TritonConverterBase::clear, std::any_cast(converter_).get());
+ }
+ return std::any_cast(converter_);
+ }
+
//helpers
bool anyNeg(const ShapeView& vec) const {
return std::any_of(vec.begin(), vec.end(), [](int64_t i) { return i < 0; });
@@ -93,6 +129,9 @@ class TritonData {
int64_t byteSize_;
std::any holder_;
std::shared_ptr result_;
+ mutable std::any converter_;
+ std::string converterName_;
+ mutable std::function converter_clear_;
};
using TritonInputData = TritonData;
diff --git a/HeterogeneousCore/SonicTriton/plugins/BuildFile.xml b/HeterogeneousCore/SonicTriton/plugins/BuildFile.xml
new file mode 100644
index 0000000000000..0427e58fdb43a
--- /dev/null
+++ b/HeterogeneousCore/SonicTriton/plugins/BuildFile.xml
@@ -0,0 +1,5 @@
+
+
+
+
+
diff --git a/HeterogeneousCore/SonicTriton/plugins/converters/FloatApFixed16Converter.cc b/HeterogeneousCore/SonicTriton/plugins/converters/FloatApFixed16Converter.cc
new file mode 100644
index 0000000000000..62b97d405eac9
--- /dev/null
+++ b/HeterogeneousCore/SonicTriton/plugins/converters/FloatApFixed16Converter.cc
@@ -0,0 +1,44 @@
+#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h"
+
+#include
+#include "ap_fixed.h"
+
+template
+class FloatApFixed16Converter : public TritonConverterBase {
+public:
+ FloatApFixed16Converter() : TritonConverterBase("FloatApFixed16F"+std::to_string(I)+"Converter", 2) {}
+
+ const uint8_t* convertIn(const float* in) const override {
+ auto temp_vec = std::make_shared>>(std::move(this->makeVecIn(in)));
+ inputHolder_.push_back(temp_vec);
+ return reinterpret_cast(temp_vec->data());
+ }
+ const float* convertOut(const uint8_t* in) const override {
+ auto temp_vec = std::make_shared>(std::move(this->makeVecOut(reinterpret_cast*>(in))));
+ outputHolder_.push_back(temp_vec);
+ return temp_vec->data();
+ }
+
+ void clear() const override {
+ inputHolder_.clear();
+ outputHolder_.clear();
+ }
+
+private:
+ std::vector> makeVecIn(const float* in) const {
+ unsigned int nfeat = sizeof(in) / sizeof(float);
+ std::vector> temp_storage(in, in + nfeat);
+ return temp_storage;
+ }
+
+ std::vector makeVecOut(const ap_fixed<16, I>* in) const {
+ unsigned int nfeat = sizeof(in) / sizeof(ap_fixed<16, I>);
+ std::vector temp_storage(in, in + nfeat);
+ return temp_storage;
+ }
+
+ mutable std::vector>>> inputHolder_;
+ mutable std::vector>> outputHolder_;
+};
+
+DEFINE_TRITON_CONVERTER(float, FloatApFixed16Converter<6>, "FloatApFixed16F6Converter");
diff --git a/HeterogeneousCore/SonicTriton/plugins/converters/FloatStandardConverter.cc b/HeterogeneousCore/SonicTriton/plugins/converters/FloatStandardConverter.cc
new file mode 100644
index 0000000000000..b2056bb85129f
--- /dev/null
+++ b/HeterogeneousCore/SonicTriton/plugins/converters/FloatStandardConverter.cc
@@ -0,0 +1,11 @@
+#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h"
+
+class FloatStandardConverter : public TritonConverterBase {
+public:
+ FloatStandardConverter() : TritonConverterBase("FloatStandardConverter") {}
+
+ const uint8_t* convertIn(const float* in) const override { return reinterpret_cast(in); }
+ const float* convertOut(const uint8_t* in) const override { return reinterpret_cast(in); }
+};
+
+DEFINE_TRITON_CONVERTER_SIMPLE(float, FloatStandardConverter);
diff --git a/HeterogeneousCore/SonicTriton/plugins/converters/Int64StandardConverter.cc b/HeterogeneousCore/SonicTriton/plugins/converters/Int64StandardConverter.cc
new file mode 100644
index 0000000000000..8e1381834c369
--- /dev/null
+++ b/HeterogeneousCore/SonicTriton/plugins/converters/Int64StandardConverter.cc
@@ -0,0 +1,11 @@
+#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h"
+
+class Int64StandardConverter : public TritonConverterBase {
+public:
+ Int64StandardConverter() : TritonConverterBase("Int64StandardConverter") {}
+
+ const uint8_t* convertIn(const int64_t* in) const override { return reinterpret_cast(in); }
+ const int64_t* convertOut(const uint8_t* in) const override { return reinterpret_cast(in); }
+};
+
+DEFINE_TRITON_CONVERTER_SIMPLE(int64_t, Int64StandardConverter);
diff --git a/HeterogeneousCore/SonicTriton/src/TritonClient.cc b/HeterogeneousCore/SonicTriton/src/TritonClient.cc
index 98380e6546f4d..cc48385e41d5a 100644
--- a/HeterogeneousCore/SonicTriton/src/TritonClient.cc
+++ b/HeterogeneousCore/SonicTriton/src/TritonClient.cc
@@ -79,6 +79,12 @@ TritonClient::TritonClient(const edm::ParameterSet& params)
if (!msg_str.empty())
throw cms::Exception("ModelErrors") << msg_str;
+ const std::vector& inputConverterDefs = params.getParameterSetVector("inputConverters");
+ std::unordered_map inConvMap;
+ for (const auto& converterDef : inputConverterDefs) {
+ inConvMap[converterDef.getParameter("inputName")] = converterDef.getParameter("converterName");
+ }
+
//setup input map
std::stringstream io_msg;
if (verbose_)
@@ -86,10 +92,16 @@ TritonClient::TritonClient(const edm::ParameterSet& params)
<< "\n";
inputsTriton_.reserve(nicInputs.size());
for (const auto& nicInput : nicInputs) {
- const auto& iname = nicInput.name();
+ const std::string iname_full = nicInput.name();
+ const auto& [iname, iconverter] = TritonClient::splitNameConverter(iname_full);
auto [curr_itr, success] = input_.emplace(
std::piecewise_construct, std::forward_as_tuple(iname), std::forward_as_tuple(iname, nicInput, noBatch_));
auto& curr_input = curr_itr->second;
+ if ( inConvMap.find(iname) == inConvMap.end() ) {
+ curr_input.setConverterParams(curr_input.defaultConverter(iconverter));
+ } else {
+ curr_input.setConverterParams(inConvMap[iname]);
+ }
inputsTriton_.push_back(curr_input.data());
if (verbose_) {
io_msg << " " << iname << " (" << curr_input.dname() << ", " << curr_input.byteSize()
@@ -101,18 +113,30 @@ TritonClient::TritonClient(const edm::ParameterSet& params)
const auto& v_outputs = params.getUntrackedParameter>("outputs");
std::unordered_set s_outputs(v_outputs.begin(), v_outputs.end());
+ const std::vector& outputConverterDefs = params.getParameterSetVector("outputConverters");
+ std::unordered_map outConvMap;
+ for (const auto& converterDef : outputConverterDefs) {
+ outConvMap[converterDef.getParameter("outputName")] = converterDef.getParameter("converterName");
+ }
+
//setup output map
if (verbose_)
io_msg << "Model outputs: "
<< "\n";
outputsTriton_.reserve(nicOutputs.size());
for (const auto& nicOutput : nicOutputs) {
- const auto& oname = nicOutput.name();
+ const std::string oname_full = nicOutput.name();
+ const auto& [oname, oconverter] = TritonClient::splitNameConverter(oname_full);
if (!s_outputs.empty() and s_outputs.find(oname) == s_outputs.end())
continue;
auto [curr_itr, success] = output_.emplace(
std::piecewise_construct, std::forward_as_tuple(oname), std::forward_as_tuple(oname, nicOutput, noBatch_));
auto& curr_output = curr_itr->second;
+ if ( outConvMap.find(oname) == outConvMap.end() ) {
+ curr_output.setConverterParams(curr_output.defaultConverter(oconverter));
+ } else {
+ curr_output.setConverterParams(outConvMap[oname]);
+ }
outputsTriton_.push_back(curr_output.data());
if (verbose_) {
io_msg << " " << oname << " (" << curr_output.dname() << ", " << curr_output.byteSize()
@@ -336,10 +360,19 @@ inference::ModelStatistics TritonClient::getServerSideStatus() const {
//for fillDescriptions
void TritonClient::fillPSetDescription(edm::ParameterSetDescription& iDesc) {
+ edm::ParameterSetDescription descInConverter;
+ descInConverter.add("converterName");
+ descInConverter.add("inputName");
+ edm::ParameterSetDescription descOutConverter;
+ descOutConverter.add("converterName");
+ descOutConverter.add("outputName");
+ std::vector blankVPSet;
edm::ParameterSetDescription descClient;
fillBasePSetDescription(descClient);
descClient.add("modelName");
descClient.add("modelVersion", "");
+ descClient.addVPSet("inputConverters", descInConverter, blankVPSet);
+ descClient.addVPSet("outputConverters", descOutConverter, blankVPSet);
//server parameters should not affect the physics results
descClient.addUntracked("batchSize");
descClient.addUntracked("address");
diff --git a/HeterogeneousCore/SonicTriton/src/TritonData.cc b/HeterogeneousCore/SonicTriton/src/TritonData.cc
index 258671be07691..a885aa0cc45fb 100644
--- a/HeterogeneousCore/SonicTriton/src/TritonData.cc
+++ b/HeterogeneousCore/SonicTriton/src/TritonData.cc
@@ -1,5 +1,6 @@
#include "HeterogeneousCore/SonicTriton/interface/TritonData.h"
#include "HeterogeneousCore/SonicTriton/interface/triton_utils.h"
+#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h"
#include "FWCore/MessageLogger/interface/MessageLogger.h"
#include "model_config.pb.h"
@@ -116,14 +117,16 @@ void TritonInputData::toServer(std::shared_ptr> ptr) {
//shape must be specified for variable dims or if batch size changes
data_->SetShape(fullShape_);
- if (byteSize_ != sizeof(DT))
- throw cms::Exception("TritonDataError") << name_ << " input(): inconsistent byte size " << sizeof(DT)
+ auto converter = createConverter();
+
+ if (byteSize_ != converter->byteSize())
+ throw cms::Exception("TritonDataError") << name_ << " input(): inconsistent byte size " << converter->byteSize()
<< " (should be " << byteSize_ << " for " << dname_ << ")";
int64_t nInput = sizeShape();
for (unsigned i0 = 0; i0 < batchSize_; ++i0) {
const DT* arr = data_in[i0].data();
- triton_utils::throwIfError(data_->AppendRaw(reinterpret_cast(arr), nInput * byteSize_),
+ triton_utils::throwIfError(data_->AppendRaw(converter->convertIn(arr), nInput * byteSize_),
name_ + " input(): unable to set data for batch entry " + std::to_string(i0));
}
@@ -138,7 +141,9 @@ TritonOutput TritonOutputData::fromServer() const {
throw cms::Exception("TritonDataError") << name_ << " output(): missing result";
}
- if (byteSize_ != sizeof(DT)) {
+ auto converter = createConverter();
+
+ if (byteSize_ != converter->byteSize()) {
throw cms::Exception("TritonDataError") << name_ << " output(): inconsistent byte size " << sizeof(DT)
<< " (should be " << byteSize_ << " for " << dname_ << ")";
}
@@ -147,14 +152,14 @@ TritonOutput TritonOutputData::fromServer() const {
TritonOutput dataOut;
const uint8_t* r0;
size_t contentByteSize;
- size_t expectedContentByteSize = nOutput * byteSize_ * batchSize_;
+ size_t expectedContentByteSize = nOutput * converter->byteSize() * batchSize_;
triton_utils::throwIfError(result_->RawData(name_, &r0, &contentByteSize), "output(): unable to get raw");
if (contentByteSize != expectedContentByteSize) {
throw cms::Exception("TritonDataError") << name_ << " output(): unexpected content byte size " << contentByteSize
<< " (expected " << expectedContentByteSize << ")";
}
- const DT* r1 = reinterpret_cast(r0);
+ const DT* r1 = converter->convertOut(r0);
dataOut.reserve(batchSize_);
for (unsigned i0 = 0; i0 < batchSize_; ++i0) {
auto offset = i0 * nOutput;
@@ -168,11 +173,13 @@ template <>
void TritonInputData::reset() {
data_->Reset();
holder_.reset();
+ converter_clear_();
}
template <>
void TritonOutputData::reset() {
result_.reset();
+ converter_clear_();
}
//explicit template instantiation declarations
diff --git a/HeterogeneousCore/SonicTriton/src/pluginFactories.cc b/HeterogeneousCore/SonicTriton/src/pluginFactories.cc
new file mode 100644
index 0000000000000..101532a5de3c1
--- /dev/null
+++ b/HeterogeneousCore/SonicTriton/src/pluginFactories.cc
@@ -0,0 +1,4 @@
+#include "HeterogeneousCore/SonicTriton/interface/TritonConverterBase.h"
+
+EDM_REGISTER_PLUGINFACTORY(TritonConverterFactory, "TritonConverterFloatFactory");
+EDM_REGISTER_PLUGINFACTORY(TritonConverterFactory, "TritonConverterInt64Factory");