chore: Runtime API for pre-allocated outputs
Showing 8 changed files with 259 additions and 21 deletions.
@@ -0,0 +1,41 @@
import logging
from typing import Any

import torch
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule

logger = logging.getLogger(__name__)


class _PreAllocatedOutputContextManager(object):
    """
    Helper class used to enable the pre-allocated output feature in runtime modules
    """

    def __init__(self, module: torch.fx.GraphModule) -> None:
        # Collect the TensorRT runtime submodules produced by the partitioner
        rt_mods = []
        for name, rt_mod in module.named_children():
            if "_run_on_acc" in name and isinstance(
                rt_mod, (PythonTorchTensorRTModule, TorchTensorRTModule)
            ):
                rt_mods.append(rt_mod)
        self.rt_mods = rt_mods

    def set_pre_allocated_output(self, enable: bool) -> None:
        # Propagate the flag to every collected TRT runtime submodule
        for mod in self.rt_mods:
            mod.set_pre_allocated_outputs(enable)

    def __enter__(self) -> "_PreAllocatedOutputContextManager":
        # Enable pre-allocated output
        self.set_pre_allocated_output(True)
        return self

    def __exit__(self, *args: Any) -> None:
        # Disable pre-allocated output
        self.set_pre_allocated_output(False)


def enable_pre_allocated_outputs(
    module: torch.fx.GraphModule,
) -> _PreAllocatedOutputContextManager:
    return _PreAllocatedOutputContextManager(module)
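
For reference, a minimal sketch of how the enable_pre_allocated_outputs API added above can be used, mirroring the tests below. This is not part of the commit; the toy model, input shape, and variable names are illustrative assumptions.

import torch
import torch_tensorrt as torchtrt

# Illustrative toy model; any module whose compiled graph contains TRT runtime
# submodules ("_run_on_acc_*") is eligible for the context manager.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval().cuda()
example_input = torch.randn(8, 16).cuda()

trt_model = torchtrt.compile(model, "dynamo", [example_input], min_block_size=1)

# Scoped usage: pre-allocated outputs are enabled inside the block and
# disabled again on exit.
with torchtrt.runtime.enable_pre_allocated_outputs(trt_model):
    out = trt_model(example_input)

# Imperative usage (as in the dynamic-shape test below): keep the handle
# and toggle the flag explicitly.
ctx = torchtrt.runtime.enable_pre_allocated_outputs(trt_model)
ctx.set_pre_allocated_output(True)
out = trt_model(example_input)
ctx.set_pre_allocated_output(False)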
@@ -0,0 +1,130 @@
import torch
import torch_tensorrt as torchtrt
from parameterized import parameterized
from torch.testing._internal.common_utils import TestCase, run_tests

INPUT_SIZE = (3, 16, 16)
TRIALS = 5


class TestPreAllocatedOutputs(TestCase):
    @parameterized.expand(
        [
            ("python_runtime", True),
            ("cpp_runtime", False),
        ]
    )
    def test_pre_allocated_outputs_default(self, _, use_python_runtime):
        class SampleModel(torch.nn.Module):
            def forward(self, x):
                return torch.softmax((x + 2) * 7, dim=0)

        model = SampleModel().eval().cuda()
        inputs = [torch.randn(*INPUT_SIZE).cuda() for _ in range(TRIALS)]
        fx_graph = torch.fx.symbolic_trace(model)

        # Validate that the results between Torch and Torch-TRT are similar
        optimized_model = torchtrt.compile(
            fx_graph,
            "torch_compile",
            inputs[0],
            min_block_size=1,
            pass_through_build_failures=True,
            use_python_runtime=use_python_runtime,
        )

        ref_out_list = []
        trt_out_list = []
        with torchtrt.runtime.enable_pre_allocated_outputs(optimized_model):
            for i in inputs:
                ref_out_list.append(fx_graph(i).detach().cpu())
                trt_out_list.append(optimized_model(i).detach().cpu())

        for torch_model_results, optimized_model_results in zip(
            ref_out_list, trt_out_list
        ):
            torch.testing.assert_close(
                torch_model_results,
                optimized_model_results,
                rtol=5e-03,
                atol=5e-03,
                equal_nan=True,
                check_dtype=True,
            )

        torch._dynamo.reset()

    @parameterized.expand(
        [
            ("python_runtime", True),
            ("cpp_runtime", False),
        ]
    )
    def test_pre_allocated_outputs_dynamic(self, _, use_python_runtime):
        class SampleModel(torch.nn.Module):
            def forward(self, x):
                return torch.relu((x + 2) * 0.5)

        inputs = torchtrt.Input(
            min_shape=(1, 3, 128, 224),
            opt_shape=(8, 3, 192, 224),
            max_shape=(16, 3, 224, 224),
            dtype=torch.float,
            name="x",
        )
        fx_graph = torch.fx.symbolic_trace(SampleModel())

        optimized_model = torchtrt.compile(
            fx_graph,
            "dynamo",
            inputs,
            min_block_size=1,
            pass_through_build_failures=True,
            torch_executed_ops={"torch.ops.aten.mul.Tensor"},
            use_python_runtime=use_python_runtime,
        )

        input_list = []
        ref_out_list = []
        trt_out_list = []
        # Input shapes change every iteration; cuda graphs mode and pre-allocated
        # outputs are toggled at different points in the loop below.
        for i in [1, 3, 8, 11, 16]:
            for j in [128, 128, 222, 222, 224]:
                input_list.append(torch.randn((i, 3, j, 224)).cuda())

        pre_allocated_output_ctx = torchtrt.runtime.enable_pre_allocated_outputs(
            optimized_model
        )
        pre_allocated_output = False
        for enable_cuda_graphs in [False, True]:
            for i in range(len(input_list)):
                # Use the outer cuda graphs setting at one index per block of
                # TRIALS iterations, and the inverse setting elsewhere
                if i % TRIALS == i // TRIALS:
                    cuda_graphs = enable_cuda_graphs
                else:
                    cuda_graphs = not enable_cuda_graphs
                # Flip pre-allocated outputs every third iteration
                if i % 3 == 0:
                    pre_allocated_output = not pre_allocated_output

                torchtrt.runtime.set_cudagraphs_mode(cuda_graphs)
                pre_allocated_output_ctx.set_pre_allocated_output(pre_allocated_output)

                ref_out_list.append(fx_graph(input_list[i]))
                trt_out_list.append(optimized_model(input_list[i]))

        for torch_model_results, optimized_model_results in zip(
            ref_out_list, trt_out_list
        ):
            torch.testing.assert_close(
                torch_model_results,
                optimized_model_results,
                rtol=5e-03,
                atol=5e-03,
                equal_nan=True,
                check_dtype=True,
            )
        torch._dynamo.reset()


if __name__ == "__main__":
    run_tests()