diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h index 88fb7ab275..7560660d81 100644 --- a/core/runtime/TRTEngine.h +++ b/core/runtime/TRTEngine.h @@ -102,13 +102,13 @@ struct TRTEngine : torch::CustomClassHolder { std::vector input_buffers = {}; std::vector output_buffers = {}; std::string shape_key; - + bool prev_cudagraphs_enabled = false; // TODO: Implement a call method // c10::List Run(c10::List inputs); void set_profiling_paths(); #ifndef NDEBUG - bool profile_execution = true; + bool profile_execution = false; #else bool profile_execution = false; #endif diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index a7908468f4..3a2bf0fd38 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -113,11 +113,16 @@ std::vector execute_engine(std::vector inputs, c10::intr LOG_INFO("" << log_info); compiled_engine->cudagraph.enable_debug_mode(); } + bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS); // Whether cudagraphs needs to record the graph on this pass - bool need_cudagraphs_record = (CUDAGRAPHS_MODE && (!_cudagraphs_validate_shapes(inputs, compiled_engine))); + // Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change + bool need_cudagraphs_record = cudagraphs_enabled && + ((!compiled_engine->prev_cudagraphs_enabled) || (!_cudagraphs_validate_shapes(inputs, compiled_engine))); - if (!CUDAGRAPHS_MODE) { + compiled_engine->prev_cudagraphs_enabled = cudagraphs_enabled; + + if (!cudagraphs_enabled) { compiled_engine->cudagraph.reset(); } @@ -211,7 +216,7 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->exec_ctx->setTensorAddress(name.c_str(), inputShapeTensorValues.back().data()), "Error while setting the tensor address for shape inputs"); - if (CUDAGRAPHS_MODE) { + if (cudagraphs_enabled) { // @peri044 I dont know if this makes sense since they are supposed to be GPU buffers compiled_engine->input_buffers[i] = input_cpu; } @@ -231,7 +236,7 @@ std::vector execute_engine(std::vector inputs, c10::intr TORCHTRT_CHECK( compiled_engine->exec_ctx->setInputShape(name.c_str(), dims), "Error while setting the input shape"); - if (CUDAGRAPHS_MODE) { + if (cudagraphs_enabled) { // If using CUDAGraphs copy formatted input to the corresponding persistent input buffer compiled_engine->input_buffers[i].copy_(formatted_inputs.back(), true); TORCHTRT_CHECK( @@ -281,7 +286,7 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->output_buffers[pyt_idx] = std::move(outputs[pyt_idx].clone()); } - if (CUDAGRAPHS_MODE) { + if (cudagraphs_enabled) { TORCHTRT_CHECK( compiled_engine->exec_ctx->setTensorAddress( name.c_str(), compiled_engine->output_buffers[pyt_idx].data_ptr()), @@ -324,7 +329,7 @@ std::vector execute_engine(std::vector inputs, c10::intr caller_exec_complete.record(compiled_engine->caller_stream); caller_exec_complete.block(compiled_engine->engine_stream); - if (!CUDAGRAPHS_MODE) { + if (!cudagraphs_enabled) { // Direct execution uses the caller buffers directly compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream); } else { @@ -350,7 +355,7 @@ std::vector execute_engine(std::vector inputs, c10::intr trt_exec_complete.record(compiled_engine->engine_stream); trt_exec_complete.block(compiled_engine->caller_stream); - if (CUDAGRAPHS_MODE) { + if (cudagraphs_enabled) { // If in CUDAGraph mode, results need to be copied to the result buffers (on caller stream) for (size_t o = 0; o < 
compiled_engine->output_buffers.size(); o++) { outputs[o].copy_(compiled_engine->output_buffers[o], false); diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp index 042bf085c8..145972afaa 100644 --- a/core/runtime/register_jit_hooks.cpp +++ b/core/runtime/register_jit_hooks.cpp @@ -111,8 +111,10 @@ TORCH_LIBRARY(tensorrt, m) { m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void { MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode; }); - m.def("get_cudagraphs_mode", []() -> bool { return CUDAGRAPHS_MODE; }); - m.def("set_cudagraphs_mode", [](bool cudagraphs_mode) -> void { CUDAGRAPHS_MODE = cudagraphs_mode; }); + m.def("get_cudagraphs_mode", []() -> int64_t { return CUDAGRAPHS_MODE; }); + m.def("set_cudagraphs_mode", [](int64_t cudagraphs_mode) -> void { + CUDAGRAPHS_MODE = CudaGraphsMode(cudagraphs_mode); + }); m.def("set_logging_level", [](int64_t level) -> void { util::logging::get_logger().set_reportable_log_level(util::logging::LogLevel(level)); }); diff --git a/core/runtime/runtime.cpp b/core/runtime/runtime.cpp index b933e081c7..82b2fb1517 100644 --- a/core/runtime/runtime.cpp +++ b/core/runtime/runtime.cpp @@ -8,7 +8,7 @@ namespace core { namespace runtime { bool MULTI_DEVICE_SAFE_MODE = false; -bool CUDAGRAPHS_MODE = false; +CudaGraphsMode CUDAGRAPHS_MODE = STANDARD; c10::optional get_most_compatible_device( const RTDevice& target_device, @@ -130,11 +130,11 @@ void set_multi_device_safe_mode(bool multi_device_safe_mode) { MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode; } -bool get_cudagraphs_mode() { +CudaGraphsMode get_cudagraphs_mode() { return CUDAGRAPHS_MODE; } -void set_cudagraphs_mode(bool cudagraphs_mode) { +void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode) { CUDAGRAPHS_MODE = cudagraphs_mode; } diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h index 86ba331796..6f1436c745 100644 --- a/core/runtime/runtime.h +++ b/core/runtime/runtime.h @@ -18,7 +18,14 @@ namespace runtime { using EngineID = int64_t; const std::string ABI_VERSION = "6"; extern bool MULTI_DEVICE_SAFE_MODE; -extern bool CUDAGRAPHS_MODE; + +typedef enum { + STANDARD = 0, + SUBGRAPH_CUDAGRAPHS, + WHOLE_GRAPH_CUDAGRAPHS, +} CudaGraphsMode; + +extern CudaGraphsMode CUDAGRAPHS_MODE; typedef enum { ABI_TARGET_IDX = 0, @@ -51,9 +58,9 @@ bool get_multi_device_safe_mode(); void set_multi_device_safe_mode(bool multi_device_safe_mode); -bool get_cudagraphs_mode(); +CudaGraphsMode get_cudagraphs_mode(); -void set_cudagraphs_mode(bool cudagraphs_mode); +void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode); class DeviceList { using DeviceMap = std::unordered_map; diff --git a/docsrc/index.rst b/docsrc/index.rst index 5d88c8ecae..c34e7ff2f2 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -67,6 +67,7 @@ Tutorials * :ref:`custom_kernel_plugins` * :ref:`mutable_torchtrt_module_example` * :ref:`weight_streaming_example` +* :ref:`cudagraphs_wrapper_example` .. toctree:: :caption: Tutorials @@ -84,6 +85,7 @@ Tutorials tutorials/_rendered_examples/dynamo/custom_kernel_plugins tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example tutorials/_rendered_examples/dynamo/weight_streaming_example + tutorials/_rendered_examples/dynamo/cudagraphs_wrapper_example Dynamo Frontend ---------------- diff --git a/examples/dynamo/cudagraphs_wrapper_example.py b/examples/dynamo/cudagraphs_wrapper_example.py new file mode 100644 index 0000000000..386eb62650 --- /dev/null +++ b/examples/dynamo/cudagraphs_wrapper_example.py @@ -0,0 +1,98 @@ +""" +.. 
_cudagraphs_wrapper_example: + +Wrapped runtime module for cuda graphs +====================================== + +If Torch-TensorRT encounters unsupported operations during compilation, it can fall back to using +PyTorch's native implementation for those specific operations. This fallback mechanism allows the +rest of the model to be executed using TensorRT, while only the unsupported parts are handled by PyTorch. +This fallback results in a graph break, which can reduce the overall performance benefits of using +TensorRT because it introduces additional overhead from switching between TensorRT and PyTorch execution contexts. + +Applying CUDA Graphs to a PyTorch module that contains graph breaks can enhance performance by leveraging +the benefits of CUDA Graphs even in the presence of these breaks. Torch-TensorRT provides a +wrapped runtime module with CUDA Graphs for modules that have graph breaks, allowing you to mitigate the +inefficiencies introduced by these breaks. +""" + +# %% +# Imports and Model Definition +# ---------------------------------- + +import torch +import torch_tensorrt + + +class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.relu((x + 2) * 0.5) + + +model = SampleModel().eval().cuda() +input = torch.randn((1, 3, 224, 224)).to("cuda") + +# %% +# Compiler options +# ---------------------------------- +# +# The 'torch_executed_ops' compiler option is used to demonstrate graph breaks for this example. +# The debug=True compiler option provides detailed insights into the compilation process and helps +# pinpoint where graph breaks occur. + +# Create a TensorRT-compiled model +trt_model = torch_tensorrt.compile( + model, + ir="dynamo", + inputs=[input], + min_block_size=1, + pass_through_build_failures=True, + debug=True, + torch_executed_ops={"torch.ops.aten.mul.Tensor"}, +) + +# %% +# Compiler log +# ---------------------------------- +# +# This compiler log indicates that the torch.ops.aten.mul.Tensor operator is executed by PyTorch. +# Performance of this module can be enhanced by using the wrapped module. + +############################################################################## +# .. code-block:: none +# +# ++++++++++++++ Dry-Run Results for Graph +++++++++++++++++ +# +# The graph consists of 3 Total Operators, of which 2 operators are supported, 66.67% coverage +# +# The following ops are currently unsupported or excluded from conversion, and are listed with their op-count in the graph: +# torch.ops.aten.mul.Tensor: 1 +# +# The following nodes are currently set to run in Torch: +# Node: torch.ops.aten.mul.Tensor, with layer location: /mul +# Note: Some of the above nodes may be supported, but were not included in a TRT graph by the partitioner + +# %% +# trt module with cuda graphs +# ---------------------------------- +# +# When CUDA Graphs are applied to a TensorRT model that contains graph breaks, each break introduces additional +# overhead. This occurs because graph breaks prevent the entire model from being executed as a single, continuous +# optimized unit. As a result, some of the performance benefits typically provided by CUDA Graphs, such as reduced +# kernel launch overhead and improved execution efficiency, may be diminished.
+with torch_tensorrt.runtime.enable_cudagraphs(): + trt_model(input) + +# %% +# Running wrapped module with cuda graphs +# ---------------------------------- +# +# Using a wrapped runtime module with CUDA Graphs allows you to encapsulate sequences of operations into graphs +# that can be executed efficiently, even in the presence of graph breaks. When a CUDA Graph context manager is +# used with the TensorRT module as a positional argument, it returns a wrapped_module. This module captures the +# execution graph, enabling efficient replay during subsequent inferences by reducing kernel launch overheads +# and improving performance. Note that initializing with the wrapper module involves a warm-up phase where the +# module is executed several times. This warm-up ensures that memory allocations and initializations are not +# recorded in CUDA Graphs, which helps maintain consistent execution paths and optimize performance. +with torch_tensorrt.runtime.enable_cudagraphs(trt_model) as wrapped_module: + wrapped_module(input) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 9492b09402..855c75a057 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -12,6 +12,9 @@ from torch_tensorrt._features import ENABLED_FEATURES from torch_tensorrt._Input import Input from torch_tensorrt.dynamo import _defaults +from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import ( + WrapperTorchTensorRTModule, +) from torch_tensorrt.fx import InputTensorSpec from torch_tensorrt.fx.lower import compile as fx_compile from torch_tensorrt.fx.utils import LowerPrecision @@ -586,7 +589,7 @@ def save( Save the model to disk in the specified output format. Arguments: - module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)): Compiled Torch-TensorRT module + module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | WrapperTorchTensorRTModule)): Compiled Torch-TensorRT module inputs (torch.Tensor): Torch input tensors arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs. kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function. @@ -594,6 +597,8 @@ def save( retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it. This flag is experimental for now. 
""" + if isinstance(module, WrapperTorchTensorRTModule): + module = module.compiled_module module_type = _parse_module_type(module) accepted_formats = {"exported_program", "torchscript"} if arg_inputs is not None and not all( diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 9859668cd9..95d7bbf713 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -36,6 +36,9 @@ post_lowering, pre_export_lowering, ) +from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import ( + WrapperTorchTensorRTModule, +) from torch_tensorrt.dynamo.utils import ( get_flat_args_with_check, get_output_metadata, @@ -373,6 +376,7 @@ def compile( use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING, use_fp32_acc: bool = _defaults.USE_FP32_ACC, enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING, + enable_wrapper_module: bool = _defaults.ENABLE_WRAPPER_MODULE, **kwargs: Any, ) -> torch.fx.GraphModule: """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT @@ -589,6 +593,7 @@ def compile( "use_fp32_acc": use_fp32_acc, "enable_cross_compile_for_windows": False, "enable_weight_streaming": enable_weight_streaming, + "enable_wrapper_module": enable_wrapper_module, } settings = CompilationSettings(**compilation_options) @@ -832,6 +837,10 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool: dryrun_stats_display(dryrun_tracker, settings.dryrun) + if settings.enable_wrapper_module: + # Capture/replay a series of CUDA operations in subgraphs in a wrapped runtime module. + partitioned_module = WrapperTorchTensorRTModule(partitioned_module) + return partitioned_module diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index f6b97b1fbb..953af35f09 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -44,6 +44,7 @@ USE_FP32_ACC = False ENABLE_WEIGHT_STREAMING = False ENABLE_CROSS_COMPILE_FOR_WINDOWS = False +ENABLE_WRAPPER_MODULE = False def default_device() -> Device: diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index 9062e2e539..c562e65049 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -16,6 +16,7 @@ ENABLE_CROSS_COMPILE_FOR_WINDOWS, ENABLE_EXPERIMENTAL_DECOMPOSITIONS, ENABLE_WEIGHT_STREAMING, + ENABLE_WRAPPER_MODULE, ENABLED_PRECISIONS, ENGINE_CAPABILITY, HARDWARE_COMPATIBLE, @@ -125,6 +126,7 @@ class CompilationSettings: use_fp32_acc: bool = USE_FP32_ACC enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS + enable_wrapper_module: bool = ENABLE_WRAPPER_MODULE _SETTINGS_TO_BE_ENGINE_INVARIANT = ( diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py index e31d73f337..bf772d3ee8 100644 --- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py @@ -107,6 +107,8 @@ def __init__( self.engine = None self.weight_name_map = weight_name_map self.target_platform = Platform.current_platform() + # Previous cuda graphs state + self.prev_cudagraphs_enabled = False if self.serialized_engine is not None and not self.settings.lazy_engine_init: self.setup_engine() @@ -171,7 +173,7 @@ def setup_engine(self) -> None: self.engine.get_tensor_shape(input_name) for input_name in self.input_names ] 
self.output_dtypes = [ - dtype._from(self.engine.get_tensor_dtype(output_name)) + dtype._from(self.engine.get_tensor_dtype(output_name)).to(torch.dtype) for output_name in self.output_names ] self.output_shapes = [ @@ -238,7 +240,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda()) for i in inputs ] - with ( torch.autograd.profiler.record_function("PythonTorchTensorRTModule:Forward") if self.profiling_enabled @@ -247,9 +248,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . self._check_initialized() cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode() - need_cudagraphs_record = ( - cudagraphs_enabled and not self.cudagraphs_validate_shapes(inputs) + + # Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change + need_cudagraphs_record = cudagraphs_enabled and ( + (not self.prev_cudagraphs_enabled) + or (not self.cudagraphs_validate_shapes(inputs)) ) + self.prev_cudagraphs_enabled = cudagraphs_enabled if need_cudagraphs_record: self._input_buffers = [None] * len(self.input_names) @@ -259,7 +264,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . self.cudagraph.reset() self.cudagraph = None - # If in safe mode, check at each iteration for for whether a switch is required + # If in safe mode, check at each iteration for whether a switch is required if ( torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE ): @@ -379,7 +384,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, . output = torch.empty( size=shape, - dtype=self.output_dtypes[o].to(torch.dtype), + dtype=self.output_dtypes[o], device=torch.cuda.current_device(), ) diff --git a/py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py new file mode 100644 index 0000000000..452004052a --- /dev/null +++ b/py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +import logging +from typing import List, Optional, Sequence, Tuple + +import torch +import torch_tensorrt +from torch.fx.experimental.proxy_tensor import unset_fake_temporarily +from torch_tensorrt.dynamo import partitioning + +logger = logging.getLogger(__name__) + + +class WrapperTorchTensorRTModule(torch.nn.Module): # type: ignore[misc] + """This wrapper runtime module records and replays the whole CUDA graph across submodules + + Args: + compiled_module: Compiled fx GraphModule that will be wrapped + Returns: + Output tensor or tensor list + """ + + def __init__( + self, + compiled_module: torch.nn.Module, + ): + super(WrapperTorchTensorRTModule, self).__init__() + self.compiled_module = compiled_module + self.inputs = partitioning.construct_submodule_inputs(compiled_module) + + self._input_buffers: List[torch.Tensor] = [] + self._output_buffers: List[torch.Tensor] = [] + self.cudagraph: Optional[torch.cuda.CUDAGraph] = None + self.shape_key: Optional[str] = None + self.prev_cudagraphs_enabled = False + self._caller_stream: Optional[torch.cuda.Stream] = None + self._engine_stream: Optional[torch.cuda.Stream] = None + + num_torch_mod = 0 + for name, _ in self.compiled_module.named_children(): + if "_run_on_acc" not in name: + num_torch_mod += 1 + if num_torch_mod > 0: + self.warm_up() + else: + logger.warning( + "Wrapper runtime module provides no benefit for a graph
module that doesn't have graph breaks" + ) + + def warm_up(self) -> None: + """ + Warm up is necessary to ensure that memory allocations and initializations + are not recorded in cuda graphs + """ + with torch_tensorrt.logging.errors(): + with unset_fake_temporarily(): + inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs] + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + for _ in range(3): + self.compiled_module(*inputs_tensor) + torch.cuda.current_stream().wait_stream(s) + + def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool: + """ + Validates whether the input shapes of the forward function have changed, + and infers output shapes if a dynamic input shape has changed. + """ + # Representation of input shapes to a given model + # Shapes are concatenated as so: + # x: (3, 4), y: (4, 5) --> Key: (3,4)(4,5) + new_shape_key = "".join(str(tuple(t.shape)).replace(" ", "") for t in inputs) + + if new_shape_key != self.shape_key: + logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}") + self.shape_key = new_shape_key + return True + + return False + + def __del__(self) -> None: + if self.cudagraph: + self.cudagraph.reset() + + def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]: + cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode() + if cudagraphs_enabled: + shape_changed = self.validate_input_shapes(inputs) + # Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change + need_cudagraphs_record = not self.prev_cudagraphs_enabled or shape_changed + self.prev_cudagraphs_enabled = cudagraphs_enabled + + if need_cudagraphs_record: + if self.cudagraph: + self.cudagraph.reset() + self._input_buffers = [None] * len(self.inputs) + + # Ensure inputs are available in all scopes and cast symbolic integers to Tensors + contiguous_inputs: List[torch.Tensor] = [ + ( + i.contiguous() + if isinstance(i, torch.Tensor) + else torch.tensor(i).cuda() + ) + for i in inputs + ] + assert len(contiguous_inputs) == len( + self.inputs + ), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}." + + for i, _ in enumerate(self.inputs): + if not contiguous_inputs[i].is_cuda: + logger.warning( + f"Detected input[{i}] is not on a cuda device. " + "This tensor is being moved by the runtime but for performance considerations, " + "ensure your inputs are all on GPU and open an issue here " + "(https://github.com/pytorch/TensorRT/issues) if this warning persists." + ) + contiguous_inputs = ( + contiguous_inputs[:i] + + [contiguous_inputs[i].cuda()] + + contiguous_inputs[i + 1 :] + ) + + assert ( + contiguous_inputs[i].dtype == self.inputs[i].dtype + ), f"Dtype mismatch for {i}th input. Expect {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
+ + if need_cudagraphs_record: + # If cudagraphs is enabled, this memory is reserved for future cudagraph runs + # Clone is required to avoid re-using user-provided GPU memory + self._input_buffers[i] = contiguous_inputs[i].clone() + else: + self._input_buffers[i].copy_(contiguous_inputs[i]) + + self._caller_stream = torch.cuda.current_stream() + if ( + self._engine_stream == torch.cuda.default_stream() + or self._engine_stream is None + ): + self._engine_stream = torch.cuda.Stream() + + self._engine_stream.wait_stream(self._caller_stream) + + with torch.cuda.stream(self._engine_stream): + if need_cudagraphs_record: + self.cudagraph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.cudagraph, stream=self._engine_stream): + self._output_buffers = self.compiled_module( + *self._input_buffers + ) + + self.cudagraph.replay() # type: ignore + self._caller_stream.wait_stream(self._engine_stream) + + if isinstance(self._output_buffers, (list, tuple)): + output_buffers = self._output_buffers + else: + output_buffers = [self._output_buffers] + outputs = [output.clone() for output in output_buffers] + if len(outputs) == 1: + return outputs[0] + return outputs + else: + if self.cudagraph: + self.cudagraph.reset() + self.cudagraph = None + return self.compiled_module(*inputs) diff --git a/py/torch_tensorrt/runtime/__init__.py b/py/torch_tensorrt/runtime/__init__.py index 77b4401222..09a478e807 100644 --- a/py/torch_tensorrt/runtime/__init__.py +++ b/py/torch_tensorrt/runtime/__init__.py @@ -5,6 +5,7 @@ from torch_tensorrt.runtime._cudagraphs import ( enable_cudagraphs, get_cudagraphs_mode, + get_whole_cudagraphs_mode, set_cudagraphs_mode, ) from torch_tensorrt.runtime._multi_device_safe_mode import set_multi_device_safe_mode diff --git a/py/torch_tensorrt/runtime/_cudagraphs.py b/py/torch_tensorrt/runtime/_cudagraphs.py index 9d1523ef2e..29d7ef895e 100644 --- a/py/torch_tensorrt/runtime/_cudagraphs.py +++ b/py/torch_tensorrt/runtime/_cudagraphs.py @@ -1,13 +1,24 @@ import logging -from typing import Any +from typing import Any, Optional import torch import torch_tensorrt +from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import ( + WrapperTorchTensorRTModule, +) + + +class CudaGraphsMode: + STANDARD = 0 + SUBGRAPH_CUDAGRAPHS = 1 + # Internal mode to apply cuda graphs for wrapped runtime module + WHOLE_GRAPH_CUDAGRAPHS = 2 + if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: _PY_RT_CUDAGRAPHS = torch.ops.tensorrt.get_cudagraphs_mode() else: - _PY_RT_CUDAGRAPHS = False + _PY_RT_CUDAGRAPHS = CudaGraphsMode.STANDARD logger = logging.getLogger(__name__) @@ -16,19 +27,33 @@ def set_cudagraphs_mode(mode: bool) -> None: # Set new cudagraphs mode for Python global _PY_RT_CUDAGRAPHS - _PY_RT_CUDAGRAPHS = mode + _PY_RT_CUDAGRAPHS = ( + CudaGraphsMode.SUBGRAPH_CUDAGRAPHS if mode else CudaGraphsMode.STANDARD + ) # Set new mode for C++ if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: - torch.ops.tensorrt.set_cudagraphs_mode(mode) + torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS) logger.info(f"Set Cudagraphs usage to {mode}") +def get_whole_cudagraphs_mode() -> bool: + # check if whole cudagraphs mode is enabled or not + global _PY_RT_CUDAGRAPHS + if _PY_RT_CUDAGRAPHS == CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS: + return True + else: + return False + + def get_cudagraphs_mode() -> bool: # Get cudagraphs mode for Python global _PY_RT_CUDAGRAPHS - return _PY_RT_CUDAGRAPHS # type: ignore + if _PY_RT_CUDAGRAPHS == CudaGraphsMode.SUBGRAPH_CUDAGRAPHS: + return True + else: + 
return False class _CudagraphsContextManager(object): @@ -37,19 +62,30 @@ class _CudagraphsContextManager(object): Used to enable cudagraphs as a context manager """ - def __init__(self) -> None: + def __init__(self, compiled_module: Optional[torch.nn.Module]) -> None: global _PY_RT_CUDAGRAPHS self.old_mode = _PY_RT_CUDAGRAPHS + self.compiled_module = compiled_module def __enter__(self) -> "_CudagraphsContextManager": - # Enable cudagraphs - set_cudagraphs_mode(True) - return self + global _PY_RT_CUDAGRAPHS + if self.compiled_module: + _PY_RT_CUDAGRAPHS = CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS + # Set new mode for C++ + if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime: + torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS) + return WrapperTorchTensorRTModule(self.compiled_module) + else: + # Enable cudagraphs + set_cudagraphs_mode(True) + return self def __exit__(self, *args: Any) -> None: # Set cudagraphs back to old mode set_cudagraphs_mode(self.old_mode) -def enable_cudagraphs() -> _CudagraphsContextManager: - return _CudagraphsContextManager() +def enable_cudagraphs( + compiled_module: Optional[torch.nn.Module] = None, +) -> _CudagraphsContextManager: + return _CudagraphsContextManager(compiled_module) diff --git a/py/torch_tensorrt/runtime/_weight_streaming.py b/py/torch_tensorrt/runtime/_weight_streaming.py index 3a33330fa1..4ec7fb02c5 100755 --- a/py/torch_tensorrt/runtime/_weight_streaming.py +++ b/py/torch_tensorrt/runtime/_weight_streaming.py @@ -3,6 +3,9 @@ import torch from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule +from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import ( + WrapperTorchTensorRTModule, +) logger = logging.getLogger(__name__) @@ -12,9 +15,14 @@ class _WeightStreamingContextManager(object): Helper class used to setup weight streaming budget """ - def __init__(self, module: torch.fx.GraphModule) -> None: + def __init__( + self, module: torch.fx.GraphModule | WrapperTorchTensorRTModule + ) -> None: rt_mods = [] self.current_device_budget = 0 + + if isinstance(module, WrapperTorchTensorRTModule): + module = module.compiled_module for name, rt_mod in module.named_children(): if "_run_on_acc" in name and isinstance( rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule) diff --git a/tests/py/dynamo/runtime/test_005_wrapper_cudagraphs.py b/tests/py/dynamo/runtime/test_005_wrapper_cudagraphs.py new file mode 100644 index 0000000000..23ecd07d83 --- /dev/null +++ b/tests/py/dynamo/runtime/test_005_wrapper_cudagraphs.py @@ -0,0 +1,270 @@ +import torch +import torch_tensorrt as torchtrt +from parameterized import parameterized +from torch.testing._internal.common_utils import TestCase, run_tests +from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import ( + WrapperTorchTensorRTModule, +) + +INPUT_SIZE = (3, 16, 16) +TRIALS = 5 + + +class TestWrapperCudagraphs(TestCase): + @parameterized.expand( + [ + ("python_runtime", True, False), + ("python_runtime_multi_out", True, True), + ("cpp_runtime", False, False), + ("cpp_runtime_multi_out", False, True), + ] + ) + def test_wrapper_cudagraphs(self, _, use_python_runtime, multi_output): + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.relu((x + 2) * 0.5) + + class SampleModelMultiOutput(torch.nn.Module): + def forward(self, x): + return torch.relu((x + 2) * 0.5), torch.relu((x - 2) * 2.1) + + input_list = [] + for _ in range(TRIALS): + input = [torch.randn(*INPUT_SIZE, dtype=torch.float32).cuda()] + 
input_list.append(input) + + model = SampleModelMultiOutput() if multi_output else SampleModel() + fx_graph = torch.fx.symbolic_trace(model) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torchtrt.compile( + fx_graph, + "dynamo", + input_list[0], + min_block_size=1, + pass_through_build_failures=True, + torch_executed_ops={"torch.ops.aten.mul.Tensor"}, + use_python_runtime=use_python_runtime, + ) + + ref_out_list = [] + trt_out_list = [] + wrapped_module = WrapperTorchTensorRTModule(optimized_model) + for enable_cuda_graphs in [False, True]: + for i in range(len(input_list)): + # Toggle the cuda graphs mode between iterations to exercise mode switching + if i % TRIALS == i // TRIALS: + cuda_graphs = enable_cuda_graphs + else: + cuda_graphs = not enable_cuda_graphs + + if cuda_graphs: + torchtrt.runtime._cudagraphs._PY_RT_CUDAGRAPHS = ( + torchtrt.runtime._cudagraphs.CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS + ) + else: + torchtrt.runtime.set_cudagraphs_mode(False) + + trt_out_list.append(wrapped_module(*input_list[i])) + ref_out_list.append(fx_graph(*input_list[i])) + + for optimized_model_results, torch_model_results in zip( + trt_out_list, ref_out_list + ): + torch.testing.assert_close( + torch_model_results, + optimized_model_results, + rtol=5e-03, + atol=5e-03, + equal_nan=True, + check_dtype=True, + ) + torch._dynamo.reset() + + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] + ) + def test_wrapper_cudagraphs_dynamic(self, _, use_python_runtime): + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.relu((x + 2) * 0.5) + + inputs = torchtrt.Input( + min_shape=(1, 3, 128, 224), + opt_shape=(8, 3, 192, 224), + max_shape=(16, 3, 224, 224), + dtype=torch.float, + name="x", + ) + fx_graph = torch.fx.symbolic_trace(SampleModel()) + + optimized_model = torchtrt.compile( + fx_graph, + "dynamo", + inputs, + min_block_size=1, + pass_through_build_failures=True, + torch_executed_ops={"torch.ops.aten.mul.Tensor"}, + use_python_runtime=use_python_runtime, + ) + + input_list = [] + ref_out_list = [] + trt_out_list = [] + # Alternate the cuda graphs setting and input shapes every five iterations. + for i in [1, 3, 8, 11, 16]: + for j in [128, 128, 222, 222, 224]: + input_list.append(torch.randn((i, 3, j, 224)).cuda()) + + wrapped_module = WrapperTorchTensorRTModule(optimized_model) + for enable_cuda_graphs in [False, True]: + for i in range(len(input_list)): + # Toggle the cuda graphs mode between iterations to exercise mode switching + if i % TRIALS == i // TRIALS: + cuda_graphs = enable_cuda_graphs + else: + cuda_graphs = not enable_cuda_graphs + + if cuda_graphs: + torchtrt.runtime._cudagraphs._PY_RT_CUDAGRAPHS = ( + torchtrt.runtime._cudagraphs.CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS + ) + else: + torchtrt.runtime.set_cudagraphs_mode(False) + trt_out_list.append(wrapped_module(input_list[i])) + ref_out_list.append(fx_graph(input_list[i])) + + for optimized_model_results, torch_model_results in zip( + trt_out_list, ref_out_list + ): + torch.testing.assert_close( + torch_model_results, + optimized_model_results, + rtol=5e-03, + atol=5e-03, + equal_nan=True, + check_dtype=True, + ) + torch._dynamo.reset() + + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] + ) + def test_wrapper_cudagraphs_conv(self, _, use_python_runtime): + """ + Graph break at a torch convolution that may perform memory allocation, + which is not expected to be recorded in the CUDA graph.
+ """ + + class SampleModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv1d(64, 6, 3) + self.relu = torch.nn.ReLU() + + def forward(self, x): + out = 1 + self.conv(x) + out = self.relu(out) + return out + + model = SampleModel().eval().cuda() + input_list = [] + trt_out_list = [] + ref_out_list = [] + + for _ in range(TRIALS): + input = [torch.randn((64, 32), dtype=torch.float32).cuda()] + input_list.append(input) + fx_graph = torch.fx.symbolic_trace(model) + + optimized_model = torchtrt.compile( + fx_graph, + inputs=input_list[0], + ir="dynamo", + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + torch_executed_ops={"torch.ops.aten.convolution.default"}, + use_python_runtime=use_python_runtime, + ) + + wrapped_module = WrapperTorchTensorRTModule(optimized_model) + torchtrt.runtime._cudagraphs._PY_RT_CUDAGRAPHS = ( + torchtrt.runtime._cudagraphs.CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS + ) + for i in range(TRIALS): + trt_out_list.append(wrapped_module(*input_list[i])) + ref_out_list.append(fx_graph(*input_list[i])) + + for optimized_model_results, torch_model_results in zip( + trt_out_list, ref_out_list + ): + torch.testing.assert_close( + torch_model_results, + optimized_model_results, + rtol=5e-03, + atol=5e-03, + equal_nan=True, + check_dtype=True, + ) + torch._dynamo.reset() + + @parameterized.expand( + [ + ("python_runtime", True), + ("cpp_runtime", False), + ] + ) + def test_wrapper_cudagraphs_api(self, _, use_python_runtime): + class SampleModel(torch.nn.Module): + def forward(self, x): + return torch.relu((x + 2) * 0.5) + + model = SampleModel().eval().cuda() + input_list = [] + trt_out_list = [] + ref_out_list = [] + + for _ in range(TRIALS): + input = [torch.randn((64, 32), dtype=torch.float32).cuda()] + input_list.append(input) + fx_graph = torch.fx.symbolic_trace(model) + + optimized_model = torchtrt.compile( + fx_graph, + inputs=input_list[0], + ir="dynamo", + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + torch_executed_ops={"torch.ops.aten.convolution.default"}, + use_python_runtime=use_python_runtime, + ) + + with torchtrt.runtime.enable_cudagraphs(optimized_model) as wrapped_module: + for i in range(TRIALS): + trt_out_list.append(wrapped_module(*input_list[i])) + ref_out_list.append(fx_graph(*input_list[i])) + + for optimized_model_results, torch_model_results in zip( + trt_out_list, ref_out_list + ): + torch.testing.assert_close( + torch_model_results, + optimized_model_results, + rtol=5e-03, + atol=5e-03, + equal_nan=True, + check_dtype=True, + ) + torch._dynamo.reset() + + +if __name__ == "__main__": + run_tests()