Skip to content

Commit

Permalink
fix: distingush engines based on compilation settings in addition to … (
Browse files Browse the repository at this point in the history
#3155)

Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan authored Sep 12, 2024
1 parent 4d2a04a commit 2be2e64
Show file tree
Hide file tree
Showing 20 changed files with 519 additions and 131 deletions.
2 changes: 1 addition & 1 deletion examples/dynamo/engine_caching_bert_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def compile_bert(iterations=3):
"truncate_double": True,
"debug": False,
"min_block_size": 1,
"make_refitable": True,
"make_refittable": True,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"engine_cache_dir": "/tmp/torch_trt_bert_engine_cache",
Expand Down
8 changes: 4 additions & 4 deletions examples/dynamo/engine_caching_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
# in a subsequent compilation, either as part of this session or a new session, the cache will
# pull the built engine and **refit** the weights which can reduce compilation times by orders of magnitude.
# As such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``),
# the engine must be refitable (``make_refittable=True``). See :ref:`refit_engine_example` for more details.
# the engine must be refittable (``make_refittable=True``). See :ref:`refit_engine_example` for more details.


def torch_compile(iterations=3):
Expand Down Expand Up @@ -97,7 +97,7 @@ def torch_compile(iterations=3):
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
"make_refitable": True,
"make_refittable": True,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
},
Expand Down Expand Up @@ -157,7 +157,7 @@ def dynamo_compile(iterations=3):
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
make_refittable=True,
cache_built_engines=cache_built_engines,
reuse_cached_engines=reuse_cached_engines,
engine_cache_size=1 << 30, # 1GB
Expand Down Expand Up @@ -268,7 +268,7 @@ def torch_compile_my_cache(iterations=3):
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
"make_refitable": True,
"make_refittable": True,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"custom_engine_cache": engine_cache,
Expand Down
4 changes: 2 additions & 2 deletions examples/dynamo/mutable_torchtrt_module_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
settings = {
"use_python": False,
"enabled_precisions": {torch.float32},
"make_refitable": True,
"make_refittable": True,
}

model = models.resnet18(pretrained=True).eval().to("cuda")
Expand Down Expand Up @@ -80,7 +80,7 @@
"use_python_runtime": True,
"enabled_precisions": {torch.float16},
"debug": True,
"make_refitable": True,
"make_refittable": True,
}

model_id = "runwayml/stable-diffusion-v1-5"
Expand Down
6 changes: 3 additions & 3 deletions examples/dynamo/refit_engine_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@


# %%
# Make a Refitable Compilation Program
# Make a refittable Compilation Program
# ---------------------------------------
#
# The inital step is to compile a module and save it as with a normal. Note that there is an
# additional parameter `make_refitable` that is set to `True`. This parameter is used to
# additional parameter `make_refittable` that is set to `True`. This parameter is used to
# indicate that the engine being built should support weight refitting later. Engines built without
# these setttings will not be able to be refit.
#
Expand All @@ -69,7 +69,7 @@
debug=debug,
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
make_refitable=True,
make_refittable=True,
) # Output is a torch.fx.GraphModule

# Save the graph module as an exported program
Expand Down
28 changes: 28 additions & 0 deletions py/torch_tensorrt/_Input.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,34 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return self.__str__()

@staticmethod
def equivalent_spec(a: Input, b: Input) -> bool:
if a.shape_mode != b.shape_mode:
return False

if a.shape_mode == Input._ShapeMode.DYNAMIC:
assert isinstance(a.shape, dict)
assert isinstance(b.shape, dict)
checks = [
a.shape["min_shape"] == b.shape["min_shape"],
a.shape["opt_shape"] == b.shape["opt_shape"],
a.shape["max_shape"] == b.shape["max_shape"],
a.dtype == b.dtype,
a.format == b.format,
a.low_tensor_domain_incl == b.low_tensor_domain_incl,
a.high_tensor_domain_excl == b.high_tensor_domain_excl,
]
return all(checks)
else:
checks = [
a.shape == b.shape,
a.dtype == b.dtype,
a.format == b.format,
a.low_tensor_domain_incl == b.low_tensor_domain_incl,
a.high_tensor_domain_excl == b.high_tensor_domain_excl,
]
return all(checks)

@staticmethod
def _supported_input_size_type(input_size: Any) -> bool:
if isinstance(input_size, torch.Size):
Expand Down
22 changes: 11 additions & 11 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def compile(
Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
] = _defaults.ENABLED_PRECISIONS,
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
make_refitable: bool = _defaults.MAKE_REFITABLE,
make_refittable: bool = _defaults.MAKE_REFITTABLE,
debug: bool = _defaults.DEBUG,
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
workspace_size: int = _defaults.WORKSPACE_SIZE,
Expand Down Expand Up @@ -180,14 +180,14 @@ def compile(

if "refit" in kwargs.keys():
warnings.warn(
"Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
"Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.",
DeprecationWarning,
stacklevel=2,
)
if make_refitable:
raise ValueError("Use flag make_refitable only. Flag refit is deprecated.")
if make_refittable:
raise ValueError("Use flag make_refittable only. Flag refit is deprecated.")
else:
make_refitable = kwargs["refit"]
make_refittable = kwargs["refit"]

engine_capability = EngineCapability._from(engine_capability)

Expand Down Expand Up @@ -238,8 +238,8 @@ def compile(
engine_cache = None
if cache_built_engines or reuse_cached_engines:
assert (
make_refitable
), "Engine caching requires make_refitable to be set to True"
make_refittable
), "Engine caching requires make_refittable to be set to True"
engine_cache = (
custom_engine_cache
if custom_engine_cache is not None
Expand Down Expand Up @@ -270,7 +270,7 @@ def compile(
"require_full_compilation": require_full_compilation,
"disable_tf32": disable_tf32,
"sparse_weights": sparse_weights,
"make_refitable": make_refitable,
"make_refittable": make_refittable,
"engine_capability": engine_capability,
"dla_sram_size": dla_sram_size,
"dla_local_dram_size": dla_local_dram_size,
Expand Down Expand Up @@ -513,7 +513,7 @@ def convert_exported_program_to_serialized_trt_engine(
require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
disable_tf32: bool = _defaults.DISABLE_TF32,
sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
make_refitable: bool = _defaults.MAKE_REFITABLE,
make_refittable: bool = _defaults.MAKE_REFITTABLE,
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
Expand Down Expand Up @@ -600,7 +600,7 @@ def convert_exported_program_to_serialized_trt_engine(
)
if "refit" in kwargs.keys():
warnings.warn(
"Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
"Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.",
DeprecationWarning,
stacklevel=2,
)
Expand Down Expand Up @@ -646,7 +646,7 @@ def convert_exported_program_to_serialized_trt_engine(
"require_full_compilation": require_full_compilation,
"disable_tf32": disable_tf32,
"sparse_weights": sparse_weights,
"make_refitable": make_refitable,
"make_refittable": make_refittable,
"engine_capability": engine_capability,
"num_avg_timing_iters": num_avg_timing_iters,
"dla_sram_size": dla_sram_size,
Expand Down
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
USE_PYTHON_RUNTIME = False
USE_FAST_PARTITIONER = True
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
MAKE_REFITABLE = False
MAKE_REFITTABLE = False
REQUIRE_FULL_COMPILATION = False
DRYRUN = False
HARDWARE_COMPATIBLE = False
Expand Down
105 changes: 91 additions & 14 deletions py/torch_tensorrt/dynamo/_engine_cache.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,33 @@
import copy
import io
import logging
import os
import pickle
import pickletools
import shutil
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, cast
from typing import Any, Dict, List, Optional, Sequence, Tuple, cast

import torch
from torch._inductor.codecache import FxGraphCachePickler
from torch._inductor.codecache import FxGraphCachePickler, sha256_hash
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo._settings import (
_SETTINGS_TO_BE_ENGINE_INVARIANT,
CompilationSettings,
)

_LOGGER: logging.Logger = logging.getLogger(__name__)

UnpackedCacheHit = Tuple[
bytes,
List[str],
List[str],
Sequence[Input],
CompilationSettings,
Optional[Dict[str, Any]],
]


class BaseEngineCache(ABC):

Expand All @@ -24,7 +40,11 @@ def __init__(
pass

@staticmethod
def get_hash(gm: torch.fx.GraphModule) -> str:
def get_hash(
gm: torch.fx.GraphModule,
input_specs: Sequence[Input],
settings: CompilationSettings,
) -> str:
"""Get the hash value of the GraphModule
Args:
Expand All @@ -39,7 +59,23 @@ def get_hash(gm: torch.fx.GraphModule) -> str:
for name, param in new_gm.named_parameters():
param.data.zero_()

hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))
graph_hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))

input_spec_strs = [str(i) for i in input_specs]
with io.BytesIO() as stream:
input_specs_data = pickle.dumps(input_spec_strs)
input_specs_data = pickletools.optimize(input_specs_data)
input_specs_hash = sha256_hash(input_specs_data)

invariant_engine_specs = [
str(getattr(settings, field)) for field in _SETTINGS_TO_BE_ENGINE_INVARIANT
]
with io.BytesIO() as stream:
engine_specs_data = pickle.dumps(invariant_engine_specs)
engine_specs_data = pickletools.optimize(engine_specs_data)
engine_specs_hash = sha256_hash(engine_specs_data)

hash_val: str = graph_hash_val + input_specs_hash + engine_specs_hash

return hash_val

Expand All @@ -48,6 +84,8 @@ def pack(
serialized_engine: bytes,
input_names: List[str],
output_names: List[str],
input_specs: Sequence[Input],
compilation_settings: CompilationSettings,
weight_name_map: Optional[Dict[Any, Any]],
) -> bytes:
"""Pack serialized engine, input names, output names, and weight map into a single blob
Expand All @@ -56,40 +94,83 @@ def pack(
serialized_engine (bytes): serialized TRT engine
input_names (List[str]): input names of TRT engine
output_names (List[str]): output names of TRT engine
input_specs (Sequence[Input]): input specs of TRT engine
compilation_settings (CompilationSettings): compilation settings of TRT engine
weight_name_map (Optional[Dict[Any, Any]]): weight name map for refitting
Returns:
bytes: packed blob
"""

settings = copy.deepcopy(compilation_settings)
return pickle.dumps(
{
"serialized_engine": bytes(serialized_engine),
"input_names": input_names,
"output_names": output_names,
"input_specs": input_specs,
"compilation_settings": settings,
"weight_name_map": weight_name_map,
}
)

@staticmethod
def unpack(
packed_obj: bytes,
) -> Tuple[bytes, List[str], List[str], Optional[Dict[Any, Any]]]:
def unpack(packed_obj: bytes) -> UnpackedCacheHit:
"""Unpack packed blob into serialized engine, input names, output names, and weight map
Args:
packed_obj (bytes): packed blob
Returns:
Tuple[bytes, List[str], List[str], Optional[Dict[str, Any]]]: serialized engine, input names, output names, weight name map
Tuple[bytes, List[str], List[str], Sequence[Input], CompilationSettings, Optional[Dict[str, Any]]]: serialized engine, input names, output names, input specs, CompilationSettings, weight name map
"""
unpacked = pickle.loads(packed_obj)
return (
unpacked["serialized_engine"],
unpacked["input_names"],
unpacked["output_names"],
unpacked["input_specs"],
unpacked["compilation_settings"],
unpacked["weight_name_map"],
)

def insert(
self, hash: str, entry: UnpackedCacheHit, *args: Any, **kwargs: Any
) -> None:
"""
Insert a cache entry into the engine cache.
Args:
hash (str): The hash value of the GraphModule.
entry (Tuple[bytes, List[str], List[str], CompilationSettings, Optional[Dict[Any, Any]]]): The cache entry to be inserted.
*args: Variable length argument list passed to ``save``.
**kwargs: Arbitrary keyword arguments passed to ``save``.
Returns:
None
"""
packed_cache_info = BaseEngineCache.pack(*entry)
return self.save(hash, packed_cache_info, *args, **kwargs)

def check(self, hash: str, *args: Any, **kwargs: Any) -> Optional[UnpackedCacheHit]:
"""
Check if a cache entry exists for the given hash.
Args:
hash (str): The hash value of the GraphModule.
*args: Variable length argument list passed to ``load``.
**kwargs: Arbitrary keyword arguments passed to ``load``.
Returns:
Optional[Tuple[bytes, List[str], List[str], CompilationSettings, Optional[Dict[Any, Any]]]]: The unpacked cache entry if found, None otherwise.
"""
packed_cache_info = self.load(hash, *args, **kwargs)

if packed_cache_info:
return BaseEngineCache.unpack(packed_cache_info)
else:
return None

@abstractmethod
def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
"""Store blob in cache
Expand Down Expand Up @@ -203,11 +284,7 @@ def LRU() -> None:
else:
LRU()

def save(
self,
hash: str,
blob: bytes,
) -> None:
def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
blob_size = len(blob)
if blob_size > self.total_engine_cache_size:
_LOGGER.warning(
Expand Down Expand Up @@ -244,7 +321,7 @@ def save(
f"The size {blob_size} is still larger than the available cache size {self.available_engine_cache_size}."
)

def load(self, hash: str) -> Optional[bytes]:
def load(self, hash: str, *args: Any, **kwargs: Any) -> Optional[bytes]:
directory = os.path.join(self.engine_cache_dir, hash)
if os.path.exists(directory):
blob_path = os.path.join(directory, "blob.bin")
Expand Down
Loading

0 comments on commit 2be2e64

Please sign in to comment.