diff --git a/core/runtime/Platform.cpp b/core/runtime/Platform.cpp
index a20159cd91..03d9e7580b 100644
--- a/core/runtime/Platform.cpp
+++ b/core/runtime/Platform.cpp
@@ -36,7 +36,6 @@ Platform::Platform() : _platform{Platform::PlatformEnum::kUNKNOWN} {}
 Platform::Platform(Platform::PlatformEnum val) : _platform{val} {}
 
 Platform::Platform(const std::string& platform_str) {
-  LOG_ERROR("Platform constructor: " << platform_str);
   auto name_map = get_name_to_platform_map();
   auto it = name_map.find(platform_str);
   if (it != name_map.end()) {
diff --git a/docsrc/conf.py b/docsrc/conf.py
index 2e782358cb..c4ddd6eaec 100644
--- a/docsrc/conf.py
+++ b/docsrc/conf.py
@@ -93,6 +93,7 @@
 sphinx_gallery_conf = {
     "examples_dirs": "../examples",
     "gallery_dirs": "tutorials/_rendered_examples/",
+    "ignore_pattern": "utils.py"
 }
 
 # Setup the breathe extension
diff --git a/docsrc/index.rst b/docsrc/index.rst
index da5ee3d690..d1a91beabc 100644
--- a/docsrc/index.rst
+++ b/docsrc/index.rst
@@ -51,6 +51,8 @@ User Guide
    user_guide/using_dla
    tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
    tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq
+   tutorials/_rendered_examples/dynamo/engine_caching_example
+   tutorials/_rendered_examples/dynamo/refit_engine_example
 
 Dynamo Frontend
 ----------------
diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst
index 22ae16ec0c..ff3563cffe 100644
--- a/examples/dynamo/README.rst
+++ b/examples/dynamo/README.rst
@@ -15,3 +15,5 @@ a number of ways you can leverage this backend to accelerate inference.
 * :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights
 * :ref:`mutable_torchtrt_module_example`: Compile, use, and modify TensorRT Graph Module with MutableTorchTensorRTModule
 * :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``
+* :ref:`engine_caching_example`: Utilizing engine caching to speed up compilation times
+* :ref:`engine_caching_bert_example`: Demonstrating engine caching on BERT
diff --git a/examples/dynamo/engine_caching_bert_example.py b/examples/dynamo/engine_caching_bert_example.py
index 43cfc5f15a..cb07e3adbb 100644
--- a/examples/dynamo/engine_caching_bert_example.py
+++ b/examples/dynamo/engine_caching_bert_example.py
@@ -1,3 +1,12 @@
+"""
+
+.. _engine_caching_bert_example:
+
+Engine Caching (BERT)
+=======================
+
+A small engine caching example using BERT.
+"""
 import numpy as np
 import torch
 import torch_tensorrt
diff --git a/examples/dynamo/engine_caching_example.py b/examples/dynamo/engine_caching_example.py
index 2d1018bb6e..ed29f5b268 100644
--- a/examples/dynamo/engine_caching_example.py
+++ b/examples/dynamo/engine_caching_example.py
@@ -1,12 +1,38 @@
+"""
+
+.. _engine_caching_example:
+
+Engine Caching
+=======================
+
+As model sizes increase, so does the cost of compilation. With AOT methods
+like ``torch_tensorrt.dynamo.compile``, this cost is paid upfront. However, if the weights
+change, the session ends, or you are using JIT methods like ``torch.compile``,
+graphs get invalidated and re-compiled, so this cost is paid repeatedly.
+Engine caching is a way to mitigate this cost by saving constructed engines to disk
+and re-using them when possible. This tutorial demonstrates how to use engine caching
+with TensorRT in PyTorch. Engine caching can significantly speed up subsequent model
+compilations by reusing previously built TensorRT engines.
+
+We'll explore two approaches:
+
+    1. Using ``torch_tensorrt.dynamo.compile``
+    2. Using ``torch.compile`` with the TensorRT backend
+
+The example uses a pre-trained ResNet18 model and shows the
+differences between compilation without caching, with caching enabled,
+and when reusing cached engines.
+"""
+
 import os
-from typing import Optional
+from typing import Optional, Dict
 
 import numpy as np
 import torch
 import torch_tensorrt as torch_trt
 import torchvision.models as models
 from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
-from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
+from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
 
 np.random.seed(0)
 torch.manual_seed(0)
@@ -22,6 +48,76 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
     if os.path.exists(path):
         os.remove(path)
 
+
+# %%
+# Engine Caching for JIT Compilation
+# ----------------------------------
+#
+# The primary goal of engine caching is to help speed up JIT workflows. ``torch.compile``
+# provides a great deal of flexibility in model construction, which makes it a good
+# first tool to try when looking to speed up your workflow. However, historically
+# the cost of compilation, and in particular recompilation, has been a barrier to entry
+# for many users. Prior to the addition of engine caching, if a subgraph was invalidated
+# for any reason, that graph was rebuilt from scratch. Now, with ``cache_built_engines=True``,
+# engines are saved to disk as they are constructed, keyed on a hash of their corresponding PyTorch subgraph. If
+# the same subgraph is encountered in a subsequent compilation, either in this session or a new session, the cache will
+# pull the built engine and **refit** the weights, which can reduce compilation times by orders of magnitude.
+# As such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``),
+# the engine must be refittable (``make_refitable=True``). See :ref:`refit_engine_example` for more details.
+
+
+def torch_compile(iterations=3):
+    times = []
+    start = torch.cuda.Event(enable_timing=True)
+    end = torch.cuda.Event(enable_timing=True)
+
+    # The 1st iteration is to measure the compilation time without engine caching.
+    # The 2nd and 3rd iterations are to measure the compilation time with engine caching.
+    # Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
+    # The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
+    for i in range(iterations):
+        inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
+        # remove timing cache and reset dynamo just for engine caching measurement
+        remove_timing_cache()
+        torch._dynamo.reset()
+
+        if i == 0:
+            cache_built_engines = False
+            reuse_cached_engines = False
+        else:
+            cache_built_engines = True
+            reuse_cached_engines = True
+
+        start.record()
+        compiled_model = torch.compile(
+            model,
+            backend="tensorrt",
+            options={
+                "use_python_runtime": True,
+                "enabled_precisions": enabled_precisions,
+                "debug": debug,
+                "min_block_size": min_block_size,
+                "make_refitable": True,
+                "cache_built_engines": cache_built_engines,
+                "reuse_cached_engines": reuse_cached_engines,
+            },
+        )
+        compiled_model(*inputs)  # trigger the compilation
+        end.record()
+        torch.cuda.synchronize()
+        times.append(start.elapsed_time(end))
+
+    print("----------------torch_compile----------------")
+    print("disable engine caching, used:", times[0], "ms")
+    print("enable engine caching to cache engines, used:", times[1], "ms")
+    print("enable engine caching to reuse engines, used:", times[2], "ms")
+
+
+torch_compile()
+
+# %%
+# Engine Caching for AOT Compilation
+# ----------------------------------
+# Similarly to the JIT workflow, AOT workflows can benefit from engine caching.
+# As the same architecture or common subgraphs get recompiled, the cache will pull
+# previously built engines and refit the weights.
+
 
 def dynamo_compile(iterations=3):
     times = []
@@ -72,43 +168,71 @@ def dynamo_compile(iterations=3):
     print("enable engine caching to cache engines, used:", times[1], "ms")
     print("enable engine caching to reuse engines, used:", times[2], "ms")
 
+dynamo_compile()
 
+# %%
 # Custom Engine Cache
-class MyEngineCache(BaseEngineCache):
+# ----------------------
+#
+# By default, the engine cache is stored in the system's temporary directory. Both the cache directory and
+# size limit can be customized by passing ``engine_cache_dir`` and ``engine_cache_size``.
+# Users can also define their own engine cache implementation by extending the ``BaseEngineCache`` class.
+# This allows for remote or shared caching if so desired.
+#
+# The custom engine cache should implement the following methods:
+#
+# - ``save``: Save the engine blob to the cache.
+# - ``load``: Load the engine blob from the cache.
+#
+# The hash provided by the cache system is a weight-agnostic hash of the originating PyTorch subgraph (post lowering).
+# The blob contains a serialized engine, calling spec data, and weight map information, in the pickle format.
+#
+# Below is an example of a custom engine cache implementation, ``RAMEngineCache``, which holds engines in memory.
+
+
+class RAMEngineCache(BaseEngineCache):
     def __init__(
         self,
-        engine_cache_dir: str,
     ) -> None:
-        self.engine_cache_dir = engine_cache_dir
+        """
+        Constructs a user-held engine cache in memory.
+        """
+        self.engine_cache: Dict[str, bytes] = {}
 
     def save(
         self,
         hash: str,
         blob: bytes,
-        prefix: str = "blob",
     ):
-        if not os.path.exists(self.engine_cache_dir):
-            os.makedirs(self.engine_cache_dir, exist_ok=True)
+        """
+        Insert the engine blob into the cache.
 
-        path = os.path.join(
-            self.engine_cache_dir,
-            f"{prefix}_{hash}.bin",
-        )
-        with open(path, "wb") as f:
-            f.write(blob)
+        Args:
+            hash (str): The hash key to associate with the engine blob.
+            blob (bytes): The engine blob to be saved.
-    def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
-        path = os.path.join(self.engine_cache_dir, f"{prefix}_{hash}.bin")
-        if os.path.exists(path):
-            with open(path, "rb") as f:
-                blob = f.read()
-            return blob
-        return None
+        Returns:
+            None
+        """
+        self.engine_cache[hash] = blob
 
+    def load(self, hash: str) -> Optional[bytes]:
+        """
+        Load the engine blob from the cache.
 
-def torch_compile(iterations=3):
+        Args:
+            hash (str): The hash key of the engine to load.
+
+        Returns:
+            Optional[bytes]: The engine blob if found, None otherwise.
+        """
+        if hash in self.engine_cache:
+            return self.engine_cache[hash]
+        else:
+            return None
+
+
+def torch_compile_my_cache(iterations=3):
     times = []
-    engine_cache = MyEngineCache("/tmp/your_dir")
+    engine_cache = RAMEngineCache()
     start = torch.cuda.Event(enable_timing=True)
     end = torch.cuda.Event(enable_timing=True)
@@ -141,7 +265,7 @@
                 "make_refitable": True,
                 "cache_built_engines": cache_built_engines,
                 "reuse_cached_engines": reuse_cached_engines,
-                "custom_engine_cache": engine_cache,  # use custom engine cache
+                "custom_engine_cache": engine_cache,
             },
         )
         compiled_model(*inputs)  # trigger the compilation
@@ -154,7 +278,4 @@
     print("enable engine caching to cache engines, used:", times[1], "ms")
     print("enable engine caching to reuse engines, used:", times[2], "ms")
 
-
-if __name__ == "__main__":
-    dynamo_compile()
-    torch_compile()
+torch_compile_my_cache()
diff --git a/examples/dynamo/refit_engine_example.py b/examples/dynamo/refit_engine_example.py
index c8cd5590d3..1feb033a3a 100644
--- a/examples/dynamo/refit_engine_example.py
+++ b/examples/dynamo/refit_engine_example.py
@@ -1,19 +1,26 @@
 """
 .. _refit_engine_example:
 
-Refit TenorRT Graph Module with Torch-TensorRT
+Refitting Torch-TensorRT Programs with New Weights
 ===================================================================
 
-We are going to demonstrate how a compiled TensorRT Graph Module can be refitted with updated weights.
-
-In many cases, we frequently update the weights of models, such as applying various LoRA to Stable Diffusion or constant A/B testing of AI products.
-That poses challenges for TensorRT inference optimizations, as compiling the TensorRT engines takes significant time, making repetitive compilation highly inefficient.
-Torch-TensorRT supports refitting TensorRT graph modules without re-compiling the engine, considerably accelerating the workflow.
+Compilation is an expensive operation as it involves many graph transformations, translations
+and optimizations applied to the model. In cases where the weights of a model might be updated
+occasionally (e.g. inserting LoRA adapters), the large cost of recompilation can make it infeasible
+to use TensorRT if the compiled program needed to be built from scratch each time. Torch-TensorRT
+provides a PyTorch native mechanism to update the weights of a compiled TensorRT program without
+recompiling from scratch, through weight refitting.
 
 In this tutorial, we are going to walk through
-1. Compiling a PyTorch model to a TensorRT Graph Module
-2. Save and load a graph module
-3. Refit the graph module
+
+    1. Compiling a PyTorch model to a TensorRT Graph Module
+    2. Saving and loading a graph module
+    3. Refitting the graph module
+
+This tutorial focuses mostly on the AOT workflow, where it is most likely that a user might need to
+manually refit a module. In the JIT workflow, weight changes trigger recompilation. As the engine
+has previously been built, with an engine cache enabled, Torch-TensorRT can automatically recognize
+a previously built engine, trigger refit, and shortcut recompilation on behalf of the user (see: :ref:`engine_caching_example`).
 """
 
 # %%
@@ -36,10 +43,17 @@
 
 # %%
-# Compile the module for the first time and save it.
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-
-model = models.resnet18(pretrained=True).eval().to("cuda")
+# Make a Refittable Compilation Program
+# ---------------------------------------
+#
+# The initial step is to compile a module and save it as normal. Note that there is an
+# additional parameter ``make_refitable`` that is set to ``True``. This parameter is used to
+# indicate that the engine being built should support weight refitting later. Engines built without
+# this setting will not be able to be refit.
+#
+# In this case we are going to compile a ResNet18 model with randomly initialized weights and save it.
+
+model = models.resnet18(pretrained=False).eval().to("cuda")
 exp_program = torch.export.export(model, tuple(inputs))
 enabled_precisions = {torch.float}
 debug = False
@@ -59,16 +73,20 @@
 )  # Output is a torch.fx.GraphModule
 
 # Save the graph module as an exported program
-# This is only supported when use_python_runtime = False
 torch_trt.save(trt_gm, "./compiled.ep", inputs=inputs)
 
 # %%
-# Refit the module with update model weights
-# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+# Refit the Program with Pretrained Weights
+# ------------------------------------------
+#
+# Random weights are not useful for inference. However, instead of recompiling the model, we can
+# refit it with the pretrained weights. This is done by setting up another PyTorch module
+# with the target weights and exporting it as an ExportedProgram. Then the ``refit_module_weights``
+# function is used to update the weights of the compiled module with the new weights.
 
 # Create and compile the updated model
-model2 = models.resnet18(pretrained=False).eval().to("cuda")
+model2 = models.resnet18(pretrained=True).eval().to("cuda")
 exp_program2 = torch.export.export(model2, tuple(inputs))
@@ -91,8 +109,32 @@
 print("Refit successfully!")
 
 # %%
-# Alternative Workflow using Python Runtime
+#
+# Advanced Usage
 # -----------------------------
-
-# Currently python runtime does not support engine serialization. So the refitting will be done in the same runtime.
-# This usecase is more useful when you need to switch different weights in the same runtime, such as using Stable Diffusion.
+#
+# There are a number of settings you can use to control the refit process.
+#
+# Weight Map Cache
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# Weight refitting works by matching the weights of the compiled module with the new weights from
+# the user-supplied ExportedProgram. Since 1:1 name matching from PyTorch to TensorRT is hard to accomplish,
+# the only guaranteed way to match weights at *refit-time* is to pass the new ExportedProgram through the
+# early phases of the compilation process to generate near-identical weight names. This can be expensive
+# and is not always necessary.
+#
+# To avoid this, **at initial compile time**, Torch-TensorRT will attempt to cache a direct mapping from PyTorch
+# weights to TensorRT weights. This cache is stored in the compiled module as metadata and can be used
+# to speed up refit. If the cache is not present, the refit system will fall back to rebuilding the mapping at
+# refit-time. Use of this cache is controlled by the ``use_weight_map_cache`` parameter.
+#
+# Since the cache uses a heuristic-based system for matching PyTorch and TensorRT weights, you may want to verify the refit. This can be done by setting
+# ``verify_output`` to ``True`` and providing sample ``arg_inputs`` and ``kwarg_inputs``. When this is done, the refit
+# system will run the refitted module and the user-supplied module on the same inputs and compare the outputs.
+#
+# In-Place Refit
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+#
+# ``in_place`` allows the user to refit the module in place. This is useful when the user wants to update the weights
+# of the compiled module without creating a new module.
diff --git a/py/torch_tensorrt/dynamo/_compiler.py b/py/torch_tensorrt/dynamo/_compiler.py
index 6cd3cf5f5f..9c1bb96dbe 100644
--- a/py/torch_tensorrt/dynamo/_compiler.py
+++ b/py/torch_tensorrt/dynamo/_compiler.py
@@ -18,7 +18,7 @@
     dryrun_stats_display,
     parse_non_trt_nodes,
 )
-from torch_tensorrt.dynamo._engine_caching import BaseEngineCache, DiskEngineCache
+from torch_tensorrt.dynamo._engine_cache import BaseEngineCache, DiskEngineCache
 from torch_tensorrt.dynamo.conversion import (
     CompilationSettings,
     UnsupportedOperatorException,
diff --git a/py/torch_tensorrt/dynamo/_engine_caching.py b/py/torch_tensorrt/dynamo/_engine_cache.py
similarity index 96%
rename from py/torch_tensorrt/dynamo/_engine_caching.py
rename to py/torch_tensorrt/dynamo/_engine_cache.py
index c8ff7aba50..2dd25ecf21 100644
--- a/py/torch_tensorrt/dynamo/_engine_caching.py
+++ b/py/torch_tensorrt/dynamo/_engine_cache.py
@@ -144,6 +144,8 @@ def get_dir_size(path: str) -> int:
         if engine_cache_dir not in DiskEngineCache.dir2hash2size_map:
             DiskEngineCache.dir2hash2size_map[engine_cache_dir] = {}
 
+        _LOGGER.info(f"Disk engine cache initialized (cache directory: {self.engine_cache_dir}, max size: {self.total_engine_cache_size})")
+
     def has_available_cache_size(self, needed_size: int) -> bool:
         """Check if the cache has available space for saving object
 
@@ -184,7 +186,7 @@ def LRU() -> None:
                             engine_hash, 0
                         )
                     )
-                    _LOGGER.info(
+                    _LOGGER.debug(
                         f"Removed the engine cache at {engine_path}, available cache size: {self.available_engine_cache_size} bytes."
                     )
                 except Exception as e:
@@ -228,7 +230,7 @@ def save(
             try:
                 with open(blob_path, "wb") as f:
                     f.write(blob)
-                _LOGGER.info(f"The blob was saved to {blob_path}")
+                _LOGGER.debug(f"Engine added to cache, saved to {blob_path}")
             except Exception as e:
                 del DiskEngineCache.dir2hash2size_map[self.engine_cache_dir][hash]
                 self.available_engine_cache_size += blob_size
@@ -247,5 +249,6 @@ def load(self, hash: str) -> Optional[bytes]:
         if os.path.exists(blob_path):
             with open(blob_path, "rb") as f:
                 blob = f.read()
+            _LOGGER.debug(f"Engine found in cache, loaded from {blob_path}")
             return blob
         return None
diff --git a/py/torch_tensorrt/dynamo/_refit.py b/py/torch_tensorrt/dynamo/_refit.py
index 4ce7d0b150..fa3cfbfb82 100644
--- a/py/torch_tensorrt/dynamo/_refit.py
+++ b/py/torch_tensorrt/dynamo/_refit.py
@@ -39,6 +39,7 @@
     set_log_level,
     to_torch_device,
     to_torch_tensorrt_device,
+    get_model_device,
 )
 from torch_tensorrt.logging import TRT_LOGGER
 
@@ -146,7 +147,7 @@ def _refit_single_trt_engine_with_gm(
     """
     refitted = set()
 
-    torch_device = list(new_gm.state_dict().values())[0].device.type
+    torch_device = get_model_device(new_gm)
     refitter = trt.Refitter(old_engine, TRT_LOGGER)
     weight_list = refitter.get_all_weights()
 
diff --git a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
index 3c97c8347a..bad0ebbcaf 100644
--- a/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
+++ b/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -27,7 +27,7 @@
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
-from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
+from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
diff --git a/py/torch_tensorrt/dynamo/conversion/_conversion.py b/py/torch_tensorrt/dynamo/conversion/_conversion.py
index cd38ce56e6..06fade9674 100644
--- a/py/torch_tensorrt/dynamo/conversion/_conversion.py
+++ b/py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -10,7 +10,7 @@
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
-from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
+from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
     TRTInterpreter,
diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index 66192d59a0..78c6c11175 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -13,7 +13,7 @@
 from torch_tensorrt._enums import dtype
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
-from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
+from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
 from torch_tensorrt.dynamo._settings import CompilationSettings
 
 from packaging import version
@@ -509,7 +509,7 @@ def parse_dynamo_kwargs(
     if kwargs.get("custom_engine_cache") is not None:
         engine_cache = kwargs.get("custom_engine_cache")
     else:
-        from torch_tensorrt.dynamo._engine_caching import DiskEngineCache
+        from torch_tensorrt.dynamo._engine_cache import DiskEngineCache
 
         engine_cache_dir = kwargs.get(
            "engine_cache_dir", _defaults.ENGINE_CACHE_DIR
diff --git a/setup.py b/setup.py
index 06b163c51c..de532d9071 100644
--- a/setup.py
+++ b/setup.py
@@ -484,7 +484,7 @@ def run(self):
         if not (PY_ONLY or NO_TS):
             tensorrt_linux_external_dir = (
                 lambda: subprocess.check_output(
-                    ["bazel", "query", "@tensorrt//:nvinfer", "--output", "location"]
+                    [BAZEL_EXE, "query", "@tensorrt//:nvinfer", "--output", "location"]
                 )
                 .decode("ascii")
                 .strip()
@@ -492,7 +492,7 @@ def run(self):
             )
             tensorrt_windows_external_dir = (
                 lambda: subprocess.check_output(
-                    ["bazel", "query", "@tensorrt_win//:nvinfer", "--output", "location"]
+                    [BAZEL_EXE, "query", "@tensorrt_win//:nvinfer", "--output", "location"]
                 )
                 .decode("ascii")
                 .strip()
diff --git a/tests/py/dynamo/models/test_engine_cache.py b/tests/py/dynamo/models/test_engine_cache.py
index 189a492d4e..770e057a36 100644
--- a/tests/py/dynamo/models/test_engine_cache.py
+++ b/tests/py/dynamo/models/test_engine_cache.py
@@ -10,7 +10,7 @@
 import torchvision.models as models
 from torch.testing._internal.common_utils import TestCase
 from torch_tensorrt.dynamo._defaults import ENGINE_CACHE_DIR
-from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
+from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
 from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
 
 assertions = unittest.TestCase()
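For reference, a minimal usage sketch of the custom engine cache hook introduced by this change; it is not part of the patch itself and only mirrors the ``RAMEngineCache`` class and the ``torch.compile`` options added in ``engine_caching_example.py`` (the model and input shape here are illustrative assumptions).

# Sketch: passing a user-defined cache derived from BaseEngineCache to the
# TensorRT backend via the "custom_engine_cache" option added in this patch.
from typing import Dict, Optional

import torch
import torch_tensorrt  # registers the "tensorrt" backend for torch.compile
import torchvision.models as models
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache


class RAMEngineCache(BaseEngineCache):
    def __init__(self) -> None:
        # Engine blobs are held in host memory, keyed by the subgraph hash.
        self.engine_cache: Dict[str, bytes] = {}

    def save(self, hash: str, blob: bytes) -> None:
        self.engine_cache[hash] = blob

    def load(self, hash: str) -> Optional[bytes]:
        # Return the blob on a cache hit, None on a miss.
        return self.engine_cache.get(hash)


model = models.resnet18(pretrained=True).eval().to("cuda")
inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]

compiled_model = torch.compile(
    model,
    backend="tensorrt",
    options={
        "use_python_runtime": True,
        "make_refitable": True,  # cached engines must be refittable
        "cache_built_engines": True,  # insert newly built engines into the cache
        "reuse_cached_engines": True,  # refit and reuse engines on a cache hit
        "custom_engine_cache": RAMEngineCache(),  # falls back to DiskEngineCache if omitted
    },
)
compiled_model(*inputs)  # the first call triggers compilation and populates the cache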