22
33import  io 
44import  logging 
5- from  typing  import  Any , List , Optional , Sequence 
5+ from  typing  import  Any , List , NamedTuple ,  Optional , Sequence 
66
77import  torch 
88from  torch_tensorrt ._enums  import  dtype 
99from  torch_tensorrt ._features  import  ENABLED_FEATURES 
1010from  torch_tensorrt ._Input  import  Input 
1111from  torch_tensorrt .dynamo ._engine_cache  import  BaseEngineCache 
1212from  torch_tensorrt .dynamo ._settings  import  CompilationSettings 
13- from  torch_tensorrt .dynamo .conversion ._TRTInterpreter  import  (
14-     TRTInterpreter ,
15-     TRTInterpreterResult ,
16- )
13+ from  torch_tensorrt .dynamo .conversion ._TRTInterpreter  import  TRTInterpreter 
1714from  torch_tensorrt .dynamo .runtime  import  PythonTorchTensorRTModule , TorchTensorRTModule 
1815from  torch_tensorrt .dynamo .utils  import  (
1916    get_cpu_memory_usage ,
2421logger  =  logging .getLogger (__name__ )
2522
2623
24+ class  SerializedInterpreterResult (NamedTuple ):
25+     serialized_engine : bytes 
26+     input_names : Sequence [str ]
27+     output_names : Sequence [str ]
28+     weight_name_map : Optional [dict [Any , Any ]]
29+     requires_output_allocator : bool 
30+ 
31+ 
2732def  infer_module_output_dtypes (
2833    module : torch .fx .GraphModule ,
2934    truncate_double : bool  =  False ,
@@ -34,7 +39,7 @@ def infer_module_output_dtypes(
3439    """ 
3540    outputs  =  [node  for  node  in  module .graph .nodes  if  node .op  ==  "output" ]
3641    outputs  =  outputs [0 ].args 
37-     return  get_output_dtypes (outputs , truncate_double )
42+     return  get_output_dtypes (outputs , truncate_double )   # type: ignore 
3843
3944
4045def  interpret_module_to_result (
@@ -44,7 +49,7 @@ def interpret_module_to_result(
4449    arg_inputs : Optional [Sequence [Input ]] =  None ,
4550    kwarg_inputs : Optional [dict [str , Any ]] =  None ,
4651    engine_cache : Optional [BaseEngineCache ] =  None ,
47- ) ->  TRTInterpreterResult :
52+ ) ->  SerializedInterpreterResult :
4853    """Interpret an FX module to a TRTInterpreterResult 
4954    Args: 
5055        module: FX GraphModule to interpret 
@@ -84,16 +89,18 @@ def interpret_module_to_result(
8489    with  io .BytesIO () as  engine_bytes :
8590        engine_bytes .write (serialized_engine )
8691        serialized_engine  =  engine_bytes .getvalue ()
87- 
88-     interpreter_result  =  TRTInterpreterResult (
89-         engine = serialized_engine ,
92+         logger .debug (
93+             f"CPU memory usage after serializing engine: { get_cpu_memory_usage ()}   MB" 
94+         )
95+     serialized_interpreter_result  =  SerializedInterpreterResult (
96+         serialized_engine = serialized_engine ,
9097        input_names = interpreter_result .input_names ,
9198        output_names = interpreter_result .output_names ,
9299        weight_name_map = interpreter_result .weight_name_map ,
93100        requires_output_allocator = interpreter_result .requires_output_allocator ,
94101    )
95102
96-     return  interpreter_result 
103+     return  serialized_interpreter_result 
97104
98105
99106def  convert_module (
@@ -132,7 +139,7 @@ def convert_module(
132139        )
133140
134141    return  rt_cls (
135-         serialized_engine = interpreter_result .engine ,
142+         serialized_engine = interpreter_result .serialized_engine ,
136143        input_binding_names = list (interpreter_result .input_names ),
137144        output_binding_names = list (interpreter_result .output_names ),
138145        name = name ,
0 commit comments