Closed the perf gap of resnet and enabled refit #3629

Merged: merged 4 commits on Jul 2, 2025
30 changes: 24 additions & 6 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -22,6 +22,9 @@
DYNAMO_CONVERTERS as CONVERTERS,
)
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
from torch_tensorrt.dynamo.conversion.impl.normalization.ops import (
batch_norm_constant_folding,
)
from torch_tensorrt.dynamo.conversion.truncate_double import repair_double_inputs
from torch_tensorrt.dynamo.lowering import (
get_decompositions,
@@ -78,8 +81,9 @@ def construct_refit_mapping(
compilation_settings=settings,
)
interpreter._construct_trt_network_def()
weight_refit_map: dict[str, torch.Tensor] = interpreter.ctx.weight_refit_map

return interpreter.ctx.weight_refit_map
return weight_refit_map


@needs_refit
@@ -90,7 +94,20 @@ def construct_refit_mapping_from_weight_name_map(
) -> dict[Any, Any]:
engine_weight_map = {}
for engine_weight_name, (sd_weight_name, np_weight_type) in weight_name_map.items():
if sd_weight_name not in state_dict:
# Add more constant folding converters here
if engine_weight_name.split(" ")[-1] in ["SCALE", "SHIFT"]:
Collaborator:
We should abstract this, IMO. If there are any weight types that require constant folding in a converter, that should be associated with the converter, and the refit system would then just iterate through all of these constant-fold operations. Ideally the converter can use the same implementation.

Collaborator (author):
Currently BN is the only one. Do you think we should have a constant_fold function and have both refit and conversion call it?

Collaborator:
Yeah, we should think of some sort of abstraction where the constant-folding process is defined once for both refit and conversion.

# Batch Norm Layer
params = {}
for w in sd_weight_name:
params[w.split(".")[-1]] = state_dict[w].cuda()
# Batch norm constant folding

scale, shift = batch_norm_constant_folding(**params, eps=1e-7)
# Pick the local scale or shift tensor based on the engine weight name suffix
engine_weight_map[engine_weight_name] = eval(
engine_weight_name.split(" ")[-1].lower()
)
elif sd_weight_name not in state_dict:
# If the weight is not in the state_dict, leave it unchanged
continue
else:
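
For context on the branch above: the SCALE/SHIFT pair is the standard constant folding of batch-norm statistics into a per-channel affine transform. A minimal sketch, assuming batch_norm_constant_folding implements the usual formula (fold_batch_norm is a hypothetical stand-in; the keyword names mirror the state_dict suffixes gathered into params above):

import torch

def fold_batch_norm(
    weight: torch.Tensor,        # gamma
    bias: torch.Tensor,          # beta
    running_mean: torch.Tensor,
    running_var: torch.Tensor,
    eps: float = 1e-7,
) -> tuple[torch.Tensor, torch.Tensor]:
    # y = gamma * (x - mean) / sqrt(var + eps) + beta  ==  scale * x + shift
    scale = weight / torch.sqrt(running_var + eps)
    shift = bias - running_mean * scale
    return scale, shift

A shared helper of roughly this shape is what the review thread suggests: defined once and called from both the converter and the refit path.
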
@@ -178,10 +195,12 @@ def _refit_single_trt_engine_with_gm(
for layer_name in weight_list:
if layer_name not in mapping:
raise AssertionError(f"{layer_name} is not found in weight mapping")
# Use Numpy to create weights
# Use Tensor to create weights
weight = mapping[layer_name]
trt_dtype = dtype._from(weight.dtype).to(trt.DataType)
trt_wt_tensor = trt.Weights(trt_dtype, weight.ctypes.data, weight.size)
trt_wt_tensor = trt.Weights(
trt_dtype, weight.data_ptr(), torch.numel(weight)
)
refitter.set_named_weights(layer_name, trt_wt_tensor, trt_wt_location)
refitted.add(layer_name)
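
For illustration, the torch-tensor path used above boils down to wrapping the tensor's storage directly as TensorRT weights, with no NumPy round-trip. A minimal sketch (an assumption-laden example, not the PR's code: it assumes a contiguous tensor that stays alive for as long as the refitter holds the pointer, and dtype is the same torch_tensorrt enum already used in this module):

import tensorrt as trt
import torch
from torch_tensorrt import dtype

# Wrap a torch tensor as trt.Weights straight from its data pointer,
# avoiding the torch -> numpy conversion used previously.
weight = torch.randn(64, 3, 7, 7).contiguous()
trt_dtype = dtype._from(weight.dtype).to(trt.DataType)
trt_wt_tensor = trt.Weights(trt_dtype, weight.data_ptr(), torch.numel(weight))
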

@@ -300,7 +319,7 @@ def refit_module_weights(

# Check the number of supported operations in the graph
num_supported_ops, total_ops = partitioning.get_graph_converter_support(
new_gm, settings.debug, settings.torch_executed_ops
new_gm, settings.torch_executed_ops
)

if num_supported_ops == 0 or (
@@ -363,7 +382,6 @@ def refit_module_weights(

# Iterate over all components that can be accelerated
# Generate the corresponding TRT Module for those
new_weight_module.module().to(CPU_DEVICE)
for name, new_submodule in new_partitioned_module.named_children():
# Refit each submodule
# Extract engine from the submodule
8 changes: 4 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/converter_utils.py
@@ -335,8 +335,8 @@ def to_trt_weights(
ctx: ConversionContext,
value: torch.Tensor,
name: str,
layer_type_name: Literal["CONVOLUTION", "DECONVOLUTION", "CONSTANT"],
weight_type_name: Literal["KERNEL", "BIAS", "CONSTANT"],
layer_type_name: Literal["CONVOLUTION", "DECONVOLUTION", "CONSTANT", "SCALE"],
weight_type_name: Literal["KERNEL", "BIAS", "CONSTANT", "SCALE", "SHIFT", "POWER"],
target: Optional[Union[Target, str]] = None,
source_ir: Optional[SourceIR] = None,
target_quantized_type: Optional[trt.DataType] = None,
@@ -362,8 +362,8 @@
)

# Weight Recording
supported_layer_types = ["CONVOLUTION", "DECONVOLUTION", "CONSTANT"]
supported_weight_types = ["KERNEL", "BIAS", "CONSTANT"]
supported_layer_types = ["CONVOLUTION", "DECONVOLUTION", "CONSTANT", "SCALE"]
supported_weight_types = ["KERNEL", "BIAS", "CONSTANT", "SCALE", "SHIFT", "POWER"]
assert (
layer_type_name in supported_layer_types
), f"Encountered unsupported layer type: {layer_type_name}. Supported types are: {supported_layer_types}. Manually calling to_trt_weights with a custom layer type is not intended for general use."