From 05b560d9a48f16dfd6aac4fbbf232798b9b8b307 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Thu, 28 Nov 2024 00:07:19 +0000 Subject: [PATCH] change to opt-in feature --- .../dynamo/engine_caching_bert_example.py | 1 + examples/dynamo/engine_caching_example.py | 5 + .../dynamo/mutable_torchtrt_module_example.py | 2 + examples/dynamo/refit_engine_example.py | 7 +- py/torch_tensorrt/dynamo/_compiler.py | 101 ++++++++++++++++-- py/torch_tensorrt/dynamo/_defaults.py | 2 +- py/torch_tensorrt/dynamo/_refit.py | 4 + py/torch_tensorrt/dynamo/backend/backends.py | 2 +- .../dynamo/conversion/_TRTInterpreter.py | 13 +-- .../runtime/_MutableTorchTensorRTModule.py | 7 +- .../py/dynamo/conversion/test_cumsum_aten.py | 4 + .../conversion/test_embedding_bag_aten.py | 4 + tests/py/dynamo/models/test_engine_cache.py | 19 +++- tests/py/dynamo/models/test_model_refit.py | 14 +++ .../models/test_weight_stripped_engine.py | 14 +++ .../runtime/test_mutable_torchtrt_module.py | 8 ++ 16 files changed, 187 insertions(+), 20 deletions(-) diff --git a/examples/dynamo/engine_caching_bert_example.py b/examples/dynamo/engine_caching_bert_example.py index 9cddefd509..1148d4f792 100644 --- a/examples/dynamo/engine_caching_bert_example.py +++ b/examples/dynamo/engine_caching_bert_example.py @@ -52,6 +52,7 @@ def compile_bert(iterations=3): "truncate_double": True, "debug": False, "min_block_size": 1, + "immutable_weights": False, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "engine_cache_dir": "/tmp/torch_trt_bert_engine_cache", diff --git a/examples/dynamo/engine_caching_example.py b/examples/dynamo/engine_caching_example.py index 20388e9372..fb4c341077 100644 --- a/examples/dynamo/engine_caching_example.py +++ b/examples/dynamo/engine_caching_example.py @@ -62,6 +62,8 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): # engines are saved to disk tied to a hash of their corresponding PyTorch subgraph. If # in a subsequent compilation, either as part of this session or a new session, the cache will # pull the built engine and **refit** the weights which can reduce compilation times by orders of magnitude. +# As such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``), +# the engine must be refittable (``immutable_weights=False``). See :ref:`refit_engine_example` for more details. def torch_compile(iterations=3): @@ -95,6 +97,7 @@ def torch_compile(iterations=3): "enabled_precisions": enabled_precisions, "debug": debug, "min_block_size": min_block_size, + "immutable_weights": False, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, }, @@ -154,6 +157,7 @@ def dynamo_compile(iterations=3): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, engine_cache_size=1 << 30, # 1GB @@ -264,6 +268,7 @@ def torch_compile_my_cache(iterations=3): "enabled_precisions": enabled_precisions, "debug": debug, "min_block_size": min_block_size, + "immutable_weights": False, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "custom_engine_cache": engine_cache, diff --git a/examples/dynamo/mutable_torchtrt_module_example.py b/examples/dynamo/mutable_torchtrt_module_example.py index 3ea9fab9a5..8b62855c32 100644 --- a/examples/dynamo/mutable_torchtrt_module_example.py +++ b/examples/dynamo/mutable_torchtrt_module_example.py @@ -31,6 +31,7 @@ settings = { "use_python": False, "enabled_precisions": {torch.float32}, + "immutable_weights": False, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -79,6 +80,7 @@ "use_python_runtime": True, "enabled_precisions": {torch.float16}, "debug": True, + "immutable_weights": False, } model_id = "runwayml/stable-diffusion-v1-5" diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py index 44f78abbc0..66a1a70964 100644 --- a/examples/dynamo/refit_engine_example.py +++ b/examples/dynamo/refit_engine_example.py @@ -46,7 +46,10 @@ # Make a refittable Compilation Program # --------------------------------------- # -# The inital step is to compile a module and save it as with a normal. +# The inital step is to compile a module and save it as with a normal. Note that there is an +# additional parameter `immutable_weights` that is set to `False`. This parameter is used to +# indicate that the engine being built should support weight refitting later. Engines built without +# these setttings will not be able to be refit. # # In this case we are going to compile a ResNet18 model with randomly initialized weights and save it. @@ -66,6 +69,8 @@ debug=debug, min_block_size=min_block_size, torch_executed_ops=torch_executed_ops, + immutable_weights=False, + reuse_cached_engines=False, ) # Output is a torch.fx.GraphModule # Save the graph module as an exported program diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py index 007f5632a3..93fbd675ec 100644 --- a/py/torch_tensorrt/dynamo/_compiler.py +++ b/py/torch_tensorrt/dynamo/_compiler.py @@ -92,6 +92,8 @@ def cross_compile_for_windows( custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE, use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING, use_fp32_acc: bool = _defaults.USE_FP32_ACC, + refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS, + strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS, immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS, enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING, **kwargs: Any, @@ -163,6 +165,8 @@ def cross_compile_for_windows( custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored. use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs. use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions. + refit_identical_engine_weights (bool): Refit engines with identical weights. This is useful when the same model is compiled multiple times with different inputs and the weights are the same. This will save time by reusing the same engine for different inputs. + strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required. immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored. enable_weight_streaming (bool): Enable weight streaming. **kwargs: Any, @@ -193,17 +197,44 @@ def cross_compile_for_windows( if "refit" in kwargs.keys(): warnings.warn( - "`refit` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", + "`refit` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", DeprecationWarning, stacklevel=2, ) + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `refit` is deprecated." + ) + else: + immutable_weights = not kwargs["refit"] if "make_refittable" in kwargs.keys(): warnings.warn( - "`make_refittable` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", + "`make_refittable` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", DeprecationWarning, stacklevel=2, ) + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `refit` is deprecated." + ) + else: + immutable_weights = not kwargs["make_refittable"] + + if refit_identical_engine_weights: + if immutable_weights: + raise ValueError( + "`immutable_weights` must be False when `refit_identical_engine_weights` is True." + ) + + if ( + not immutable_weights + and not refit_identical_engine_weights + and enable_weight_streaming + ): + raise ValueError( + "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305" + ) engine_capability = EngineCapability._from(engine_capability) @@ -288,6 +319,8 @@ def cross_compile_for_windows( "lazy_engine_init": lazy_engine_init, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, + "refit_identical_engine_weights": refit_identical_engine_weights, + "strip_engine_weights": strip_engine_weights, "immutable_weights": immutable_weights, "enable_cross_compile_for_windows": True, "enable_weight_streaming": enable_weight_streaming, @@ -475,17 +508,44 @@ def compile( if "refit" in kwargs.keys(): warnings.warn( - "`refit` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", + "`refit` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", DeprecationWarning, stacklevel=2, ) + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `refit` is deprecated." + ) + else: + immutable_weights = not kwargs["refit"] if "make_refittable" in kwargs.keys(): warnings.warn( - "`make_refittable` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", + "`make_refittable` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", DeprecationWarning, stacklevel=2, ) + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `refit` is deprecated." + ) + else: + immutable_weights = not kwargs["make_refittable"] + + if refit_identical_engine_weights: + if immutable_weights: + raise ValueError( + "`immutable_weights` must be False when `refit_identical_engine_weights` is True." + ) + + if ( + not immutable_weights + and not refit_identical_engine_weights + and enable_weight_streaming + ): + raise ValueError( + "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305" + ) if ( "enable_cross_compile_for_windows" in kwargs.keys() @@ -965,18 +1025,47 @@ def convert_exported_program_to_serialized_trt_engine( DeprecationWarning, stacklevel=2, ) + if "refit" in kwargs.keys(): warnings.warn( - "`refit` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", + "`refit` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", DeprecationWarning, stacklevel=2, ) + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `refit` is deprecated." + ) + else: + immutable_weights = not kwargs["refit"] + if "make_refittable" in kwargs.keys(): warnings.warn( - "`make_refittable` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.", + "`make_refittable` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.", DeprecationWarning, stacklevel=2, ) + if immutable_weights: + raise ValueError( + "Use flag `immutable_weights` only. Flag `refit` is deprecated." + ) + else: + immutable_weights = not kwargs["make_refittable"] + + if refit_identical_engine_weights: + if immutable_weights: + raise ValueError( + "`immutable_weights` must be False when `refit_identical_engine_weights` is True." + ) + + if ( + not immutable_weights + and not refit_identical_engine_weights + and enable_weight_streaming + ): + raise ValueError( + "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305" + ) if arg_inputs is None and inputs is None: raise AssertionError("'arg_inputs' and 'inputs' should not both be None.") diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 1341ca739f..76630a75a5 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -43,7 +43,7 @@ USE_FP32_ACC = False REFIT_IDENTICAL_ENGINE_WEIGHTS = False STRIP_ENGINE_WEIGHTS = False -IMMUTABLE_WEIGHTS = False +IMMUTABLE_WEIGHTS = True ENABLE_WEIGHT_STREAMING = False ENABLE_CROSS_COMPILE_FOR_WINDOWS = False diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py index ca379a9ada..98e6b627ab 100644 --- a/py/torch_tensorrt/dynamo/_refit.py +++ b/py/torch_tensorrt/dynamo/_refit.py @@ -285,6 +285,10 @@ def refit_module_weights( assert settings is not None + assert ( + not settings.immutable_weights + ), "Refitting is not enabled. Please recompile the engine with immutable_weights=False." + if settings.debug: set_log_level(logger.parent, logging.DEBUG) diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index fa808aa20b..c8a30e656b 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -112,7 +112,7 @@ def _pretraced_backend( "require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt" ) if settings.strip_engine_weights: - logger.warning( + logger.error( "strip_engine_weights arg is not supported for torch.compile()" ) trt_compiled = compile_module( diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py index 7ffc02ca3d..d7c0ea449e 100644 --- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py +++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py @@ -291,6 +291,8 @@ def _populate_trt_builder_config( # non-refittable engine if self.compilation_settings.strip_engine_weights: _LOGGER.warning("strip_engine_weights will be ignored.") + if self.compilation_settings.refit_identical_engine_weights: + _LOGGER.warning("refit_identical_engine_weights will be ignored.") else: # refittable engine if self.compilation_settings.refit_identical_engine_weights: @@ -496,16 +498,15 @@ def _save_weight_mapping(self) -> None: suffix = sd_weight_name_list[-1] # Retrieve each weight name(s) in state_dict if layer_type == "CONSTANT": - if "embedding" in suffix: - sd_weight_name = f"{sd_weight_name}.weight" - elif "weight" in suffix or "mm_other" in suffix: - # Linear layer weight + if ( + "embedding" in suffix + or "weight" in suffix + or "mm_other" in suffix + ): sd_weight_name = f"{sd_weight_name}.weight" elif "running_mean" in suffix: - # Linear layer weight sd_weight_name = f"{sd_weight_name}.running_mean" elif "running_var" in suffix: - # Linear layer weight sd_weight_name = f"{sd_weight_name}.running_var" elif "bias" in suffix: sd_weight_name = f"{sd_weight_name}.bias" diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index f51707768e..134d84cf6d 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -65,6 +65,7 @@ def __init__( Union[torch.dtype, dtype] ] = _defaults.ENABLED_PRECISIONS, engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY, + immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS, debug: bool = _defaults.DEBUG, num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS, workspace_size: int = _defaults.WORKSPACE_SIZE, @@ -102,7 +103,7 @@ def __init__( assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False sparse_weights (bool): Enable sparsity for convolution and fully connected layers. enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels - refit (bool): Enable refitting + immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. debug (bool): Enable debuggable engine capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels @@ -150,6 +151,9 @@ def __init__( self.kwarg_inputs: dict[str, Any] = {} device = to_torch_tensorrt_device(device) enabled_precisions = {dtype._from(p) for p in enabled_precisions} + assert ( + not immutable_weights + ), "`immutable_weights` has to be False for a MutableTorchTensorRTModule." compilation_options = { "enabled_precisions": ( enabled_precisions @@ -176,6 +180,7 @@ def __init__( "require_full_compilation": require_full_compilation, "disable_tf32": disable_tf32, "sparse_weights": sparse_weights, + "immutable_weights": immutable_weights, "engine_capability": engine_capability, "dla_sram_size": dla_sram_size, "dla_local_dram_size": dla_local_dram_size, diff --git a/tests/py/dynamo/conversion/test_cumsum_aten.py b/tests/py/dynamo/conversion/test_cumsum_aten.py index 4143401bd4..8ab699468d 100644 --- a/tests/py/dynamo/conversion/test_cumsum_aten.py +++ b/tests/py/dynamo/conversion/test_cumsum_aten.py @@ -24,6 +24,7 @@ def forward(self, x): self.run_test( Cumsum(), inputs, + immutable_weights=True, ) @parameterized.expand( @@ -43,6 +44,7 @@ def forward(self, x): self.run_test( Cumsum(), inputs, + immutable_weights=True, ) @parameterized.expand( @@ -63,6 +65,7 @@ def forward(self, x): self.run_test( Cumsum(), inputs, + immutable_weights=True, ) @parameterized.expand( @@ -92,6 +95,7 @@ def forward(self, x): self.run_test_with_dynamic_shape( Cumsum(), inputs, + immutable_weights=True, ) diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py index 3fef3d70cf..1f119bd77e 100644 --- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py +++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py @@ -148,6 +148,7 @@ def forward(self, weight, indices): precision=weight.dtype, enable_passes=True, propagate_shapes=True, + immutable_weights=True, ) @parameterized.expand( @@ -345,6 +346,7 @@ def forward(self, weight, indices, offsets): precision=weight.dtype, enable_passes=True, propagate_shapes=True, + immutable_weights=True, ) @parameterized.expand( @@ -409,6 +411,7 @@ def forward(self, weight, indices, offsets): precision=weight.dtype, enable_passes=True, propagate_shapes=True, + immutable_weights=True, ) @parameterized.expand( @@ -490,6 +493,7 @@ def forward(self, weights, indices, offsets, per_sample_weights=None): min_block_size=1, cache_built_engines=False, reuse_cached_engines=False, + immutable_weights=True, ) # use the inputs with different shape to inference: if per_sample_weights is None: diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py index 5044654d81..68451674c5 100644 --- a/tests/py/dynamo/models/test_engine_cache.py +++ b/tests/py/dynamo/models/test_engine_cache.py @@ -74,7 +74,7 @@ def test_reexport_is_equal(self): ), ) settings1 = CompilationSettings( - cache_built_engines=True, reuse_cached_engines=True + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True ) hash1 = BaseEngineCache.get_hash(exp_program1.module(), input_specs1, settings1) @@ -89,7 +89,7 @@ def test_reexport_is_equal(self): ), ) settings2 = CompilationSettings( - cache_built_engines=True, reuse_cached_engines=True + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True ) hash2 = BaseEngineCache.get_hash(exp_program2.module(), input_specs2, settings2) @@ -111,7 +111,7 @@ def test_input_shape_change_is_not_equal(self): ), ) settings1 = CompilationSettings( - cache_built_engines=True, reuse_cached_engines=True + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True ) hash1 = BaseEngineCache.get_hash(exp_program1.module(), input_specs1, settings1) @@ -126,7 +126,7 @@ def test_input_shape_change_is_not_equal(self): ), ) settings2 = CompilationSettings( - cache_built_engines=True, reuse_cached_engines=True + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True ) hash2 = BaseEngineCache.get_hash(exp_program2.module(), input_specs2, settings2) @@ -148,6 +148,7 @@ def test_engine_settings_is_not_equal(self): ), ) settings1 = CompilationSettings( + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True, enabled_precisions={torch.float32}, @@ -165,6 +166,7 @@ def test_engine_settings_is_not_equal(self): ), ) settings2 = CompilationSettings( + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True, enabled_precisions={torch.float32, torch.float16}, @@ -223,6 +225,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): enabled_precisions={torch.float}, debug=False, min_block_size=1, + immutable_weights=False, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, engine_cache_dir=engine_cache_dir, @@ -286,6 +289,7 @@ def test_dynamo_compile_with_custom_engine_cache(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, + immutable_weights=False, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, custom_engine_cache=custom_engine_cache, @@ -332,6 +336,7 @@ def test_dynamo_compile_change_input_shape(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, + immutable_weights=False, cache_built_engines=True, reuse_cached_engines=True, ) @@ -386,6 +391,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "engine_cache_dir": engine_cache_dir, @@ -449,6 +455,7 @@ def test_torch_compile_with_custom_engine_cache(self): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "custom_engine_cache": custom_engine_cache, @@ -498,6 +505,7 @@ def test_torch_trt_compile_change_input_shape(self): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "cache_built_engines": True, "reuse_cached_engines": True, "custom_engine_cache": custom_engine_cache, @@ -540,6 +548,7 @@ def forward(self, x): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "cache_built_engines": True, "reuse_cached_engines": True, "custom_engine_cache": custom_engine_cache, @@ -628,6 +637,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): enabled_precisions={torch.float}, debug=False, min_block_size=1, + immutable_weights=False, cache_built_engines=False, reuse_cached_engines=False, strip_engine_weights=False, @@ -858,6 +868,7 @@ def remove_timing_cache(path=timing_cache_path): enabled_precisions={torch.float32}, debug=False, min_block_size=1, + immutable_weights=False, truncate_double=True, device=DEVICE, disable_tf32=True, diff --git a/tests/py/dynamo/models/test_model_refit.py b/tests/py/dynamo/models/test_model_refit.py index 331db1d4fd..bb61ac2d43 100644 --- a/tests/py/dynamo/models/test_model_refit.py +++ b/tests/py/dynamo/models/test_model_refit.py @@ -55,6 +55,7 @@ def test_mapping(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) settings = trt_gm._run_on_acc_0.settings runtime = trt.Runtime(TRT_LOGGER) @@ -106,6 +107,7 @@ def test_refit_one_engine_with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -155,6 +157,7 @@ def test_refit_one_engine_no_map_with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) trt_gm._run_on_acc_0.weight_name_map = None @@ -205,6 +208,7 @@ def test_refit_one_engine_with_wrong_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) # Manually Deleted all batch norm layer. This suppose to fail the fast refit trt_gm._run_on_acc_0.weight_name_map = { @@ -261,6 +265,7 @@ def test_refit_one_engine_bert_with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -313,6 +318,7 @@ def test_refit_one_engine_inline_runtime__with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) torchtrt.save(trt_gm, trt_ep_path) trt_gm = torch.export.load(trt_ep_path) @@ -358,6 +364,7 @@ def test_refit_one_engine_python_runtime_with_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -427,6 +434,7 @@ def forward(self, x): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, torch_executed_ops=torch_executed_ops, reuse_cached_engines=False, ) @@ -477,6 +485,7 @@ def test_refit_one_engine_without_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -527,6 +536,7 @@ def test_refit_one_engine_bert_without_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -579,6 +589,7 @@ def test_refit_one_engine_inline_runtime_without_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) torchtrt.save(trt_gm, trt_ep_path) trt_gm = torch.export.load(trt_ep_path) @@ -624,6 +635,7 @@ def test_refit_one_engine_python_runtime_without_weightmap(): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, ) new_trt_gm = refit_module_weights( @@ -693,6 +705,7 @@ def forward(self, x): enabled_precisions=enabled_precisions, debug=debug, min_block_size=min_block_size, + immutable_weights=False, torch_executed_ops=torch_executed_ops, reuse_cached_engines=False, ) @@ -746,6 +759,7 @@ def forward(self, x): enabled_precisions={torch.float}, debug=True, min_block_size=1, + immutable_weights=False, ) num_pyt_segments = len( diff --git a/tests/py/dynamo/models/test_weight_stripped_engine.py b/tests/py/dynamo/models/test_weight_stripped_engine.py index e647d623b5..67cfd167ed 100644 --- a/tests/py/dynamo/models/test_weight_stripped_engine.py +++ b/tests/py/dynamo/models/test_weight_stripped_engine.py @@ -26,6 +26,7 @@ def test_three_ways_to_compile(self): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "strip_engine_weights": False, "refit_identical_engine_weights": False, } @@ -76,6 +77,7 @@ def test_three_ways_to_compile_weight_stripped_engine(self): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "strip_engine_weights": True, "refit_identical_engine_weights": False, } @@ -117,12 +119,14 @@ def test_weight_stripped_engine_sizes(self): weight_included_engine = convert_exported_program_to_serialized_trt_engine( exp_program, example_inputs, + immutable_weights=False, strip_engine_weights=False, refit_identical_engine_weights=False, ) weight_stripped_engine = convert_exported_program_to_serialized_trt_engine( exp_program, example_inputs, + immutable_weights=False, strip_engine_weights=True, refit_identical_engine_weights=False, ) @@ -130,6 +134,7 @@ def test_weight_stripped_engine_sizes(self): convert_exported_program_to_serialized_trt_engine( exp_program, example_inputs, + immutable_weights=False, strip_engine_weights=True, refit_identical_engine_weights=True, ) @@ -162,6 +167,7 @@ def test_weight_stripped_engine_results(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, + immutable_weights=False, strip_engine_weights=True, refit_identical_engine_weights=False, ) @@ -187,6 +193,7 @@ def test_weight_stripped_engine_results(self): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "cache_built_engines": False, "reuse_cached_engines": False, "refit_identical_engine_weights": False, @@ -226,6 +233,7 @@ def test_engine_caching_saves_weight_stripped_engine(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, + immutable_weights=False, strip_engine_weights=False, # engine cache will save the stripped engine even if this is False refit_identical_engine_weights=True, cache_built_engines=True, @@ -291,6 +299,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): enabled_precisions={torch.float}, debug=False, min_block_size=1, + immutable_weights=False, cache_built_engines=cache_built_engines, reuse_cached_engines=reuse_cached_engines, engine_cache_dir=engine_cache_dir, @@ -371,6 +380,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "cache_built_engines": cache_built_engines, "reuse_cached_engines": reuse_cached_engines, "engine_cache_dir": engine_cache_dir, @@ -444,6 +454,7 @@ def forward(self, x): "enabled_precisions": {torch.float}, "debug": False, "min_block_size": 1, + "immutable_weights": False, "cache_built_engines": True, "reuse_cached_engines": True, "engine_cache_dir": engine_cache_dir, @@ -478,6 +489,7 @@ def forward(self, x): ir="dynamo", inputs=tuple(inputs), min_block_size=1, + immutable_weights=False, use_python_runtime=True, strip_engine_weights=True, refit_identical_engine_weights=False, @@ -517,6 +529,7 @@ def test_two_TRTRuntime_in_refitting(self): use_python_runtime=use_python_runtime, debug=False, min_block_size=1, + immutable_weights=False, strip_engine_weights=True, refit_identical_engine_weights=False, ) @@ -549,6 +562,7 @@ def test_refit_identical_engine_weights(self): enabled_precisions={torch.float}, debug=False, min_block_size=1, + immutable_weights=False, strip_engine_weights=True, refit_identical_engine_weights=True, ) diff --git a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py index fd9fa4e1e0..f2bcaf7ede 100644 --- a/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py +++ b/tests/py/dynamo/runtime/test_mutable_torchtrt_module.py @@ -49,6 +49,7 @@ def test_resnet18(): compile_spec = { "use_python_runtime": False, "enabled_precisions": {torch.float32}, + "immutable_weights": False, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -88,6 +89,7 @@ def test_save(): compile_spec = { "use_python_runtime": False, "enabled_precisions": {torch.float32}, + "immutable_weights": False, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -121,6 +123,7 @@ def test_resnet18_modify_attribute(): compile_spec = { "use_python_runtime": False, "enabled_precisions": {torch.float32}, + "immutable_weights": False, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -161,6 +164,7 @@ def test_resnet18_modify_attribute_no_refit(): compile_spec = { "use_python_runtime": False, "enabled_precisions": {torch.float32}, + "immutable_weights": False, } model = models.resnet18(pretrained=True).eval().to("cuda") @@ -239,6 +243,7 @@ def forward(self, x, b=5, c=None, d=None): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", + "immutable_weights": False, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) @@ -299,6 +304,7 @@ def set_weights(self): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", + "immutable_weights": False, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) @@ -361,6 +367,7 @@ def set_layer(self): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", + "immutable_weights": False, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec) @@ -429,6 +436,7 @@ def forward(self, x, b=5, c=None, d=None): "optimization_level": 1, "min_block_size": 1, "ir": "dynamo", + "immutable_weights": False, } mutable_module = torch_trt.MutableTorchTensorRTModule(model, **compile_spec)