feat: Add Selective ATen decompositions (#2173)

gs-olive authored Aug 17, 2023
1 parent f70574e commit 91fcea4
Showing 6 changed files with 314 additions and 24 deletions.
1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/_defaults.py
@@ -11,3 +11,4 @@
TRUNCATE_LONG_AND_DOUBLE = False
USE_PYTHON_RUNTIME = False
USE_FAST_PARTITIONER = True
ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -4,6 +4,7 @@
import torch
from torch_tensorrt.dynamo._defaults import (
DEBUG,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
OPTIMIZATION_LEVEL,
@@ -19,6 +20,27 @@

@dataclass
class CompilationSettings:
"""Compilation settings for Torch-TensorRT Dynamo Paths
Args:
precision (torch.dtype): Model Layer precision
debug (bool): Whether to print out verbose debugging information
workspace_size (int): Workspace TRT is allowed to use for the module (0 is default)
min_block_size (int): Minimum number of operators per TRT-Engine Block
torch_executed_ops (Sequence[str]): Sequence of operations to run in Torch, regardless of converter coverage
pass_through_build_failures (bool): Whether to fail on TRT engine build errors (True) or not (False)
max_aux_streams (Optional[int]): Maximum number of allowed auxiliary TRT streams for each engine
version_compatible (bool): Provide version forward-compatibility for engine plan files
optimization_level (Optional[int]): Builder optimization 0-5, higher levels imply longer build time,
searching for more optimization options. TRT defaults to 3
use_python_runtime (Optional[bool]): Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
argument as None
truncate_long_and_double (bool): Truncate int64/float64 TRT engine inputs or weights to int32/float32
enable_experimental_decompositions (bool): Whether to enable all core aten decompositions
or only a selected subset of them
"""

precision: torch.dtype = PRECISION
debug: bool = DEBUG
workspace_size: int = WORKSPACE_SIZE
@@ -31,3 +53,4 @@ class CompilationSettings:
use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE
use_fast_partitioner: bool = USE_FAST_PARTITIONER
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS
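
With this field in place, the flag can be set like any other compilation option. A minimal sketch for illustration (assuming torch_tensorrt is installed with this change; the import path matches the one used by compile.py later in this diff):

import torch
from torch_tensorrt.dynamo import CompilationSettings

# Every field has a default from _defaults.py, so only the overrides
# need to be named; enable_experimental_decompositions defaults to False.
settings = CompilationSettings(
    precision=torch.float16,
    enable_experimental_decompositions=True,
)
assert settings.enable_experimental_decompositions
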
2 changes: 1 addition & 1 deletion py/torch_tensorrt/dynamo/backend/backends.py
@@ -56,7 +56,7 @@ def aot_torch_tensorrt_aten_backend(
gm,
sample_inputs,
fw_compiler=make_boxed_compiler(custom_backend),
-decompositions=get_decompositions(),
+decompositions=get_decompositions(settings.enable_experimental_decompositions),
)


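This is where the flag takes effect on the torch.compile (JIT) path: the parsed settings decide which decomposition table aot_autograd applies before partitioning. A hedged sketch of exercising it end to end, assuming the backend is registered under the name "torch_tensorrt" (as backends.py does) and that the options dict accepts this key; TinyModel is a placeholder:

import torch


class TinyModel(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # aten.gelu appears in the curated set below, so it is decomposed
        # on both paths; the experimental flag widens the set further.
        return torch.nn.functional.gelu(x)


model = TinyModel().eval().cuda()
inputs = (torch.randn(8, 16).cuda(),)

optimized = torch.compile(
    model,
    backend="torch_tensorrt",
    options={"enable_experimental_decompositions": True},
)
optimized(*inputs)
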
10 changes: 7 additions & 3 deletions py/torch_tensorrt/dynamo/compile.py
@@ -13,6 +13,7 @@
from torch_tensorrt.dynamo import CompilationSettings, partitioning
from torch_tensorrt.dynamo._defaults import (
DEBUG,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
MAX_AUX_STREAMS,
MIN_BLOCK_SIZE,
OPTIMIZATION_LEVEL,
@@ -61,6 +62,7 @@ def compile(
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
use_python_runtime: bool = USE_PYTHON_RUNTIME,
use_fast_partitioner: bool = USE_FAST_PARTITIONER,
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
**kwargs: Any,
) -> torch.fx.GraphModule:
if debug:
@@ -71,9 +73,10 @@

logger.warning(
"The Dynamo backend is an experimental feature, for which only the "
+ "following arguments are supported: "
+ "{enabled_precisions, debug, workspace_size, min_block_size, "
+ "torch_executed_ops, pass_through_build_failures, use_fast_partitioner}"
"following arguments are supported: "
"{enabled_precisions, debug, workspace_size, min_block_size, "
"torch_executed_ops, pass_through_build_failures, use_fast_partitioner, "
"enable_experimental_decompositions}"
)

if not isinstance(inputs, collections.abc.Sequence):
@@ -114,6 +117,7 @@
"use_python_runtime": use_python_runtime,
"truncate_long_and_double": truncate_long_and_double,
"use_fast_partitioner": use_fast_partitioner,
"enable_experimental_decompositions": enable_experimental_decompositions,
}

settings = CompilationSettings(**compilation_options)
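On the ahead-of-time path the argument travels from compile() into compilation_options and then into CompilationSettings above. A sketch under the assumption that dynamo.compile is exported at the package level and accepts an nn.Module here (TinyModel is the placeholder from the previous sketch):

import torch
import torch_tensorrt

model = TinyModel().eval().cuda()
inputs = [torch.randn(8, 16).cuda()]

# Default (False): only the curated subset in torch_enabled_decompositions
# (next file) is applied during lowering.
trt_module = torch_tensorrt.dynamo.compile(
    model,
    inputs,
    enabled_precisions={torch.float16},
    enable_experimental_decompositions=False,
)
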
200 changes: 200 additions & 0 deletions py/torch_tensorrt/dynamo/lowering/_decomposition_groups.py
@@ -0,0 +1,200 @@
from typing import Any, Callable, Dict, Set

import torch
from torch._decomp import core_aten_decompositions
from torch._decomp import get_decompositions as get_torch_decompositions
from torch._ops import OpOverload

aten = torch.ops.aten

_core_aten_decompositions: Dict[
OpOverload, Callable[[Any], Any]
] = core_aten_decompositions()
torch_enabled_decompositions: Set[OpOverload] = {
aten._adaptive_avg_pool2d_backward,
aten.addcdiv,
aten.addcdiv_,
aten.addcmul,
aten.addcmul_,
aten.addr,
aten.aminmax,
aten.arange.default,
aten.arange.start,
aten.avg_pool2d_backward,
aten.binary_cross_entropy,
aten.binary_cross_entropy_backward,
aten.binary_cross_entropy_with_logits,
aten.celu,
aten.col2im,
aten.count_nonzero,
aten.cudnn_batch_norm,
aten.cudnn_batch_norm_backward,
aten.deg2rad,
aten.detach,
aten.diag_embed,
aten.diagonal_backward,
aten.dot,
aten.elu,
aten.elu_backward,
aten._embedding_bag,
aten.embedding_dense_backward,
aten._euclidean_dist.default,
aten.expand_as,
aten.eye,
aten.fill,
aten.frac,
aten._fused_moving_avg_obs_fq_helper,
aten.gelu,
aten.gelu_backward,
aten.glu_backward,
aten.grid_sampler_2d,
aten.hardshrink,
aten.hardshrink_backward,
aten.hardsigmoid,
aten.hardsigmoid_backward,
aten.hardswish,
aten.hardswish_,
aten.hardswish_backward,
aten.hardtanh,
aten.hardtanh_,
aten.hardtanh_backward,
aten.heaviside,
aten.huber_loss,
aten.huber_loss_backward,
aten.im2col,
aten.index_add,
aten.index_add_,
aten.index_copy,
aten.index_copy_,
aten.index_fill,
aten.index_fill_,
aten.index_select,
aten.isneginf,
aten.isposinf,
aten.l1_loss,
aten.leaky_relu,
aten.leaky_relu_,
aten.leaky_relu_backward,
aten.lerp,
aten.linspace,
aten.logaddexp,
aten.logaddexp2,
aten.logit,
aten.logit_backward,
aten.log_sigmoid_backward,
aten.log_sigmoid_forward,
aten._log_softmax,
aten._log_softmax_backward_data,
aten.logspace,
aten.logsumexp.default,
aten.masked_fill,
aten.masked_fill_,
aten.max_pool2d_with_indices_backward,
aten.mish,
aten.mse_loss,
aten.mse_loss_backward,
aten.mv,
aten.mvlgamma,
aten.nansum,
aten.nan_to_num,
aten.narrow,
# TODO: Disable the below operators once freezing is done
aten.native_batch_norm,
aten.native_batch_norm_backward,
aten._native_batch_norm_legit,
aten._native_batch_norm_legit_functional,
aten._native_batch_norm_legit_no_training,
aten.native_dropout_backward,
aten.native_group_norm,
aten.native_group_norm_backward,
aten.native_layer_norm,
aten.native_layer_norm_backward,
aten.new_empty,
aten.new_full,
aten.new_ones,
aten.new_zeros,
aten.nll_loss_backward,
aten.nll_loss_forward,
aten.norm,
aten.ones,
aten.ones_like,
aten._prelu_kernel,
aten._prelu_kernel_backward,
aten._reshape_alias,
aten.rad2deg,
aten.renorm,
aten.renorm_,
aten.rot90,
aten.rsub.Scalar,
aten.rsub.Tensor,
aten.select_backward,
aten.select_scatter,
aten.sgn,
aten.sigmoid_backward,
aten.silu,
aten.silu_,
aten.silu_backward,
aten.sinc,
aten.slice_backward,
aten.smooth_l1_loss,
aten.smooth_l1_loss_backward,
aten.soft_margin_loss,
aten.soft_margin_loss_backward,
aten._softmax,
aten._softmax_backward_data,
aten.softplus,
aten.softplus_backward,
aten.softshrink,
aten.softshrink_backward,
aten.special_entr,
aten.special_log_ndtr,
aten.special_xlog1py,
aten.stack,
aten.t,
aten.tanh_backward,
aten.threshold,
aten.threshold_backward,
aten.trace,
aten.transpose.int,
aten.tril.default,
aten.triu.default,
aten.unfold,
aten.unfold_backward,
aten.unfold_copy,
aten.upsample_bilinear2d,
aten.upsample_bilinear2d.vec,
aten.upsample_nearest2d_backward,
aten.xlogy,
aten.zero,
aten.zero_,
aten.zeros,
aten.zeros_like,
# Non-default convenience decompositions
aten.clamp_min,
aten.clamp_max,
aten.linalg_vector_norm,
aten.full,
aten.repeat,
}
torch_disabled_decompositions: Set[OpOverload] = set()


ENABLED_TORCH_DECOMPOSITIONS: Dict[
OpOverload, Callable[[Any], Any]
] = get_torch_decompositions(torch_enabled_decompositions)
TORCH_TRT_DECOMPOSITIONS: Dict[OpOverload, Callable[[Any], Any]] = {}


def check_decomp_set_invariants() -> None:
"""Validates no overlap between enabled and disabled decomposition sets"""
overlap = torch_enabled_decompositions.intersection(torch_disabled_decompositions)

if overlap:
raise AssertionError(
f"Detected {overlap} registered in both torch_enabled_decompositions "
"and torch_disabled_decompositions. Ensure all operator(s) are in "
"at most one of the two sets."
)


check_decomp_set_invariants()
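
These sets are consumed by get_decompositions, which lives in the sixth changed file (not rendered on this page). A plausible sketch of its selection logic, inferred from the get_decompositions(settings.enable_experimental_decompositions) call in backends.py and the flag's docstring in _settings.py; the exact body is an assumption:

from typing import Any, Callable, Dict

from torch._ops import OpOverload
from torch_tensorrt.dynamo.lowering._decomposition_groups import (
    ENABLED_TORCH_DECOMPOSITIONS,
    TORCH_TRT_DECOMPOSITIONS,
    _core_aten_decompositions,
    torch_disabled_decompositions,
)


def get_decompositions(
    enable_experimental_decompositions: bool = False,
) -> Dict[OpOverload, Callable[[Any], Any]]:
    if enable_experimental_decompositions:
        # All core ATen decompositions, minus any explicitly disabled ops,
        # with Torch-TensorRT's own decompositions layered on top.
        filtered = {
            op: decomp
            for op, decomp in _core_aten_decompositions.items()
            if op not in torch_disabled_decompositions
        }
        return {**filtered, **TORCH_TRT_DECOMPOSITIONS}
    # Otherwise: only the curated subset enumerated above.
    return {**ENABLED_TORCH_DECOMPOSITIONS, **TORCH_TRT_DECOMPOSITIONS}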