Skip to content

Commit

Permalink
docs: Adding words to the refit and engine caching tutorials
Browse files Browse the repository at this point in the history
  • Loading branch information
narendasan committed Sep 3, 2024
1 parent 8e75039 commit a4684e3
Show file tree
Hide file tree
Showing 15 changed files with 241 additions and 61 deletions.
1 change: 0 additions & 1 deletion core/runtime/Platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ Platform::Platform() : _platform{Platform::PlatformEnum::kUNKNOWN} {}
Platform::Platform(Platform::PlatformEnum val) : _platform{val} {}

Platform::Platform(const std::string& platform_str) {
LOG_ERROR("Platform constructor: " << platform_str);
auto name_map = get_name_to_platform_map();
auto it = name_map.find(platform_str);
if (it != name_map.end()) {
Expand Down
1 change: 1 addition & 0 deletions docsrc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@
sphinx_gallery_conf = {
"examples_dirs": "../examples",
"gallery_dirs": "tutorials/_rendered_examples/",
"ignore_pattern": "utils.py"
}

# Setup the breathe extension
Expand Down
2 changes: 2 additions & 0 deletions docsrc/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ User Guide
user_guide/using_dla
tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq
tutorials/_rendered_examples/dynamo/engine_caching_example
tutorials/_rendered_examples/dynamo/refit_engine_example

Dynamo Frontend
----------------
Expand Down
2 changes: 2 additions & 0 deletions examples/dynamo/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,5 @@ a number of ways you can leverage this backend to accelerate inference.
* :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights
* :ref:`mutable_torchtrt_module_example`: Compile, use, and modify TensorRT Graph Module with MutableTorchTensorRTModule
* :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``
* :ref:`engine_caching_example`: Utilizing engine caching to speed up compilation times
* :ref:`engine_caching_bert_example`: Demonstrating engine caching on BERT
9 changes: 9 additions & 0 deletions examples/dynamo/engine_caching_bert_example.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
"""
.. _engine_caching_bert_example:
Engine Caching (BERT)
=======================
Small caching example on BERT.
"""
import numpy as np
import torch
import torch_tensorrt
Expand Down
177 changes: 149 additions & 28 deletions examples/dynamo/engine_caching_example.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,38 @@
"""
.. _engine_caching_example:
Engine Caching
=======================
As model sizes increase, the cost of compilation will as well. With AOT methods
like ``torch.dynamo.compile``, this cost is paid upfront. However if the weights
change, the session ends or you are using JIT methods like ``torch.compile``, as
graphs get invalidated they get re-compiled, this cost will get paid repeatedly.
Engine caching is a way to mitigate this cost by saving constructed engines to disk
and re-using them when possible. This tutorial demonstrates how to use engine caching
with TensorRT in PyTorch. Engine caching can significantly speed up subsequent model
compilations reusing previously built TensorRT engines.
We'll explore two approaches:
1. Using torch_tensorrt.dynamo.compile
2. Using torch.compile with the TensorRT backend
The example uses a pre-trained ResNet18 model and shows the
differences between compilation without caching, with caching enabled,
and when reusing cached engines.
"""

import os
from typing import Optional
from typing import Optional, Dict

import numpy as np
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from torch_tensorrt.dynamo._defaults import TIMING_CACHE_PATH
from torch_tensorrt.dynamo._engine_caching import BaseEngineCache
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache

np.random.seed(0)
torch.manual_seed(0)
Expand All @@ -22,6 +48,76 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
if os.path.exists(path):
os.remove(path)

# %%
# Engine Caching for JIT Compilation
# ----------------------------------
#
# The primary goal of engine caching is to help speed up JIT workflows. ``torch.compile``
# provides a great deal of flexibility in model construction which makes it a good
# first tool to try when looking to speed up your workflow. However, historically
# the cost of compilation and in particular recompilation has been a barrier to entry
# for many users. If for some reason a subgraph gets invalidated, that graph is reconstructed
# scratch prior to the addition of engine caching. Now as engines are constructed, with ``cache_built_engines=True``,
# engines are saved to disk tied to a hash of their corresponding PyTorch subgraph. If
# in a subsequent compilation, either as part of this session or a new session, the cache will
# pull the built engine and **refit** the weights which can reduce compilation times by orders of magnitude.
# As such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``),
# the engine must be refitable (``make_refittable=True``). See :ref:`refit_engine_example` for more details.

def torch_compile(iterations=3):
times = []
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

# The 1st iteration is to measure the compilation time without engine caching
# The 2nd and 3rd iterations are to measure the compilation time with engine caching.
# Since the 2nd iteration needs to compile and save the engine, it will be slower than the 1st iteration.
# The 3rd iteration should be faster than the 1st iteration because it loads the cached engine.
for i in range(iterations):
inputs = [torch.rand((100, 3, 224, 224)).to("cuda")]
# remove timing cache and reset dynamo just for engine caching messurement
remove_timing_cache()
torch._dynamo.reset()

if i == 0:
cache_built_engines = False
reuse_cached_engines = False
else:
cache_built_engines = True
reuse_cached_engines = True

start.record()
compiled_model = torch.compile(
model,
backend="tensorrt",
options={
"use_python_runtime": True,
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
"make_refitable": True,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
},
)
compiled_model(*inputs) # trigger the compilation
end.record()
torch.cuda.synchronize()
times.append(start.elapsed_time(end))

print("----------------torch_compile----------------")
print("disable engine caching, used:", times[0], "ms")
print("enable engine caching to cache engines, used:", times[1], "ms")
print("enable engine caching to reuse engines, used:", times[2], "ms")

torch_compile()

# %%
# Engine Caching for AOT Compilation
# ----------------------------------
# Similarly to the JIT workflow, AOT workflows can benefit from engine caching.
# As the same architecture or common subgraphs get recompiled, the cache will pull
# previously built engines and refit the weights.

def dynamo_compile(iterations=3):
times = []
Expand Down Expand Up @@ -72,43 +168,71 @@ def dynamo_compile(iterations=3):
print("enable engine caching to cache engines, used:", times[1], "ms")
print("enable engine caching to reuse engines, used:", times[2], "ms")

dynamo_compile()

# %%
# Custom Engine Cache
class MyEngineCache(BaseEngineCache):
# ----------------------
#
# By default, the engine cache is stored in the system's temporary directory. Both the cache directory and
# size limit can be customized by passing ``engine_cache_dir`` and ``engine_cache_size``.
# Users can also define their own engine cache implementation by extending the ``BaseEngineCache`` class.
# This allows for remote or shared caching if so desired.
#
# The custom engine cache should implement the following methods:
# - ``save``: Save the engine blob to the cache.
# - ``load``: Load the engine blob from the cache.
#
# The hash provided by the cache systen is a weight agnostic hash of the originating PyTorch subgraph (post lowering).
# The blob contains a serialized engine, calling spec data, and weight map information in the pickle format
#
# Below is an example of a custom engine cache implementation that implents a ``RAMEngineCache``.

class RAMEngineCache(BaseEngineCache):
def __init__(
self,
engine_cache_dir: str,
) -> None:
self.engine_cache_dir = engine_cache_dir
"""
Constructs a user held engine cache in memory.
"""
self.engine_cache: Dict[str, bytes] = {}

def save(
self,
hash: str,
blob: bytes,
prefix: str = "blob",
):
if not os.path.exists(self.engine_cache_dir):
os.makedirs(self.engine_cache_dir, exist_ok=True)
"""
Insert the engine blob to the cache.
path = os.path.join(
self.engine_cache_dir,
f"{prefix}_{hash}.bin",
)
with open(path, "wb") as f:
f.write(blob)
Args:
hash (str): The hash key to associate with the engine blob.
blob (bytes): The engine blob to be saved.
def load(self, hash: str, prefix: str = "blob") -> Optional[bytes]:
path = os.path.join(self.engine_cache_dir, f"{prefix}_{hash}.bin")
if os.path.exists(path):
with open(path, "rb") as f:
blob = f.read()
return blob
return None
Returns:
None
"""
self.engine_cache[hash] = blob

def load(self, hash: str) -> Optional[bytes]:
"""
Load the engine blob from the cache.
def torch_compile(iterations=3):
Args:
hash (str): The hash key of the engine to load.
Returns:
Optional[bytes]: The engine blob if found, None otherwise.
"""
if hash in self.engine_cache:
return self.engine_cache[hash]
else:
return None


def torch_compile_my_cache(iterations=3):
times = []
engine_cache = MyEngineCache("/tmp/your_dir")
engine_cache = RAMEngineCache()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

Expand Down Expand Up @@ -141,7 +265,7 @@ def torch_compile(iterations=3):
"make_refitable": True,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"custom_engine_cache": engine_cache, # use custom engine cache
"custom_engine_cache": engine_cache,
},
)
compiled_model(*inputs) # trigger the compilation
Expand All @@ -154,7 +278,4 @@ def torch_compile(iterations=3):
print("enable engine caching to cache engines, used:", times[1], "ms")
print("enable engine caching to reuse engines, used:", times[2], "ms")


if __name__ == "__main__":
dynamo_compile()
torch_compile()
torch_compile_my_cache()
84 changes: 63 additions & 21 deletions examples/dynamo/refit_engine_example.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
"""
.. _refit_engine_example:
Refit TenorRT Graph Module with Torch-TensorRT
Refitting Torch-TensorRT Programs with New Weights
===================================================================
We are going to demonstrate how a compiled TensorRT Graph Module can be refitted with updated weights.
In many cases, we frequently update the weights of models, such as applying various LoRA to Stable Diffusion or constant A/B testing of AI products.
That poses challenges for TensorRT inference optimizations, as compiling the TensorRT engines takes significant time, making repetitive compilation highly inefficient.
Torch-TensorRT supports refitting TensorRT graph modules without re-compiling the engine, considerably accelerating the workflow.
Compilation is an expensive operation as it involves many graph transformations, translations
and optimizations applied on the model. In cases were the weights of a model might be updated
occasionally (e.g. inserting LoRA adapters), the large cost of recompilation can make it infeasible
to use TensorRT if the compiled program needed to be built from scratch each time. Torch-TensorRT
provides a PyTorch native mechanism to update the weights of a compiled TensorRT program without
recompiling from scratch through weight refitting.
In this tutorial, we are going to walk through
1. Compiling a PyTorch model to a TensorRT Graph Module
2. Save and load a graph module
3. Refit the graph module
1. Compiling a PyTorch model to a TensorRT Graph Module
2. Save and load a graph module
3. Refit the graph module
This tutorial focuses mostly on the AOT workflow where it is most likely that a user might need to
manually refit a module. In the JIT workflow, weight changes trigger recompilation. As the engine
has previously been built, with an engine cache enabled, Torch-TensorRT can automatically recognize
a previously built engine, trigger refit and short cut recompilation on behalf of the user (see: :ref:`engine_caching_example`).
"""

# %%
Expand All @@ -36,10 +43,17 @@


# %%
# Compile the module for the first time and save it.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

model = models.resnet18(pretrained=True).eval().to("cuda")
# Make a Refitable Compilation Program
# ---------------------------------------
#
# The inital step is to compile a module and save it as with a normal. Note that there is an
# additional parameter `make_refitable` that is set to `True`. This parameter is used to
# indicate that the engine being built should support weight refitting later. Engines built without
# these setttings will not be able to be refit.
#
# In this case we are going to compile a ResNet18 model with randomly initialized weights and save it.

model = models.resnet18(pretrained=False).eval().to("cuda")
exp_program = torch.export.export(model, tuple(inputs))
enabled_precisions = {torch.float}
debug = False
Expand All @@ -59,16 +73,20 @@
) # Output is a torch.fx.GraphModule

# Save the graph module as an exported program
# This is only supported when use_python_runtime = False
torch_trt.save(trt_gm, "./compiled.ep", inputs=inputs)


# %%
# Refit the module with update model weights
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# Refit the Program with Pretrained Weights
# ------------------------------------------
#
# Random weights are not useful for inference. But now instead of recompiling the model, we can
# refit the model with the pretrained weights. This is done by setting up another PyTorch module
# with the target weights and exporting it as an ExportedProgram. Then the ``refit_module_weights``
# function is used to update the weights of the compiled module with the new weights.

# Create and compile the updated model
model2 = models.resnet18(pretrained=False).eval().to("cuda")
model2 = models.resnet18(pretrained=True).eval().to("cuda")
exp_program2 = torch.export.export(model2, tuple(inputs))


Expand All @@ -91,8 +109,32 @@
print("Refit successfully!")

# %%
# Alternative Workflow using Python Runtime
#
# Advanced Usage
# -----------------------------

# Currently python runtime does not support engine serialization. So the refitting will be done in the same runtime.
# This usecase is more useful when you need to switch different weights in the same runtime, such as using Stable Diffusion.
#
# There are a number of settings you can use to control the refit process
#
# Weight Map Cache
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# Weight refitting works by matching the weights of the compiled module with the new weights from
# the user supplied ExportedProgram. Since 1:1 name matching from PyTorch to TensorRT is hard to accomplish,
# the only gaurenteed way to match weights at *refit-time* is to pass the new ExportedProgram through the
# early phases of the compilation process to generate near identical weight names. This can be expensive
# and is not always necessary.
#
# To avoid this, **At initial compile**, Torch-TensorRt will attempt to cache a direct mapping from PyTorch
# weights to TensorRT weights. This cache is stored in the compiled module as metadata and can be used
# to speed up refit. If the cache is not present, the refit system will fallback to rebuilding the mapping at
# refit-time. Use of this cache is controlled by the ``use_weight_map_cache`` parameter.
#
# Since the cache uses a heuristic based system for matching PyTorch and TensorRT weights, you may want to verify the refitting. This can be done by setting
# ``verify_output`` to True and providing sample ``arg_inputs`` and ``kwarg_inputs``. When this is done, the refit
# system will run the refitted module and the user supplied module on the same inputs and compare the outputs.
#
# In-Place Refit
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# ``in_place`` allows the user to refit the module in place. This is useful when the user wants to update the weights
# of the compiled module without creating a new module.
Loading

0 comments on commit a4684e3

Please sign in to comment.