chore: updates
peri044 committed Oct 16, 2024
1 parent cb03ca1 commit 839c72e
Showing 4 changed files with 18 additions and 22 deletions.
21 changes: 4 additions & 17 deletions core/runtime/TRTEngine.cpp
@@ -324,28 +324,15 @@ FlattenedState TRTEngine::__obj_flatten__() {
auto trt_engine = std::string((const char*)serialized_trt_engine->data(), serialized_trt_engine->size());

return std::tuple(
std::tuple("version", ABI_VERSION),
std::tuple("name", this->name),
std::tuple("serialized_engine", base64_encode(trt_engine)),
std::tuple("device_info", this->device_info.serialize()),
std::tuple("serialized_engine", base64_encode(trt_engine)),
std::tuple("in_binding_names", this->in_binding_names),
std::tuple("out_binding_names", this->out_binding_names),
std::tuple("target_platform", this->target_platform.serialize()),
std::tuple("hardware_compatible", this->hardware_compatible),
std::tuple("serialized_metadata", this->serialized_metadata));
// std::tuple("engine_stream", this->engine_stream),
// std::tuple("caller_stream", this->caller_stream),
// std::tuple("input_buffers", this->input_buffers),
// std::tuple("output_buffers", this->output_buffers),
// std::tuple("shape_key", this->shape_key),
// std::tuple("cudagraph_mempool_id", this->cudagraph_mempool_id),
// std::tuple("profile_execution", this->profile_execution),
// std::tuple("device_profile_path", this->device_profile_path),
// std::tuple("input_profile_path", this->input_profile_path),
// std::tuple("output_profile_path", this->output_profile_path),
// std::tuple("enqueue_profile_path", this->enqueue_profile_path),
// std::tuple("trt_engine_profile_path", this->trt_engine_profile_path),
// std::tuple("cuda_graph_debug_path", this->cuda_graph_debug_path),
// std::tuple("trt_engine_profiler", this->trt_engine_profiler),);
std::tuple("serialized_metadata", this->serialized_metadata),
std::tuple("target_platform", this->target_platform.serialize()));
}

} // namespace runtime
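For context, the flattened state above base64-encodes the serialized TensorRT engine before handing it to Python. A minimal sketch of the intended round trip, assuming the C++ base64_encode helper emits standard base64 (the byte string below is illustrative):

import base64

engine_bytes = b"\x00\x01\x02\x03"  # stands in for serialized_trt_engine->data()
encoded = base64.b64encode(engine_bytes).decode("utf-8")  # what __obj_flatten__ ships to Python
assert base64.b64decode(encoded) == engine_bytes          # what the Python side recovers
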
7 changes: 4 additions & 3 deletions core/runtime/TRTEngine.h
@@ -20,14 +20,15 @@ namespace core {
namespace runtime {

using FlattenedState = std::tuple<
std::tuple<std::string, std::string>, // ABI_VERSION
std::tuple<std::string, std::string>, // name
std::tuple<std::string, std::string>, // engine
std::tuple<std::string, std::string>, // device
std::tuple<std::string, std::string>, // engine
std::tuple<std::string, std::vector<std::string>>, // input binding names
std::tuple<std::string, std::vector<std::string>>, // output binding names
std::tuple<std::string, std::string>, // Platform
std::tuple<std::string, bool>, // HW compatibility
std::tuple<std::string, std::string>>; // serialized metadata
std::tuple<std::string, std::string>, // serialized metadata
std::tuple<std::string, std::string>>; // Platform

struct TRTEngine : torch::CustomClassHolder {
// Each engine needs its own runtime object
3 changes: 2 additions & 1 deletion py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -180,8 +180,10 @@ def setup_engine(self) -> None:
"""
if self.engine is not None:
return

self.engine = torch.classes.tensorrt.Engine(self._pack_engine_info())

@staticmethod
def encode_metadata(metadata: Any) -> str:
metadata = copy.deepcopy(metadata)
dumped_metadata = pickle.dumps(metadata)
@@ -270,7 +272,6 @@ def forward(self, *inputs: Any) -> torch.Tensor | Tuple[torch.Tensor, ...]:
(i if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
for i in inputs
]

outputs: List[torch.Tensor] = torch.ops.tensorrt.execute_engine(
list(input_tensors), self.engine
)
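encode_metadata above deep-copies and pickles the metadata dictionary; its body is truncated in this hunk. A minimal sketch of the round trip, assuming the truncated code goes on to base64-encode the pickled bytes (the metadata value here is made up):

import base64
import copy
import pickle

metadata = {"settings": {"hardware_compatible": False}}  # illustrative payload
dumped = pickle.dumps(copy.deepcopy(metadata))           # mirrors the shown lines of encode_metadata
encoded = base64.b64encode(dumped).decode("utf-8")       # assumed final encoding step
assert pickle.loads(base64.b64decode(encoded)) == metadata
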
9 changes: 8 additions & 1 deletion py/torch_tensorrt/dynamo/runtime/register_fake_class.py
@@ -1,3 +1,4 @@
import base64
from typing import Any

import torch
@@ -11,4 +12,10 @@ def __init__(self) -> None:

@classmethod
def __obj_unflatten__(cls, flattened_tq: Any) -> Any:
return cls(**dict(flattened_tq))
engine_info = [info[1] for info in flattened_tq]
engine_info[3] = base64.b64decode(engine_info[3]) # decode engine
engine_info[4] = str(engine_info[4][0]) # input names
engine_info[5] = str(engine_info[5][0]) # output names
engine_info[6] = str(int(engine_info[6])) # hw compatible
trt_engine = torch.classes.tensorrt.Engine(engine_info)
return trt_engine
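For reference, a sketch of the positional layout that __obj_unflatten__ assumes for the flattened state produced by TRTEngine::__obj_flatten__ above. Field names and order follow the tuple built in TRTEngine.cpp; the values here are placeholders, not real engine data:

flattened_tq = [
    ("version", "..."),                   # 0: ABI version
    ("name", "..."),                      # 1: engine name
    ("device_info", "..."),               # 2: serialized device info
    ("serialized_engine", "AAEC"),        # 3: base64-encoded engine -> b64decode
    ("in_binding_names", ["input_0"]),    # 4: input binding names
    ("out_binding_names", ["output_0"]),  # 5: output binding names
    ("hardware_compatible", False),       # 6: HW-compatibility flag -> str(int(...))
    ("serialized_metadata", "..."),       # 7: serialized metadata
    ("target_platform", "..."),           # 8: serialized target platform
]
engine_info = [value for _, value in flattened_tq]  # same extraction the list comprehension above performs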
