From 2d70d18fccde9222b1caa717f60f4f496d93b5fd Mon Sep 17 00:00:00 2001 From: Jiwoong Choi Date: Fri, 30 Aug 2024 13:49:15 +0900 Subject: [PATCH] fix: allow `dict` type `module_outputs` in `infer_module_output_dtypes` --- py/torch_tensorrt/dynamo/conversion/_conversion.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py index cd38ce56e6..6d4be2a65a 100644 --- a/py/torch_tensorrt/dynamo/conversion/_conversion.py +++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py @@ -46,7 +46,9 @@ def infer_module_output_dtypes( kwarg_inputs = {} torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device) module_outputs = module(*torch_inputs, **torch_kwarg_inputs) - if not isinstance(module_outputs, (list, tuple)): + if isinstance(module_outputs, dict): + module_outputs = list(module_outputs.values()) + elif not isinstance(module_outputs, (list, tuple)): module_outputs = [module_outputs] # Int64 outputs can sometimes be generated from within other operators