chore: cleanup in WrapperTorchTensorRTModule
keehyuna committed Nov 6, 2024
1 parent ecee5a6 commit 0c6b6cd
Showing 2 changed files with 54 additions and 30 deletions.
py/torch_tensorrt/dynamo/_compiler.py (4 changes: 3 additions & 1 deletion)
@@ -522,7 +522,9 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     if len(trt_modules) > 1:
         # Capture/replay a series of CUDA operations in subgraphs in a wrapped runtime module.
         partitioned_module = WrapperTorchTensorRTModule(
-            partitioned_module, dryrun_tracker.output_dtypes
+            partitioned_module,
+            dryrun_tracker.output_shapes,
+            dryrun_tracker.output_dtypes,
         )

     return partitioned_module
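The comment in this hunk states the wrapper's job: capture a series of CUDA operations from the subgraphs once, then replay them on later calls. For readers new to the mechanism, here is a minimal, self-contained sketch of PyTorch's documented CUDA graph capture/replay pattern; the Linear model and buffer names are illustrative and not part of this commit.

import torch

# CUDA graphs replay fixed memory addresses, so the same "static"
# tensors must be reused for every capture and replay.
model = torch.nn.Linear(16, 4).cuda()
static_input = torch.randn(8, 16, device="cuda")

# Warm up on a side stream before capture, as the CUDA graphs docs require.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_output = model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Record: kernels launched inside this context are captured, not executed.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = model(static_input)

# Replay: copy new data into the captured input buffer, then relaunch the
# whole captured kernel sequence with a single call.
static_input.copy_(torch.randn(8, 16, device="cuda"))
g.replay()
print(static_output)  # holds the result for the new input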
py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py (80 changes: 51 additions & 29 deletions)
@@ -1,15 +1,16 @@
 from __future__ import annotations

 import logging
+from contextlib import nullcontext
 from tempfile import tempdir
 from typing import List, Optional, Sequence, Tuple

-import nvtx
 import torch
 import torch_tensorrt
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch_tensorrt.dynamo import partitioning
 from torch_tensorrt.dynamo.conversion import DYNAMIC_DIM
+from torch_tensorrt.dynamo.utils import input_is_dynamic
 from torch_tensorrt.runtime._utils import _is_switch_required, _select_rt_device

 logger = logging.getLogger(__name__)
@@ -21,12 +22,13 @@ class WrapperTorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
     def __init__(
         self,
         original_module: torch.nn.Module,
+        output_shapes: List[torch.Size],
         output_dtypes: List[torch.dtype],
     ):
         super(WrapperTorchTensorRTModule, self).__init__()
         self.original_module = original_module
         self.inputs = partitioning.construct_submodule_inputs(original_module)
-        self.output_shapes: List[torch.Tensor] = []
+        self.output_shapes = output_shapes
         self.output_dtypes = output_dtypes

         self._input_buffers: List[torch.Tensor] = []
@@ -37,6 +39,7 @@ def __init__(
         self.cudagraphs_enabled = False
         self._caller_stream: Optional[torch.cuda.Stream] = None
         self._engine_stream: Optional[torch.cuda.Stream] = None
+        self.input_is_dynamic = input_is_dynamic(self.inputs)

         # Disable cudagrphs in submodules as it will be enabled in wrapper
         for name, rt_mod in self.original_module.named_children():
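The new self.input_is_dynamic flag is computed once at construction so that validate_input_shapes (next hunk) can skip re-running the module when every input shape is static. A rough, illustrative sketch of what such a check amounts to; the real input_is_dynamic lives in torch_tensorrt.dynamo.utils, and treating DYNAMIC_DIM as a -1 sentinel is an assumption here.

from typing import Sequence

import torch

DYNAMIC_DIM = -1  # assumed sentinel for "this dimension is not fixed"


def shapes_are_dynamic(shapes: Sequence[torch.Size]) -> bool:
    # Dynamic if any dimension of any input uses the sentinel value.
    return any(DYNAMIC_DIM in tuple(shape) for shape in shapes)


print(shapes_are_dynamic([torch.Size([1, 3, 224, 224])]))   # False
print(shapes_are_dynamic([torch.Size([-1, 3, 224, 224])]))  # True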
@@ -67,11 +70,12 @@ def validate_input_shapes(self, inputs: Sequence[torch.Tensor]) -> bool:
             logger.debug(f"Input shape changed {self.shape_key} -> {new_shape_key}")
             self.shape_key = new_shape_key

-            # TODO: avoid it for static input shape
-            outputs = self.original_module(*inputs)
-            if not isinstance(outputs, (list, tuple)):
-                outputs = [outputs]
-            self.output_shapes = [tuple(output.shape) for output in outputs]
+            if self.input_is_dynamic:
+                tmp_outputs = self.original_module(*inputs)
+                if not isinstance(tmp_outputs, (list, tuple)):
+                    tmp_outputs = [tmp_outputs]
+                self.output_shapes = [tuple(output.shape) for output in tmp_outputs]
+
             return True

         return False
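This hunk makes the output-shape refresh conditional: only dynamically shaped inputs can change the output shapes, so only they pay for an extra forward pass through the original module. The surrounding shape-key bookkeeping, in isolation, looks roughly like the sketch below; ShapeTracker and make_shape_key are illustrative names, not code from the repository.

from typing import List, Optional, Sequence, Tuple

import torch


def make_shape_key(inputs: Sequence[torch.Tensor]) -> str:
    # One string summarizing every input shape; cheap to compare per call.
    return "|".join(str(tuple(t.shape)) for t in inputs)


class ShapeTracker:
    def __init__(self, module: torch.nn.Module) -> None:
        self.module = module
        self.shape_key: Optional[str] = None
        self.output_shapes: List[Tuple[int, ...]] = []

    def validate(self, inputs: Sequence[torch.Tensor]) -> bool:
        new_key = make_shape_key(inputs)
        if new_key != self.shape_key:
            self.shape_key = new_key
            outputs = self.module(*inputs)
            if not isinstance(outputs, (list, tuple)):
                outputs = [outputs]
            self.output_shapes = [tuple(o.shape) for o in outputs]
            return True  # shapes changed: any recorded CUDA graph is stale
        return False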
@@ -86,8 +90,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             (i.contiguous() if isinstance(i, torch.Tensor) else torch.tensor(i).cuda())
             for i in inputs
         ]
-        with nvtx.annotate("Wrapper:Forward", color="orange"):
-
+        with (
+            torch.autograd.profiler.record_function(
+                "WrapperTorchTensorRTModule:Forward"
+            )
+            if self.profiling_enabled
+            else nullcontext()
+        ):
             shape_changed = self.validate_input_shapes(inputs)
             cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()
             # Cudagraphs record is required if cudagraphs_enabled is toggled to True regardless of shape change
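This hunk, and the matching ones below, replace the unconditional nvtx.annotate ranges with profiler ranges that cost nothing when profiling is off: a conditional expression selects either a named record_function range or a no-op nullcontext. The pattern in isolation, with illustrative names:

from contextlib import nullcontext

import torch

profiling_enabled = False  # toggle to emit named ranges in the profiler


def forward_step(model: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # Choose the context manager up front: a profiler range when enabled,
    # otherwise a context manager that does nothing at all.
    with (
        torch.autograd.profiler.record_function("Example:Forward")
        if profiling_enabled
        else nullcontext()
    ):
        return model(x)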
@@ -100,6 +109,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             if need_cudagraphs_record:
                 if self.cudagraph:
                     self.cudagraph.reset()
+
                 self._input_buffers = [None] * len(self.inputs)
                 self._output_buffers = [None] * len(self.output_shapes)
@@ -139,15 +149,21 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     ]
                     logger.warning(f"Moved all input Tensors to cuda:{device_id}")

-            with nvtx.annotate("Wrapper:ProcessInputs", color="orange"):
+            with (
+                torch.autograd.profiler.record_function(
+                    "WrapperTorchTensorRTModule:ProcessInputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 assert len(contiguous_inputs) == len(
                     self.inputs
                 ), f"Wrong number of inputs, expect {len(self.inputs)} get {len(contiguous_inputs)}."

-                for i, input_name in enumerate(self.inputs):
+                for i, _ in enumerate(self.inputs):
                     if not contiguous_inputs[i].is_cuda:
                         logger.warning(
-                            f"Detected input {input_name} of engine {self.engine.name} is not on a cuda device. "
+                            f"Detected input[{i}] of engine {self.engine.name} is not on a cuda device. "
                             "This tensor is being moved by the runtime but for performance considerations, "
                             "ensure your inputs are all on GPU and open an issue here "
                             "(https://github.com/pytorch/TensorRT/issues) if this warning persists."
@@ -169,7 +185,13 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                     elif cudagraphs_enabled:
                         self._input_buffers[i].copy_(contiguous_inputs[i])

-            with nvtx.annotate("ProcessOutputs", color="red"):
+            with (
+                torch.autograd.profiler.record_function(
+                    "WrapperTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 # create output tensors
                 outputs: List[torch.Tensor] = []
@@ -189,34 +211,35 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:

                     if need_cudagraphs_record:
                         self._output_buffers[o] = outputs[o].clone()

-            with nvtx.annotate("Wrapper:TensorRTRuntime", color="orange"):
+            with (
+                torch.autograd.profiler.record_function(
+                    "WrapperTorchTensorRTModule:TensorRTRuntime"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
                 self._caller_stream = torch.cuda.current_stream()
                 if (
                     self._engine_stream == torch.cuda.default_stream()
                     or self._engine_stream is None
                 ):
                     self._engine_stream = torch.cuda.Stream()

-                with nvtx.annotate("wait_stream", color="green"):
-                    self._engine_stream.wait_stream(self._caller_stream)
+                self._engine_stream.wait_stream(self._caller_stream)

                 with torch.cuda.stream(self._engine_stream):
                     if cudagraphs_enabled:
                         if need_cudagraphs_record:
-                            with nvtx.annotate("CUDAGraph", color="green"):
-                                self.cudagraph = torch.cuda.CUDAGraph()
+                            self.cudagraph = torch.cuda.CUDAGraph()

                             if self.profiling_enabled:
                                 self.cudagraph.enable_debug_mode()
-                            with nvtx.annotate("torch.cuda.graph", color="green"):
-                                with torch.cuda.graph(
-                                    self.cudagraph, stream=self._engine_stream
-                                ):
-                                    with nvtx.annotate("record", color="green"):
-                                        self._output_buffers = self.original_module(
-                                            *self._input_buffers
-                                        )
+                            with torch.cuda.graph(
+                                self.cudagraph, stream=self._engine_stream
+                            ):
+                                self._output_buffers = self.original_module(
+                                    *self._input_buffers
+                                )

                             if self.profiling_enabled:
                                 import tempfile
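The TensorRTRuntime block above runs the engine on a dedicated side stream, ordered after the caller's stream so that engine kernels cannot start before their inputs are ready. The same hand-off in isolation; the element-wise multiply is a stand-in for engine execution:

import torch

caller_stream = torch.cuda.current_stream()
engine_stream = torch.cuda.Stream()

x = torch.randn(1024, device="cuda")  # produced on the caller stream

# The side stream must wait until the caller stream has produced x.
engine_stream.wait_stream(caller_stream)
with torch.cuda.stream(engine_stream):
    y = x * 2  # runs asynchronously on the engine stream

# Before the caller consumes y, it waits for the engine stream in turn.
caller_stream.wait_stream(engine_stream)
print(y.sum())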
@@ -225,8 +248,7 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
                                 self.cudagraph.debug_dump(
                                     f"{tempdir}/{self.name}_cudagraph.dot"
                                 )
-                        with nvtx.annotate("replay", color="green"):
-                            self.cudagraph.replay()  # type: ignore
+                        self.cudagraph.replay()  # type: ignore

                 else:
                     outputs = self.original_module(*inputs)
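For context, the wrapper keys record/replay off the global cudagraphs mode read through torch_tensorrt.runtime.get_cudagraphs_mode() in the Forward hunk. An end-to-end usage sketch under stated assumptions: set_cudagraphs_mode is assumed to be the matching setter, and the compiled module is assumed to contain more than one TensorRT subgraph, the case this commit targets.

import torch

import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).cuda()
x = torch.randn(8, 16, device="cuda")

trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=[x])

# The wrapper consults this global mode on every forward() call.
torch_tensorrt.runtime.set_cudagraphs_mode(True)
out = trt_model(x)  # first call records the CUDA graph(s)
out = trt_model(x)  # later calls with the same shapes replay them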
