change to opt-in feature
zewenli98 committed Nov 28, 2024
1 parent c69c61a commit 05b560d
Showing 16 changed files with 187 additions and 20 deletions.
1 change: 1 addition & 0 deletions examples/dynamo/engine_caching_bert_example.py
@@ -52,6 +52,7 @@ def compile_bert(iterations=3):
"truncate_double": True,
"debug": False,
"min_block_size": 1,
"immutable_weights": False,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"engine_cache_dir": "/tmp/torch_trt_bert_engine_cache",
5 changes: 5 additions & 0 deletions examples/dynamo/engine_caching_example.py
@@ -62,6 +62,8 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
# engines are saved to disk tied to a hash of their corresponding PyTorch subgraph. If the hash
# is matched in a subsequent compilation, either as part of this session or a new session, the cache
# will pull the built engine and **refit** the weights, which can reduce compilation times by orders of magnitude.
# As such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``),
# the engine must be refittable (``immutable_weights=False``). See :ref:`refit_engine_example` for more details.


def torch_compile(iterations=3):
@@ -95,6 +97,7 @@ def torch_compile(iterations=3):
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
"immutable_weights": False,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
},
@@ -154,6 +157,7 @@ def dynamo_compile(iterations=3):
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
immutable_weights=False,
cache_built_engines=cache_built_engines,
reuse_cached_engines=reuse_cached_engines,
engine_cache_size=1 << 30, # 1GB
@@ -264,6 +268,7 @@ def torch_compile_my_cache(iterations=3):
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
"immutable_weights": False,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"custom_engine_cache": engine_cache,
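
The flags added above are the combination that engine caching now requires. A minimal, hedged sketch of how they fit together (not part of the commit; the cache directory path is illustrative):

import torch
import torch_tensorrt  # registers the "tensorrt" backend for torch.compile
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

compiled = torch.compile(
    model,
    backend="tensorrt",
    options={
        # Engines default to immutable (non-refittable) after this commit, so caching
        # must explicitly opt back in to refittable engines.
        "immutable_weights": False,
        "cache_built_engines": True,    # insert newly built engines into the cache
        "reuse_cached_engines": True,   # pull and refit matching engines on later compilations
        "engine_cache_dir": "/tmp/torch_trt_engine_cache",  # illustrative path
    },
)
compiled(*inputs)
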
2 changes: 2 additions & 0 deletions examples/dynamo/mutable_torchtrt_module_example.py
@@ -31,6 +31,7 @@
settings = {
"use_python": False,
"enabled_precisions": {torch.float32},
"immutable_weights": False,
}

model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -79,6 +80,7 @@
"use_python_runtime": True,
"enabled_precisions": {torch.float16},
"debug": True,
"immutable_weights": False,
}

model_id = "runwayml/stable-diffusion-v1-5"
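
For context, a hedged sketch of how the settings above are consumed. The refit-on-weight-change flow follows the mutable-module example rather than this diff, so treat the load_state_dict step as an assumption:

import torch
import torch_tensorrt
import torchvision.models as models

settings = {
    "use_python": False,
    "enabled_precisions": {torch.float32},
    "immutable_weights": False,  # a mutable module must stay refittable
}

model = models.resnet18(pretrained=True).eval().to("cuda")
mutable_module = torch_tensorrt.MutableTorchTensorRTModule(model, **settings)
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
mutable_module(*inputs)

# Swapping in new weights marks the module for refit; the engine is refitted on the next call.
model2 = models.resnet18(pretrained=False).eval().to("cuda")
mutable_module.load_state_dict(model2.state_dict())
mutable_module(*inputs)
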
7 changes: 6 additions & 1 deletion examples/dynamo/refit_engine_example.py
@@ -46,7 +46,10 @@
# Make a refittable Compilation Program
# ---------------------------------------
#
# The inital step is to compile a module and save it as with a normal.
# The initial step is to compile a module and save it as normal. Note that there is an
# additional parameter `immutable_weights` that is set to `False`. This parameter is used to
# indicate that the engine being built should support weight refitting later. Engines built without
# this setting will not be able to be refit.
#
# In this case we are going to compile a ResNet18 model with randomly initialized weights and save it.

@@ -66,6 +69,8 @@
debug=debug,
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
immutable_weights=False,
reuse_cached_engines=False,
) # Output is a torch.fx.GraphModule

# Save the graph module as an exported program
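
A condensed, hedged sketch of the workflow the example above walks through: build a refittable engine, then swap in new weights with refit_module_weights instead of recompiling. Argument names follow the example file and may differ across versions:

import torch
import torch_tensorrt
import torchvision.models as models
from torch_tensorrt.dynamo import refit_module_weights

inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

# Compile with randomly initialized weights, opting in to refit support.
model = models.resnet18(pretrained=False).eval().to("cuda")
exp_program = torch.export.export(model, tuple(inputs))
trt_gm = torch_tensorrt.dynamo.compile(
    exp_program,
    tuple(inputs),
    immutable_weights=False,     # the engine must be refittable
    reuse_cached_engines=False,  # build a fresh engine (with its weight map) rather than reusing a cached one
)

# Refit the compiled module with pretrained weights instead of rebuilding the engine.
model2 = models.resnet18(pretrained=True).eval().to("cuda")
exp_program2 = torch.export.export(model2, tuple(inputs))
new_trt_gm = refit_module_weights(
    compiled_module=trt_gm,
    new_weight_module=exp_program2,
    arg_inputs=tuple(inputs),
)
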
101 changes: 95 additions & 6 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -92,6 +92,8 @@ def cross_compile_for_windows(
custom_engine_cache: Optional[BaseEngineCache] = _defaults.CUSTOM_ENGINE_CACHE,
use_explicit_typing: bool = _defaults.USE_EXPLICIT_TYPING,
use_fp32_acc: bool = _defaults.USE_FP32_ACC,
refit_identical_engine_weights: bool = _defaults.REFIT_IDENTICAL_ENGINE_WEIGHTS,
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
**kwargs: Any,
@@ -163,6 +165,8 @@ def cross_compile_for_windows(
custom_engine_cache (Optional[BaseEngineCache]): Engine cache instance to use for saving and loading engines. Users can provide their own engine cache by inheriting from BaseEngineCache. If used, engine_cache_dir and engine_cache_size will be ignored.
use_explicit_typing (bool): This flag enables strong typing in TensorRT compilation which respects the precisions set in the Pytorch model. This is useful when users have mixed precision graphs.
use_fp32_acc (bool): This option inserts cast to FP32 nodes around matmul layers and TensorRT ensures the accumulation of matmul happens in FP32. Use this only when FP16 precision is configured in enabled_precisions.
refit_identical_engine_weights (bool): Refit engines with identical weights. This is useful when the same model is compiled multiple times with different inputs and the weights are the same. This will save time by reusing the same engine for different inputs.
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
enable_weight_streaming (bool): Enable weight streaming.
**kwargs: Any,
@@ -193,17 +197,44 @@

if "refit" in kwargs.keys():
warnings.warn(
"`refit` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.",
"`refit` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.",
DeprecationWarning,
stacklevel=2,
)
if immutable_weights:
raise ValueError(
"Use flag `immutable_weights` only. Flag `refit` is deprecated."
)
else:
immutable_weights = not kwargs["refit"]

if "make_refittable" in kwargs.keys():
warnings.warn(
"`make_refittable` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.",
"`make_refittable` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.",
DeprecationWarning,
stacklevel=2,
)
if immutable_weights:
raise ValueError(
"Use flag `immutable_weights` only. Flag `refit` is deprecated."
)
else:
immutable_weights = not kwargs["make_refittable"]

if refit_identical_engine_weights:
if immutable_weights:
raise ValueError(
"`immutable_weights` must be False when `refit_identical_engine_weights` is True."
)

if (
not immutable_weights
and not refit_identical_engine_weights
and enable_weight_streaming
):
raise ValueError(
"TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305"
)

engine_capability = EngineCapability._from(engine_capability)

@@ -288,6 +319,8 @@ def cross_compile_for_windows(
"lazy_engine_init": lazy_engine_init,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"refit_identical_engine_weights": refit_identical_engine_weights,
"strip_engine_weights": strip_engine_weights,
"immutable_weights": immutable_weights,
"enable_cross_compile_for_windows": True,
"enable_weight_streaming": enable_weight_streaming,
@@ -475,17 +508,44 @@ def compile(

if "refit" in kwargs.keys():
warnings.warn(
"`refit` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.",
"`refit` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.",
DeprecationWarning,
stacklevel=2,
)
if immutable_weights:
raise ValueError(
"Use flag `immutable_weights` only. Flag `refit` is deprecated."
)
else:
immutable_weights = not kwargs["refit"]

if "make_refittable" in kwargs.keys():
warnings.warn(
"`make_refittable` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.",
"`make_refittable` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.",
DeprecationWarning,
stacklevel=2,
)
if immutable_weights:
raise ValueError(
"Use flag `immutable_weights` only. Flag `refit` is deprecated."
)
else:
immutable_weights = not kwargs["make_refittable"]

if refit_identical_engine_weights:
if immutable_weights:
raise ValueError(
"`immutable_weights` must be False when `refit_identical_engine_weights` is True."
)

if (
not immutable_weights
and not refit_identical_engine_weights
and enable_weight_streaming
):
raise ValueError(
"TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305"
)

if (
"enable_cross_compile_for_windows" in kwargs.keys()
@@ -965,18 +1025,47 @@ def convert_exported_program_to_serialized_trt_engine(
DeprecationWarning,
stacklevel=2,
)

if "refit" in kwargs.keys():
warnings.warn(
"`refit` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.",
"`refit` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.",
DeprecationWarning,
stacklevel=2,
)
if immutable_weights:
raise ValueError(
"Use flag `immutable_weights` only. Flag `refit` is deprecated."
)
else:
immutable_weights = not kwargs["refit"]

if "make_refittable" in kwargs.keys():
warnings.warn(
"`make_refittable` is deprecated. Engines are refittable by default. Please set immutable_weights=True to build a non-refittable engine whose weights will be fixed.",
"`make_refittable` is deprecated. Please set `immutable_weights=True` to build a non-refittable engine whose weights will be fixed.",
DeprecationWarning,
stacklevel=2,
)
if immutable_weights:
raise ValueError(
"Use flag `immutable_weights` only. Flag `refit` is deprecated."
)
else:
immutable_weights = not kwargs["make_refittable"]

if refit_identical_engine_weights:
if immutable_weights:
raise ValueError(
"`immutable_weights` must be False when `refit_identical_engine_weights` is True."
)

if (
not immutable_weights
and not refit_identical_engine_weights
and enable_weight_streaming
):
raise ValueError(
"TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now. This issue was reported on https://github.com/pytorch/TensorRT/issues/3305"
)

if arg_inputs is None and inputs is None:
raise AssertionError("'arg_inputs' and 'inputs' should not both be None.")
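
To make the argument handling repeated above easier to follow, here is a standalone sketch of the same checks. resolve_weight_flags is a hypothetical helper, not a function in the library:

import warnings
from typing import Any


def resolve_weight_flags(
    immutable_weights: bool = True,  # mirrors the new default in _defaults.py
    refit_identical_engine_weights: bool = False,
    enable_weight_streaming: bool = False,
    **kwargs: Any,
) -> bool:
    # Deprecated `refit` / `make_refittable` map onto `immutable_weights`.
    for legacy in ("refit", "make_refittable"):
        if legacy in kwargs:
            warnings.warn(
                f"`{legacy}` is deprecated. Please set `immutable_weights=True` to build "
                "a non-refittable engine whose weights will be fixed.",
                DeprecationWarning,
            )
            if immutable_weights:
                raise ValueError(
                    f"Use flag `immutable_weights` only. Flag `{legacy}` is deprecated."
                )
            immutable_weights = not kwargs[legacy]

    # Refitting identical weights only makes sense for refittable engines.
    if refit_identical_engine_weights and immutable_weights:
        raise ValueError(
            "`immutable_weights` must be False when `refit_identical_engine_weights` is True."
        )

    # TensorRT's REFIT flag is currently incompatible with weight streaming (issue #3305).
    if not immutable_weights and not refit_identical_engine_weights and enable_weight_streaming:
        raise ValueError(
            "TensorRT's `REFIT` flag is not compatible with `enable_weight_streaming=True` for now."
        )

    return immutable_weights
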
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
@@ -43,7 +43,7 @@
USE_FP32_ACC = False
REFIT_IDENTICAL_ENGINE_WEIGHTS = False
STRIP_ENGINE_WEIGHTS = False
IMMUTABLE_WEIGHTS = False
IMMUTABLE_WEIGHTS = True
ENABLE_WEIGHT_STREAMING = False
ENABLE_CROSS_COMPILE_FOR_WINDOWS = False

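
This one-line default flip is the "opt-in" change in the commit title: engines are now immutable unless requested otherwise. A hedged illustration of the consequence, relying on the new assertion in _refit.py shown below and assuming CUDA is available:

import torch
import torch_tensorrt
import torchvision.models as models
from torch_tensorrt.dynamo import refit_module_weights

inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
model = models.resnet18(pretrained=True).eval().to("cuda")
exp_program = torch.export.export(model, tuple(inputs))

# Defaults now build a non-refittable engine, so refitting it trips the new assertion.
trt_default = torch_tensorrt.dynamo.compile(exp_program, tuple(inputs))
try:
    refit_module_weights(compiled_module=trt_default, new_weight_module=exp_program, arg_inputs=tuple(inputs))
except AssertionError as err:
    print(err)  # "Refitting is not enabled. Please recompile the engine with immutable_weights=False."

# Opting back in restores refit support for this engine.
trt_refittable = torch_tensorrt.dynamo.compile(exp_program, tuple(inputs), immutable_weights=False)
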
4 changes: 4 additions & 0 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -285,6 +285,10 @@ def refit_module_weights(

assert settings is not None

assert (
not settings.immutable_weights
), "Refitting is not enabled. Please recompile the engine with immutable_weights=False."

if settings.debug:
set_log_level(logger.parent, logging.DEBUG)

2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/backend/backends.py
@@ -112,7 +112,7 @@ def _pretraced_backend(
"require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt"
)
if settings.strip_engine_weights:
logger.warning(
logger.error(
"strip_engine_weights arg is not supported for torch.compile()"
)
trt_compiled = compile_module(
13 changes: 7 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -291,6 +291,8 @@ def _populate_trt_builder_config(
# non-refittable engine
if self.compilation_settings.strip_engine_weights:
_LOGGER.warning("strip_engine_weights will be ignored.")
if self.compilation_settings.refit_identical_engine_weights:
_LOGGER.warning("refit_identical_engine_weights will be ignored.")
else:
# refittable engine
if self.compilation_settings.refit_identical_engine_weights:
@@ -496,16 +498,15 @@ def _save_weight_mapping(self) -> None:
suffix = sd_weight_name_list[-1]
# Retrieve each weight name(s) in state_dict
if layer_type == "CONSTANT":
if "embedding" in suffix:
sd_weight_name = f"{sd_weight_name}.weight"
elif "weight" in suffix or "mm_other" in suffix:
# Linear layer weight
if (
"embedding" in suffix
or "weight" in suffix
or "mm_other" in suffix
):
sd_weight_name = f"{sd_weight_name}.weight"
elif "running_mean" in suffix:
# Linear layer weight
sd_weight_name = f"{sd_weight_name}.running_mean"
elif "running_var" in suffix:
# Linear layer weight
sd_weight_name = f"{sd_weight_name}.running_var"
elif "bias" in suffix:
sd_weight_name = f"{sd_weight_name}.bias"
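
A standalone restatement of the restructured CONSTANT-layer branch above, to make the suffix-to-state_dict mapping explicit. The helper is illustrative, not library code, and the fallback branch is an assumption since the rest of the hunk is not shown:

def map_constant_suffix(sd_weight_name: str, suffix: str) -> str:
    # Embedding, linear, and matmul constants all map to the module's ``.weight`` entry.
    if "embedding" in suffix or "weight" in suffix or "mm_other" in suffix:
        return f"{sd_weight_name}.weight"
    # Batch-norm running statistics keep their own state_dict entries.
    if "running_mean" in suffix:
        return f"{sd_weight_name}.running_mean"
    if "running_var" in suffix:
        return f"{sd_weight_name}.running_var"
    if "bias" in suffix:
        return f"{sd_weight_name}.bias"
    return sd_weight_name  # assumption: unmatched suffixes pass through unchanged
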
py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py
@@ -65,6 +65,7 @@ def __init__(
Union[torch.dtype, dtype]
] = _defaults.ENABLED_PRECISIONS,
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
debug: bool = _defaults.DEBUG,
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -102,7 +103,7 @@ def __init__(
assume_dynamic_shape_support (bool): Setting this to true enables the converters work for both dynamic and static shapes. Default: False
sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
enabled_precision (Set(Union(torch.dtype, torch_tensorrt.dtype))): The set of datatypes that TensorRT can use when selecting kernels
refit (bool): Enable refitting
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable.
debug (bool): Enable debuggable engine
capability (torch_tensorrt.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
@@ -150,6 +151,9 @@ def __init__(
self.kwarg_inputs: dict[str, Any] = {}
device = to_torch_tensorrt_device(device)
enabled_precisions = {dtype._from(p) for p in enabled_precisions}
assert (
not immutable_weights
), "`immutable_weights` has to be False for a MutableTorchTensorRTModule."
compilation_options = {
"enabled_precisions": (
enabled_precisions
@@ -176,6 +180,7 @@ def __init__(
"require_full_compilation": require_full_compilation,
"disable_tf32": disable_tf32,
"sparse_weights": sparse_weights,
"immutable_weights": immutable_weights,
"engine_capability": engine_capability,
"dla_sram_size": dla_sram_size,
"dla_local_dram_size": dla_local_dram_size,
4 changes: 4 additions & 0 deletions tests/py/dynamo/conversion/test_cumsum_aten.py
@@ -24,6 +24,7 @@ def forward(self, x):
self.run_test(
Cumsum(),
inputs,
immutable_weights=True,
)

@parameterized.expand(
@@ -43,6 +44,7 @@ def forward(self, x):
self.run_test(
Cumsum(),
inputs,
immutable_weights=True,
)

@parameterized.expand(
@@ -63,6 +65,7 @@ def forward(self, x):
self.run_test(
Cumsum(),
inputs,
immutable_weights=True,
)

@parameterized.expand(
@@ -92,6 +95,7 @@ def forward(self, x):
self.run_test_with_dynamic_shape(
Cumsum(),
inputs,
immutable_weights=True,
)

