fix: Adding shape distinguishing to the engine cache #3154

Closed
2 changes: 1 addition & 1 deletion examples/dynamo/engine_caching_bert_example.py
@@ -52,7 +52,7 @@ def compile_bert(iterations=3):
"truncate_double": True,
"debug": False,
"min_block_size": 1,
"make_refitable": True,
"make_refittable": True,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"engine_cache_dir": "/tmp/torch_trt_bert_engine_cache",
8 changes: 4 additions & 4 deletions examples/dynamo/engine_caching_example.py
@@ -63,7 +63,7 @@ def remove_timing_cache(path=TIMING_CACHE_PATH):
# in a subsequent compilation, either as part of this session or a new session, the cache will
# pull the built engine and **refit** the weights which can reduce compilation times by orders of magnitude.
# As such, in order to insert a new engine into the cache (i.e. ``cache_built_engines=True``),
# the engine must be refitable (``make_refittable=True``). See :ref:`refit_engine_example` for more details.
# the engine must be refittable (``make_refittable=True``). See :ref:`refit_engine_example` for more details.


def torch_compile(iterations=3):
@@ -97,7 +97,7 @@ def torch_compile(iterations=3):
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
"make_refitable": True,
"make_refittable": True,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
},
@@ -157,7 +157,7 @@ def dynamo_compile(iterations=3):
enabled_precisions=enabled_precisions,
debug=debug,
min_block_size=min_block_size,
make_refitable=True,
make_refittable=True,
cache_built_engines=cache_built_engines,
reuse_cached_engines=reuse_cached_engines,
engine_cache_size=1 << 30, # 1GB
@@ -268,7 +268,7 @@ def torch_compile_my_cache(iterations=3):
"enabled_precisions": enabled_precisions,
"debug": debug,
"min_block_size": min_block_size,
"make_refitable": True,
"make_refittable": True,
"cache_built_engines": cache_built_engines,
"reuse_cached_engines": reuse_cached_engines,
"custom_engine_cache": engine_cache,
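For reference, the caching flags touched in this file fit together roughly as follows. This is a minimal sketch assuming the dynamo API as shown in this diff; the model, shapes, and cache path are illustrative:

import torch
import torch_tensorrt
import torchvision.models as models

model = models.resnet18().eval().cuda()
args = (torch.randn(2, 3, 224, 224).cuda(),)

# First compilation builds the TensorRT engine and inserts it into the cache;
# a later compilation of the same graph, with the same input specs and
# engine-invariant settings, pulls the engine back out and only refits weights.
trt_gm = torch_tensorrt.dynamo.compile(
    torch.export.export(model, args),
    inputs=list(args),
    make_refittable=True,  # required whenever cache_built_engines=True
    cache_built_engines=True,
    reuse_cached_engines=True,
    engine_cache_dir="/tmp/torch_trt_engine_cache",  # illustrative path
)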
4 changes: 2 additions & 2 deletions examples/dynamo/mutable_torchtrt_module_example.py
@@ -31,7 +31,7 @@
settings = {
"use_python": False,
"enabled_precisions": {torch.float32},
"make_refitable": True,
"make_refittable": True,
}

model = models.resnet18(pretrained=True).eval().to("cuda")
@@ -80,7 +80,7 @@
"use_python_runtime": True,
"enabled_precisions": {torch.float16},
"debug": True,
"make_refitable": True,
"make_refittable": True,
}

model_id = "runwayml/stable-diffusion-v1-5"
6 changes: 3 additions & 3 deletions examples/dynamo/refit_engine_example.py
@@ -43,11 +43,11 @@


# %%
# Make a Refitable Compilation Program
# Make a refittable Compilation Program
# ---------------------------------------
#
# The initial step is to compile a module and save it as you normally would. Note that there is an
# additional parameter `make_refitable` that is set to `True`. This parameter is used to
# additional parameter `make_refittable` that is set to `True`. This parameter is used to
# indicate that the engine being built should support weight refitting later. Engines built without
# these settings will not be able to be refit.
#
@@ -69,7 +69,7 @@
debug=debug,
min_block_size=min_block_size,
torch_executed_ops=torch_executed_ops,
make_refitable=True,
make_refittable=True,
) # Output is a torch.fx.GraphModule

# Save the graph module as an exported program
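Condensed into a sketch, the workflow this example documents looks like the following; the model, weights, and shapes are illustrative, and the `refit_module_weights` usage is assumed from the documentation rather than verified against this branch:

import torch
import torch_tensorrt
import torchvision.models as models

args = (torch.randn(2, 3, 224, 224).cuda(),)
exp_program = torch.export.export(models.resnet18().eval().cuda(), args)

# Compile with refit support; engines built without make_refittable=True
# cannot have their weights swapped afterwards.
trt_gm = torch_tensorrt.dynamo.compile(
    exp_program, inputs=list(args), make_refittable=True
)

# Later: point the compiled module at newly trained weights without a rebuild.
new_program = torch.export.export(
    models.resnet18(weights=models.ResNet18_Weights.DEFAULT).eval().cuda(), args
)
refitted_gm = torch_tensorrt.dynamo.refit_module_weights(trt_gm, new_program)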
22 changes: 11 additions & 11 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -60,7 +60,7 @@ def compile(
Set[Union[torch.dtype, dtype]], Tuple[Union[torch.dtype, dtype]]
] = _defaults.ENABLED_PRECISIONS,
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
make_refitable: bool = _defaults.MAKE_REFITABLE,
make_refittable: bool = _defaults.MAKE_REFITTABLE,
debug: bool = _defaults.DEBUG,
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -180,14 +180,14 @@ def compile(

if "refit" in kwargs.keys():
warnings.warn(
"Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
"Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.",
DeprecationWarning,
stacklevel=2,
)
if make_refitable:
raise ValueError("Use flag make_refitable only. Flag refit is deprecated.")
if make_refittable:
raise ValueError("Use flag make_refittable only. Flag refit is deprecated.")
else:
make_refitable = kwargs["refit"]
make_refittable = kwargs["refit"]

engine_capability = EngineCapability._from(engine_capability)

@@ -238,8 +238,8 @@ def compile(
engine_cache = None
if cache_built_engines or reuse_cached_engines:
assert (
make_refitable
), "Engine caching requires make_refitable to be set to True"
make_refittable
), "Engine caching requires make_refittable to be set to True"
engine_cache = (
custom_engine_cache
if custom_engine_cache is not None
@@ -270,7 +270,7 @@
"require_full_compilation": require_full_compilation,
"disable_tf32": disable_tf32,
"sparse_weights": sparse_weights,
"make_refitable": make_refitable,
"make_refittable": make_refittable,
"engine_capability": engine_capability,
"dla_sram_size": dla_sram_size,
"dla_local_dram_size": dla_local_dram_size,
@@ -513,7 +513,7 @@ def convert_exported_program_to_serialized_trt_engine(
require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
disable_tf32: bool = _defaults.DISABLE_TF32,
sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
make_refitable: bool = _defaults.MAKE_REFITABLE,
make_refittable: bool = _defaults.MAKE_REFITTABLE,
engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -600,7 +600,7 @@ def convert_exported_program_to_serialized_trt_engine(
)
if "refit" in kwargs.keys():
warnings.warn(
"Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
"Refit is deprecated. Please use make_refittable=True if you want to enable refitting of the engine.",
DeprecationWarning,
stacklevel=2,
)
@@ -646,7 +646,7 @@ def convert_exported_program_to_serialized_trt_engine(
"require_full_compilation": require_full_compilation,
"disable_tf32": disable_tf32,
"sparse_weights": sparse_weights,
"make_refitable": make_refitable,
"make_refittable": make_refittable,
"engine_capability": engine_capability,
"num_avg_timing_iters": num_avg_timing_iters,
"dla_sram_size": dla_sram_size,
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
@@ -26,7 +26,7 @@
USE_PYTHON_RUNTIME = False
USE_FAST_PARTITIONER = True
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
MAKE_REFITABLE = False
MAKE_REFITTABLE = False
REQUIRE_FULL_COMPILATION = False
DRYRUN = False
HARDWARE_COMPATIBLE = False
84 changes: 77 additions & 7 deletions py/torch_tensorrt/dynamo/_engine_cache.py
@@ -1,17 +1,25 @@
import copy
from dataclasses import asdict
import logging
import os
import io
import pickle
import pickletools
import shutil
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Sequence, Tuple, cast
from ast import literal_eval

import torch
from torch._inductor.codecache import FxGraphCachePickler
from torch._inductor.codecache import FxGraphCachePickler, sha256_hash
from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
from torch_tensorrt.dynamo._settings import CompilationSettings, _SETTINGS_TO_BE_ENGINE_INVARIANT
from torch_tensorrt._Input import Input

_LOGGER: logging.Logger = logging.getLogger(__name__)

UnpackedCacheHit = Tuple[bytes, List[str], List[str], Tuple[Input, ...], CompilationSettings, Optional[Dict[Any, Any]]]

class BaseEngineCache(ABC):

@@ -24,7 +32,7 @@ def __init__(
pass

@staticmethod
def get_hash(gm: torch.fx.GraphModule) -> str:
def get_hash(gm: torch.fx.GraphModule, input_specs: Sequence[Input], settings: CompilationSettings) -> str:
"""Get the hash value of the GraphModule

Args:
@@ -39,7 +47,22 @@ def get_hash(gm: torch.fx.GraphModule) -> str:
for name, param in new_gm.named_parameters():
param.data.zero_()

hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))
graph_hash_val = cast(str, FxGraphCachePickler.get_hash(new_gm))

input_spec_strs = [str(i) for i in input_specs]
input_specs_data = pickletools.optimize(pickle.dumps(input_spec_strs))
input_specs_hash = sha256_hash(input_specs_data)

invariant_engine_specs = [str(getattr(settings, field)) for field in _SETTINGS_TO_BE_ENGINE_INVARIANT]
engine_specs_data = pickletools.optimize(pickle.dumps(invariant_engine_specs))
engine_specs_hash = sha256_hash(engine_specs_data)

# TODO: first-pass scheme for combining the three hashes; @Evan please iterate on this
hash_val = graph_hash_val + input_specs_hash + engine_specs_hash

return hash_val
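To see why this combination distinguishes shapes, here is the same idea as a standalone sketch; `hashlib` stands in for `sha256_hash`, and every name is illustrative rather than the module's API:

import hashlib
import pickle
import pickletools

def combined_key(graph_hash: str, input_specs: list, invariant_settings: list) -> str:
    # Hash the stringified input specs and the engine-invariant settings
    # separately, then concatenate both digests with the graph hash.
    specs_blob = pickletools.optimize(pickle.dumps([str(s) for s in input_specs]))
    settings_blob = pickletools.optimize(pickle.dumps([str(s) for s in invariant_settings]))
    return (
        graph_hash
        + hashlib.sha256(specs_blob).hexdigest()
        + hashlib.sha256(settings_blob).hexdigest()
    )

# Identical graphs with different input shapes now map to different cache keys:
key_small = combined_key("graph0", ["Input(shape=(1, 3, 224, 224))"], ["fp32"])
key_large = combined_key("graph0", ["Input(shape=(8, 3, 224, 224))"], ["fp32"])
assert key_small != key_large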

@@ -48,6 +71,8 @@ def pack(
serialized_engine: bytes,
input_names: List[str],
output_names: List[str],
input_specs: Tuple[Input, ...],
compilation_settings: CompilationSettings,
weight_name_map: Optional[Dict[Any, Any]],
) -> bytes:
"""Pack serialized engine, input names, output names, and weight map into a single blob
@@ -61,35 +86,80 @@
Returns:
bytes: packed blob
"""

settings = copy.deepcopy(compilation_settings)
settings.torch_executed_ops = {
f"torch.ops.{op.__str__()}"
for op in settings.torch_executed_ops
}

return pickle.dumps(
{
"serialized_engine": bytes(serialized_engine),
"input_names": input_names,
"output_names": output_names,
"input_specs": input_specs,
"compilation_settings": settings,
"weight_name_map": weight_name_map,
}
)

@staticmethod
def unpack(
packed_obj: bytes,
) -> Tuple[bytes, List[str], List[str], Optional[Dict[Any, Any]]]:
def unpack(packed_obj: bytes) -> UnpackedCacheHit:
"""Unpack packed blob into serialized engine, input names, output names, and weight map

Args:
packed_obj (bytes): packed blob

Returns:
Tuple[bytes, List[str], List[str], Optional[Dict[str, Any]]]: serialized engine, input names, output names, weight name map
Tuple[bytes, List[str], List[str], Tuple[Input, ...], CompilationSettings, Optional[Dict[str, Any]]]: serialized engine, input names, output names, input specs, compilation settings, weight name map
"""
unpacked = pickle.loads(packed_obj)
return (
unpacked["serialized_engine"],
unpacked["input_names"],
unpacked["output_names"],
unpacked["input_specs"],
unpacked["compilation_settings"],
unpacked["weight_name_map"],
)

def insert(self, hash: str, entry: UnpackedCacheHit, *args: Any, **kwargs: Any) -> None:
"""
Insert a cache entry into the engine cache.

Args:
hash (str): The hash value of the GraphModule.
entry (Tuple[bytes, List[str], List[str], Tuple[Input, ...], CompilationSettings, Optional[Dict[Any, Any]]]): The cache entry to be inserted.
*args: Variable length argument list passed to ``save``.
**kwargs: Arbitrary keyword arguments passed to ``save``.

Returns:
None
"""
packed_cache_info = BaseEngineCache.pack(*entry)
return self.save(hash, packed_cache_info, *args, **kwargs)


def check(self, hash: str, *args: Any, **kwargs: Any) -> Optional[UnpackedCacheHit]:
"""
Check if a cache entry exists for the given hash.

Args:
hash (str): The hash value of the GraphModule.
*args: Variable length argument list passed to ``load``.
**kwargs: Arbitrary keyword arguments passed to ``load``.

Returns:
Optional[Tuple[bytes, List[str], List[str], Tuple[Input, ...], CompilationSettings, Optional[Dict[Any, Any]]]]: The unpacked cache entry if found, None otherwise.
"""
packed_cache_info = self.load(hash, *args, **kwargs)

if packed_cache_info:
return BaseEngineCache.unpack(packed_cache_info)
else:
return None

@abstractmethod
def save(self, hash: str, blob: bytes, *args: Any, **kwargs: Any) -> None:
"""Store blob in cache
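For orientation, a concrete cache only has to supply the raw `save`/`load` pair; hashing, `pack`/`unpack`, and the `insert`/`check` protocol above stay in `BaseEngineCache`. A minimal disk-backed sketch (the class name and paths are illustrative, not the shipped `DiskEngineCache`):

import os
from typing import Optional

class SimpleDiskCache(BaseEngineCache):
    def __init__(self, engine_cache_dir: str) -> None:
        self.engine_cache_dir = engine_cache_dir
        os.makedirs(engine_cache_dir, exist_ok=True)

    def save(self, hash: str, blob: bytes, *args, **kwargs) -> None:
        # Store the packed blob under its combined hash.
        with open(os.path.join(self.engine_cache_dir, f"{hash}.bin"), "wb") as f:
            f.write(blob)

    def load(self, hash: str, *args, **kwargs) -> Optional[bytes]:
        path = os.path.join(self.engine_cache_dir, f"{hash}.bin")
        if not os.path.exists(path):
            return None
        with open(path, "rb") as f:
            return f.read()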
4 changes: 2 additions & 2 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -254,7 +254,7 @@ def refit_module_weights(
]
assert (
encoded_metadata != ""
), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version with make_refitable=True"
), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version with make_refittable=True"
settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"]
# Handle torch modules
compiled_submodules_map = dict(compiled_submodules)
@@ -270,7 +270,7 @@
settings = submodule.settings

assert (
settings.make_refitable
settings.make_refittable
), "Refitting is not enabled. Please recompile the engine with refit=True."

if settings.debug:
31 changes: 28 additions & 3 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass, field
from typing import Collection, Optional, Set, Union
from typing import Collection, Optional, Set, Tuple, Union

from torch.fx.node import Target
from torch_tensorrt._Device import Device
@@ -18,7 +18,7 @@
ENGINE_CAPABILITY,
HARDWARE_COMPATIBLE,
LAZY_ENGINE_INIT,
MAKE_REFITABLE,
MAKE_REFITTABLE,
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
NUM_AVG_TIMING_ITERS,
@@ -98,7 +98,7 @@ class CompilationSettings:
disable_tf32: bool = DISABLE_TF32
assume_dynamic_shape_support: bool = ASSUME_DYNAMIC_SHAPE_SUPPORT
sparse_weights: bool = SPARSE_WEIGHTS
make_refitable: bool = MAKE_REFITABLE
make_refittable: bool = MAKE_REFITTABLE
engine_capability: EngineCapability = field(
default_factory=lambda: ENGINE_CAPABILITY
)
@@ -112,3 +112,28 @@ class CompilationSettings:
lazy_engine_init: bool = LAZY_ENGINE_INIT
cache_built_engines: bool = CACHE_BUILT_ENGINES
reuse_cached_engines: bool = REUSE_CACHED_ENGINES


_SETTINGS_TO_BE_ENGINE_INVARIANT = (
    "enabled_precisions",
    "max_aux_streams",
    "version_compatible",
    "optimization_level",
    "disable_tf32",
    "sparse_weights",
    "make_refittable",
    "engine_capability",
    "hardware_compatible",
)


def settings_are_compatible(set_a: CompilationSettings, set_b: CompilationSettings) -> Tuple[bool, Set[str]]:
incompatible_settings: Set[str] = set()

for field in _SETTINGS_TO_BE_ENGINE_INVARIANT:
if getattr(set_a, field) != getattr(set_b, field):
incompatible_settings.add(field)

if len(incompatible_settings) == 0:
return True, set()
else:
return False, incompatible_settings
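Downstream, a cache hit can then be validated along these lines; a hedged sketch in which the import path follows this diff and the settings instances are illustrative:

from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible

current = CompilationSettings(make_refittable=True)
cached = CompilationSettings(make_refittable=True, sparse_weights=True)

compatible, mismatched = settings_are_compatible(current, cached)
if not compatible:
    # Treat the hit as a miss rather than reuse an engine built under
    # different engine-invariant settings.
    print(f"Engine cache entry rejected; mismatched fields: {sorted(mismatched)}")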