From 59fb544f8c72f8d4e60764515d64632ba32039c1 Mon Sep 17 00:00:00 2001 From: Anurag Dixit Date: Wed, 15 Sep 2021 09:50:11 -0700 Subject: [PATCH 1/2] Add support for inplace mul_ operator from pytorch Signed-off-by: Anurag Dixit --- torch2trt/converters/mul.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch2trt/converters/mul.py b/torch2trt/converters/mul.py index eefd744c..ebbcccdf 100644 --- a/torch2trt/converters/mul.py +++ b/torch2trt/converters/mul.py @@ -3,6 +3,7 @@ @tensorrt_converter('torch.mul') +@tensorrt_converter('torch.Tensor.mul_') @tensorrt_converter('torch.Tensor.__imul__') @tensorrt_converter('torch.Tensor.__mul__') @tensorrt_converter('torch.Tensor.__rmul__') From dc01d80fe4033fe1f2302681f9405ab719c9ad3c Mon Sep 17 00:00:00 2001 From: Anurag Dixit Date: Wed, 15 Sep 2021 10:13:29 -0700 Subject: [PATCH 2/2] Updated TRT python api in torch2trt Signed-off-by: Anurag Dixit --- torch2trt/torch2trt.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/torch2trt/torch2trt.py b/torch2trt/torch2trt.py index 6a33a9ee..206dcdb8 100644 --- a/torch2trt/torch2trt.py +++ b/torch2trt/torch2trt.py @@ -555,10 +555,15 @@ def torch2trt(module, outputs = (outputs,) ctx.mark_outputs(outputs, output_names) - builder.max_workspace_size = max_workspace_size - builder.fp16_mode = fp16_mode builder.max_batch_size = max_batch_size - builder.strict_type_constraints = strict_type_constraints + config = builder.create_builder_config() + config.max_workspace_size = max_workspace_size + + if strict_type_constraints: + config.set_flag(trt.BuilderFlag.STRICT_TYPES) + + if fp16_mode: + config.set_flag(trt.BuilderFlag.FP16) if int8_mode: @@ -566,7 +571,7 @@ def torch2trt(module, if int8_calib_dataset is None: int8_calib_dataset = TensorBatchDataset(inputs_in) - builder.int8_mode = True + config.set_flag(trt.BuilderFlag.INT8) #Making sure not to run calibration with QAT mode on if not 'qat_mode' in kwargs: @@ -575,7 +580,7 @@ def torch2trt(module, inputs, int8_calib_dataset, batch_size=int8_calib_batch_size, algorithm=int8_calib_algorithm ) - engine = builder.build_cuda_engine(network) + engine = builder.build_engine(network, config) module_trt = TRTModule(engine, input_names, output_names)