Wrapper module around TRT + pytorch subgraphs #3270

Open
wants to merge 13 commits into main
4 changes: 2 additions & 2 deletions core/runtime/TRTEngine.h
@@ -102,13 +102,13 @@ struct TRTEngine : torch::CustomClassHolder {
std::vector<at::Tensor> input_buffers = {};
std::vector<at::Tensor> output_buffers = {};
std::string shape_key;

bool prev_cudagraphs_enabled = false;
// TODO: Implement a call method
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);

void set_profiling_paths();
#ifndef NDEBUG
bool profile_execution = true;
bool profile_execution = false;
#else
bool profile_execution = false;
#endif
19 changes: 12 additions & 7 deletions core/runtime/execute_engine.cpp
@@ -113,11 +113,16 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
LOG_INFO("" << log_info);
compiled_engine->cudagraph.enable_debug_mode();
}
bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);

// Whether cudagraphs needs to record the graph on this pass
bool need_cudagraphs_record = (CUDAGRAPHS_MODE && (!_cudagraphs_validate_shapes(inputs, compiled_engine)));
// Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
bool need_cudagraphs_record = cudagraphs_enabled &&
((!compiled_engine->prev_cudagraphs_enabled) || (!_cudagraphs_validate_shapes(inputs, compiled_engine)));

if (!CUDAGRAPHS_MODE) {
compiled_engine->prev_cudagraphs_enabled = cudagraphs_enabled;

if (!cudagraphs_enabled) {
compiled_engine->cudagraph.reset();
}

@@ -211,7 +216,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()),
"Error while setting the tensor address for shape inputs");

if (CUDAGRAPHS_MODE) {
if (cudagraphs_enabled) {
// @peri044 I dont know if this makes sense since they are supposed to be GPU buffers
compiled_engine->input_buffers[i] = input_cpu;
}
@@ -231,7 +236,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape");

if (CUDAGRAPHS_MODE) {
if (cudagraphs_enabled) {
// If using CUDAGraphs copy formatted input to the corresponding persistent input buffer
compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true);
TORCHTRT_CHECK(
@@ -281,7 +286,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone());
}

if (CUDAGRAPHS_MODE) {
if (cudagraphs_enabled) {
TORCHTRT_CHECK(
compiled_engine->exec_ctx->setTensorAddress(
name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()),
@@ -324,7 +329,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
caller_exec_complete.record(compiled_engine->caller_stream);
caller_exec_complete.block(compiled_engine->engine_stream);

if (!CUDAGRAPHS_MODE) {
if (!cudagraphs_enabled) {
// Direct execution uses the caller buffers directly
compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
} else {
@@ -350,7 +355,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
trt_exec_complete.record(compiled_engine->engine_stream);
trt_exec_complete.block(compiled_engine->caller_stream);

if (CUDAGRAPHS_MODE) {
if (cudagraphs_enabled) {
// If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream)
for (size_t o = 0; o < compiled_engine->output_buffers.size(); o++) {
outputs[o].copy_(compiled_engine->output_buffers[o], false);
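The net effect of the changes above: graph capture is triggered either when CUDA Graphs is switched on between calls or when the cached input shape key no longer matches, and the CUDA Graph path is only taken when the subgraph-level mode is active. A minimal Python sketch of that decision, with names that mirror the C++ above (prev_cudagraphs_enabled, _cudagraphs_validate_shapes):

def needs_cudagraphs_record(cudagraphs_enabled: bool,
                            prev_cudagraphs_enabled: bool,
                            shapes_valid: bool) -> bool:
    # Re-record whenever CUDA Graphs was just switched on, or the cached shape key
    # no longer matches the current inputs (mirrors the C++ condition above).
    return cudagraphs_enabled and (not prev_cudagraphs_enabled or not shapes_valid)

# First pass after enabling CUDA Graphs always records, even if shapes are unchanged.
assert needs_cudagraphs_record(True, False, True)
# Already enabled on the previous pass and shapes still match: replay without re-recording.
assert not needs_cudagraphs_record(True, True, True)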
6 changes: 4 additions & 2 deletions core/runtime/register_jit_hooks.cpp
@@ -111,8 +111,10 @@ TORCH_LIBRARY(tensorrt, m) {
m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
});
m.def("get_cudagraphs_mode", []() -> bool { return CUDAGRAPHS_MODE; });
m.def("set_cudagraphs_mode", [](bool cudagraphs_mode) -> void { CUDAGRAPHS_MODE = cudagraphs_mode; });
m.def("get_cudagraphs_mode", []() -> int64_t { return CUDAGRAPHS_MODE; });
m.def("set_cudagraphs_mode", [](int64_t cudagraphs_mode) -> void {
CUDAGRAPHS_MODE = CudaGraphsMode(cudagraphs_mode);
});
m.def("set_logging_level", [](int64_t level) -> void {
util::logging::get_logger().set_reportable_log_level(util::logging::LogLevel(level));
});
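Because the mode is now exchanged as an integer rather than a bool, the registered custom ops can be exercised from Python roughly as below. This is a hedged sketch: it assumes that importing torch_tensorrt loads the extension that runs TORCH_LIBRARY(tensorrt, m), and that the integer values follow the CudaGraphsMode enum declared in runtime.h (0 = STANDARD, 1 = SUBGRAPH_CUDAGRAPHS, 2 = WHOLE_GRAPH_CUDAGRAPHS).

import torch
import torch_tensorrt  # assumed to load the C++ runtime that registers the tensorrt::* ops

# Integer values assumed to mirror the CudaGraphsMode enum from runtime.h.
STANDARD, SUBGRAPH_CUDAGRAPHS, WHOLE_GRAPH_CUDAGRAPHS = 0, 1, 2

torch.ops.tensorrt.set_cudagraphs_mode(SUBGRAPH_CUDAGRAPHS)
assert torch.ops.tensorrt.get_cudagraphs_mode() == SUBGRAPH_CUDAGRAPHS

# Restore the default behavior.
torch.ops.tensorrt.set_cudagraphs_mode(STANDARD)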
6 changes: 3 additions & 3 deletions core/runtime/runtime.cpp
@@ -8,7 +8,7 @@ namespace core {
namespace runtime {

bool MULTI_DEVICE_SAFE_MODE = false;
bool CUDAGRAPHS_MODE = false;
CudaGraphsMode CUDAGRAPHS_MODE = STANDARD;

c10::optional<RTDevice> get_most_compatible_device(
const RTDevice& target_device,
@@ -130,11 +130,11 @@ void set_multi_device_safe_mode(bool multi_device_safe_mode) {
MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
}

bool get_cudagraphs_mode() {
CudaGraphsMode get_cudagraphs_mode() {
return CUDAGRAPHS_MODE;
}

void set_cudagraphs_mode(bool cudagraphs_mode) {
void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode) {
CUDAGRAPHS_MODE = cudagraphs_mode;
}

13 changes: 10 additions & 3 deletions core/runtime/runtime.h
@@ -18,7 +18,14 @@ namespace runtime {
using EngineID = int64_t;
const std::string ABI_VERSION = "6";
extern bool MULTI_DEVICE_SAFE_MODE;
extern bool CUDAGRAPHS_MODE;

typedef enum {
STANDARD = 0,
SUBGRAPH_CUDAGRAPHS,
WHOLE_GRAPH_CUDAGRAPHS,
} CudaGraphsMode;

extern CudaGraphsMode CUDAGRAPHS_MODE;

typedef enum {
ABI_TARGET_IDX = 0,
@@ -51,9 +58,9 @@ bool get_multi_device_safe_mode();

void set_multi_device_safe_mode(bool multi_device_safe_mode);

bool get_cudagraphs_mode();
CudaGraphsMode get_cudagraphs_mode();

void set_cudagraphs_mode(bool cudagraphs_mode);
void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode);

class DeviceList {
using DeviceMap = std::unordered_map<int, RTDevice>;
2 changes: 2 additions & 0 deletions docsrc/index.rst
@@ -67,6 +67,7 @@ Tutorials
* :ref:`custom_kernel_plugins`
* :ref:`mutable_torchtrt_module_example`
* :ref:`weight_streaming_example`
* :ref:`cudagraphs_wrapper_example`

.. toctree::
:caption: Tutorials
@@ -84,6 +85,7 @@ Tutorials
tutorials/_rendered_examples/dynamo/custom_kernel_plugins
tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
tutorials/_rendered_examples/dynamo/weight_streaming_example
tutorials/_rendered_examples/dynamo/cudagraphs_wrapper_example

Dynamo Frontend
----------------
98 changes: 98 additions & 0 deletions examples/dynamo/cudagraphs_wrapper_example.py
@@ -0,0 +1,98 @@
"""
.. _cudagraphs_wrapper_example:

Wrapped runtime module for CUDA Graphs
======================================

If Torch-TensorRT encounters unsupported operations during compilation, it can fall back to
PyTorch's native implementation for those specific operations. This fallback mechanism allows the
rest of the model to be executed with TensorRT, while only the unsupported parts are handled by PyTorch.
Each fallback results in a graph break, which can reduce the overall performance benefit of using
TensorRT because it introduces additional overhead from switching between TensorRT and PyTorch execution contexts.

Applying CUDA Graphs to a PyTorch module that contains graph breaks can enhance performance by
leveraging the benefits of CUDA Graphs even in the presence of these breaks. Torch-TensorRT provides
a wrapper runtime module with CUDA Graphs for modules that have graph breaks, allowing you to mitigate
the inefficiencies these breaks introduce.
"""

# %%
# Imports and Model Definition
# ----------------------------------

import torch
import torch_tensorrt


class SampleModel(torch.nn.Module):
def forward(self, x):
return torch.relu((x + 2) * 0.5)


model = SampleModel().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

# %%
# Compiler options
# ----------------------------------
#
# The 'torch_executed_ops' compiler option is used to create graph breaks for this example.
# The debug=True compiler option provides detailed insight into the compilation process and helps
# pinpoint where graph breaks occur.

# Create a TensorRT-compiled model
trt_model = torch_tensorrt.compile(
model,
ir="dynamo",
inputs=[input],
min_block_size=1,
pass_through_build_failures=True,
debug=True,
torch_executed_ops={"torch.ops.aten.mul.Tensor"},
)

# %%
# Compiler log
# ----------------------------------
#
# This compiler log indicates that the torch.ops.aten.mul.Tensor operator is executed by PyTorch.
# Performance of this module can be enhanced by using the wrapped module.

##############################################################################
# .. code-block:: none
#
# ++++++++++++++ Dry-Run Results for Graph +++++++++++++++++
#
# The graph consists of 3 Total Operators, of which 2 operators are supported, 66.67% coverage
#
# The following ops are currently unsupported or excluded from conversion, and are listed with their op-count in the graph:
# torch.ops.aten.mul.Tensor: 1
#
# The following nodes are currently set to run in Torch:
# Node: torch.ops.aten.mul.Tensor, with layer location: /mul
# Note: Some of the above nodes may be supported, but were not included in a TRT graph by the partitioner

# %%
# TRT module with CUDA Graphs
# ----------------------------------
#
# When CUDA Graphs are applied to a TensorRT model that contains graph breaks, each break introduces additional
# overhead. This occurs because graph breaks prevent the entire model from being executed as a single, continuous
# optimized unit. As a result, some of the performance benefits typically provided by CUDA Graphs, such as reduced
# kernel launch overhead and improved execution efficiency, may be diminished.
with torch_tensorrt.runtime.enable_cudagraphs():
trt_model(input)

# %%
# Running the wrapped module with CUDA Graphs
# ----------------------------------
#
# Using a wrapped runtime module with CUDA Graphs allows you to encapsulate sequences of operations into graphs
# that can be executed efficiently, even in the presence of graph breaks. When a CUDA Graph context manager is
# used with the TensorRT module as a positional argument, it returns a wrapped_module. This module captures the
# execution graph, enabling efficient replay during subsequent inferences by reducing kernel launch overheads
# and improving performance. Note that initializing with the wrapper module involves a warm-up phase where the
# module is executed several times. This warm-up ensures that memory allocations and initializations are not
# recorded in CUDA Graphs, which helps maintain consistent execution paths and optimize performance.
with torch_tensorrt.runtime.enable_cudagraphs(trt_model) as wrapped_module:
wrapped_module(input)
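As a hedged follow-up to the example above (not part of the file added by this PR), one way to sanity-check the wrapper is to compare its output against the plain TensorRT module. The sketch below reuses trt_model and input from the example and assumes both paths return a single tensor.

# Reference output from the TensorRT module without CUDA Graphs.
ref = trt_model(input)

# Output from the wrapped module running under CUDA Graphs.
with torch_tensorrt.runtime.enable_cudagraphs(trt_model) as wrapped_module:
    out = wrapped_module(input)

# Hypothetical sanity check; tolerances may need adjusting for reduced precisions.
assert torch.allclose(ref, out, rtol=1e-4, atol=1e-4)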
7 changes: 6 additions & 1 deletion py/torch_tensorrt/_compile.py
@@ -12,6 +12,9 @@
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import _defaults
from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
WrapperTorchTensorRTModule,
)
from torch_tensorrt.fx import InputTensorSpec
from torch_tensorrt.fx.lower import compile as fx_compile
from torch_tensorrt.fx.utils import LowerPrecision
@@ -586,14 +589,16 @@ def save(
Save the model to disk in the specified output format.

Arguments:
module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)): Compiled Torch-TensorRT module
module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | WrapperTorchTensorRTModule)): Compiled Torch-TensorRT module
inputs (torch.Tensor): Torch input tensors
arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
output_format (str): Format to save the model. Options include exported_program | torchscript.
retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
This flag is experimental for now.
"""
if isinstance(module, WrapperTorchTensorRTModule):
module = module.compiled_module
module_type = _parse_module_type(module)
accepted_formats = {"exported_program", "torchscript"}
if arg_inputs is not None and not all(
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -36,6 +36,9 @@
post_lowering,
pre_export_lowering,
)
from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
WrapperTorchTensorRTModule,
)
from torch_tensorrt.dynamo.utils import (
get_flat_args_with_check,
get_output_metadata,
@@ -373,6 +376,7 @@ def compile(
use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING,
use_fp32_acc: bool = _defaults.USE_FP32_ACC,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
enable_wrapper_module: bool = _defaults.ENABLE_WRAPPER_MODULE,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -589,6 +593,7 @@ def compile(
"use_fp32_acc": use_fp32_acc,
"enable_cross_compile_for_windows": False,
"enable_weight_streaming": enable_weight_streaming,
"enable_wrapper_module": enable_wrapper_module,
}

settings = CompilationSettings(**compilation_options)
@@ -832,6 +837,10 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

dryrun_stats_display(dryrun_tracker, settings.dryrun)

if settings.enable_wrapper_module:
# Capture/replay a series of CUDA operations in subgraphs in a wrapped runtime module.
partitioned_module = WrapperTorchTensorRTModule(partitioned_module)

return partitioned_module


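Combined with the save() change in py/torch_tensorrt/_compile.py above, a minimal end-to-end sketch of the new option might look like the following. It is an illustration only: it assumes enable_wrapper_module=True makes compile() return a WrapperTorchTensorRTModule (per this diff), uses a toy model like the one in the example file, and uses a hypothetical output path.

import torch
import torch_tensorrt
from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
    WrapperTorchTensorRTModule,
)


class ToyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu((x + 2) * 0.5)


model = ToyModel().eval().cuda()
x = torch.randn((1, 3, 224, 224), device="cuda")

# Hedged sketch: compile with the new flag so the partitioned module is wrapped for
# CUDA Graph capture/replay across TRT and PyTorch subgraphs.
trt_mod = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[x],
    min_block_size=1,
    torch_executed_ops={"torch.ops.aten.mul.Tensor"},  # force a graph break for illustration
    enable_wrapper_module=True,
)
assert isinstance(trt_mod, WrapperTorchTensorRTModule)

# save() now unwraps the wrapper to its compiled_module before exporting.
torch_tensorrt.save(trt_mod, "trt.ep", inputs=[x])  # "trt.ep" is a hypothetical path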
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -44,6 +44,7 @@
USE_FP32_ACC = False
ENABLE_WEIGHT_STREAMING = False
ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
ENABLE_WRAPPER_MODULE = False


def default_device() -> Device:
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -16,6 +16,7 @@
ENABLE_CROSS_COMPILE_FOR_WINDOWS,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENABLE_WEIGHT_STREAMING,
ENABLE_WRAPPER_MODULE,
ENABLED_PRECISIONS,
ENGINE_CAPABILITY,
HARDWARE_COMPATIBLE,
@@ -125,6 +126,7 @@ class CompilationSettings:
use_fp32_acc: bool = USE_FP32_ACC
enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
enable_wrapper_module: bool = ENABLE_WRAPPER_MODULE


_SETTINGS_TO_BE_ENGINE_INVARIANT = (