diff --git a/examples/dynamo/engine_caching_bert_example.py b/examples/dynamo/engine_caching_bert_example.py
index 428c414a06..989913bd31 100644
--- a/examples/dynamo/engine_caching_bert_example.py
+++ b/examples/dynamo/engine_caching_bert_example.py
@@ -52,7 +52,7 @@ def compile_bert(iterations=3):
         "truncate_double": True,
         "debug": False,
         "min_block_size": 1,
-        "make_refitable": True,
+        "make_refittable": True,
         "cache_built_engines": cache_built_engines,
         "reuse_cached_engines": reuse_cached_engines,
         "engine_cache_dir": "/tmp/torch_trt_bert_engine_cache",
diff --git a/examples/dynamo/engine_caching_example.py b/examples/dynamo/engine_caching_example.py
index 5154dc1e2c..28ff73aa72 100644
--- a/examples/dynamo/engine_caching_example.py
+++ b/examples/dynamo/engine_caching_example.py
@@ -63,7 +63,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
 # in a subsequent compilation, either as part of this session or a new session, the cache will
 # pull the built engine and **refit** the weights which can reduce compilation times by orders of magnitude.
 # As such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``),
-# the engine must be refitable (``make_refittable=True``). See :ref:`refit_engine_example` for more details.
+# the engine must be refittable (``make_refittable=True``). See :ref:`refit_engine_example` for more details.


 def torch_compile(iterations=3):
@@ -97,7 +97,7 @@ def torch_compile(iterations=3):
             "enabled_precisions": enabled_precisions,
             "debug": debug,
             "min_block_size": min_block_size,
-            "make_refitable": True,
+            "make_refittable": True,
             "cache_built_engines": cache_built_engines,
             "reuse_cached_engines": reuse_cached_engines,
         },
@@ -157,7 +157,7 @@ def dynamo_compile(iterations=3):
         enabled_precisions=enabled_precisions,
         debug=debug,
         min_block_size=min_block_size,
-        make_refitable=True,
+        make_refittable=True,
         cache_built_engines=cache_built_engines,
         reuse_cached_engines=reuse_cached_engines,
         engine_cache_size=1 << 30,  # 1GB
@@ -268,7 +268,7 @@ def torch_compile_my_cache(iterations=3):
             "enabled_precisions": enabled_precisions,
             "debug": debug,
             "min_block_size": min_block_size,
-            "make_refitable": True,
+            "make_refittable": True,
             "cache_built_engines": cache_built_engines,
             "reuse_cached_engines": reuse_cached_engines,
             "custom_engine_cache": engine_cache,
diff --git a/examples/dynamo/mutable_torchtrt_module_example.py b/examples/dynamo/mutable_torchtrt_module_example.py
index a10c0e17ae..b68c9a11ee 100644
--- a/examples/dynamo/mutable_torchtrt_module_example.py
+++ b/examples/dynamo/mutable_torchtrt_module_example.py
@@ -31,7 +31,7 @@
 settings = {
     "use_python": False,
     "enabled_precisions": {torch.float32},
-    "make_refitable": True,
+    "make_refittable": True,
 }

 model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -80,7 +80,7 @@
     "use_python_runtime": True,
     "enabled_precisions": {torch.float16},
     "debug": True,
-    "make_refitable": True,
+    "make_refittable": True,
 }

 model_id = "runwayml/stable-diffusion-v1-5"
diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py
index 1feb033a3a..adf1057055 100644
--- a/examples/dynamo/refit_engine_example.py
+++ b/examples/dynamo/refit_engine_example.py
@@ -43,11 +43,11 @@


 # %%
-# Make a Refitable Compilation Program
+# Make a Refittable Compilation Program
 # ---------------------------------------
 #
 # The initial step is to compile a module and save it as normal. Note that there is an
-# additional parameter `make_refitable` that is set to `True`. This parameter is used to
+# additional parameter `make_refittable` that is set to `True`. This parameter is used to
 # indicate that the engine being built should support weight refitting later. Engines built without
 # this setting will not be able to be refit.
 #
@@ -69,7 +69,7 @@
     debug=debug,
     min_block_size=min_block_size,
     torch_executed_ops=torch_executed_ops,
-    make_refitable=True,
+    make_refittable=True,
 )  # Output is a torch.fx.GraphModule

 # Save the graph module as an exported program
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 2e6ff039b4..a18b4aac35 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -60,7 +60,7 @@ def compile(
         Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
     ] = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-    make_refitable: bool = _defaults.MAKE_REFITABLE,
+    make_refittable: bool = _defaults.MAKE_REFITTABLE,
     debug: bool = _defaults.DEBUG,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -180,14 +180,14 @@ def compile(

     if "refit" in kwargs.keys():
         warnings.warn(
-            "Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
+            "Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.",
             DeprecationWarning,
             stacklevel=2,
         )
-        if make_refitable:
-            raise ValueError("Use flag make_refitable only. Flag refit is deprecated.")
+        if make_refittable:
+            raise ValueError("Use flag make_refittable only. Flag refit is deprecated.")
         else:
-            make_refitable = kwargs["refit"]
+            make_refittable = kwargs["refit"]

     engine_capability = EngineCapability._from(engine_capability)
@@ -238,8 +238,8 @@ def compile(
     engine_cache = None
     if cache_built_engines or reuse_cached_engines:
         assert (
-            make_refitable
-        ), "Engine caching requires make_refitable to be set to True"
+            make_refittable
+        ), "Engine caching requires make_refittable to be set to True"
         engine_cache = (
             custom_engine_cache
             if custom_engine_cache is not None
@@ -270,7 +270,7 @@ def compile(
         "require_full_compilation": require_full_compilation,
         "disable_tf32": disable_tf32,
         "sparse_weights": sparse_weights,
-        "make_refitable": make_refitable,
+        "make_refittable": make_refittable,
         "engine_capability": engine_capability,
         "dla_sram_size": dla_sram_size,
         "dla_local_dram_size": dla_local_dram_size,
@@ -513,7 +513,7 @@ def convert_exported_program_to_serialized_trt_engine(
     require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
     disable_tf32: bool = _defaults.DISABLE_TF32,
     sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
-    make_refitable: bool = _defaults.MAKE_REFITABLE,
+    make_refittable: bool = _defaults.MAKE_REFITTABLE,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -600,7 +600,7 @@ def convert_exported_program_to_serialized_trt_engine(
     )
     if "refit" in kwargs.keys():
         warnings.warn(
-            "Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
+            "Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.",
             DeprecationWarning,
             stacklevel=2,
         )
@@ -646,7 +646,7 @@ def convert_exported_program_to_serialized_trt_engine(
         "require_full_compilation": require_full_compilation,
         "disable_tf32": disable_tf32,
         "sparse_weights": sparse_weights,
-        "make_refitable": make_refitable,
+        "make_refittable": make_refittable,
         "engine_capability": engine_capability,
         "num_avg_timing_iters": num_avg_timing_iters,
         "dla_sram_size": dla_sram_size,
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py
index 83e85cb3c7..dad97d6489 100644
--- a/py/torch_tensorrt/dynamo/_defaults.py
+++ b/py/torch_tensorrt/dynamo/_defaults.py
@@ -26,7 +26,7 @@
 USE_PYTHON_RUNTIME = False
 USE_FAST_PARTITIONER = True
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
-MAKE_REFITABLE = False
+MAKE_REFITTABLE = False
 REQUIRE_FULL_COMPILATION = False
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
diff --git a/py/torch_tensorrt/dynamo/_engine_cache.py b/py/torch_tensorrt/dynamo/_engine_cache.py
index 7a33a81521..e775f2730f 100644
--- a/py/torch_tensorrt/dynamo/_engine_cache.py
+++ b/py/torch_tensorrt/dynamo/_engine_cache.py
@@ -1,17 +1,25 @@
 import copy
 import logging
 import os
 import pickle
+import pickletools
 import shutil
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Tuple, cast
+from typing import Any, Dict, List, Optional, Sequence, Tuple, cast

 import torch
-from torch._inductor.codecache import FxGraphCachePickler
+from torch._inductor.codecache import FxGraphCachePickler, sha256_hash
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
+from torch_tensorrt._Input import Input
+from torch_tensorrt.dynamo._settings import (
+    _SETTINGS_TO_BE_ENGINE_INVARIANT,
+    CompilationSettings,
+)

 _LOGGER: logging.Logger = logging.getLogger(__name__)

+UnpackedCacheHit = Tuple[
+    bytes,
+    List[str],
+    List[str],
+    Tuple[Input],
+    CompilationSettings,
+    Optional[Dict[Any, Any]],
+]


 class BaseEngineCache(ABC):

@@ -24,7 +32,7 @@ def __init__(
         pass

     @staticmethod
-    def get_hash(gm: torch.fx.GraphModule) -> str:
+    def get_hash(
+        gm: torch.fx.GraphModule,
+        input_specs: Sequence[Input],
+        settings: CompilationSettings,
+    ) -> str:
         """Get the hash value of the GraphModule

         Args:
@@ -39,7 +47,22 @@ def get_hash(gm: torch.fx.GraphModule) -> str:
         for name, param in new_gm.named_parameters():
             param.data.zero_()

-        hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))
+        graph_hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))
+
+        # Hash the string forms of the input specs
+        input_spec_strs = [str(i) for i in input_specs]
+        input_specs_data = pickletools.optimize(pickle.dumps(input_spec_strs))
+        input_specs_hash = sha256_hash(input_specs_data)
+
+        # Hash the engine-invariant subset of the compilation settings
+        invariant_engine_specs = [
+            str(getattr(settings, field)) for field in _SETTINGS_TO_BE_ENGINE_INVARIANT
+        ]
+        engine_specs_data = pickletools.optimize(pickle.dumps(invariant_engine_specs))
+        engine_specs_hash = sha256_hash(engine_specs_data)
+
+        # TODO: The three hashes are combined by simple concatenation for now;
+        # a more robust combination scheme could be adopted later.
+        hash_val = graph_hash_val + input_specs_hash + engine_specs_hash

         return hash_val

@@ -48,6 +71,8 @@ def pack(
         serialized_engine: bytes,
         input_names: List[str],
         output_names: List[str],
+        input_specs: Tuple[Input],
+        compilation_settings: CompilationSettings,
         weight_name_map: Optional[Dict[Any, Any]],
     ) -> bytes:
-        """Pack serialized engine, input names, output names, and weight map into a single blob
+        """Pack serialized engine, input names, output names, input specs, compilation settings, and weight map into a single blob
@@ -61,35 +86,80 @@ def pack(
         Returns:
             bytes: packed blob
         """
+
+        settings = copy.deepcopy(compilation_settings)
+        settings.torch_executed_ops = {
+            f"torch.ops.{op.__str__()}" for op in settings.torch_executed_ops
+        }
+
         return pickle.dumps(
             {
                 "serialized_engine": bytes(serialized_engine),
                 "input_names": input_names,
                 "output_names": output_names,
+                "input_specs": input_specs,
+                "compilation_settings": settings,
                 "weight_name_map": weight_name_map,
             }
         )

     @staticmethod
-    def unpack(
-        packed_obj: bytes,
-    ) -> Tuple[bytes, List[str], List[str], Optional[Dict[Any, Any]]]:
-        """Unpack packed blob into serialized engine, input names, output names, and weight map
+    def unpack(packed_obj: bytes) -> UnpackedCacheHit:
+        """Unpack packed blob into serialized engine, input names, output names, input specs, compilation settings, and weight map

         Args:
             packed_obj (bytes): packed blob

         Returns:
-            Tuple[bytes, List[str], List[str], Optional[Dict[str, Any]]]: serialized engine, input names, output names, weight name map
+            Tuple[bytes, List[str], List[str], Tuple[Input], CompilationSettings, Optional[Dict[str, Any]]]: serialized engine, input names, output names, input specs, compilation settings, weight name map
         """
         unpacked = pickle.loads(packed_obj)
         return (
             unpacked["serialized_engine"],
             unpacked["input_names"],
             unpacked["output_names"],
+            unpacked["input_specs"],
+            unpacked["compilation_settings"],
             unpacked["weight_name_map"],
         )

+    def insert(
+        self, hash: str, entry: UnpackedCacheHit, *args: Any, **kwargs: Any
+    ) -> None:
+        """
+        Insert a cache entry into the engine cache.
+
+        Args:
+            hash (str): The cache key, as produced by ``get_hash``.
+            entry (Tuple[bytes, List[str], List[str], Tuple[Input], CompilationSettings, Optional[Dict[Any, Any]]]): The cache entry to be inserted.
+            *args: Variable length argument list passed to ``save``.
+            **kwargs: Arbitrary keyword arguments passed to ``save``.
+
+        Returns:
+            None
+        """
+        packed_cache_info = BaseEngineCache.pack(*entry)
+        return self.save(hash, packed_cache_info, *args, **kwargs)
+
+    def check(self, hash: str, *args: Any, **kwargs: Any) -> Optional[UnpackedCacheHit]:
+        """
+        Check if a cache entry exists for the given hash.
+
+        Args:
+            hash (str): The cache key, as produced by ``get_hash``.
+            *args: Variable length argument list passed to ``load``.
+            **kwargs: Arbitrary keyword arguments passed to ``load``.
+
+        Returns:
+            Optional[Tuple[bytes, List[str], List[str], Tuple[Input], CompilationSettings, Optional[Dict[Any, Any]]]]: The unpacked cache entry if found, None otherwise.
+        """
+        packed_cache_info = self.load(hash, *args, **kwargs)
+
+        if packed_cache_info:
+            return BaseEngineCache.unpack(packed_cache_info)
+        else:
+            return None
+
     @abstractmethod
     def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
         """Store blob in cache
diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index 8b0d7c3e20..c4043260fb 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -254,7 +254,7 @@ def refit_module_weights(
         ]
         assert (
             encoded_metadata != ""
-        ), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version with make_refitable=True"
+        ), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old; please recompile using the latest version with make_refittable=True"
         settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"]
         # Handle torch modules
         compiled_submodules_map = dict(compiled_submodules)
@@ -270,7 +270,7 @@ def refit_module_weights(
             settings = submodule.settings

     assert (
-        settings.make_refitable
-    ), "Refitting is not enabled. Please recompile the engine with refit=True."
+        settings.make_refittable
+    ), "Refitting is not enabled. Please recompile the engine with make_refittable=True."

     if settings.debug:
diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py
index 063f6f3718..1f1c29fde5 100644
--- a/py/torch_tensorrt/dynamo/_settings.py
+++ b/py/torch_tensorrt/dynamo/_settings.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass, field
-from typing import Collection, Optional, Set, Union
+from typing import Collection, Optional, Set, Tuple, Union

 from torch.fx.node import Target
 from torch_tensorrt._Device import Device
@@ -18,7 +18,7 @@
     ENGINE_CAPABILITY,
     HARDWARE_COMPATIBLE,
     LAZY_ENGINE_INIT,
-    MAKE_REFITABLE,
+    MAKE_REFITTABLE,
     MAX_AUX_STREAMS,
     MIN_BLOCK_SIZE,
     NUM_AVG_TIMING_ITERS,
@@ -98,7 +98,7 @@ class CompilationSettings:
     disable_tf32: bool = DISABLE_TF32
     assume_dynamic_shape_support: bool = ASSUME_DYNAMIC_SHAPE_SUPPORT
     sparse_weights: bool = SPARSE_WEIGHTS
-    make_refitable: bool = MAKE_REFITABLE
+    make_refittable: bool = MAKE_REFITTABLE
     engine_capability: EngineCapability = field(
         default_factory=lambda: ENGINE_CAPABILITY
     )
@@ -112,3 +112,28 @@ class CompilationSettings:
     lazy_engine_init: bool = LAZY_ENGINE_INIT
     cache_built_engines: bool = CACHE_BUILT_ENGINES
     reuse_cached_engines: bool = REUSE_CACHED_ENGINES
+
+
+_SETTINGS_TO_BE_ENGINE_INVARIANT = (
+    "enabled_precisions",
+    "max_aux_streams",
+    "version_compatible",
+    "optimization_level",
+    "disable_tf32",
+    "sparse_weights",
+    "make_refittable",
+    "engine_capability",
+    "hardware_compatible",
+)
+
+
+def settings_are_compatible(
+    set_a: CompilationSettings, set_b: CompilationSettings
+) -> Tuple[bool, Set[str]]:
+    incompatible_settings: Set[str] = set()
+
+    for field in _SETTINGS_TO_BE_ENGINE_INVARIANT:
+        if getattr(set_a, field) != getattr(set_b, field):
+            incompatible_settings.add(field)
+
+    if len(incompatible_settings) == 0:
+        return True, set()
+    else:
+        return False, incompatible_settings
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 84fe345137..d837c704bf 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -27,7 +27,7 @@
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
-from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
     DYNAMO_CONVERTERS as CONVERTERS,
@@ -282,7 +282,7 @@ def _populate_trt_builder_config(
         if self.compilation_settings.disable_tf32:
             builder_config.clear_flag(trt.BuilderFlag.TF32)

-        if self.compilation_settings.make_refitable:
+        if self.compilation_settings.make_refittable:
             builder_config.set_flag(trt.BuilderFlag.REFIT)

         if strict_type_constraints:
@@ -533,18 +533,17 @@ def run(
             self.compilation_settings.cache_built_engines
             or self.compilation_settings.reuse_cached_engines
         ):
-            hash_val = self.engine_cache.get_hash(self.module)
+            hash_val = self.engine_cache.get_hash(
+                self.module, self.input_specs, self.compilation_settings
+            )

             if self.compilation_settings.reuse_cached_engines:
                 # query the cached TRT engine
-                blob = self.engine_cache.load(hash_val)
-                if blob is not None:  # hit the cache
-                    serialized_engine, input_names, output_names, weight_name_map = (
-                        self.engine_cache.unpack(blob)
-                    )
-                    self._input_names = input_names
-                    self._output_names = output_names
-                    self.weight_name_map = weight_name_map
+                cached_data = self.engine_cache.check(hash_val)
+                if cached_data is not None:  # hit the cache
+                    (
+                        serialized_engine,
+                        self._input_names,
+                        self._output_names,
+                        engine_input_specs,
+                        engine_compilation_settings,
+                        self.weight_name_map,
+                    ) = cached_data
+
+                    settings_compatible, incompatible_settings = settings_are_compatible(
+                        self.compilation_settings, engine_compilation_settings
+                    )
+                    assert settings_compatible, (
+                        f"Attempted to refit a prebuilt engine with incompatible settings: {incompatible_settings} "
+                        f"(cached settings: {engine_compilation_settings}, requested settings: {self.compilation_settings})"
+                    )
+
                     _LOGGER.info(
                         "Found the cached engine that corresponds to this graph. It is directly loaded."
                     )
@@ -581,7 +580,7 @@ def run(

         self._construct_trt_network_def()

-        if self.compilation_settings.make_refitable:
+        if self.compilation_settings.make_refittable:
             self._save_weight_mapping()

         build_engine_start_time = datetime.now()
@@ -612,13 +611,14 @@ def run(
             self.engine_cache is not None
             and self.compilation_settings.cache_built_engines
         ):
-            blob = self.engine_cache.pack(
-                serialized_engine,
-                self._input_names,
-                self._output_names,
-                self.weight_name_map,
-            )
-            self.engine_cache.save(hash_val, blob)
+            self.engine_cache.insert(
+                hash_val,
+                (
+                    serialized_engine,
+                    self._input_names,
+                    self._output_names,
+                    self.input_specs,
+                    self.compilation_settings,
+                    self.weight_name_map,
+                ),
+            )

         with io.BytesIO() as engine_bytes:
             engine_bytes.write(serialized_engine)
diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
index 672a7e267d..4e98892c5c 100644
--- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
@@ -65,7 +65,7 @@ def __init__(
             Union[torch.dtype, dtype]
         ] = _defaults.ENABLED_PRECISIONS,
         engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-        make_refitable: bool = _defaults.MAKE_REFITABLE,
+        make_refittable: bool = _defaults.MAKE_REFITTABLE,
         debug: bool = _defaults.DEBUG,
         num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
         workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -152,8 +152,8 @@ def __init__(
         device = to_torch_tensorrt_device(device)
         enabled_precisions = {dtype._from(p) for p in enabled_precisions}
         assert (
-            make_refitable
-        ), "'make_refitable' has to be True for a MutableTorchTensorRTModule."
+            make_refittable
+        ), "'make_refittable' has to be True for a MutableTorchTensorRTModule."
         compilation_options = {
             "enabled_precisions": (
                 enabled_precisions
@@ -180,7 +180,7 @@ def __init__(
             "require_full_compilation": require_full_compilation,
             "disable_tf32": disable_tf32,
             "sparse_weights": sparse_weights,
-            "make_refitable": make_refitable,
+            "make_refittable": make_refittable,
             "engine_capability": engine_capability,
             "dla_sram_size": dla_sram_size,
             "dla_local_dram_size": dla_local_dram_size,
diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
index 63a932c353..be169ea2f2 100644
--- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
+++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py
@@ -184,10 +184,6 @@ def setup_engine(self) -> None:

     def encode_metadata(self, metadata: Any) -> str:
         metadata = copy.deepcopy(metadata)
-        metadata["settings"].torch_executed_ops = {
-            f"torch.ops.{op.__str__()}"
-            for op in metadata["settings"].torch_executed_ops
-        }
         dumped_metadata = pickle.dumps(metadata)
         encoded_metadata = base64.b64encode(dumped_metadata).decode("utf-8")
         return encoded_metadata
@@ -196,9 +192,6 @@ def encode_metadata(self, metadata: Any) -> str:
     def decode_metadata(encoded_metadata: bytes) -> Any:
         dumped_metadata = base64.b64decode(encoded_metadata.encode("utf-8"))
         metadata = pickle.loads(dumped_metadata)
-        metadata["settings"].torch_executed_ops = {
-            eval(op) for op in metadata["settings"].torch_executed_ops
-        }
         return metadata

     def get_extra_state(self) -> SerializedTorchTensorRTModuleFmt:
@@ -240,7 +234,9 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None:
             serialized_metadata = serialized_engine_info[SERIALIZED_METADATA_IDX]
             assert isinstance(serialized_metadata, bytes)
-            self.settings = TorchTensorRTModule.decode_metadata(serialized_metadata)
+            metadata = TorchTensorRTModule.decode_metadata(serialized_metadata)
+            self.settings = metadata["settings"]
+            self.weight_name_map = metadata["weight_name_map"]
         else:
             self.engine = None
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 2af7922cd1..d8aea04fbb 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -449,7 +449,6 @@ def parse_dynamo_kwargs(
     Returns:
         CompilationSettings object with relevant kwargs
     """
-
     # Initialize an empty CompilationSettings object
     settings = CompilationSettings()

@@ -500,11 +499,11 @@ def parse_dynamo_kwargs(

     # If cache_built_engines and reuse_cached_engines are True but custom_engine_cache is not provided,
     # then create a default disk engine cache

     if kwargs.get("cache_built_engines") or kwargs.get("reuse_cached_engines"):
         assert kwargs.get(
-            "make_refitable"
-        ), "Engine caching requires make_refitable to be set to True"
+            "make_refittable"
+        ), "Engine caching requires make_refittable to be set to True"

         if kwargs.get("custom_engine_cache") is not None:
             engine_cache = kwargs.get("custom_engine_cache")
@@ -519,6 +519,9 @@ def parse_dynamo_kwargs(
             )
             engine_cache = DiskEngineCache(engine_cache_dir, engine_cache_size)

+    if kwargs.get("torch_executed_ops"):
+        settings.torch_executed_ops = kwargs.get("torch_executed_ops")
+
     logger.info("Compilation Settings: %s\n", settings)

     return settings, engine_cache
diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py
index 770e057a36..9d2588a50d 100644
--- a/tests/py/dynamo/models/test_engine_cache.py
+++ b/tests/py/dynamo/models/test_engine_cache.py
@@ -4,6 +4,8 @@
 import unittest
 from typing import Optional

+from torch_tensorrt.dynamo._settings import CompilationSettings
+
 import pytest
 import torch
 import torch_tensorrt as torch_trt
@@ -16,6 +18,7 @@

 assertions = unittest.TestCase()

+
 class MyEngineCache(BaseEngineCache):
     def __init__(
         self,
@@ -25,6 +28,8 @@ def __init__(
         if not os.path.exists(self.engine_cache_dir):
             os.makedirs(self.engine_cache_dir, exist_ok=True)

+        self.hashes = {}
+
     def save(
         self,
         hash: str,
@@ -41,17 +46,127 @@ def save(
         with open(path, "wb") as f:
             f.write(blob)

+        self.hashes[hash] = 0
+
     def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
         path = os.path.join(self.engine_cache_dir, f"{prefix}_{hash}.bin")
         if os.path.exists(path):
             with open(path, "rb") as f:
                 blob = f.read()
+            self.hashes[hash] += 1
             return blob
         return None


+class TestHashFunction(TestCase):
+    def test_reexport_is_equal(self):
+        pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
+        example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
+        batch = torch.export.Dim("batch", min=1, max=200)
+
+        exp_program1 = torch.export.export(
+            pyt_model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
+        )
+        input_specs1 = (
+            torch_trt.Input(
+                min_shape=(1, 3, 224, 224),
+                opt_shape=(100, 3, 224, 224),
+                max_shape=(200, 3, 224, 224),
+            ),
+        )
+        settings1 = CompilationSettings(
+            make_refittable=True, cache_built_engines=True, reuse_cached_engines=True
+        )
+        hash1 = BaseEngineCache.get_hash(exp_program1.module(), input_specs1, settings1)
+
+        exp_program2 = torch.export.export(
+            pyt_model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
+        )
+        input_specs2 = (
+            torch_trt.Input(
+                min_shape=(1, 3, 224, 224),
+                opt_shape=(100, 3, 224, 224),
+                max_shape=(200, 3, 224, 224),
+            ),
+        )
+        settings2 = CompilationSettings(
+            make_refittable=True, cache_built_engines=True, reuse_cached_engines=True
+        )
+        hash2 = BaseEngineCache.get_hash(exp_program2.module(), input_specs2, settings2)
+
+        self.assertEqual(hash1, hash2)
+
+    def test_input_shape_change_is_not_equal(self):
+        pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
+        example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
+        batch = torch.export.Dim("batch", min=1, max=200)
+
+        exp_program1 = torch.export.export(
+            pyt_model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
+        )
+        input_specs1 = (
+            torch_trt.Input(
+                min_shape=(1, 3, 224, 224),
+                opt_shape=(100, 3, 224, 224),
+                max_shape=(200, 3, 224, 224),
+            ),
+        )
+        settings1 = CompilationSettings(
+            make_refittable=True, cache_built_engines=True, reuse_cached_engines=True
+        )
+        hash1 = BaseEngineCache.get_hash(exp_program1.module(), input_specs1, settings1)
+
+        exp_program2 = torch.export.export(
+            pyt_model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
+        )
+        input_specs2 = (
+            torch_trt.Input(
+                min_shape=(1, 3, 300, 300),
+                opt_shape=(100, 3, 300, 300),
+                max_shape=(200, 3, 300, 300),
+            ),
+        )
+        settings2 = CompilationSettings(
+            make_refittable=True, cache_built_engines=True, reuse_cached_engines=True
+        )
+        hash2 = BaseEngineCache.get_hash(exp_program2.module(), input_specs2, settings2)
+
+        self.assertNotEqual(hash1, hash2)
+
+    def test_engine_settings_is_not_equal(self):
+        pyt_model = models.resnet18(pretrained=True).eval().to("cuda")
+        example_inputs = (torch.randn((100, 3, 224, 224)).to("cuda"),)
+        batch = torch.export.Dim("batch", min=1, max=200)
+
+        exp_program1 = torch.export.export(
+            pyt_model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
+        )
+        input_specs1 = (
+            torch_trt.Input(
+                min_shape=(1, 3, 224, 224),
+                opt_shape=(100, 3, 224, 224),
+                max_shape=(200, 3, 224, 224),
+            ),
+        )
+        settings1 = CompilationSettings(
+            make_refittable=True,
+            cache_built_engines=True,
+            reuse_cached_engines=True,
+            enabled_precisions={torch.float32},
+        )
+        hash1 = BaseEngineCache.get_hash(exp_program1.module(), input_specs1, settings1)
+
+        exp_program2 = torch.export.export(
+            pyt_model, args=example_inputs, dynamic_shapes={"x": {0: batch}}
+        )
+        # Keep the input specs identical to input_specs1 so that only the
+        # engine-invariant settings differ between the two hashes
+        input_specs2 = (
+            torch_trt.Input(
+                min_shape=(1, 3, 224, 224),
+                opt_shape=(100, 3, 224, 224),
+                max_shape=(200, 3, 224, 224),
+            ),
+        )
+        settings2 = CompilationSettings(
+            make_refittable=True,
+            cache_built_engines=True,
+            reuse_cached_engines=True,
+            enabled_precisions={torch.float32, torch.float16},
+        )
+        hash2 = BaseEngineCache.get_hash(exp_program2.module(), input_specs2, settings2)
+
+        self.assertNotEqual(hash1, hash2)
+
+
 class TestEngineCache(TestCase):
+    @pytest.mark.xfail
     def 
test_dynamo_compile_with_default_disk_engine_cache(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, - make_refitable=True, + make_refittable=True, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, + engine_cache_dir=engine_cache_dir ) end.record() torch.cuda.synchronize() + torch._dynamo.reset() times.append(start.elapsed_time(end)) results.append(trt_gm(*inputs)) @@ -119,7 +237,7 @@ def test_dynamo_compile_with_default_disk_engine_cache(self): def test_dynamo_compile_with_custom_engine_cache(self): model = models.resnet18(pretrained=True).eval().to("cuda") - engine_cache_dir = "/tmp/your_dir" + engine_cache_dir = "/tmp/test_torch_dynamo_with_custom_engine_cache" if os.path.exists(engine_cache_dir): shutil.rmtree(engine_cache_dir) @@ -138,9 +256,6 @@ def test_dynamo_compile_with_custom_engine_cache(self): # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine. inputs = [torch.rand((128, 3, 224, 224)).to("cuda")] results = [] - times = [] - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) for i in range(3): if i == 0: cache_built_engines = False @@ -149,7 +264,6 @@ def test_dynamo_compile_with_custom_engine_cache(self): cache_built_engines = True reuse_cached_engines = True - start.record() trt_gm = torch_trt.dynamo.compile( exp_program, tuple(inputs), @@ -157,14 +271,11 @@ def test_dynamo_compile_with_custom_engine_cache(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, - make_refitable=True, + make_refittable=True, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, custom_engine_cache=custom_engine_cache, ) - end.record() - torch.cuda.synchronize() - times.append(start.elapsed_time(end)) results.append(trt_gm(*inputs)) cos_sim = cosine_similarity(results[0], results[1]) @@ -179,11 +290,37 @@ def test_dynamo_compile_with_custom_engine_cache(self): msg=f"results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) - assertions.assertTrue( - times[0] > times[2], - msg=f"Engine caching didn't speed up the compilation. 
Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms",
-        )
+        for h, count in custom_engine_cache.hashes.items():
+            assertions.assertTrue(
+                count == 1,
+                f"cache was not hit exactly once for entry ({h}, hit: {count})",
+            )
+
+    def test_dynamo_compile_change_input_shape(self):
+        """Runs compilation 3 times; the cache should miss each time"""
+        model = models.resnet18(pretrained=True).eval().to("cuda")
+
+        engine_cache_dir = "/tmp/test_torch_dynamo_with_custom_engine_cache"
+        if os.path.exists(engine_cache_dir):
+            shutil.rmtree(engine_cache_dir)
+
+        custom_engine_cache = MyEngineCache(engine_cache_dir)
+
+        for i in range(3):
+            # Use a different input shape on each iteration so that every
+            # compilation misses the cache
+            inputs = (torch.rand((4 * (i + 1), 3, 224, 224)).to("cuda"),)
+            trt_gm = torch_trt.dynamo.compile(
+                torch.export.export(model, args=inputs),
+                inputs=inputs,
+                use_python_runtime=False,
+                enabled_precisions={torch.float},
+                debug=False,
+                min_block_size=1,
+                make_refittable=True,
+                cache_built_engines=True,
+                reuse_cached_engines=True,
+                custom_engine_cache=custom_engine_cache,
+            )
+
+        for h, count in custom_engine_cache.hashes.items():
+            assertions.assertTrue(
+                count == 0, f"Unintended cache hit for entry ({h}, hit: {count})"
+            )
+
+    @pytest.mark.xfail
     def test_torch_compile_with_default_disk_engine_cache(self):
         # Custom Engine Cache
         model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -210,6 +347,7 @@ def test_torch_compile_with_default_disk_engine_cache(self):
                 cache_built_engines = True
                 reuse_cached_engines = True

+            torch.cuda.synchronize()
             start.record()
             compiled_model = torch.compile(
                 model,
@@ -219,16 +357,18 @@ def test_torch_compile_with_default_disk_engine_cache(self):
                     "enabled_precisions": {torch.float},
                     "debug": False,
                     "min_block_size": 1,
-                    "make_refitable": True,
+                    "make_refittable": True,
                     "cache_built_engines": cache_built_engines,
                     "reuse_cached_engines": reuse_cached_engines,
                     "engine_cache_dir": engine_cache_dir,
                     "engine_cache_size": 1 << 30,  # 1GB
+                    "torch_executed_ops": {"torch.ops.aten.relu.default"},
                 },
             )
             results.append(compiled_model(*inputs))  # trigger the compilation
             end.record()
             torch.cuda.synchronize()
+            torch._dynamo.reset()
             times.append(start.elapsed_time(end))

         cos_sim = cosine_similarity(results[0], results[1])
@@ -252,7 +392,7 @@ def test_torch_compile_with_custom_engine_cache(self):
         # Custom Engine Cache
         model = models.resnet18(pretrained=True).eval().to("cuda")

-        engine_cache_dir = "/tmp/your_dir"
+        engine_cache_dir = "/tmp/test_torch_compile_with_custom_engine_cache"
         if os.path.exists(engine_cache_dir):
             shutil.rmtree(engine_cache_dir)

@@ -284,15 +424,17 @@ def test_torch_compile_with_custom_engine_cache(self):
                     "enabled_precisions": {torch.float},
                     "debug": False,
                     "min_block_size": 1,
-                    "make_refitable": True,
+                    "make_refittable": True,
                     "cache_built_engines": cache_built_engines,
                     "reuse_cached_engines": reuse_cached_engines,
                     "custom_engine_cache": custom_engine_cache,
+                    "torch_executed_ops": {"torch.ops.aten.relu.default"},
                 },
             )
             results.append(compiled_model(*inputs))  # trigger the compilation
             end.record()
             torch.cuda.synchronize()
+            torch._dynamo.reset()
             times.append(start.elapsed_time(end))

         cos_sim = cosine_similarity(results[0], results[1])
@@ -307,7 +449,36 @@ def test_torch_compile_with_custom_engine_cache(self):
             msg=f"results[1] doesn't match with results[2]. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
         )

-        assertions.assertTrue(
-            times[0] > times[2],
-            msg=f"Engine caching didn't speed up the compilation. Time taken without engine caching: {times[0]} ms, time taken with engine caching: {times[2]} ms",
-        )
+        for h, count in custom_engine_cache.hashes.items():
+            assertions.assertTrue(
+                count == 1,
+                f"cache was not hit exactly once for entry ({h}, hit: {count})",
+            )
+
+    def test_torch_compile_change_input_shape(self):
+        # Custom Engine Cache
+        model = models.resnet18(pretrained=True).eval().to("cuda")
+
+        engine_cache_dir = "/tmp/test_torch_compile_with_default_disk_engine_cache"
+        if os.path.exists(engine_cache_dir):
+            shutil.rmtree(engine_cache_dir)
+
+        custom_engine_cache = MyEngineCache(engine_cache_dir)
+        for i in range(3):
+            # Use a different input shape on each iteration so that every
+            # compilation misses the cache
+            inputs = [torch.rand((4 * (i + 1), 3, 224, 224)).to("cuda")]
+            compiled_model = torch.compile(
+                model,
+                backend="tensorrt",
+                options={
+                    "use_python_runtime": True,
+                    "enabled_precisions": {torch.float},
+                    "debug": False,
+                    "min_block_size": 1,
+                    "make_refittable": True,
+                    "cache_built_engines": True,
+                    "reuse_cached_engines": True,
+                    "custom_engine_cache": custom_engine_cache,
+                    "torch_executed_ops": {"torch.ops.aten.relu.default"},
+                },
+            )
+            compiled_model(*inputs)  # trigger the compilation
+
+        for h, count in custom_engine_cache.hashes.items():
+            assertions.assertTrue(
+                count == 0, f"Unintended cache hit for entry ({h}, hit: {count})"
+            )
diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py
index 9782cd829c..b5d3c962d6 100644
--- a/tests/py/dynamo/models/test_model_refit.py
+++ b/tests/py/dynamo/models/test_model_refit.py
@@ -57,7 +57,7 @@ def test_mapping():
         enabled_precisions=enabled_precisions,
         debug=debug,
         min_block_size=min_block_size,
-        make_refitable=True,
+        make_refittable=True,
     )
     settings = trt_gm._run_on_acc_0.settings
     runtime = trt.Runtime(TRT_LOGGER)
@@ -109,7 +109,7 @@ def test_refit_one_engine_with_weightmap():
         enabled_precisions=enabled_precisions,
         debug=debug,
         min_block_size=min_block_size,
-        make_refitable=True,
+        make_refittable=True,
     )

     new_trt_gm = refit_module_weights(
@@ -159,7 +159,7 @@ def test_refit_one_engine_no_map_with_weightmap():
         enabled_precisions=enabled_precisions,
         debug=debug,
         min_block_size=min_block_size,
-        make_refitable=True,
+        make_refittable=True,
     )
     trt_gm._run_on_acc_0.weight_name_map = None

@@ -210,7 +210,7 @@ def test_refit_one_engine_with_wrong_weightmap():
         enabled_precisions=enabled_precisions,
         debug=debug,
         min_block_size=min_block_size,
-        make_refitable=True,
+        make_refittable=True,
     )
     # Manually deleted all batch norm layers. This is expected to make the fast refit fail
     trt_gm._run_on_acc_0.weight_name_map = {
@@ -267,7 +267,7 @@ def test_refit_one_engine_bert_with_weightmap():
         enabled_precisions=enabled_precisions,
         debug=debug,
         min_block_size=min_block_size,
-        make_refitable=True,
+        make_refittable=True,
     )

     new_trt_gm = refit_module_weights(
@@ -320,7 +320,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap():
         enabled_precisions=enabled_precisions,
         debug=debug,
         min_block_size=min_block_size,
-        make_refitable=True,
+        make_refittable=True,
     )
     torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
     trt_gm = torch.export.load(trt_ep_path)
@@ -366,7 +366,7 @@ def test_refit_one_engine_python_runtime_with_weightmap():
         enabled_precisions=enabled_precisions,
         debug=debug,
         min_block_size=min_block_size,
-        make_refitable=True,
+        make_refittable=True,
     )

     new_trt_gm = refit_module_weights(
@@ -428,7 +428,7 @@ def forward(self, x):
     exp_program = torch.export.export(model, tuple(inputs))
     exp_program2 = torch.export.export(model2, tuple(inputs))

-    torch_executed_ops = {torch.ops.aten.convolution.default}
+    torch_executed_ops = {"torch.ops.aten.convolution.default"}
     trt_gm = torchtrt.dynamo.compile(
         exp_program,
         tuple(inputs),
@@ -436,7 +436,7 @@ def forward(self, x):
         enabled_precisions=enabled_precisions,
         debug=debug,
         min_block_size=min_block_size,
-        make_refitable=True,
+        make_refittable=True,
         torch_executed_ops=torch_executed_ops,
     )

@@ -486,7 +486,7 @@ def test_refit_one_engine_without_weightmap():
         enabled_precisions=enabled_precisions,
         debug=debug,
         min_block_size=min_block_size,
-        make_refitable=True,
+        make_refittable=True,
     )

     new_trt_gm = refit_module_weights(
@@ -537,7 +537,7 @@ def test_refit_one_engine_bert_without_weightmap():
         enabled_precisions=enabled_precisions,
         debug=debug,
         min_block_size=min_block_size,
-        make_refitable=True,
+        make_refittable=True,
     )

     new_trt_gm = refit_module_weights(
@@ -590,7 +590,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap():
         enabled_precisions=enabled_precisions,
         debug=debug,
         min_block_size=min_block_size,
-        make_refitable=True,
+        make_refittable=True,
     )
     torchtrt.save(trt_gm, trt_ep_path, inputs=inputs)
     trt_gm = torch.export.load(trt_ep_path)
@@ -636,7 +636,7 @@ def test_refit_one_engine_python_runtime_without_weightmap():
         enabled_precisions=enabled_precisions,
         debug=debug,
         min_block_size=min_block_size,
-        make_refitable=True,
+        make_refittable=True,
     )

     new_trt_gm = refit_module_weights(
@@ -698,7 +698,7 @@ def forward(self, x):
     exp_program = torch.export.export(model, tuple(inputs))
     exp_program2 = torch.export.export(model2, tuple(inputs))

-    torch_executed_ops = {torch.ops.aten.convolution.default}
+    torch_executed_ops = {"torch.ops.aten.convolution.default"}
     trt_gm = torchtrt.dynamo.compile(
         exp_program,
         tuple(inputs),
@@ -706,7 +706,7 @@ def forward(self, x):
         enabled_precisions=enabled_precisions,
         debug=debug,
         min_block_size=min_block_size,
-        make_refitable=True,
+        make_refittable=True,
         torch_executed_ops=torch_executed_ops,
     )

diff --git a/tests/py/dynamo/models/test_models_export.py b/tests/py/dynamo/models/test_models_export.py
index bf19c3c5e6..a82b60330f 100644
--- a/tests/py/dynamo/models/test_models_export.py
+++ b/tests/py/dynamo/models/test_models_export.py
@@ -1,7 +1,10 @@
 # type: ignore
 import unittest

-import modelopt
+import importlib.metadata
+import importlib.util
+
+from packaging.version import Version
 import pytest
 import timm
 import torch
@@ -195,8 +197,13 @@ def test_resnet18_half(ir):
     torch.cuda.get_device_properties(torch.cuda.current_device()).major < 9,
     "FP8 compilation in Torch-TRT is not supported on cards older than Hopper",
 )
+@unittest.skipIf(
+    not importlib.util.find_spec("modelopt"),
+    reason="ModelOpt is necessary to run this test",
+)
 @pytest.mark.unit
 def test_base_fp8(ir):
+    import modelopt
+
     class SimpleNetwork(torch.nn.Module):
         def __init__(self):
             super(SimpleNetwork, self).__init__()
@@ -239,13 +246,14 @@ def calibrate_loop(model):
     outputs_trt = trt_model(input_tensor)
     assert torch.allclose(output_pyt, outputs_trt, rtol=1e-3, atol=1e-2)

-
 @unittest.skipIf(
-    modelopt.__version__ < "0.16.1",
-    "Int8 quantization is supported in modelopt since 0.16.1 or later",
+    not importlib.util.find_spec("modelopt")
+    or Version(importlib.metadata.version("modelopt")) < Version("0.16.1"),
+    "Int8 quantization requires modelopt 0.16.1 or later",
 )
 @pytest.mark.unit
 def test_base_int8(ir):
+    import modelopt
+
     class SimpleNetwork(torch.nn.Module):
         def __init__(self):
             super(SimpleNetwork, self).__init__()
diff --git a/tests/py/dynamo/runtime/test_002_lazy_engine_init.py b/tests/py/dynamo/runtime/test_002_lazy_engine_init.py
index 008b0f53b1..aafd099bde 100644
--- a/tests/py/dynamo/runtime/test_002_lazy_engine_init.py
+++ b/tests/py/dynamo/runtime/test_002_lazy_engine_init.py
@@ -281,7 +281,7 @@ def forward(self, a, b):
         "ir": "dynamo",
         "lazy_engine_init": True,
         "use_python_runtime": True,
-        "torch_executed_ops": [torch.ops.aten.sub.Tensor],
+        "torch_executed_ops": {"torch.ops.aten.sub.Tensor"},
         "cache_built_engines": False,
         "reuse_cached_engines": False,
     }
@@ -325,7 +325,7 @@ def forward(self, a, b):
         "ir": "dynamo",
         "lazy_engine_init": True,
         "use_python_runtime": False,
-        "torch_executed_ops": [torch.ops.aten.sub.Tensor],
+        "torch_executed_ops": {"torch.ops.aten.sub.Tensor"},
         "cache_built_engines": False,
         "reuse_cached_engines": False,
     }
diff --git a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py
index 86e7678a66..b52530efd1 100644
--- a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py
+++ b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py
@@ -49,7 +49,7 @@ def test_resnet18():
     compile_spec = {
         "use_python_runtime": False,
         "enabled_precisions": {torch.float32},
-        "make_refitable": True,
+        "make_refittable": True,
     }

     model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -89,7 +89,7 @@ def test_save():
     compile_spec = {
         "use_python_runtime": False,
         "enabled_precisions": {torch.float32},
-        "make_refitable": True,
+        "make_refittable": True,
     }

     model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -123,7 +123,7 @@ def test_resnet18_modify_attribute():
     compile_spec = {
         "use_python_runtime": False,
         "enabled_precisions": {torch.float32},
-        "make_refitable": True,
+        "make_refittable": True,
     }

     model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -164,7 +164,7 @@ def test_resnet18_modify_attribute_no_refit():
     compile_spec = {
         "use_python_runtime": False,
         "enabled_precisions": {torch.float32},
-        "make_refitable": True,
+        "make_refittable": True,
     }

     model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -243,7 +243,7 @@ def forward(self, x, b=5, c=None, d=None):
         "optimization_level": 1,
         "min_block_size": 1,
         "ir": "dynamo",
-        "make_refitable": True,
+        "make_refittable": True,
     }

     mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec)
@@ -304,7 +304,7 @@ def set_weights(self):
         "optimization_level": 1,
         "min_block_size": 1,
         "ir": "dynamo",
-        "make_refitable": True,
+        "make_refittable": True,
     }

     mutable_module = 
torch_trt.MutableTorchTensorRTModule(model, **compile_spec) @@ -367,7 +367,7 @@ def set_layer(self): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", - "make_refitable": True, + "make_refittable": True, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) @@ -436,7 +436,7 @@ def forward(self, x, b=5, c=None, d=None): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", - "make_refitable": True, + "make_refittable": True, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec)
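
A note on the behavioral change bundled with the rename: cache keys are now derived from the graph, the input specs, and the engine-invariant subset of the compilation settings, and each cached entry carries the settings it was built with so that incompatible reuse can be rejected. The following is a minimal standalone sketch (not part of the diff; the Linear model and shapes are placeholders chosen for illustration) of how the `settings_are_compatible` helper and the extended `BaseEngineCache.get_hash` added above are expected to behave:

import torch
import torch_tensorrt as torch_trt
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
from torch_tensorrt.dynamo._settings import (
    CompilationSettings,
    settings_are_compatible,
)

# Two settings objects that differ only in an engine-invariant field
settings_a = CompilationSettings(make_refittable=True)
settings_b = CompilationSettings(make_refittable=True, sparse_weights=True)

# sparse_weights is listed in _SETTINGS_TO_BE_ENGINE_INVARIANT, so the pair is
# reported as incompatible and the mismatched field is named
compatible, mismatched = settings_are_compatible(settings_a, settings_b)
assert not compatible and mismatched == {"sparse_weights"}

# The same engine-invariant fields feed the cache key, so the two settings
# objects should also yield different hashes for an identical graph and specs
model = torch.nn.Linear(4, 4).eval().to("cuda")
gm = torch.export.export(model, (torch.randn(1, 4).to("cuda"),)).module()
specs = (torch_trt.Input(shape=(1, 4)),)
assert BaseEngineCache.get_hash(gm, specs, settings_a) != BaseEngineCache.get_hash(
    gm, specs, settings_b
)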