diff --git a/core/runtime/TRTEngine.cpp b/core/runtime/TRTEngine.cpp
index 36ab06a4f6..5a5c1ad83d 100644
--- a/core/runtime/TRTEngine.cpp
+++ b/core/runtime/TRTEngine.cpp
@@ -213,10 +213,6 @@ TRTEngine::TRTEngine(
   LOG_DEBUG(*this);
 }
 
-void TRTEngine::set_whole_cudagraphs(bool enable) {
-  whole_cudagraphs = enable;
-}
-
 TRTEngine::~TRTEngine() {
   trt_engine_profiler.reset();
   exec_ctx.reset();
diff --git a/core/runtime/TRTEngine.h b/core/runtime/TRTEngine.h
index fb922da18a..7560660d81 100644
--- a/core/runtime/TRTEngine.h
+++ b/core/runtime/TRTEngine.h
@@ -87,7 +87,6 @@ struct TRTEngine : torch::CustomClassHolder {
   bool set_device_memory_budget(int64_t budget);
   int64_t get_streamable_device_memory_budget();
   int64_t get_automatic_device_memory_budget();
-  void set_whole_cudagraphs(bool enable);
   std::vector<std::string> infer_outputs(std::vector<std::vector<int64_t>> input_shapes);
   friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
   static const char BINDING_DELIM = '%';
@@ -104,13 +103,12 @@
   std::vector<at::Tensor> output_buffers = {};
   std::string shape_key;
   bool prev_cudagraphs_enabled = false;
-  bool whole_cudagraphs = false;
   // TODO: Implement a call method
   // c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
   void set_profiling_paths();
 #ifndef NDEBUG
-  bool profile_execution = true;
+  bool profile_execution = false;
 #else
   bool profile_execution = false;
 #endif
diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp
index 460a9cb221..3a2bf0fd38 100644
--- a/core/runtime/execute_engine.cpp
+++ b/core/runtime/execute_engine.cpp
@@ -113,7 +113,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     LOG_INFO("" << log_info);
     compiled_engine->cudagraph.enable_debug_mode();
   }
-  bool cudagraphs_enabled = (!compiled_engine->whole_cudagraphs && CUDAGRAPHS_MODE);
+  bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
 
   // Whether cudagraphs needs to record the graph on this pass
   // Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
diff --git a/core/runtime/register_jit_hooks.cpp b/core/runtime/register_jit_hooks.cpp
index 3a6f7605b5..145972afaa 100644
--- a/core/runtime/register_jit_hooks.cpp
+++ b/core/runtime/register_jit_hooks.cpp
@@ -87,7 +87,6 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
         .def("dump_engine_layer_info_to_file", &TRTEngine::dump_engine_layer_info_to_file)
         .def("dump_engine_layer_info", &TRTEngine::dump_engine_layer_info)
         .def("get_engine_layer_info", &TRTEngine::get_engine_layer_info)
-        .def("set_whole_cudagraphs", &TRTEngine::set_whole_cudagraphs)
         .def("infer_outputs", &TRTEngine::infer_outputs)
         .def_property(
             "device_memory_budget",
@@ -112,8 +111,10 @@ TORCH_LIBRARY(tensorrt, m) {
   m.def("set_multi_device_safe_mode", [](bool multi_device_safe_mode) -> void {
     MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
   });
-  m.def("get_cudagraphs_mode", []() -> bool { return CUDAGRAPHS_MODE; });
-  m.def("set_cudagraphs_mode", [](bool cudagraphs_mode) -> void { CUDAGRAPHS_MODE = cudagraphs_mode; });
+  m.def("get_cudagraphs_mode", []() -> int64_t { return CUDAGRAPHS_MODE; });
+  m.def("set_cudagraphs_mode", [](int64_t cudagraphs_mode) -> void {
+    CUDAGRAPHS_MODE = CudaGraphsMode(cudagraphs_mode);
+  });
   m.def("set_logging_level", [](int64_t level) -> void {
     util::logging::get_logger().set_reportable_log_level(util::logging::LogLevel(level));
   });
diff --git a/core/runtime/runtime.cpp b/core/runtime/runtime.cpp
index b933e081c7..82b2fb1517 100644
--- a/core/runtime/runtime.cpp
+++ b/core/runtime/runtime.cpp
@@ -8,7 +8,7 @@ namespace core {
 namespace runtime {
 
 bool MULTI_DEVICE_SAFE_MODE = false;
-bool CUDAGRAPHS_MODE = false;
+CudaGraphsMode CUDAGRAPHS_MODE = STANDARD;
 
 c10::optional<RTDevice> get_most_compatible_device(
     const RTDevice& target_device,
@@ -130,11 +130,11 @@ void set_multi_device_safe_mode(bool multi_device_safe_mode) {
   MULTI_DEVICE_SAFE_MODE = multi_device_safe_mode;
 }
 
-bool get_cudagraphs_mode() {
+CudaGraphsMode get_cudagraphs_mode() {
   return CUDAGRAPHS_MODE;
 }
 
-void set_cudagraphs_mode(bool cudagraphs_mode) {
+void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode) {
   CUDAGRAPHS_MODE = cudagraphs_mode;
 }
diff --git a/core/runtime/runtime.h b/core/runtime/runtime.h
index 86ba331796..6f1436c745 100644
--- a/core/runtime/runtime.h
+++ b/core/runtime/runtime.h
@@ -18,7 +18,14 @@ namespace runtime {
 using EngineID = int64_t;
 const std::string ABI_VERSION = "6";
 extern bool MULTI_DEVICE_SAFE_MODE;
-extern bool CUDAGRAPHS_MODE;
+
+typedef enum {
+  STANDARD = 0,
+  SUBGRAPH_CUDAGRAPHS,
+  WHOLE_GRAPH_CUDAGRAPHS,
+} CudaGraphsMode;
+
+extern CudaGraphsMode CUDAGRAPHS_MODE;
 
 typedef enum {
   ABI_TARGET_IDX = 0,
@@ -51,9 +58,9 @@ bool get_multi_device_safe_mode();
 
 void set_multi_device_safe_mode(bool multi_device_safe_mode);
 
-bool get_cudagraphs_mode();
+CudaGraphsMode get_cudagraphs_mode();
 
-void set_cudagraphs_mode(bool cudagraphs_mode);
+void set_cudagraphs_mode(CudaGraphsMode cudagraphs_mode);
 
 class DeviceList {
   using DeviceMap = std::unordered_map<int, RTDevice>;
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index 304263006f..bf772d3ee8 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -107,8 +107,6 @@ def __init__(
         self.engine = None
         self.weight_name_map = weight_name_map
         self.target_platform = Platform.current_platform()
-        # Check if CUDA graph capture is enabled in the parent node
-        self.whole_cudagraphs = False
         # Previous cuda graphs state
         self.prev_cudagraphs_enabled = False
 
@@ -151,14 +149,6 @@ def set_default_device_memory_budget(self) -> int:
         logger.debug(f"Weight streaming budget set to {budget_bytes}B")
         return self._set_device_memory_budget(budget_bytes)
 
-    def set_whole_cudagraphs(self, enable: bool) -> None:
-        """
-        When the global CUDA graphs mode is enabled, the parent wrapper module handles all
-        CUDA graph recording and replay. Therefore, any child modules must disable their
-        own CUDA graph functionality to avoid conflicts.
-        """
-        self.whole_cudagraphs = enable
-
     def setup_engine(self) -> None:
         assert (
             self.target_platform == Platform.current_platform()
@@ -257,10 +247,8 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
         ):
             self._check_initialized()
 
-            cudagraphs_enabled = (
-                torch_tensorrt.runtime.get_cudagraphs_mode()
-                and not self.whole_cudagraphs
-            )
+            cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
+
             # Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
             need_cudagraphs_record = cudagraphs_enabled and (
                 (not self.prev_cudagraphs_enabled)
diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
index 9cad1c8994..1bebe20fda 100644
--- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -195,14 +195,6 @@ def set_device_memory_budget(self, budget_bytes: int) -> int:
 
         return budget_bytes
 
-    def set_whole_cudagraphs(self, enable: bool) -> None:
-        """
-        When the global CUDA graphs mode is enabled, the parent wrapper module handles all
-        CUDA graph recording and replay. Therefore, any child modules must disable their
-        own CUDA graph functionality to avoid conflicts.
-        """
-        self.engine.set_whole_cudagraphs(enable)
-
     def setup_engine(self) -> None:
         """
         Setup engine for a module which has deferred engine setup.
diff --git a/py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py
index 1ac7e14eb6..452004052a 100644
--- a/py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py
@@ -1,15 +1,12 @@
 from __future__ import annotations
 
 import logging
-from contextlib import nullcontext
-from tempfile import tempdir
 from typing import List, Optional, Sequence, Tuple
 
 import torch
 import torch_tensorrt
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt.dynamo import partitioning
-from torch_tensorrt.runtime._utils import _is_switch_required, _select_rt_device
 
 logger = logging.getLogger(__name__)
 
@@ -35,16 +32,20 @@ def __init__(
         self._output_buffers: List[torch.Tensor] = []
         self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
         self.shape_key: Optional[str] = None
-        self.profiling_enabled = False
         self.prev_cudagraphs_enabled = False
         self._caller_stream: Optional[torch.cuda.Stream] = None
         self._engine_stream: Optional[torch.cuda.Stream] = None
 
-        # Disable cudagrphs in submodules as it will be enabled in wrapper
-        for name, rt_mod in self.compiled_module.named_children():
-            if "_run_on_acc" in name:
-                rt_mod.set_whole_cudagraphs(True)
-        self.warm_up()
+        num_torch_mod = 0
+        for name, _ in self.compiled_module.named_children():
+            if "_run_on_acc" not in name:
+                num_torch_mod += 1
+        if num_torch_mod > 0:
+            self.warm_up()
+        else:
+            logger.warning(
+                "Wrapper runtime module provides no benefit for a graph module that doesn't have graph breaks"
+            )
 
     def warm_up(self) -> None:
         """
@@ -83,157 +84,86 @@ def __del__(self) -> None:
             self.cudagraph.reset()
 
     def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
-        # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
-        contiguous_inputs: List[torch.Tensor] = [
-            (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
-            for i in inputs
-        ]
-        with (
-            torch.autograd.profiler.record_function(
-                "WrapperTorchTensorRTModule:Forward"
-            )
-            if self.profiling_enabled
-            else nullcontext()
-        ):
+        cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
+        if cudagraphs_enabled:
             shape_changed = self.validate_input_shapes(inputs)
-            cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
             # Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
-            need_cudagraphs_record = cudagraphs_enabled and (
-                (not self.prev_cudagraphs_enabled) or shape_changed
-            )
+            need_cudagraphs_record = not self.prev_cudagraphs_enabled or shape_changed
             self.prev_cudagraphs_enabled = cudagraphs_enabled
 
             if need_cudagraphs_record:
                 if self.cudagraph:
                     self.cudagraph.reset()
-                self._input_buffers = [None] * len(self.inputs)
-            if not cudagraphs_enabled and self.cudagraph:
-                self.cudagraph.reset()
-                self.cudagraph = None
-
-            # If in safe mode, check at each iteration for for whether a switch is required
-            if (
-                torch_tensorrt.runtime._multi_device_safe_mode._PY_RT_MULTI_DEVICE_SAFE_MODE
-            ):
-                curr_device_id = torch.cuda.current_device()
-                curr_device_properties = torch.cuda.get_device_properties(
-                    curr_device_id
+            # Ensure inputs are available in all scopes and cast symbolic integers to Tensors
+            contiguous_inputs: List[torch.Tensor] = [
+                (
+                    i.contiguous()
+                    if isinstance(i, torch.Tensor)
+                    else torch.tensor(i).cuda()
                 )
-                logger.debug(f"Current Device: cuda:{curr_device_id}")
-
-                # If a switch is required, move all inputs to new device and set as active device
-                if _is_switch_required(
-                    curr_device_id,
-                    self.target_device_id,
-                    curr_device_properties,
-                    self.target_device_properties,
-                ):
-                    device_id, _ = _select_rt_device(
-                        curr_device_id,
-                        self.target_device_id,
-                        self.target_device_properties,
+                for i in inputs
+            ]
+            assert len(contiguous_inputs) == len(
+                self.inputs
+            ), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}."
+
+            for i, _ in enumerate(self.inputs):
+                if not contiguous_inputs[i].is_cuda:
+                    logger.warning(
+                        f"Detected input[{i}] is not on a cuda device. "
+                        "This tensor is being moved by the runtime but for performance considerations, "
+                        "ensure your inputs are all on GPU and open an issue here "
+                        "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
+                    )
+                    contiguous_inputs = (
+                        contiguous_inputs[:i]
+                        + [contiguous_inputs[i].cuda()]
+                        + contiguous_inputs[i + 1 :]
                     )
-                    # Update current device
-                    device = torch.device(device_id)
-                    torch.cuda.set_device(device_id)
+                assert (
+                    contiguous_inputs[i].dtype == self.inputs[i].dtype
+                ), f"Dtype mismatch for {i}th input. Expect {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
 
-                    contiguous_inputs = [
-                        tensor.to(device) for tensor in contiguous_inputs
-                    ]
-                    logger.warning(f"Moved all input Tensors to cuda:{device_id}")
+                if need_cudagraphs_record:
+                    # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
+                    # Clone is required to avoid re-using user-provided GPU memory
+                    self._input_buffers[i] = contiguous_inputs[i].clone()
+                else:
+                    self._input_buffers[i].copy_(contiguous_inputs[i])
 
-            with (
-                torch.autograd.profiler.record_function(
-                    "WrapperTorchTensorRTModule:ProcessInputs"
-                )
-                if self.profiling_enabled
-                else nullcontext()
+            self._caller_stream = torch.cuda.current_stream()
+            if (
+                self._engine_stream == torch.cuda.default_stream()
+                or self._engine_stream is None
             ):
-                assert len(contiguous_inputs) == len(
-                    self.inputs
-                ), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}."
-
-                for i, _ in enumerate(self.inputs):
-                    if not contiguous_inputs[i].is_cuda:
-                        logger.warning(
-                            f"Detected input[{i}] of engine {self.engine.name} is not on a cuda device. "
-                            "This tensor is being moved by the runtime but for performance considerations, "
-                            "ensure your inputs are all on GPU and open an issue here "
-                            "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
-                        )
-                        contiguous_inputs = (
-                            contiguous_inputs[:i]
-                            + [contiguous_inputs[i].cuda()]
-                            + contiguous_inputs[i + 1 :]
-                        )
+                self._engine_stream = torch.cuda.Stream()
 
-                        assert (
-                            contiguous_inputs[i].dtype == self.inputs[i].dtype
-                        ), f"Dtype mismatch for {i}th input. Expect {self.inputs[i].dtype}, got {contiguous_inputs[i].dtype}."
+            self._engine_stream.wait_stream(self._caller_stream)
 
-                        if need_cudagraphs_record:
-                            # If cudagraphs is enabled, this memory is reserved for future cudagraph runs
-                            # Clone is required to avoid re-using user-provided GPU memory
-                            self._input_buffers[i] = contiguous_inputs[i].clone()
-                        elif cudagraphs_enabled:
-                            self._input_buffers[i].copy_(contiguous_inputs[i])
+            with torch.cuda.stream(self._engine_stream):
+                if need_cudagraphs_record:
+                    self.cudagraph = torch.cuda.CUDAGraph()
+                    with torch.cuda.graph(self.cudagraph, stream=self._engine_stream):
+                        self._output_buffers = self.compiled_module(
+                            *self._input_buffers
                         )
 
-            with (
-                torch.autograd.profiler.record_function(
-                    "WrapperTorchTensorRTModule:TensorRTRuntime"
-                )
-                if self.profiling_enabled
-                else nullcontext()
-            ):
-                self._caller_stream = torch.cuda.current_stream()
-                if (
-                    self._engine_stream == torch.cuda.default_stream()
-                    or self._engine_stream is None
-                ):
-                    self._engine_stream = torch.cuda.Stream()
-
-                self._engine_stream.wait_stream(self._caller_stream)
-
-                with torch.cuda.stream(self._engine_stream):
-                    if cudagraphs_enabled:
-                        if need_cudagraphs_record:
-                            self.cudagraph = torch.cuda.CUDAGraph()
-
-                            if self.profiling_enabled:
-                                self.cudagraph.enable_debug_mode()
-                            with torch.cuda.graph(
-                                self.cudagraph, stream=self._engine_stream
-                            ):
-                                self._output_buffers = self.compiled_module(
-                                    *self._input_buffers
-                                )
-
-                            if self.profiling_enabled:
-                                import tempfile
-
-                                with tempfile.TemporaryDirectory() as tmpdir:
-                                    self.cudagraph.debug_dump(
-                                        f"{tempdir}/{self.name}_cudagraph.dot"
-                                    )
-                        self.cudagraph.replay()  # type: ignore
-
-                    else:
-                        outputs = self.compiled_module(*inputs)
-
-                self._caller_stream.wait_stream(self._engine_stream)
-
-                if cudagraphs_enabled:
-                    if isinstance(self._output_buffers, (list, tuple)):
-                        output_buffers = self._output_buffers
-                    else:
-                        output_buffers = [self._output_buffers]
-                    outputs = [output.clone() for output in output_buffers]
-                    if len(outputs) == 1:
-                        return outputs[0]
-
-                    return outputs
+                self.cudagraph.replay()  # type: ignore
+            self._caller_stream.wait_stream(self._engine_stream)
+
+            if isinstance(self._output_buffers, (list, tuple)):
+                output_buffers = self._output_buffers
             else:
-                    return outputs
+                output_buffers = [self._output_buffers]
+            outputs = [output.clone() for output in output_buffers]
+            if len(outputs) == 1:
+                return outputs[0]
+            return outputs
+        else:
+            if self.cudagraph:
+                self.cudagraph.reset()
+            self.cudagraph = None
+            return self.compiled_module(*inputs)
diff --git a/py/torch_tensorrt/runtime/__init__.py b/py/torch_tensorrt/runtime/__init__.py
index 77b4401222..09a478e807 100644
--- a/py/torch_tensorrt/runtime/__init__.py
+++ b/py/torch_tensorrt/runtime/__init__.py
@@ -5,6 +5,7 @@
 from torch_tensorrt.runtime._cudagraphs import (
     enable_cudagraphs,
     get_cudagraphs_mode,
+    get_whole_cudagraphs_mode,
     set_cudagraphs_mode,
 )
 from torch_tensorrt.runtime._multi_device_safe_mode import set_multi_device_safe_mode
diff --git a/py/torch_tensorrt/runtime/_cudagraphs.py b/py/torch_tensorrt/runtime/_cudagraphs.py
index 7949a5d4a3..29d7ef895e 100644
--- a/py/torch_tensorrt/runtime/_cudagraphs.py
+++ b/py/torch_tensorrt/runtime/_cudagraphs.py
@@ -7,10 +7,18 @@
     WrapperTorchTensorRTModule,
 )
 
+
+class CudaGraphsMode:
+    STANDARD = 0
+    SUBGRAPH_CUDAGRAPHS = 1
+    # Internal mode to apply cuda graphs for wrapped runtime module
+    WHOLE_GRAPH_CUDAGRAPHS = 2
+
+
 if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime:
     _PY_RT_CUDAGRAPHS = torch.ops.tensorrt.get_cudagraphs_mode()
 else:
-    _PY_RT_CUDAGRAPHS = False
+    _PY_RT_CUDAGRAPHS = CudaGraphsMode.STANDARD
 
 logger = logging.getLogger(__name__)
 
@@ -19,19 +27,33 @@ def set_cudagraphs_mode(mode: bool) -> None:
     # Set new cudagraphs mode for Python
     global _PY_RT_CUDAGRAPHS
-    _PY_RT_CUDAGRAPHS = mode
+    _PY_RT_CUDAGRAPHS = (
+        CudaGraphsMode.SUBGRAPH_CUDAGRAPHS if mode else CudaGraphsMode.STANDARD
+    )
 
     # Set new mode for C++
     if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime:
-        torch.ops.tensorrt.set_cudagraphs_mode(mode)
+        torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS)
 
     logger.info(f"Set Cudagraphs usage to {mode}")
 
 
+def get_whole_cudagraphs_mode() -> bool:
+    # check if whole cudagraphs mode is enabled or not
+    global _PY_RT_CUDAGRAPHS
+    if _PY_RT_CUDAGRAPHS == CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS:
+        return True
+    else:
+        return False
+
+
 def get_cudagraphs_mode() -> bool:
     # Get cudagraphs mode for Python
     global _PY_RT_CUDAGRAPHS
-    return _PY_RT_CUDAGRAPHS  # type: ignore
+    if _PY_RT_CUDAGRAPHS == CudaGraphsMode.SUBGRAPH_CUDAGRAPHS:
+        return True
+    else:
+        return False
 
 
 class _CudagraphsContextManager(object):
@@ -40,17 +62,22 @@ class _CudagraphsContextManager(object):
     """
     Used to enable cudagraphs as a context manager
     """
 
-    def __init__(self, module_to_wrap: Optional[torch.nn.Module]) -> None:
+    def __init__(self, compiled_module: Optional[torch.nn.Module]) -> None:
         global _PY_RT_CUDAGRAPHS
         self.old_mode = _PY_RT_CUDAGRAPHS
-        self.module_to_wrap = module_to_wrap
+        self.compiled_module = compiled_module
 
     def __enter__(self) -> "_CudagraphsContextManager":
-        # Enable cudagraphs
-        set_cudagraphs_mode(True)
-        if self.module_to_wrap:
-            return WrapperTorchTensorRTModule(self.module_to_wrap)
+        global _PY_RT_CUDAGRAPHS
+        if self.compiled_module:
+            _PY_RT_CUDAGRAPHS = CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS
+            # Set new mode for C++
+            if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime:
+                torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS)
+            return WrapperTorchTensorRTModule(self.compiled_module)
         else:
+            # Enable cudagraphs
+            set_cudagraphs_mode(True)
             return self
 
     def __exit__(self, *args: Any) -> None:
@@ -59,6 +86,6 @@ def __exit__(self, *args: Any) -> None:
 
 
 def enable_cudagraphs(
-    module_to_wrap: Optional[torch.nn.Module] = None,
+    compiled_module: Optional[torch.nn.Module] = None,
 ) -> _CudagraphsContextManager:
-    return _CudagraphsContextManager(module_to_wrap)
+    return _CudagraphsContextManager(compiled_module)
diff --git a/tests/py/dynamo/runtime/test_005_wrapper_cudagraphs.py b/tests/py/dynamo/runtime/test_005_wrapper_cudagraphs.py
index ac404aa93d..23ecd07d83 100644
--- a/tests/py/dynamo/runtime/test_005_wrapper_cudagraphs.py
+++ b/tests/py/dynamo/runtime/test_005_wrapper_cudagraphs.py
@@ -49,7 +49,7 @@ def forward(self, x):
         ref_out_list = []
         trt_out_list = []
-
+        wrapped_module = WrapperTorchTensorRTModule(optimized_model)
         for enable_cuda_graphs in [False, True]:
             for i in range(len(input_list)):
                 # Toggles cuda graph at all index in TRIALS
@@ -57,9 +57,15 @@ def forward(self, x):
                     cuda_graphs = enable_cuda_graphs
                 else:
                     cuda_graphs = not enable_cuda_graphs
-                torchtrt.runtime.set_cudagraphs_mode(cuda_graphs)
-                trt_out_list.append(optimized_model(*input_list[i]))
+                if cuda_graphs:
+                    torchtrt.runtime._cudagraphs._PY_RT_CUDAGRAPHS = (
+                        torchtrt.runtime._cudagraphs.CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS
+                    )
+                else:
+                    torchtrt.runtime.set_cudagraphs_mode(False)
+
+                trt_out_list.append(wrapped_module(*input_list[i]))
                 ref_out_list.append(fx_graph(*input_list[i]))
 
         for optimized_model_results, torch_model_results in zip(
@@ -113,6 +119,7 @@ def forward(self, x):
             for j in [128, 128, 222, 222, 224]:
                 input_list.append(torch.randn((i, 3, j, 224)).cuda())
 
+        wrapped_module = WrapperTorchTensorRTModule(optimized_model)
         for enable_cuda_graphs in [False, True]:
             for i in range(len(input_list)):
                 # Toggles cuda graph at all index in TRIALS
@@ -120,10 +127,15 @@ def forward(self, x):
                     cuda_graphs = enable_cuda_graphs
                 else:
                     cuda_graphs = not enable_cuda_graphs
-                torchtrt.runtime.set_cudagraphs_mode(cuda_graphs)
+                if cuda_graphs:
+                    torchtrt.runtime._cudagraphs._PY_RT_CUDAGRAPHS = (
+                        torchtrt.runtime._cudagraphs.CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS
+                    )
+                else:
+                    torchtrt.runtime.set_cudagraphs_mode(False)
 
+                trt_out_list.append(wrapped_module(input_list[i]))
                 trt_out_list.append(fx_graph(input_list[i]))
-                ref_out_list.append(optimized_model(input_list[i]))
 
         for optimized_model_results, torch_model_results in zip(
             trt_out_list, ref_out_list
@@ -182,9 +194,12 @@ def forward(self, x):
             use_python_runtime=use_python_runtime,
         )
 
-        torchtrt.runtime.set_cudagraphs_mode(True)
+        wrapped_module = WrapperTorchTensorRTModule(optimized_model)
+        torchtrt.runtime._cudagraphs._PY_RT_CUDAGRAPHS = (
+            torchtrt.runtime._cudagraphs.CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS
+        )
         for i in range(TRIALS):
-            trt_out_list.append(optimized_model(*input_list[i]))
+            trt_out_list.append(wrapped_module(*input_list[i]))
             ref_out_list.append(fx_graph(*input_list[i]))
 
         for optimized_model_results, torch_model_results in zip(
@@ -207,20 +222,9 @@ def forward(self, x):
         ]
     )
     def test_wrapper_cudagraphs_api(self, _, use_python_runtime):
-        """
-        3 api draft
-        """
-
         class SampleModel(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.conv = torch.nn.Conv1d(64, 6, 3)
-                self.relu = torch.nn.ReLU()
-
             def forward(self, x):
-                out = 1 + self.conv(x)
-                out = self.relu(out)
-                return out
+                return torch.relu((x + 2) * 0.5)
 
         model = SampleModel().eval().cuda()
         input_list = []
@@ -232,7 +236,6 @@ def forward(self, x):
             input_list.append(input)
 
         fx_graph = torch.fx.symbolic_trace(model)
-        # 1. Compiler option: enable_wrapper_module=True
         optimized_model = torchtrt.compile(
             fx_graph,
             inputs=input_list[0],
@@ -242,45 +245,13 @@ def forward(self, x):
             ir="dynamo",
             min_block_size=1,
             cache_built_engines=False,
             reuse_cached_engines=False,
             torch_executed_ops={"torch.ops.aten.convolution.default"},
             use_python_runtime=use_python_runtime,
-            enable_wrapper_module=True,
         )
 
-        with torchtrt.runtime.enable_cudagraphs():
-            for i in range(TRIALS):
-                trt_out_list.append(optimized_model(*input_list[i]))
-                ref_out_list.append(fx_graph(*input_list[i]))
-
-        # Compiler again to generate normal module
-        optimized_model = torchtrt.compile(
-            fx_graph,
-            inputs=input_list[0],
-            ir="dynamo",
-            min_block_size=1,
-            cache_built_engines=False,
-            reuse_cached_engines=False,
-            torch_executed_ops={"torch.ops.aten.convolution.default"},
-            use_python_runtime=use_python_runtime,
-        )
-        # This is current cuda runtime api
-        with torchtrt.runtime.enable_cudagraphs():
-            for i in range(TRIALS):
-                trt_out_list.append(optimized_model(*input_list[i]))
-                ref_out_list.append(fx_graph(*input_list[i]))
-
-        # 2. Optional parameter in existing cuda runtime api
-        # WrapperTorchTensorRTModule can be simplified to have only cuda graph path
         with torchtrt.runtime.enable_cudagraphs(optimized_model) as wrapped_module:
             for i in range(TRIALS):
                 trt_out_list.append(wrapped_module(*input_list[i]))
                 ref_out_list.append(fx_graph(*input_list[i]))
 
-        # 3. Use Wrapper module directly
-        wrapped_module = WrapperTorchTensorRTModule(optimized_model)
-        with torchtrt.runtime.enable_cudagraphs():
-            for i in range(TRIALS):
-                trt_out_list.append(wrapped_module(*input_list[i]))
-                ref_out_list.append(fx_graph(*input_list[i]))
-
         for optimized_model_results, torch_model_results in zip(
             trt_out_list, ref_out_list
         ):
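
Usage sketch for the two CUDA graphs modes after this change, mirroring the context-manager API exercised in test_005_wrapper_cudagraphs.py above. This is a minimal illustration, not part of the diff: the model, inputs, and compile settings are placeholders, and any dynamo-compiled Torch-TensorRT module would work in their place.

import torch
import torch_tensorrt

# Placeholder model and input (illustrative only).
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(8, 16).cuda()]
optimized_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

# SUBGRAPH_CUDAGRAPHS: each TRT engine records and replays its own CUDA graph.
torch_tensorrt.runtime.set_cudagraphs_mode(True)
out = optimized_model(*inputs)
torch_tensorrt.runtime.set_cudagraphs_mode(False)

# WHOLE_GRAPH_CUDAGRAPHS: passing the compiled module to enable_cudagraphs()
# switches the runtime to whole-graph mode and returns a WrapperTorchTensorRTModule
# that captures the entire forward pass, including Torch fallback segments,
# in a single CUDA graph.
with torch_tensorrt.runtime.enable_cudagraphs(optimized_model) as wrapped_module:
    out = wrapped_module(*inputs)

On exit, the context manager restores the previous mode, so per-subgraph and whole-graph usage can be mixed within one process.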