chore: update test case
keehyuna committed Nov 15, 2024
1 parent fdbd3d8 commit ce063a2
Showing 6 changed files with 18 additions and 75 deletions.
py/torch_tensorrt/_compile.py (6 additions, 1 deletion)
@@ -12,6 +12,9 @@
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
+from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
+    WrapperTorchTensorRTModule,
+)
 from torch_tensorrt.fx import InputTensorSpec
 from torch_tensorrt.fx.lower import compile as fx_compile
 from torch_tensorrt.fx.utils import LowerPrecision
@@ -586,14 +589,16 @@ def save(
     Save the model to disk in the specified output format.

     Arguments:
-        module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule)): Compiled Torch-TensorRT module
+        module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | WrapperTorchTensorRTModule)): Compiled Torch-TensorRT module
         inputs (torch.Tensor): Torch input tensors
         arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
         kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
         output_format (str): Format to save the model. Options include exported_program | torchscript.
         retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
             This flag is experimental for now.
     """
+    if isinstance(module, WrapperTorchTensorRTModule):
+        module = module.original_module
     module_type = _parse_module_type(module)
     accepted_formats = {"exported_program", "torchscript"}
     if arg_inputs is not None and not all(
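In practice this means torch_tensorrt.save() now accepts the CUDA-graph wrapper directly and peels it back to the underlying module before parsing its type. A minimal sketch of the resulting workflow, assuming the standard compile()/save() entry points (the model and file name are illustrative):

import torch
import torch_tensorrt

class TinyModel(torch.nn.Module):  # illustrative model
    def forward(self, x):
        return torch.relu(x + 1)

model = TinyModel().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# With graph breaks, compile() may hand back a WrapperTorchTensorRTModule;
# save() now unwraps it to module.original_module before serializing.
trt_module = torch_tensorrt.compile(model, "dynamo", inputs, min_block_size=1)
torch_tensorrt.save(trt_module, "trt.ep", inputs=inputs)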
py/torch_tensorrt/dynamo/_compiler.py (1 addition, 1 deletion)
@@ -835,7 +835,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

     dryrun_stats_display(dryrun_tracker, settings.dryrun)

-    if len(trt_modules) > 1:
+    if len(dryrun_tracker.to_run_in_torch) > 0:
         # Capture/replay a series of CUDA operations in subgraphs in a wrapped runtime module.
         partitioned_module = WrapperTorchTensorRTModule(
             partitioned_module,
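The guard change is behavioral: the old check wrapped the output only when partitioning produced more than one TRT engine, while the new one wraps whenever at least one subgraph falls back to Torch, which is the case that actually needs the wrapped runtime. A self-contained restatement of the two predicates (the tracker field is taken from the diff; the rest of the class is a stand-in):

from dataclasses import dataclass, field
from typing import List

@dataclass
class DryRunTracker:  # minimal stand-in for the real tracker
    to_run_in_torch: List[str] = field(default_factory=list)

def needs_cudagraph_wrapper(trt_modules: List[str], tracker: DryRunTracker) -> bool:
    # Old rule: len(trt_modules) > 1, i.e. wrap only for multiple TRT engines.
    # New rule: wrap whenever any subgraph stays in eager Torch, since one TRT
    # engine plus a Torch fallback still spans two runtimes at execution time.
    return len(tracker.to_run_in_torch) > 0

# One engine plus one Torch fallback: the old rule skipped the wrapper, the new one applies it.
print(needs_cudagraph_wrapper(["_run_on_acc_0"], DryRunTracker(["_run_on_gpu_1"])))  # True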
@@ -78,6 +78,7 @@ def __init__(
         self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
         self._caller_stream: Optional[torch.cuda.Stream] = None
         self._engine_stream: Optional[torch.cuda.Stream] = None
+
         # TODO: Make the below a Dictionary {shape: cudagraph}
         self.shape_key: Optional[str] = None
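The TODO retained above points at a known limitation: the runtime keeps a single shape_key and re-captures its one CUDA graph whenever the input shapes change. A hypothetical sketch of the dictionary the comment asks for (illustrative helper, not the module's actual code):

from typing import Dict, List, Optional
import torch

class ShapeKeyedGraphCache:  # hypothetical helper, per the TODO above
    def __init__(self) -> None:
        self._graphs: Dict[str, torch.cuda.CUDAGraph] = {}

    @staticmethod
    def _key(inputs: List[torch.Tensor]) -> str:
        # Same idea as shape_key: a string identifying the input shapes
        return str([tuple(t.shape) for t in inputs])

    def get(self, inputs: List[torch.Tensor]) -> Optional[torch.cuda.CUDAGraph]:
        return self._graphs.get(self._key(inputs))

    def put(self, inputs: List[torch.Tensor], graph: torch.cuda.CUDAGraph) -> None:
        # Cache one captured graph per distinct shape instead of re-capturing
        self._graphs[self._key(inputs)] = graph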
@@ -46,7 +46,7 @@ def __init__(
             if "_run_on_acc" in name:
                 rt_mod.set_cudagraphs_enabled_parent_module(True)

-        # TODO: check if only torch needs warm up.
+        # Warm up is necessary to ensure that memory allocations and initializations are not recorded in cuda graphs
         with unset_fake_temporarily():
             inputs_tensor = [spec.torch_tensor.cuda() for spec in self.inputs]
             s = torch.cuda.Stream()
@@ -256,7 +256,6 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, ...]:
             self._caller_stream.wait_stream(self._engine_stream)

         if cudagraphs_enabled:
-            # TODO: submodule to return list only
             if isinstance(self._output_buffers, (list, tuple)):
                 output_buffers = self._output_buffers
             else:
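The new comment states the standard CUDA Graphs constraint: capture records every kernel and allocation issued on the stream, so lazy initializations and one-time allocations must run before capture starts. A generic PyTorch sketch of the warm-up-then-capture pattern (illustrative code, not the wrapper's implementation):

import torch

model = torch.nn.Linear(16, 16).cuda().eval()
static_input = torch.randn(8, 16, device="cuda")

# Warm up on a side stream so allocations and one-time setup are not recorded.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture once; replay() re-runs the recorded kernels on the same buffers.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = model(static_input)

static_input.copy_(torch.randn(8, 16, device="cuda"))
g.replay()  # static_output now reflects the new input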
py/torch_tensorrt/runtime/_weight_streaming.py (9 additions, 1 deletion)
@@ -3,6 +3,9 @@

 import torch
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
+from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
+    WrapperTorchTensorRTModule,
+)

 logger = logging.getLogger(__name__)

@@ -12,9 +15,14 @@ class _WeightStreamingContextManager(object):
     Helper class used to setup weight streaming budget
     """

-    def __init__(self, module: torch.fx.GraphModule) -> None:
+    def __init__(
+        self, module: torch.fx.GraphModule | WrapperTorchTensorRTModule
+    ) -> None:
         rt_mods = []
         self.current_device_budget = 0
+
+        if isinstance(module, WrapperTorchTensorRTModule):
+            module = module.original_module
         for name, rt_mod in module.named_children():
             if "_run_on_acc" in name and isinstance(
                 rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule)
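For context, this context manager backs torch_tensorrt.runtime.weight_streaming, which after this change also accepts a wrapped module and unwraps it before collecting the _run_on_acc submodules. A usage sketch following the documented pattern (the model and the 50% budget are illustrative; weight streaming assumes the engine was built with enable_weight_streaming=True and use_explicit_typing=True):

import torch
import torch_tensorrt

class TinyModel(torch.nn.Module):  # illustrative model
    def forward(self, x):
        return torch.relu(x @ x)

inputs = [torch.randn(32, 32).cuda()]
trt_model = torch_tensorrt.compile(
    TinyModel().eval().cuda(),
    "dynamo",
    inputs,
    use_explicit_typing=True,
    enable_weight_streaming=True,
)

# Wrapped or unwrapped compiled modules can now be passed here.
with torch_tensorrt.runtime.weight_streaming(trt_model) as ctx:
    ctx.device_budget = int(ctx.total_device_budget * 0.5)  # stream half the weights
    out = trt_model(*inputs)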
tests/py/dynamo/runtime/test_002_cudagraphs_py.py (0 additions, 70 deletions)
@@ -158,76 +158,6 @@ def forward(self, x):
                 msg=f"CUDA Graph Python TRT outputs don't match with the original model. (trial: {i})",
             )

-    def test_cudagraphs_dynamic_py(self):
-        class SampleModel(torch.nn.Module):
-            def forward(self, x):
-                return torch.relu((x + 2) * 0.5)
-
-        # TODO: more dynamic dim
-        # TODO: multiple output
-        # TODO: module that graph cannot be used
-        inputs = torch_tensorrt.Input(
-            min_shape=(1, 3, 224, 224),
-            opt_shape=(8, 3, 224, 224),
-            max_shape=(16, 3, 224, 224),
-            dtype=torch.float,
-            name="x",
-        )
-        fx_graph = torch.fx.symbolic_trace(SampleModel())
-
-        # Validate that the results between Torch and Torch-TRT are similar
-        optimized_model = torch_tensorrt.compile(
-            fx_graph,
-            "dynamo",
-            inputs,
-            min_block_size=1,
-            pass_through_build_failures=True,
-            torch_executed_ops={"torch.ops.aten.mul.Tensor"},
-            use_python_runtime=True,
-        )
-
-        result_samples = []
-        torch_results_samples = []
-
-        inputs = []
-        for i in [1, 3, 8, 11, 16]:
-            inputs.append(torch.randn((i, 3, 224, 224)).cuda())
-
-        for n in range(len(inputs) * TRIALS):
-            i = n // TRIALS
-            # disable cuda graph at all index for all trials
-            if n % TRIALS == n // TRIALS:
-                torch_tensorrt.runtime.set_cudagraphs_mode(False)
-            else:
-                torch_tensorrt.runtime.set_cudagraphs_mode(True)
-
-            result_samples.append(optimized_model(inputs[i]).detach().cpu())
-            torch_results_samples.append(fx_graph(inputs[i]).detach().cpu())
-
-        for n in range(len(inputs) * TRIALS):
-            i = n // TRIALS
-            # enable cuda graph at all index for all trials
-            if n % TRIALS == n // TRIALS:
-                torch_tensorrt.runtime.set_cudagraphs_mode(True)
-            else:
-                torch_tensorrt.runtime.set_cudagraphs_mode(False)
-
-            result_samples.append(optimized_model(inputs[i]).detach().cpu())
-            torch_results_samples.append(fx_graph(inputs[i]).detach().cpu())
-
-        for i, (optimized_model_results, torch_model_results) in enumerate(
-            zip(result_samples, torch_results_samples)
-        ):
-            max_diff = float(
-                torch.max(torch.abs(optimized_model_results - torch_model_results))
-            )
-            self.assertAlmostEqual(
-                max_diff,
-                0,
-                DECIMALS_OF_AGREEMENT,
-                msg=f"CUDA Graph Python TRT outputs don't match with the original model. (trial: {i})",
-            )


 if __name__ == "__main__":
     run_tests()
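The deleted test exercised the global runtime toggle rather than a new API; set_cudagraphs_mode itself is unchanged by this commit. A minimal sketch of flipping it around inference, using the same entry points the removed test called (illustrative model):

import torch
import torch_tensorrt

class TinyModel(torch.nn.Module):  # illustrative model
    def forward(self, x):
        return torch.relu((x + 2) * 0.5)

inputs = [torch.randn(8, 3, 224, 224).cuda()]
trt_model = torch_tensorrt.compile(
    TinyModel().eval().cuda(), "dynamo", inputs, min_block_size=1
)

torch_tensorrt.runtime.set_cudagraphs_mode(True)   # capture/replay enabled
out_graph = trt_model(*inputs)
torch_tensorrt.runtime.set_cudagraphs_mode(False)  # back to standard execution
out_eager = trt_model(*inputs)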
