Fix Triton C API mode missing infer requested output datatype bug
matthewkotila committed Aug 7, 2024
1 parent c7b1642 commit 5df151a
Showing 8 changed files with 39 additions and 40 deletions.
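In brief: the Triton C API client backend constructed its requested-output wrapper with only the tensor name, so the datatype never reached the shared InferRequestedOutput base class or the request records built from it. The diffs below thread the datatype through the wrapper, make GetDataTypeSize() report unsupported types explicitly, and require a datatype when constructing a RecordData. The sketch below is illustrative only: the leading parameters of Create() do not appear in the visible hunk and are assumed here, and "output0"/"FP32" are placeholder values.

    // Hedged sketch of the repaired flow (assumed Create() signature; only
    // class_count and datatype are visible in the first diff below).
    InferRequestedOutput* output{nullptr};
    TritonCApiInferRequestedOutput::Create(
        &output, "output0" /* name */, 0 /* class_count */, "FP32" /* datatype */);
    // The wrapper now forwards "FP32" to the InferRequestedOutput base class
    // instead of silently dropping it, so the profile data exporter can later
    // divide the recorded byte size by the element size of the datatype.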
6 changes: 3 additions & 3 deletions src/client_backend/triton_c_api/triton_c_api_backend.cc
@@ -406,7 +406,7 @@ TritonCApiInferRequestedOutput::Create(
     const size_t class_count, const std::string& datatype)
 {
   TritonCApiInferRequestedOutput* local_infer_output =
-      new TritonCApiInferRequestedOutput(name);
+      new TritonCApiInferRequestedOutput(name, datatype);

   tc::InferRequestedOutput* triton_infer_output;
   RETURN_IF_TRITON_ERROR(tc::InferRequestedOutput::Create(
@@ -427,8 +427,8 @@ TritonCApiInferRequestedOutput::SetSharedMemory(
 }

 TritonCApiInferRequestedOutput::TritonCApiInferRequestedOutput(
-    const std::string& name)
-    : InferRequestedOutput(BackendKind::TRITON_C_API, name)
+    const std::string& name, const std::string& datatype)
+    : InferRequestedOutput(BackendKind::TRITON_C_API, name, datatype)
 {
 }

3 changes: 2 additions & 1 deletion src/client_backend/triton_c_api/triton_c_api_backend.h
@@ -215,7 +215,8 @@ class TritonCApiInferRequestedOutput : public InferRequestedOutput {
       const std::string& name, size_t byte_size, size_t offset = 0) override;

  private:
-  explicit TritonCApiInferRequestedOutput(const std::string& name);
+  explicit TritonCApiInferRequestedOutput(
+      const std::string& name, const std::string& datatype);

   std::unique_ptr<tc::InferRequestedOutput> output_;
 };
29 changes: 10 additions & 19 deletions src/client_backend/triton_c_api/triton_loader.cc
@@ -338,25 +338,16 @@ TritonLoader::StartTriton()
   // Create the allocator that will be used to allocate buffers for
   // the result tensors.
   RETURN_IF_TRITONSERVER_ERROR(
-      GetSingleton()->response_allocator_new_fn_(
-          &allocator_,
-          reinterpret_cast<
-              TRITONSERVER_Error* (*)(TRITONSERVER_ResponseAllocator* allocator,
-                                      const char* tensor_name, size_t byte_size,
-                                      TRITONSERVER_MemoryType memory_type,
-                                      int64_t memory_type_id, void* userp,
-                                      void** buffer, void** buffer_userp,
-                                      TRITONSERVER_MemoryType*
-                                          actual_memory_type,
-                                      int64_t* actual_memory_type_id)>(
-              ResponseAlloc),
-          reinterpret_cast<
-              TRITONSERVER_Error* (*)(TRITONSERVER_ResponseAllocator* allocator,
-                                      void* buffer, void* buffer_userp,
-                                      size_t byte_size,
-                                      TRITONSERVER_MemoryType memory_type,
-                                      int64_t memory_type_id)>(ResponseRelease),
-          nullptr /* start_fn */),
+      GetSingleton()
+          ->response_allocator_new_fn_(
+              &allocator_,
+              reinterpret_cast<
+                  TRITONSERVER_Error* (*)(TRITONSERVER_ResponseAllocator * allocator, const char* tensor_name, size_t byte_size, TRITONSERVER_MemoryType memory_type, int64_t memory_type_id, void* userp, void** buffer, void** buffer_userp, TRITONSERVER_MemoryType* actual_memory_type, int64_t* actual_memory_type_id)>(
+                  ResponseAlloc),
+              reinterpret_cast<
+                  TRITONSERVER_Error* (*)(TRITONSERVER_ResponseAllocator * allocator, void* buffer, void* buffer_userp, size_t byte_size, TRITONSERVER_MemoryType memory_type, int64_t memory_type_id)>(
+                  ResponseRelease),
+              nullptr /* start_fn */),
       "creating response allocator");

   return Error::Success;
3 changes: 2 additions & 1 deletion src/perf_utils.cc
@@ -413,7 +413,7 @@ ParseTensorFormat(const std::string& content_type_str)
   }
 }

-size_t
+std::optional<size_t>
 GetDataTypeSize(const std::string& data_type)
 {
   if (data_type == "BOOL") {
@@ -445,6 +445,7 @@ GetDataTypeSize(const std::string& data_type)
   } else {
     std::cerr << "WARNING: unsupported data type: '" + data_type + "'"
               << std::endl;
+    return {};
   }
 }

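A hedged usage sketch for the new signature (not part of the commit): callers now have to check the optional before using the element size, mirroring the pattern the profile_data_exporter.cc hunk below adopts. The include path and the free function wrapping the call are assumptions made for illustration.

    #include <iostream>
    #include <optional>
    #include <string>

    #include "perf_utils.h"  // assumed include path for GetDataTypeSize()

    void
    PrintElementSize(const std::string& data_type)
    {
      const std::optional<size_t> size{
          triton::perfanalyzer::GetDataTypeSize(data_type)};
      if (!size) {
        // GetDataTypeSize() has already printed its "unsupported data type"
        // warning; skip the computation instead of using a bogus size.
        return;
      }
      std::cout << data_type << ": " << size.value() << " byte(s)" << std::endl;
    }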
5 changes: 3 additions & 2 deletions src/perf_utils.h
@@ -1,4 +1,4 @@
-// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 //
 // Redistribution and use in source and binary forms, with or without
 // modification, are permitted provided that the following conditions
@@ -36,6 +36,7 @@
 #include <iomanip>
 #include <iostream>
 #include <memory>
+#include <optional>
 #include <random>

 #include "client_backend/client_backend.h"
@@ -138,6 +139,6 @@ std::function<std::chrono::nanoseconds(std::mt19937&)> ScheduleDistribution(
 cb::TensorFormat ParseTensorFormat(const std::string& tensor_format_str);

 // Returns the size of a given data type in bytes.
-size_t GetDataTypeSize(const std::string& data_type);
+std::optional<size_t> GetDataTypeSize(const std::string& data_type);

 }} // namespace triton::perfanalyzer
7 changes: 6 additions & 1 deletion src/profile_data_exporter.cc
@@ -205,11 +205,16 @@ ProfileDataExporter::AddDataToJSON(
 {
   // TPA-268: support N-dimensional tensor
   size_t data_size;
+  // TODO TPA-283: Add support for N-dimensional string tensors
   if (data_type == "BYTES" || data_type == "JSON") {
     // return string as is instead of array of chars
     data_size = 1;
   } else {
-    data_size = byte_size / GetDataTypeSize(data_type);
+    const std::optional<size_t> data_type_size{GetDataTypeSize(data_type)};
+    if (!data_type_size) {
+      return;
+    }
+    data_size = byte_size / data_type_size.value();
     if (data_size > 1) {
       json.SetArray();
     }
2 changes: 1 addition & 1 deletion src/request_record.h
@@ -35,7 +35,7 @@ namespace triton { namespace perfanalyzer {

 /// A record containing the data of a single request input or response output
 struct RecordData {
-  RecordData(const uint8_t* buf, size_t size, std::string data_type = "")
+  RecordData(const uint8_t* buf, size_t size, std::string data_type)
   {
     uint8_t* array = new uint8_t[size];
     std::memcpy(array, buf, size);
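With the default argument gone, every RecordData construction must name its datatype explicitly, which is why the test file below now passes "fake_datatype" everywhere. A minimal construction sketch, with an illustrative buffer and the "BOOL" datatype string chosen only because it appears in the perf_utils.cc hunk above:

    // Illustrative only: RecordData now requires all three arguments.
    uint8_t buf[] = {0x01, 0x00, 0x01};
    triton::perfanalyzer::RecordData record(buf, sizeof(buf), "BOOL");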
24 changes: 12 additions & 12 deletions src/test_profile_data_collector.cc
@@ -70,14 +70,14 @@ TEST_CASE("profile_data_collector: AddData")
   uint8_t fake_data_in[] = {0x01, 0x02, 0x03, 0x04};
   uint8_t fake_data_out[] = {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08};
   RequestRecord::RequestInput request1_request_input{
-      {"key1", RecordData(fake_data_in, 1)},
-      {"key2", RecordData(fake_data_in, 2)}};
+      {"key1", RecordData(fake_data_in, 1, "fake_datatype")},
+      {"key2", RecordData(fake_data_in, 2, "fake_datatype")}};
   RequestRecord::ResponseOutput request1_response1_output{
-      {"key1", RecordData(fake_data_out, 1)},
-      {"key2", RecordData(fake_data_out, 2)}};
+      {"key1", RecordData(fake_data_out, 1, "fake_datatype")},
+      {"key2", RecordData(fake_data_out, 2, "fake_datatype")}};
   RequestRecord::ResponseOutput request1_response2_output{
-      {"key3", RecordData(fake_data_out, 3)},
-      {"key4", RecordData(fake_data_out, 4)}};
+      {"key3", RecordData(fake_data_out, 3, "fake_datatype")},
+      {"key4", RecordData(fake_data_out, 4, "fake_datatype")}};

   RequestRecord request_record1{
       request1_timestamp,
@@ -95,14 +95,14 @@
   auto request2_response1_timestamp{clock_epoch + nanoseconds(5)};
   auto request2_response2_timestamp{clock_epoch + nanoseconds(6)};
   RequestRecord::RequestInput request2_request_input{
-      {"key3", RecordData(fake_data_in, 3)},
-      {"key4", RecordData(fake_data_in, 4)}};
+      {"key3", RecordData(fake_data_in, 3, "fake_datatype")},
+      {"key4", RecordData(fake_data_in, 4, "fake_datatype")}};
   RequestRecord::ResponseOutput request2_response1_output{
-      {"key5", RecordData(fake_data_out, 5)},
-      {"key6", RecordData(fake_data_out, 6)}};
+      {"key5", RecordData(fake_data_out, 5, "fake_datatype")},
+      {"key6", RecordData(fake_data_out, 6, "fake_datatype")}};
   RequestRecord::ResponseOutput request2_response2_output{
-      {"key7", RecordData(fake_data_out, 7)},
-      {"key8", RecordData(fake_data_out, 8)}};
+      {"key7", RecordData(fake_data_out, 7, "fake_datatype")},
+      {"key8", RecordData(fake_data_out, 8, "fake_datatype")}};

   RequestRecord request_record2{
       request2_timestamp,
