5050from  torch_tensorrt .dynamo .debug ._DebuggerConfig  import  DebuggerConfig 
5151from  torch_tensorrt .dynamo .debug ._supports_debugger  import  cls_supports_debugger 
5252from  torch_tensorrt .dynamo .observer  import  Observer 
53- from  torch_tensorrt .dynamo .utils  import  DYNAMIC_DIM , deallocate_module , to_torch_device 
53+ from  torch_tensorrt .dynamo .utils  import  (
54+     DYNAMIC_DIM ,
55+     deallocate_module ,
56+     get_cpu_memory_usage ,
57+     to_torch_device ,
58+ )
5459from  torch_tensorrt .logging  import  TRT_LOGGER 
5560
5661_LOGGER : logging .Logger  =  logging .getLogger (__name__ )
@@ -65,7 +70,7 @@ class UnsupportedOperatorException(RuntimeError):
6570
6671
6772class  TRTInterpreterResult (NamedTuple ):
68-     serialized_engine :  bytes 
73+     engine :  trt . ICudaEngine 
6974    input_names : Sequence [str ]
7075    output_names : Sequence [str ]
7176    weight_name_map : Optional [dict [Any , Any ]]
@@ -512,8 +517,7 @@ def _save_weight_mapping(self) -> None:
512517        _LOGGER .info ("Building weight name mapping..." )
513518        # Stage 1: Name mapping 
514519        torch_device  =  to_torch_device (self .compilation_settings .device )
515-         self .module .to (torch_device )
516-         sd  =  self .module .state_dict ()
520+         sd  =  {k : v .to (torch_device ) for  k , v  in  self .module .state_dict ().items ()}
517521        weight_name_map : dict [str , Any ] =  {}
518522        weight_refit_map  =  self .ctx .weight_refit_map 
519523        constant_mapping  =  {k : v  for  k , v  in  weight_refit_map .items () if  v .size  ==  1 }
@@ -592,13 +596,11 @@ def _save_weight_mapping(self) -> None:
592596        torch .cuda .empty_cache ()
593597
594598    @needs_refit   # type: ignore[misc]  
595-     def  _insert_engine_to_cache (self , hash_val : str , serialized_engine : bytes ) ->  None :
599+     def  _insert_engine_to_cache (self , hash_val : str , engine : trt .ICudaEngine ) ->  None :
600+         serialized_engine  =  engine .serialize ()
596601        # TODO: @Evan is waiting for TRT's feature to cache the weight-stripped engine 
597602        # if not self.compilation_settings.strip_engine_weights: 
598603        #     # set EXCLUDE_WEIGHTS flag to strip weights 
599-         #     runtime = trt.Runtime(TRT_LOGGER) 
600-         #     engine = runtime.deserialize_cuda_engine(serialized_engine) 
601- 
602604        #     serialization_config = engine.create_serialization_config() 
603605        #     serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS) 
604606        #     serialized_engine = engine.serialize_with_config( 
@@ -733,6 +735,9 @@ def run(
733735                        return  interpreter_result   # type: ignore[no-any-return] 
734736
735737        self ._construct_trt_network_def ()
738+         _LOGGER .debug (
739+             f"CPU memory usage after network construction: { get_cpu_memory_usage ()}  
740+         )
736741
737742        if  not  self .compilation_settings .immutable_weights :
738743            self ._save_weight_mapping ()
@@ -750,16 +755,19 @@ def run(
750755        self ._create_timing_cache (
751756            builder_config , self .compilation_settings .timing_cache_path 
752757        )
753-         serialized_engine  =  self .builder .build_serialized_network (
758+ 
759+         cuda_engine  =  self .builder .build_engine_with_config (
754760            self .ctx .net , builder_config 
755761        )
756-         assert  serialized_engine 
762+         assert  cuda_engine 
763+ 
764+         _LOGGER .debug (
765+             f"CPU memory usage after engine building: { get_cpu_memory_usage ()}  
766+         )
757767
758768        _LOGGER .info (
759769            f"Build TRT engine elapsed time: { datetime .now () -  build_engine_start_time }  
760770        )
761-         _LOGGER .info (f"TRT Engine uses: { serialized_engine .nbytes }  )
762- 
763771        self .ctx .clear_cpu_weights_reference_holder ()
764772
765773        self ._save_timing_cache (
@@ -772,14 +780,10 @@ def run(
772780            and  self .compilation_settings .cache_built_engines 
773781            and  self .engine_cache  is  not None 
774782        ):
775-             self ._insert_engine_to_cache (hash_val , serialized_engine )
776- 
777-         with  io .BytesIO () as  engine_bytes :
778-             engine_bytes .write (serialized_engine )
779-             engine_str  =  engine_bytes .getvalue ()
783+             self ._insert_engine_to_cache (hash_val , cuda_engine )
780784
781785        return  TRTInterpreterResult (
782-             engine_str ,
786+             cuda_engine ,
783787            self ._input_names ,
784788            self ._output_names ,
785789            self .weight_name_map ,
0 commit comments