Added refitting acceleration (#2983)
cehongwang authored Aug 13, 2024
1 parent e90576a commit 3ee2d81
Showing 7 changed files with 732 additions and 58 deletions.
163 changes: 134 additions & 29 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -3,7 +3,7 @@
import collections.abc
import copy
import logging
from typing import Any, Optional, Sequence, Tuple
from typing import Any, List, Optional, Sequence, Tuple

import numpy as np
import tensorrt as trt
@@ -13,7 +13,7 @@
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import partitioning
from torch_tensorrt.dynamo._exporter import inline_torch_modules
from torch_tensorrt.dynamo.conversion import CompilationSettings
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
DYNAMO_CONVERTERS as CONVERTERS,
@@ -108,38 +108,97 @@ def construct_refit_mapping(
return weight_map


def construct_refit_mapping_from_weight_name_map(
weight_name_map: dict[Any, Any], state_dict: dict[Any, Any]
) -> dict[Any, Any]:
engine_weight_map = {}
for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
trt_dtype = dtype.try_from(np_weight_type).to(trt.DataType)
torch_dtype = dtype.try_from(np_weight_type).to(torch.dtype)
if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
# Batch Norm Layer
params = {}
for w in sd_weight_name:
params[w.split(".")[-1]] = state_dict[w]
scale = params["weight"] / torch.sqrt(params["running_var"] + 1e-7)
shift = params["bias"] - params["running_mean"] * scale
# Pick the scale or shift tensor computed above, matching the weight name suffix
engine_weight_map[engine_weight_name] = eval(
engine_weight_name.split(" ")[-1].lower()
)

elif sd_weight_name not in state_dict:
# If the weight is not in the state_dict, leave it unchanged
continue
else:
engine_weight_map[engine_weight_name] = state_dict[sd_weight_name]

engine_weight_map[engine_weight_name] = (
engine_weight_map[engine_weight_name]
.clone()
.reshape(-1)
.contiguous()
.to(torch_dtype),
trt_dtype,
)

return engine_weight_map


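For orientation, a sketch of the data this new helper consumes and produces; the entry names below are hypothetical, not taken from a real engine:

import numpy as np

# Hypothetical weight_name_map as cached at compile time:
# engine weight name -> (state_dict name(s), numpy dtype).
# Batch-norm entries end in "SCALE" or "SHIFT" and list every
# state_dict tensor folded into that single engine weight.
weight_name_map = {
    "conv1 KERNEL": ("conv1.weight", np.float32),
    "bn1 SCALE": (
        ["bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var"],
        np.float32,
    ),
}

# Given a matching state_dict, construct_refit_mapping_from_weight_name_map
# returns one entry per engine weight: a flattened, contiguous torch tensor
# paired with its TensorRT dtype, e.g.
# {"conv1 KERNEL": (tensor([...]), trt.DataType.FLOAT), ...}
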
def _refit_single_trt_engine_with_gm(
new_gm: torch.fx.GraphModule,
old_engine: trt.ICudaEngine,
input_list: Tuple[Any, ...],
input_list: Sequence[Any],
settings: CompilationSettings = CompilationSettings(),
weight_name_map: Optional[dict[str, List[str]]] = None,
) -> None:
"""
Refit a TensorRT Engine in place
"""
# Get the refitting mapping
mapping = construct_refit_mapping(new_gm, input_list, settings)

refitted = set()

trt_wt_location = trt.TensorLocation.HOST
refitter = trt.Refitter(old_engine, TRT_LOGGER)
weight_list = refitter.get_all_weights()

for layer_name in weight_list:
if layer_name not in mapping:
raise AssertionError(f"{layer_name} is not found in weight mapping")
# Use Numpy to create weights
weight, datatype = mapping[layer_name]
trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
refitted.add(layer_name)
if weight_name_map:
# Get the refitting mapping
trt_wt_location = trt.TensorLocation.DEVICE
mapping = construct_refit_mapping_from_weight_name_map(
weight_name_map, new_gm.state_dict()
)
for layer_name in weight_list:
if layer_name not in mapping:
logger.warning(f"{layer_name} is not found in weight mapping.")
continue
# Hand the GPU tensor to TensorRT by device pointer
weight, weight_dtype = mapping[layer_name]
trt_wt_tensor = trt.Weights(
weight_dtype, weight.data_ptr(), torch.numel(weight)
)
refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
assert (
len(refitter.get_missing_weights()) == 0
), "Fast refitting failed due to incomplete mapping"

if len(refitted) != len(weight_list):
logger.warning("Not all weights have been refitted!!!")
else:
mapping = construct_refit_mapping(new_gm, input_list, settings)
trt_wt_location = trt.TensorLocation.HOST
for layer_name in weight_list:
if layer_name not in mapping:
raise AssertionError(f"{layer_name} is not found in weight mapping")
# Use Numpy to create weights
weight, datatype = mapping[layer_name]
trt_wt_tensor = trt.Weights(datatype, weight.ctypes.data, weight.size)
refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
refitted.add(layer_name)

if len(refitted) != len(weight_list):
logger.warning("Not all weights have been refitted!!!")

if not refitter.refit_cuda_engine():
logger.error("Error: failed to refit new weights.")
exit(0)
raise AssertionError("Refitting failed.")

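Condensed, the fast path above reduces to this standalone sketch, assuming mapping holds (flattened GPU tensor, trt.DataType) pairs as produced by the new helper:

import tensorrt as trt
import torch

def refit_engine_fast(engine: trt.ICudaEngine, mapping: dict) -> None:
    # Weights stay on the GPU and are handed to TensorRT by pointer,
    # skipping the host-side reconstruction the slow path performs.
    refitter = trt.Refitter(engine, trt.Logger(trt.Logger.WARNING))
    for name in refitter.get_all_weights():
        if name not in mapping:
            continue  # surfaced below by get_missing_weights()
        tensor, trt_dtype = mapping[name]
        wts = trt.Weights(trt_dtype, tensor.data_ptr(), torch.numel(tensor))
        refitter.set_named_weights(name, wts, trt.TensorLocation.DEVICE)
    assert len(refitter.get_missing_weights()) == 0, "incomplete mapping"
    if not refitter.refit_cuda_engine():
        raise AssertionError("Refitting failed.")
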

def refit_module_weights(
@@ -148,6 +207,8 @@ def refit_module_weights(
arg_inputs: Optional[Tuple[Any, ...]] = None,
kwarg_inputs: Optional[dict[str, Any]] = None,
verify_output: bool = False,
use_weight_map_cache: bool = True,
in_place: bool = False,
) -> torch.fx.GraphModule:
"""
Refit a compiled graph module with ExportedProgram. This performs weight updates in compiled_module without recompiling the engine.
@@ -170,7 +231,12 @@
if len(list(compiled_module.named_children())) == 0:
inline_module = True

compiled_module = copy.deepcopy(compiled_module)
if not in_place:
compiled_module = copy.deepcopy(compiled_module)
elif inline_module:
raise AssertionError(
"Exported program does not support modifying in place. Please set inplace to false and use the returned graph module."
)

# Get the settings and check the setting to be uniform
settings: CompilationSettings = None
@@ -182,13 +248,14 @@
for name, engine in compiled_module.__dict__.items()
if "engine" in name
]
encoded_settings = compiled_submodules[0][1].__getstate__()[0][
# [('_run_on_acc_0', inline_module)]
encoded_metadata = compiled_submodules[0][1].__getstate__()[0][
SERIALIZED_METADATA_IDX
]
assert (
encoded_settings != ""
), "Settings are not saved in the engine. Please recompile the engine with make_refitable=True."
settings = TorchTensorRTModule.decode_metadata(encoded_settings)
encoded_metadata != ""
), "The engine provided is either not refittable or was built with a version of Torch-TensorRT that is too old, please recompile using the latest version with make_refitable=True"
settings = TorchTensorRTModule.decode_metadata(encoded_metadata)["settings"]
# Handle torch modules
compiled_submodules_map = dict(compiled_submodules)
for name, submodule in compiled_module.named_children():
@@ -287,6 +354,7 @@ def refit_module_weights(
# Extract engine from the submodule
try:
if inline_module:
weight_name_map = None
compiled_submodule = compiled_submodules_map[name]
# If this is a torch module, load the old state_dict
if "_run_on_acc" not in name:
@@ -297,8 +365,33 @@
engine = get_engine_from_encoded_engine(
engine_info[ENGINE_IDX], runtime
)
if use_weight_map_cache:
encoded_metadata = compiled_submodule.__getstate__()[0][
SERIALIZED_METADATA_IDX
]
weight_name_map = TorchTensorRTModule.decode_metadata(
encoded_metadata
)["weight_name_map"]
if not weight_name_map:
use_weight_map_cache = False
logger.warning(
"This engine does not have a weight map cache. Rebuilding the weight map"
)
else:
compiled_submodule = getattr(compiled_module, name)
weight_name_map = None
if use_weight_map_cache:
try:
weight_name_map = compiled_submodule.weight_name_map
except AttributeError:
logger.warning(
"The module was compiled with an old version of Torch-TensorRT. Rebuilding the weight map."
)
if not weight_name_map:
use_weight_map_cache = False
logger.warning(
"This engine does not have a weight map cache. Rebuilding the weight map"
)
if isinstance(compiled_submodule, PythonTorchTensorRTModule):
engine = compiled_submodule.engine
elif isinstance(compiled_submodule, TorchTensorRTModule):
@@ -335,13 +428,25 @@ def refit_module_weights(
to_torch_device(settings.device),
name,
)

_refit_single_trt_engine_with_gm(
new_gm=new_submodule,
old_engine=engine,
input_list=submodule_inputs,
settings=settings,
)
try:
_refit_single_trt_engine_with_gm(
new_gm=new_submodule,
old_engine=engine,
input_list=submodule_inputs,
settings=settings,
weight_name_map=weight_name_map,
)
except AssertionError as e:
# If fast refit was attempted and failed, fall back to regular refit
logger.warning(e)
if use_weight_map_cache and weight_name_map:
_refit_single_trt_engine_with_gm(
new_gm=new_submodule,
old_engine=engine,
input_list=submodule_inputs,
settings=settings,
weight_name_map=None,
)

if isinstance(compiled_submodule, TorchTensorRTModule):
serialized_engine = bytes(engine.serialize())
The remaining lines of _refit.py and the diffs for the other 6 changed files are not shown.
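For an end-to-end picture, a usage sketch of the new options; this assumes refit_module_weights is re-exported from torch_tensorrt.dynamo, that its two leading parameters are named compiled_module and new_weight_module (the hunk above truncates them), and that make_refitable is the compile flag named in the assert message:

import torch
import torch_tensorrt as torch_trt

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]
exp_program = torch.export.export(model, tuple(inputs))

# Compile with refitting enabled so the weight-name map is cached in the
# engine metadata alongside the settings.
trt_gm = torch_trt.dynamo.compile(exp_program, inputs, make_refitable=True)

# Re-export a model with updated weights and refit without recompiling.
model2 = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU()).cuda()
exp_program2 = torch.export.export(model2, tuple(inputs))
new_trt_gm = torch_trt.dynamo.refit_module_weights(
    compiled_module=trt_gm,
    new_weight_module=exp_program2,
    arg_inputs=inputs,
    use_weight_map_cache=True,  # try the fast path; it falls back on failure
    in_place=False,             # refit a deep copy and return it
)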
