Torch-TensorRT v2.3.0
Windows Support, Dynamic Shape and Quantization in Dynamo, PyTorch 2.3, CUDA 12.1, TensorRT 10.0
Torch-TensorRT 2.3.0 targets PyTorch 2.3, CUDA 12.1 (builds for CUDA 11.8 are available via the PyTorch package index - https://download.pytorch.org/whl/cu118) and TensorRT 10.0. 2.3.0 adds official support for Windows as a platform. On Windows, only the Dynamo frontend is supported, and users are currently required to use the Python-only runtime (support for the C++ runtime will be added in a future version). This release also adds support for dynamic shapes without recompilation. Users can now also run quantized models with Torch-TensorRT using the Model Optimizer toolkit (https://github.com/NVIDIA/TensorRT-Model-Optimizer).
Note: Python 3.12 is not supported, as the Dynamo stack in PyTorch 2.3.0 does not support Python 3.12.
Windows
In this release we introduce Windows support for the Python runtime using the Dynamo paths. Users can now directly optimize PyTorch models with TensorRT on Windows, with minimal code changes. This integration enables Python-only optimization in the Torch-TensorRT Dynamo compilation paths (`ir="dynamo"` and `ir="torch_compile"`).
import torch
import torch_tensorrt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().to("cuda")
input = torch.randn((1, 3, 224, 224)).to("cuda")

# Compile with the Dynamo frontend and run inference
trt_mod = torch_tensorrt.compile(model, ir="dynamo", inputs=[input])
trt_mod(input)
Dynamic Shaped Model Compilation in Dynamo
Dynamic shape support has become more robust in v2.3.0. Torch-TensorRT now leverages symbolic information in the graph to calculate intermediate shape ranges, which allows more dynamic shape cases to be supported. For AOT workflows using torch.export, these new features require no changes. For JIT workflows, which previously relied on `torch.compile` guards to automatically recompile the engines when the input size changes, users can now mark dynamic dimensions using the torch APIs (https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html). Using these APIs means that, as long as inputs do not violate the specified constraints, engines will not recompile.
AOT workflow
import torch
import torch_tensorrt

# `model` is assumed to be defined as in the Windows example above
compile_spec = {
    "inputs": [torch_tensorrt.Input(min_shape=(1, 3, 224, 224),
                                    opt_shape=(4, 3, 224, 224),
                                    max_shape=(8, 3, 224, 224),
                                    dtype=torch.float32)],
    "enabled_precisions": {torch.float},
}
trt_model = torch_tensorrt.compile(model, **compile_spec)
JIT workflow
import torch
import torch_tensorrt

# `model` is assumed to be defined as in the Windows example above
compile_spec = {"enabled_precisions": {torch.float}}
inputs = torch.randn((4, 3, 224, 224)).to("cuda")

# Mark dimension 0 as dynamic with the range [1, 8]
torch._dynamo.mark_dynamic(inputs, 0, min=1, max=8)
trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)

# Compilation is deferred until the first call
trt_model(inputs)
More information can be found here: https://pytorch.org/TensorRT/user_guide/dynamic_shapes.html
Explicit Dynamic Shape support in Converters
Converters now explicitly declare their support for dynamic shapes, and we are progressively adding and verifying this support. Converter writers can specify support for dynamic shapes using the `supports_dynamic_shapes` argument of the `dynamo_tensorrt_converter` decorator.
@dynamo_tensorrt_converter(
    torch.ops.aten.convolution.default,
    capability_validator=lambda conv_node: conv_node.args[7] in ([0], [0, 0], [0, 0, 0]),
    supports_dynamic_shapes=True,
)  # type: ignore[misc]
def aten_ops_convolution(
    ctx: ConversionContext,
    target: Target,
    args: Tuple[Argument, ...],
    kwargs: Dict[str, Argument],
    name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
    ...
By default, if a converter has not been marked as supporting dynamic shapes, its operator will be run in PyTorch when the user specifies dynamic inputs. This ensures that compilation succeeds with some valid compiled module. However, many operators already support dynamic shape in an untested fashion, so users can choose to enable the full converter library for dynamic shape using the `assume_dynamic_shape_support` flag, as shown below. This flag assumes all converters support dynamic shape, leading to more operations being run in TensorRT, with the potential drawback that some ops may cause compilation or runtime failures. Future releases will progressively add dynamic shape coverage for all Core ATen operators.
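A minimal sketch of using the flag, assuming the same `model` and dynamic shape range from the AOT example above:

import torch
import torch_tensorrt

trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[torch_tensorrt.Input(min_shape=(1, 3, 224, 224),
                                 opt_shape=(4, 3, 224, 224),
                                 max_shape=(8, 3, 224, 224),
                                 dtype=torch.float32)],
    # Treat all converters as dynamic-shape capable; unverified converters
    # may fail at compile or run time
    assume_dynamic_shape_support=True,
)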
Quantization in Dynamo
We introduce support for model quantization in FP8. We support models quantized using the NVIDIA TensorRT-Model-Optimizer toolkit. This toolkit inserts quantization nodes in the graph, which are converted and used by TensorRT to quantize the model to lower precision. Although the toolkit supports quantization in various data types, only FP8 is supported in this release.
Please refer to our end-to-end example, Torch Compile VGG16 with FP8 and PTQ, for how to use this.
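As a rough sketch of that flow (assuming the `nvidia-modelopt` package is installed, and that `model`, `input`, and a user-supplied `calibrate_loop` function that runs calibration data through the model are defined elsewhere):

import torch
import torch_tensorrt
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import export_torch_mode

# Insert FP8 quantize/dequantize nodes and calibrate with user data
mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=calibrate_loop)

with torch.no_grad(), export_torch_mode():
    # Export the quantized model and compile it with FP8 enabled
    exp_program = torch.export.export(model, (input,))
    trt_model = torch_tensorrt.dynamo.compile(
        exp_program,
        inputs=[input],
        enabled_precisions={torch.float8_e4m3fn},
    )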
Engine Version and Hardware Compatibility
We introduce new compilation arguments, `hardware_compatible: bool` and `version_compatible: bool`, which enable two key features in TensorRT.
hardware_compatible
Enabling hardware compatibility mode will generate TRT Engines which are compatible with Ampere and newer GPUs. As a result, engines built on one GPU can later be run on others, without requiring recompilation.
version_compatible
Enabling version compatibility mode will generate TRT Engines which are compatible with newer versions of TensorRT. As a result, engines built with one version of TensorRT will be forward compatible with newer TRT versions, without needing recompilation.
...
trt_mod = torch_tensorrt.compile(model, ir="dynamo", inputs=[input], hardware_compatible=True, version_compatible=True)
...
New Data Type Support
Torch-TensorRT includes a number of new data types that leverage dedicated hardware on Ampere, Hopper, and future architectures. `bfloat16` has been added as a supported type alongside FP16 and FP32 and can be enabled for additional kernel tactic options. Models that contain BF16 weights can now be provided to Torch-TensorRT without modification. FP8 has been added, with support for Hopper and newer architectures, as a new quantization format (see above), similar to INT8. Finally, native support for INT64 inputs and computation has been added. In the past, the `truncate_long_and_double` feature flag had to be enabled in order to handle INT64 and FLOAT64 computation, inputs, and weights; this flag caused the compiler to truncate any INT64 or FLOAT64 objects to INT32 and FLOAT32, respectively. Now INT64 objects are not truncated and remain INT64. As FLOAT64 truncation is still required, the `truncate_long_and_double` flag has been renamed `truncate_double`, and `truncate_long_and_double` is now deprecated.
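A minimal sketch enabling BF16 kernel tactics alongside FP32, with `truncate_double` handling any FLOAT64 weights (assuming the ResNet-18 `model` from the Windows example above):

import torch
import torch_tensorrt

input = torch.randn((1, 3, 224, 224)).to("cuda")
trt_model = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=[input],
    # Allow TensorRT to select BF16 kernel tactics in addition to FP32
    enabled_precisions={torch.float32, torch.bfloat16},
    # Downcast any FLOAT64 weights/computation to FLOAT32
    truncate_double=True,
)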
What's Changed
- feat: support group_norm, batch_norm, and layer_norm by @zewenli98 in #2330
- support argmax converter by @bowang007 in #2291
- feat: Decomposition for `_unsafe_index` by @gs-olive in #2386
- docs: Add documentation of `torch.compile` backend usage by @gs-olive in #2363
- fix: Remove supported ops from decompositions by @gs-olive in #2390
- fix: Converter, inputs, and utils bugfixes for Transformer XL by @gs-olive in #2404
- feat: support embedding_bag converter (1D input) by @zewenli98 in #2395
- feat: support chunk dynamo converter by @zewenli98 in #2401
- chore: Add documentation for dynamo.compile backend by @peri044 in #2389
- Support new FX Legacy Registry in opset coverage tool by @laikhtewari in #2366
- fix: type error in embedding_bag by @zewenli98 in #2418
- feat: support cumsum dynamo converter by @zewenli98 in #2403
- 2.0 docs overhaul by @narendasan in #2420
- feat: support tile dynamo converter by @zewenli98 in #2402
- chore: update perf tooling to add dynamo options by @peri044 in #2423
- feat: Add `aten.unbind` decomposition for VIT by @gs-olive in #2430
- fix: Segfault fix for Benchmarks by @gs-olive in #2432
- examples: Stable Diffusion `torch.compile` sample with output image by @gs-olive in #2417
- minor fix: Parse out slashes in Docker container name by @gs-olive in #2437
- fix: Docs rendering on PyTorch site by @gs-olive in #2440
- Numpy changes for aten::index converter by @apbose in #2396
- feat: a lowering pass to re-compose ops into aten.linear by @zewenli98 in #2411
- chore: fix docs for export by @peri044 in #2447
- chore: add additional BN native converter by @peri044 in #2446
- minor fix: Update Benchmark values by @gs-olive in #2453
- Delete .circleci directory by @bigfootjon in #2456
- fix: Wrap perf benchmarks with no_grad by @gs-olive in #2466
- fix: Error with `aten.view` across Tensor memory by @gs-olive in #2464
- fix: Bug in slice operator with default inputs by @gs-olive in #2463
- Expose IGridSampleLayer by @apbose in #2290
- feat: support more elementwise and unary dynamo converters by @zewenli98 in #2429
- feat: support for many padding dynamo converters by @zewenli98 in #2482
- chore: Add TRT runner via onnx by @peri044 in #2503
- feat: support aten.amin dynamo converter by @zewenli98 in #2504
- chore: Switch to new export apis by @peri044 in #2376
- feat: support aten.arange.start_step dynamo converter by @zewenli98 in #2505
- fix/feat: Add support for multiple TRT Build Args by @gs-olive in #2510
- feat: Safety Mode for Runtime by @gs-olive in #2512
- feat: support argmin aten converter by @bowang007 in #2501
- feat: expose IResizeLayer in dynamo.conversion.impl by @bowang007 in #2488
- feat: support aten.sort dynamo converter by @zewenli98 in #2514
- fix: aten.index converter by @zewenli98 in #2487
- DLFW changes by @apbose in #2397
- fix: output shape bug in deconv by @zewenli98 in #2537
- fix: Torch nightly version constraint by @gs-olive in #2546
- Fix memory leaks by @gcuendet in #2526
- fix: Docker builder with new Torch versions by @gs-olive in #2547
- chore: fix deconv padding by @peri044 in #2527
- Fix: aten::matmul converter behavior with 1d tensors by @mfeliz-cruise in #2450
- fix: Specify expecttest version to fix CI by @gs-olive in #2554
- feat: Add dryrun feature to Dynamo paths by @gs-olive in #2451
- feat: Add hardware compatibility option in Dynamo by @gs-olive in #2445
- feat: support aten.clamp.Tensor and update aten.clamp.default dynamo converters by @zewenli98 in #2522
- feat: support aten.trunc dynamo converter by @zewenli98 in #2543
- feat: support aten.copy dynamo converter by @zewenli98 in #2550
- fix: Repair usage of `torch_executed_ops` by @gs-olive in #2562
- fix: Switch all copies to force cast by @gs-olive in #2563
- Exposing select layer by @apbose in #2490
- feat: support aten.remainder.Scalar and aten.remainder.Tensor by @zewenli98 in #2566
- chore: Update torch to 2.3dev by @peri044 in #2585
- feat: Add support for flash attention converter by @gs-olive in #2560
- [feat] Support conversion of scaled_dot_product_attention by @mfeliz-cruise in #2549
- Grant write permission to token by @huydhn in #2591
- Clean up AWS credentials by @huydhn in #2589
- feat: support _pdist_forward dynamo converter by @zewenli98 in #2570
- feat: support aten.flip dynamo converter by @zewenli98 in #2540
- feat: add `convert_method_to_trt_engine()` for dynamo by @zewenli98 in #2467
- Clean up AWS credentials by @huydhn in #2592
- feat: support aten.any related converters in dynamo by @bowang007 in #2578
- feat: support aten.scalar_tensor dynamo converter by @zewenli98 in #2595
- fix: Update release branch for Docker build GHA by @gs-olive in #2600
- fix: Add write permissions to Docker GHA by @gs-olive in #2619
- feat: support aten.pixel_shuffle dynamo converter by @zewenli98 in #2596
- DLFW changes by @apbose in #2552
- feat: support aten.roll dynamo converter by @zewenli98 in #2569
- fix: Linter + config fix by @gs-olive in #2636
- small fix: Remove extraneous argument in `compile` by @gs-olive in #2635
- fix: Remove keyserver fetch from Dockerfile by @gs-olive in #2639
- small fix: Index validator enable int64 by @gs-olive in #2642
- fix: update tests of pooling converters by @zewenli98 in #2613
- chore: Set default return type to ExportedProgram by @peri044 in #2575
- feat: Fixed conv1d converter when weights are Tensor by @andi4191 in #2542
- feat: support converter for torch.log10 by @bowang007 in #2621
- [feat] support converter for torch.log2 by @bowang007 in #2620
- fix: Update index in installer for CI failures by @gs-olive in #2679
- chore: update pytorch to 2.3 by @peri044 in #2697
- feat: Add save API for torch-trt compiled models by @peri044 in #2691
- chore: Update versions by @peri044 in #2732
- feat: Implement symbolic shape propagation, sym_size converter by @peri044 in #2473
- feat: Add dynamic shapes support for torch.compile workflow by @peri044 in #2627
- feat: cherry-pick of Selectively enable different frontends (#2693) by @peri044 in #2761
- chore: Upgrade TensorRT version to TRT 10 EA by @peri044 in #2699
- fix: Remove references to implicit batch for TRT 10 by @gs-olive in #2773
- feat: Python Runtime Windows Builds on TRT 10 by @gs-olive in #2764
- TRT-10 GA Support for release/2.3 branch by @zewenli98 in #2778
- chore: cherry pick commits from main into release/2.3 by @peri044 in #2769
- chore: Add cherry picks by @peri044 in #2797
- 2.3 cherry pick feat: Adding support for native int64 (#2789) by @narendasan in #2802
- chore: Updates for 2.3 by @peri044 in #2788
- chore: Cherry pick embedding_bag converter for release 2.3 by @zewenli98 in #2807
- feat: Add validators for dynamic shapes in converter registration by @peri044 in #2796
- chore: enable DS support for converters by @peri044 in #2775
- feat(//py/torch_tensorrt/dynamo): Support for BF16 (#2833) by @narendasan in #2845
- chore: cherry-pick of #2858 by @peri044 in #2861
- chore: Minor fix 2.3 by @peri044 in #2866
- chore: cherry pick of #2832 by @zewenli98 in #2852
- feat: Implement FP8 functionality by @peri044 in #2763
- chore: cherry pick cudnn dependency removal commit by @peri044 in #2870
- cherry pick pull/2879 to release/2.3 branch by @lanluo-nvidia in #2882
- cherry-pick: Add release flag for nightly build tag (#2821) by @gs-olive in #2835
New Contributors
- @bigfootjon made their first contribution in #2456
- @huydhn made their first contribution in #2591
Full Changelog: v2.2.0...v2.3.0