From 4ec12c499b751d09139cba39e40d7f4abc471c4c Mon Sep 17 00:00:00 2001 From: John Welsh Date: Tue, 26 Jul 2022 12:56:02 -0700 Subject: [PATCH 1/2] add cpu tracing --- torch2trt/tests/test_cpu_tracing.py | 22 ++++++++++++++++++++++ torch2trt/torch2trt.py | 19 ++++++++++++++++--- 2 files changed, 38 insertions(+), 3 deletions(-) create mode 100644 torch2trt/tests/test_cpu_tracing.py diff --git a/torch2trt/tests/test_cpu_tracing.py b/torch2trt/tests/test_cpu_tracing.py new file mode 100644 index 00000000..6f65201c --- /dev/null +++ b/torch2trt/tests/test_cpu_tracing.py @@ -0,0 +1,22 @@ +import pytest +import torch +from torch2trt import torch2trt + + +def test_cpu_tracing(): + + model = torch.nn.Conv2d(3, 3, kernel_size=1) + + data = torch.randn(1, 3, 32, 32) + + model_trt = torch2trt(model, [data]) + + assert(hasattr(model_trt, 'engine')) + assert(model_trt.engine is not None) + + data = torch.randn(1, 3, 32, 32) + assert(torch.allclose(model(data), model_trt(data), atol=1e-3, rtol=1e-3)) + + +if __name__ == '__main__': + test_cpu_tracing() diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py index cdd425aa..e7bc923f 100644 --- a/torch2trt/torch2trt.py +++ b/torch2trt/torch2trt.py @@ -509,7 +509,8 @@ def add_inputs(self, torch_inputs, names=None, dynamic_axes=None): shape=shape, dtype=torch_dtype_to_trt(torch_input.dtype), ) - trt_tensor.location = torch_device_to_trt(torch_input.device) + # bindings are always on the GPU, even for CPU tensors; forward() moves inputs to CUDA before running + trt_tensor.location = trt.TensorLocation.DEVICE torch_input._trt = trt_tensor def mark_outputs(self, torch_outputs, names=None): @@ -520,7 +521,8 @@ def mark_outputs(self, torch_outputs, names=None): for i, torch_output in enumerate(torch_outputs): trt_tensor = torch_output._trt trt_tensor.name = names[i] - trt_tensor.location = torch_device_to_trt(torch_output.device) + # bindings are always on the GPU, even for CPU tensors; forward() moves outputs back to the first input's device + trt_tensor.location = trt.TensorLocation.DEVICE trt_tensor.dtype = 
torch_dtype_to_trt(torch_output.dtype) self.network.mark_output(trt_tensor) @@ -576,9 +578,15 @@ def _load_from_state_dict( def forward(self, *inputs): bindings = [None] * (len(self.input_names) + len(self.output_names)) + # flatten inputs if self.input_flattener is not None: inputs = self.input_flattener.flatten(inputs) + input_device = inputs[0].device + + # place inputs on device + inputs = [t.cuda() for t in inputs] + for i, input_name in enumerate(self.input_names): idx = self.engine.get_binding_index(input_name) shape = tuple(inputs[i].shape) @@ -600,6 +608,10 @@ def forward(self, *inputs): bindings, torch.cuda.current_stream().cuda_stream ) + # map outputs back to the first input's original device + outputs = [t.to(input_device) for t in outputs] + + # unflatten outputs if self.output_flattener is not None: outputs = self.output_flattener.unflatten(outputs) else: @@ -782,7 +794,8 @@ def torch2trt(module, config.set_calibration_profile(profile) # BUILD ENGINE - + torch.cuda.empty_cache() + engine = builder.build_engine(network, config) module_trt = TRTModule(engine, input_names, output_names, input_flattener=input_flattener, output_flattener=output_flattener) From 493dc7db0de610d2c3c72bbdd185e4478758cc1d Mon Sep 17 00:00:00 2001 From: John Welsh Date: Tue, 26 Jul 2022 12:58:06 -0700 Subject: [PATCH 2/2] remove empty cache --- torch2trt/torch2trt.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py index e7bc923f..a385a431 100644 --- a/torch2trt/torch2trt.py +++ b/torch2trt/torch2trt.py @@ -794,8 +794,7 @@ def torch2trt(module, config.set_calibration_profile(profile) # BUILD ENGINE - torch.cuda.empty_cache() - + engine = builder.build_engine(network, config) module_trt = TRTModule(engine, input_names, output_names, input_flattener=input_flattener, output_flattener=output_flattener)