diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py
index 9492b09402..521acc17f4 100644
--- a/py/torch_tensorrt/_compile.py
+++ b/py/torch_tensorrt/_compile.py
@@ -12,6 +12,9 @@
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
+from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
+    WrapperTorchTensorRTModule,
+)
 from torch_tensorrt.fx import InputTensorSpec
 from torch_tensorrt.fx.lower import compile as fx_compile
 from torch_tensorrt.fx.utils import LowerPrecision
@@ -586,7 +589,7 @@ def save(
     Save the model to disk in the specified output format.
 
     Arguments:
-        module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)): Compiled Torch-TensorRT module
+        module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | WrapperTorchTensorRTModule)): Compiled Torch-TensorRT module
         inputs (torch.Tensor): Torch input tensors
         arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
         kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
@@ -594,6 +597,8 @@ def save(
         retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it. This flag is experimental for now.
     """
+    if isinstance(module, WrapperTorchTensorRTModule):
+        module = module.original_module
     module_type = _parse_module_type(module)
     accepted_formats = {"exported_program", "torchscript"}
     if arg_inputs is not None and not all(
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 5137161840..fc4262e603 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -835,7 +835,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
 
     dryrun_stats_display(dryrun_tracker, settings.dryrun)
 
-    if len(trt_modules) > 1:
+    if len(dryrun_tracker.to_run_in_torch) > 0:
         # Capture/replay a series of CUDA operations in subgraphs in a wrapped runtime module.
         partitioned_module = WrapperTorchTensorRTModule(
             partitioned_module,
diff --git a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
index 35a5b1a306..9ceaf63ebc 100644
--- a/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -78,6 +78,7 @@ def __init__(
         self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
         self._caller_stream: Optional[torch.cuda.Stream] = None
         self._engine_stream: Optional[torch.cuda.Stream] = None
+        # TODO: Make the below a Dictionary {shape: cudagraph}
         self.shape_key: Optional[str] = None
diff --git a/py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py
index 0a8ec3a85c..85fe3fb60f 100644
--- a/py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py
@@ -46,7 +46,7 @@ def __init__(
             if "_run_on_acc" in name:
                 rt_mod.set_cudagraphs_enabled_parent_module(True)
 
-        # TODO: check if only torch needs warm up.
+        # Warm up is necessary to ensure that memory allocations and initializations are not recorded in cuda graphs
         with unset_fake_temporarily():
            inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
            s = torch.cuda.Stream()
@@ -256,7 +256,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                self._caller_stream.wait_stream(self._engine_stream)
 
            if cudagraphs_enabled:
-                # TODO: submodule to return list only
                if isinstance(self._output_buffers, (list, tuple)):
                    output_buffers = self._output_buffers
                else:
diff --git a/py/torch_tensorrt/runtime/_weight_streaming.py b/py/torch_tensorrt/runtime/_weight_streaming.py
index 3a33330fa1..4b1e192138 100755
--- a/py/torch_tensorrt/runtime/_weight_streaming.py
+++ b/py/torch_tensorrt/runtime/_weight_streaming.py
@@ -3,6 +3,9 @@
 
 import torch
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
+from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
+    WrapperTorchTensorRTModule,
+)
 
 logger = logging.getLogger(__name__)
@@ -12,9 +15,14 @@ class _WeightStreamingContextManager(object):
     Helper class used to setup weight streaming budget
     """
 
-    def __init__(self, module: torch.fx.GraphModule) -> None:
+    def __init__(
+        self, module: torch.fx.GraphModule | WrapperTorchTensorRTModule
+    ) -> None:
         rt_mods = []
         self.current_device_budget = 0
+
+        if isinstance(module, WrapperTorchTensorRTModule):
+            module = module.original_module
         for name, rt_mod in module.named_children():
             if "_run_on_acc" in name and isinstance(
                 rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule)
diff --git a/tests/py/dynamo/runtime/test_002_cudagraphs_py.py b/tests/py/dynamo/runtime/test_002_cudagraphs_py.py
index ac5909604c..4f962083a8 100644
--- a/tests/py/dynamo/runtime/test_002_cudagraphs_py.py
+++ b/tests/py/dynamo/runtime/test_002_cudagraphs_py.py
@@ -158,76 +158,6 @@ def forward(self, x):
                 msg=f"CUDA Graph Python TRT outputs don't match with the original model. (trial: {i})",
             )
 
-    def test_cudagraphs_dynamic_py(self):
-        class SampleModel(torch.nn.Module):
-            def forward(self, x):
-                return torch.relu((x + 2) * 0.5)
-
-        # TODO: more dynamic dim
-        # TODO: multiple output
-        # TODO: module that graph cannot be used
-        inputs = torch_tensorrt.Input(
-            min_shape=(1, 3, 224, 224),
-            opt_shape=(8, 3, 224, 224),
-            max_shape=(16, 3, 224, 224),
-            dtype=torch.float,
-            name="x",
-        )
-        fx_graph = torch.fx.symbolic_trace(SampleModel())
-
-        # Validate that the results between Torch and Torch-TRT are similar
-        optimized_model = torch_tensorrt.compile(
-            fx_graph,
-            "dynamo",
-            inputs,
-            min_block_size=1,
-            pass_through_build_failures=True,
-            torch_executed_ops={"torch.ops.aten.mul.Tensor"},
-            use_python_runtime=True,
-        )
-
-        result_samples = []
-        torch_results_samples = []
-
-        inputs = []
-        for i in [1, 3, 8, 11, 16]:
-            inputs.append(torch.randn((i, 3, 224, 224)).cuda())
-
-        for n in range(len(inputs) * TRIALS):
-            i = n // TRIALS
-            # disable cuda graph at all index for all trials
-            if n % TRIALS == n // TRIALS:
-                torch_tensorrt.runtime.set_cudagraphs_mode(False)
-            else:
-                torch_tensorrt.runtime.set_cudagraphs_mode(True)
-
-            result_samples.append(optimized_model(inputs[i]).detach().cpu())
-            torch_results_samples.append(fx_graph(inputs[i]).detach().cpu())
-
-        for n in range(len(inputs) * TRIALS):
-            i = n // TRIALS
-            # enable cuda graph at all index for all trials
-            if n % TRIALS == n // TRIALS:
-                torch_tensorrt.runtime.set_cudagraphs_mode(True)
-            else:
-                torch_tensorrt.runtime.set_cudagraphs_mode(False)
-
-            result_samples.append(optimized_model(inputs[i]).detach().cpu())
-            torch_results_samples.append(fx_graph(inputs[i]).detach().cpu())
-
-        for i, (optimized_model_results, torch_model_results) in enumerate(
-            zip(result_samples, torch_results_samples)
-        ):
-            max_diff = float(
-                torch.max(torch.abs(optimized_model_results - torch_model_results))
-            )
-            self.assertAlmostEqual(
-                max_diff,
-                0,
-                DECIMALS_OF_AGREEMENT,
-                msg=f"CUDA Graph Python TRT outputs don't match with the original model. (trial: {i})",
-            )
-
 
 if __name__ == "__main__":
     run_tests()