feat: exclude refit sensitive ops from TRT compilation (#3159)
peri044 committed Sep 17, 2024
1 parent bc93437 commit 1e9aefe
Showing 13 changed files with 174 additions and 56 deletions.
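
At a high level, this commit threads the user's CompilationSettings into the converter registry so that each converter's capability_validator can veto TRT conversion of refit-sensitive ops (for example aten.cumsum and aten._embedding_bag) whenever a refittable engine is requested. The snippet below is a simplified sketch of that flow, not the verbatim code; the DYNAMO_CONVERTERS alias and import paths are assumptions based on the diffs that follow.

from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
    DYNAMO_CONVERTERS as CONVERTERS,  # assumed alias, mirroring the CONVERTERS usage in _compiler.py
)

# User asks for a refittable engine.
settings = CompilationSettings(make_refittable=True)

# A single call now hands the full settings object to the registry; internally it
# also registers settings.torch_executed_ops as disallowed targets.
CONVERTERS.set_compilation_settings(settings)

# From here on, every capability_validator is invoked as validator(node, settings),
# so refit-sensitive converters can return False and fall back to Torch execution.
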
10 changes: 4 additions & 6 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -314,11 +314,9 @@ def compile_module(
dryrun_tracker = DryRunTracker()
if sample_kwarg_inputs is None:
sample_kwarg_inputs = {}
# Assume converters support dynamic shapes and disable validation
CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)

# Set torch-executed ops
CONVERTERS.set_disallowed_targets(settings.torch_executed_ops)
# Configure user compilation settings to converters.
CONVERTERS.set_compilation_settings(settings)

# Check the number of supported operations in the graph
num_supported_ops, total_ops = partitioning.get_graph_converter_support(
@@ -670,8 +668,8 @@ def convert_exported_program_to_serialized_trt_engine(
settings = CompilationSettings(**compilation_options)
logger.info("Compilation Settings: %s\n", settings)

# Assume converters support dynamic shapes and disable validation
CONVERTERS.set_dynamic_shape_support(settings.assume_dynamic_shape_support)
# Configure user compilation settings to converters.
CONVERTERS.set_compilation_settings(settings)

try:
interpreter_result = interpret_module_to_result(
3 changes: 1 addition & 2 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -6,6 +6,7 @@
from typing import Any, List, Optional, Sequence, Tuple

import numpy as np
import tensorrt as trt
import torch
from torch.export import ExportedProgram
from torch_tensorrt._enums import dtype
@@ -42,8 +43,6 @@
)
from torch_tensorrt.logging import TRT_LOGGER

import tensorrt as trt

logger = logging.getLogger(__name__)


40 changes: 27 additions & 13 deletions py/torch_tensorrt/dynamo/conversion/_ConverterRegistry.py
@@ -23,6 +23,7 @@
from torch import SymBool, SymFloat, SymInt
from torch._ops import OpOverloadPacket
from torch.fx.node import Argument, Node, Target, _get_qualified_name
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.fx.converter_registry import CONVERTERS as FX_CONVERTERS

@@ -82,7 +83,9 @@ class ConverterSupport:
"""

converter_implementation: ConverterImplSignature
capability_validator: Callable[[Node], bool] = field(default=lambda node: True)
capability_validator: Callable[[Node, CompilationSettings], bool] = field(
default=lambda node, compilation_settings: True
)
supports_dynamic_shapes: bool = False


@@ -112,18 +115,20 @@ def has_dynamic_shapes_in_args(

def has_static_shapes_in_args(
arg_positions_to_check: Optional[List[int]] = None,
) -> Callable[[torch.fx.Node], bool]:
) -> Callable[[torch.fx.Node, CompilationSettings], bool]:
"""Returns True if a node has static inputs in node.args at specified positions"""
_has_static_shapes = lambda node, arg_positions_to_check: not _has_dynamic_shapes(
node, arg_positions_to_check
_has_static_shapes = lambda node, compilation_settings, arg_positions_to_check: not _has_dynamic_shapes(
node, compilation_settings, arg_positions_to_check
)
return functools.partial(
_has_static_shapes, arg_positions_to_check=arg_positions_to_check
)


def _has_dynamic_shapes(
node: torch.fx.Node, arg_positions_to_check: Optional[List[int]] = None
node: torch.fx.Node,
compilation_settings: CompilationSettings = None,
arg_positions_to_check: Optional[List[int]] = None,
) -> bool:
# Validate that none of the inputs to the node have Dynamic shapes
assert isinstance(
@@ -188,7 +193,7 @@ def dynamo_tensorrt_converter(
key: Target,
*,
enabled: bool = True,
capability_validator: Optional[Callable[[Node], bool]] = None,
capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None,
priority: ConverterPriority = ConverterPriority.STANDARD,
supports_dynamic_shapes: bool = False,
) -> Callable[[ConverterImplSignature], ConverterImplSignature]:
@@ -297,7 +302,6 @@ def __init__(
],
registry_names: Optional[Sequence[str]] = None,
registry_calling_conventions: Optional[Sequence[CallingConvention]] = None,
assume_dynamic_shape_support: bool = False,
):
# Copy reference to each dictionary object into attribute list
self.registries = list(registries)
@@ -318,12 +322,16 @@ def __init__(
CallingConvention.CTX for _ in range(len(self.registries))
]

self.compilation_settings: CompilationSettings = None
self.disallowed_targets: Collection[Target] = set()
self.assume_dynamic_shape_support = assume_dynamic_shape_support
self.validate_invariants()

def set_dynamic_shape_support(self, assume_dynamic_shape_support: bool) -> None:
self.assume_dynamic_shape_support = assume_dynamic_shape_support
def set_compilation_settings(
self, compilation_settings: CompilationSettings
) -> None:
self.compilation_settings = compilation_settings
# set torch executed ops as disallowed targets
self.set_disallowed_targets(compilation_settings.torch_executed_ops)

def set_disallowed_targets(self, torch_executed_ops: Collection[Target]) -> None:
self.disallowed_targets = torch_executed_ops
@@ -412,7 +420,11 @@ def __getitem__(

self.validate_invariants()
key = node.target

assume_dynamic_shape_support = False
if self.compilation_settings:
assume_dynamic_shape_support = (
self.compilation_settings.assume_dynamic_shape_support
)
if (
key in self.disallowed_targets
or self.qualified_name_or_str(key) in self.disallowed_targets
@@ -436,8 +448,10 @@ def __getitem__(
# 2) Assume dynamic_shape support is True
# 3) Node only has static shaped inputs
# 4) Node has dynamic inputs and the converter has supports_dynamic_shapes=True
if candidate.capability_validator(node) and (
self.assume_dynamic_shape_support
if candidate.capability_validator(
node, self.compilation_settings
) and (
assume_dynamic_shape_support
or not node_has_dynamic_shapes(node)
or candidate.supports_dynamic_shapes
):
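
To illustrate the new two-argument validator convention introduced above, here is a hedged sketch of registering a converter whose capability_validator consults the active CompilationSettings. The converter name and body are placeholders for illustration only; the real registrations live in aten_ops_converters.py (see the diff further below).

import torch
from torch.fx.node import Node
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConverterRegistry import dynamo_tensorrt_converter

def refit_sensitive_validator(node: Node, settings: CompilationSettings = None) -> bool:
    # Reject TRT conversion (falling back to Torch) when a refittable engine is requested.
    if settings and settings.make_refittable:
        return False
    return True

@dynamo_tensorrt_converter(
    torch.ops.aten.cumsum.default,  # example target only; the actual registration is shown below
    capability_validator=refit_sensitive_validator,
    supports_dynamic_shapes=True,
)
def example_cumsum_converter(ctx, target, args, kwargs, name):
    ...  # converter implementation omitted in this sketch
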
8 changes: 6 additions & 2 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -18,6 +18,7 @@
)

import numpy as np
import tensorrt as trt
import torch
import torch.fx
from torch.fx.node import _get_qualified_name
@@ -43,7 +44,6 @@
from torch_tensorrt.fx.observer import Observer
from torch_tensorrt.logging import TRT_LOGGER

import tensorrt as trt
from packaging import version

_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -89,6 +89,11 @@ def __init__(
self.builder.create_network(flag), compilation_settings
)

self.compilation_settings = compilation_settings
if not CONVERTERS.compilation_settings:
# Configure user compilation settings to converters.
CONVERTERS.set_compilation_settings(compilation_settings)

assert TRTInterpreter._all_precisions_supported(
compilation_settings.enabled_precisions
), f"Attempted to enable kernel precisions that are not supported (got: {compilation_settings.enabled_precisions}, support: {_defaults.SUPPORTED_KERNEL_PRECISIONS})"
@@ -117,7 +122,6 @@ def __init__(
self._itensor_to_tensor_meta: Dict[trt.tensorrt.ITensor, TensorMetadata] = (
dict()
)
self.compilation_settings = compilation_settings

# Data types for TRT Module output Tensors
self.output_dtypes = (
66 changes: 47 additions & 19 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -7,6 +7,7 @@
import numpy as np
import torch
from torch.fx.node import Argument, Node, Target
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo._SourceIR import SourceIR
from torch_tensorrt.dynamo.conversion import impl
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -48,7 +49,7 @@ def get_ir(target: Target) -> SourceIR:
return SourceIR.UNKNOWN


def one_user_validator(node: Node) -> bool:
def one_user_validator(node: Node, settings: CompilationSettings = None) -> bool:
# Validate only one user, which is a getitem node that accesses the first element in the list
return (
len(node.users) == 1
@@ -270,7 +271,11 @@ def aten_ops_embedding(
)


def embedding_bag_validator(node: Node) -> bool:
def embedding_bag_validator(node: Node, settings: CompilationSettings = None) -> bool:
# Embedding bag op is not refitable
if settings.make_refittable:
return False

if not one_user_validator(node):
return False
meta = node.args[1].meta
@@ -416,7 +421,7 @@ def aten_ops_symsize_int(
return impl.shape.shape(ctx, target, SourceIR.ATEN, name, args[0], args[1])


def index_dtype_validator(node: Node) -> bool:
def index_dtype_validator(node: Node, settings: CompilationSettings = None) -> bool:
index = node.args[1]
for ind in index:
if ind is not None:
@@ -837,7 +842,7 @@ def aten_ops_select(
)


def index_put_validator(node: Node) -> bool:
def index_put_validator(node: Node, settings: CompilationSettings = None) -> bool:
if args_bounds_check(node.args, 3, False): # Check if accumulate is valid
_LOGGER.debug("We do not support accumulate=True for aten.index_put operation")
accumulate_valid = False
@@ -924,7 +929,18 @@ def aten_ops_slice(
)


@dynamo_tensorrt_converter(torch.ops.aten.cumsum.default, supports_dynamic_shapes=True)
def refit_validator(node: Node, settings: CompilationSettings = None) -> bool:
# cumsum op is not refitable
if settings and settings.make_refittable:
return False
return True


@dynamo_tensorrt_converter(
torch.ops.aten.cumsum.default,
capability_validator=refit_validator,
supports_dynamic_shapes=True,
)
@enforce_tensor_types(
{
0: (TRTTensor,),
@@ -970,7 +986,7 @@ def aten_ops_tile(
)


def zero_output_validator(node: Node) -> bool:
def zero_output_validator(node: Node, settings: CompilationSettings = None) -> bool:
if 0 in node.args[1]:
_LOGGER.debug(
f"We do not support output tensor {node.args[1]} tensors with zero-sized dimensions for this operation."
@@ -1027,7 +1043,9 @@ def aten_ops_permute(
)


def to_copy_dtype_validator(placeholder_only: bool) -> Callable[[Node], bool]:
def to_copy_dtype_validator(
placeholder_only: bool, settings: CompilationSettings = None
) -> Callable[[Node, CompilationSettings], bool]:
"""Return validator for to_copy node with placeholder restrictions"""

def validate_dtype(to_copy_node: Node) -> bool:
@@ -1059,7 +1077,7 @@ def validate_dtype(to_copy_node: Node) -> bool:
)
return False

def validator(to_copy_node: Node) -> bool:
def validator(to_copy_node: Node, settings: CompilationSettings = None) -> bool:
"""Returns true if the to_copy node can be converted to TRT
and the placeholder restriction is satisfied
"""
@@ -1074,7 +1092,9 @@ def validator(to_copy_node: Node) -> bool:

@dynamo_tensorrt_converter(
torch.ops.aten.clone.default,
capability_validator=lambda node: not is_only_operator_on_placeholder(node),
capability_validator=lambda node, settings: not is_only_operator_on_placeholder(
node, settings
),
supports_dynamic_shapes=True,
)
@dynamo_tensorrt_converter(
@@ -2128,7 +2148,7 @@ def aten_ops_logical_xor(
)


def bitwise_type_validator(node: Node) -> bool:
def bitwise_type_validator(node: Node, settings: CompilationSettings = None) -> bool:
supported_type = [torch.bool, bool]

tensor_targets = [
@@ -2271,7 +2291,9 @@ def aten_ops_bitwise_xor(
)


def bitwise_not_type_validator(node: Node) -> bool:
def bitwise_not_type_validator(
node: Node, settings: CompilationSettings = None
) -> bool:
val = node.args[0]
val_meta = val.meta.get("tensor_meta")

@@ -2453,7 +2475,7 @@ def aten_ops_le(
)


def conv_param_validator(conv_node: Node) -> bool:
def conv_param_validator(conv_node: Node, settings: CompilationSettings = None) -> bool:
return conv_node.args[7] in ([0], [0, 0], [0, 0, 0])


@@ -2549,7 +2571,9 @@ def aten_ops_cdist_forward(
)


def avg_pool_param_validator(pool_node: Node) -> bool:
def avg_pool_param_validator(
pool_node: Node, settings: CompilationSettings = None
) -> bool:
ceil_mode = args_bounds_check(pool_node.args, 4, False)
divisor_override = args_bounds_check(pool_node.args, 6)

@@ -2665,12 +2689,12 @@ def aten_ops_adaptive_avg_poolNd(
)


def topk_validator(node: Node) -> bool:
def topk_validator(node: Node, settings: CompilationSettings = None) -> bool:
k = node.args[1]
return topk_sort_validator(k)


def sort_validator(node: Node) -> bool:
def sort_validator(node: Node, settings: CompilationSettings = None) -> bool:
meta_data = node.args[0].meta.get("tensor_meta")
if meta_data is None:
return False
@@ -2692,7 +2716,9 @@ def topk_sort_validator(k: int) -> bool:
return True


def max_pool_param_validator(pool_node: Node) -> bool:
def max_pool_param_validator(
pool_node: Node, settings: CompilationSettings = None
) -> bool:
dilation = args_bounds_check(pool_node.args, 4, 1)
ceil_mode = args_bounds_check(pool_node.args, 5, False)

@@ -2746,7 +2772,7 @@ def aten_ops_max_pool(
)


def attention_validator(node: Node) -> bool:
def attention_validator(node: Node, settings: CompilationSettings = None) -> bool:
# Currently, `attn_mask` is not supported
return args_bounds_check(node.args, 3) is None

@@ -3637,7 +3663,7 @@ def aten_ops_flip(
)


def zero_diag_size_validator(node: Node) -> bool:
def zero_diag_size_validator(node: Node, settings: CompilationSettings = None) -> bool:
meta = node.args[0].meta.get("tensor_meta")
if meta:
input_shape = meta.shape
@@ -3765,7 +3791,9 @@ def aten_ops_index_select(
)


def dropout_inference_validator(node: Node) -> bool:
def dropout_inference_validator(
node: Node, settings: CompilationSettings = None
) -> bool:
train_mode = args_bounds_check(node.args, 2, None)
if train_mode is False:
return True
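
For context, a minimal end-to-end sketch of the behavior these validators enable: compiling a model that contains aten.cumsum while requesting a refittable engine now leaves that op in Torch instead of baking it into the TRT engine. The compile kwarg name below is an assumption that mirrors the CompilationSettings.make_refittable field shown in these diffs; the public argument name may differ.

import torch
import torch_tensorrt

class CumsumModule(torch.nn.Module):
    def forward(self, x):
        y = torch.relu(x) @ x
        return torch.cumsum(y, dim=0)

inputs = [torch.randn(4, 4).cuda()]
exp_program = torch.export.export(CumsumModule().eval().cuda(), tuple(inputs))

# With refitting requested, refit_validator rejects aten.cumsum, so the partitioner
# keeps that op in Torch and only the remaining graph is compiled to TensorRT.
trt_module = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=inputs,
    make_refittable=True,  # assumed kwarg mirroring CompilationSettings.make_refittable
)
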
(Diffs for the remaining 8 changed files in this commit are not shown.)
