Torch-TensorRT v1.2.0
PyTorch 1.12, Collections based I/O, FX Frontend, torchtrtc custom op support, CMake build system and Community Windows Support
Torch-TensorRT 1.2.0 targets PyTorch 1.12, CUDA 11.6, cuDNN 8.4 and TensorRT 8.4. This release focuses on a couple of key new APIs to handle function I/O that uses collection types, which should enable whole new classes of models to be compiled by Torch-TensorRT without source code modification. It also introduces the "FX Frontend", a new frontend for Torch-TensorRT which leverages FX, a high-level IR built into PyTorch with extensive Python APIs. For use cases which do not need to run outside of Python, this may be a strong option to try, as it is easily extensible in a familiar development environment. In Torch-TensorRT 1.2.0, the FX frontend should be considered beta level in stability. `torchtrtc` has received improvements which target the ability to handle operators outside of the core PyTorch op set. This includes custom operators from libraries such as `torchvision` and `torchtext`. Similarly, users can provide custom converters to `torchtrtc` to extend the compiler's support from the command line instead of having to write an application to do so. Finally, Torch-TensorRT introduces community-supported Windows and CMake support.
New Dependencies
nvidia-tensorrt
For previous versions of Torch-TensorRT, users had to install TensorRT via a system package manager and modify their `LD_LIBRARY_PATH` in order to set up Torch-TensorRT. Now users should install the TensorRT Python API as part of the installation procedure. This can be done via the following steps:
```sh
pip install nvidia-pyindex
pip install nvidia-tensorrt==8.4.3.1
pip install torch-tensorrt==1.2.0 -f https://github.com/pytorch/tensorrt/releases
```
Installing the TensorRT pip package will allow Torch-TensorRT to automatically load the TensorRT libraries without any modification to environment variables. It is also a necessary dependency for the FX Frontend.
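As a quick sanity check (a minimal sketch, assuming the packages above installed successfully), both modules should now import with no environment variable changes:

```python
import tensorrt
import torch_tensorrt

# With the nvidia-tensorrt wheel installed, torch_tensorrt locates the
# TensorRT libraries automatically; no LD_LIBRARY_PATH edits are required.
print(tensorrt.__version__)        # expected: 8.4.3.1
print(torch_tensorrt.__version__)  # expected: 1.2.0
```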
torchvision
Some FX frontend converters are designed to target operators from third-party libraries like torchvision. As such, you must have torchvision installed in order to use them; the dependency is optional if you do not need this support.
Jetson
Starting from this release we will be distributing precompiled binaries of our NGC release branches for aarch64 (as well as x86_64), starting with ngc/22.11. These releases are designed to be paired with NVIDIA-distributed builds of PyTorch, including the NGC containers and Jetson builds, and are equivalent to the prepackaged distribution of Torch-TensorRT that ships in those containers. They represent the state of the master branch at the time of branch cutting, so they may lag in features by a month or so. These releases ship separately from minor version releases like this one. Going forward, these NGC releases should be the primary release channel used on Jetson (including for building from source).
NOTE: NGC PyTorch builds are not identical to builds you might install through normal channels like pytorch.org, and in the past this has caused portability issues between pytorch.org builds and NGC builds. Therefore, in workflows such as exporting a TorchScript module on an x86 machine and then compiling on Jetson, we strongly recommend using the NGC container release on x86 for your host machine operations. More information about Jetson support can be found alongside the 22.07 release (https://github.com/pytorch/TensorRT/releases/tag/v1.2.0a0.nv22.07)
Collections based I/O [Experimental]
Torch-TensorRT has previously operated under the assumption that `nn.Module` forward functions can trivially be reduced to the form `forward([Tensor]) -> [Tensor]`. Typically this implies functions of the form `forward(Tensor, Tensor, ..., Tensor) -> (Tensor, Tensor, ..., Tensor)`. However, as model complexity increases, grouping inputs may make it easier to manage many inputs, so function signatures similar to `forward([Tensor], (Tensor, Tensor)) -> [Tensor]` or `forward((Tensor, Tensor)) -> (Tensor, (Tensor, Tensor))` may become more common. In Torch-TensorRT 1.2.0, more of these kinds of use cases are supported using the new experimental `input_signature` compile spec API. This API allows users to group Input specs in the same way they would group the input Tensors used to call the original module's forward function. This informs Torch-TensorRT how to map a Tensor input from its location in a group into the engine, and from the engine back into its grouping returned to the user.
To make this concrete, consider the following standard case:

```python
import torch
import torch.nn as nn
import torch_tensorrt

class StandardTensorInput(nn.Module):
    def __init__(self):
        super(StandardTensorInput, self).__init__()

    def forward(self, x, y):
        r = x + y
        return r

x = torch.Tensor([1, 2, 3]).to("cuda")
y = torch.Tensor([4, 5, 6]).to("cuda")

module = StandardTensorInput().eval().to("cuda")
trt_module = torch_tensorrt.compile(
    module,
    inputs=[
        torch_tensorrt.Input(x.shape),
        torch_tensorrt.Input(y.shape),
    ],
    min_block_size=1,
)

out = trt_module(x, y)
print(out)
```
Here a user has defined two explicit tensor inputs and used the existing list-based API to define the input specs.
With Torch-TensorRT 1.2.0 the following use cases are now possible using the new `input_signature` API:
- Tuple-based input collection

```python
from typing import Tuple

class TupleInput(nn.Module):
    def __init__(self):
        super(TupleInput, self).__init__()

    def forward(self, z: Tuple[torch.Tensor, torch.Tensor]):
        r = z[0] + z[1]
        return r

x = torch.Tensor([1, 2, 3]).to("cuda")
y = torch.Tensor([4, 5, 6]).to("cuda")

module = TupleInput().eval().to("cuda")
trt_module = torch_tensorrt.compile(
    module,
    input_signature=((x, y),),  # Note how inputs are grouped with the new API
    min_block_size=1,
)

out = trt_module((x, y))
print(out)
```
- List-based input collection

```python
from typing import List

class ListInput(nn.Module):
    def __init__(self):
        super(ListInput, self).__init__()

    def forward(self, z: List[torch.Tensor]):
        r = z[0] + z[1]
        return r

x = torch.Tensor([1, 2, 3]).to("cuda")
y = torch.Tensor([4, 5, 6]).to("cuda")

module = ListInput().eval().to("cuda")
trt_module = torch_tensorrt.compile(
    module,
    input_signature=([x, y],),  # Again, note how inputs are grouped with the new API
    min_block_size=1,
)

out = trt_module([x, y])
print(out)
```
Note how the input specs (in this case just example tensors) are provided to the compiler. The `input_signature` argument expects a `Tuple[Union[torch.Tensor, torch_tensorrt.Input, List, Tuple]]` grouped in a format representative of how the function would be called. In these cases it is just a list or tuple of specs.
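Since `torch_tensorrt.Input` is part of the accepted union, specs can stand in for the example tensors inside a grouping. A minimal sketch, reusing the `TupleInput` module above (the explicit shapes are an assumption matching the example tensors):

```python
# Sketch: grouping torch_tensorrt.Input specs instead of example tensors.
# Reuses the TupleInput module from above; shape [3] matches x and y.
trt_module = torch_tensorrt.compile(
    TupleInput().eval().to("cuda"),
    input_signature=((torch_tensorrt.Input(shape=[3]), torch_tensorrt.Input(shape=[3])),),
    min_block_size=1,
)
out = trt_module((x, y))
```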
More advanced cases are supported as well:
- Tuple I/O
```python
class TupleInputOutput(nn.Module):
    def __init__(self):
        super(TupleInputOutput, self).__init__()

    def forward(self, z: Tuple[torch.Tensor, torch.Tensor]):
        r1 = z[0] + z[1]
        r2 = z[0] - z[1]
        r1 = r1 * 10
        r = (r1, r2)
        return r

x = torch.Tensor([1, 2, 3]).to("cuda")
y = torch.Tensor([4, 5, 6]).to("cuda")

module = TupleInputOutput().eval().to("cuda")
trt_module = torch_tensorrt.compile(
    module,
    input_signature=((x, y),),  # Again, note how inputs are grouped with the new API
    min_block_size=1,
)

out = trt_module((x, y))
print(out)
```
- List I/O
```python
class ListInputOutput(nn.Module):
    def __init__(self):
        super(ListInputOutput, self).__init__()

    def forward(self, z: List[torch.Tensor]):
        r1 = z[0] + z[1]
        r2 = z[0] - z[1]
        r = [r1, r2]
        return r

x = torch.Tensor([1, 2, 3]).to("cuda")
y = torch.Tensor([4, 5, 6]).to("cuda")

module = ListInputOutput().eval().to("cuda")
trt_module = torch_tensorrt.compile(
    module,
    input_signature=([x, y],),  # Again, note how inputs are grouped with the new API
    min_block_size=1,
)

out = trt_module([x, y])
print(out)
```
- Multiple Groups of Mixed Types

```python
class MultiGroupIO(nn.Module):
    def __init__(self):
        super(MultiGroupIO, self).__init__()

    def forward(self, z: List[torch.Tensor], a: Tuple[torch.Tensor, torch.Tensor]):
        r1 = z[0] + z[1]
        r2 = a[0] + a[1]
        r3 = r1 - r2
        r4 = [r1, r2]
        return (r3, r4)

x = torch.Tensor([1, 2, 3]).to("cuda")
y = torch.Tensor([4, 5, 6]).to("cuda")

module = MultiGroupIO().eval().to("cuda")
trt_module = torch_tensorrt.compile(
    module,
    input_signature=([x, y], (x, y)),  # Again, note how inputs are grouped with the new API
    min_block_size=1,
)

out = trt_module([x, y], (x, y))
print(out)
```
These features are supported in C++ as well:
```c++
torch::jit::Module mod;
try {
  // Deserialize the ScriptModule from a file using torch::jit::load().
  mod = torch::jit::load(path);
} catch (const c10::Error& e) {
  std::cerr << "error loading the model\n";
}
mod.eval();
mod.to(torch::kCUDA);

// `inputs` is a std::vector<torch::Tensor> of example tensors and `in0` is one
// such tensor; both are defined by the surrounding application code.
std::vector<torch::jit::IValue> inputs_;
for (auto in : inputs) {
  inputs_.push_back(torch::jit::IValue(in.clone()));
}

// Group the example tensors into a list, mirroring the module's forward signature.
std::vector<torch::jit::IValue> complex_inputs;
auto input_list = c10::impl::GenericList(c10::TensorType::get());
input_list.push_back(inputs_[0]);
input_list.push_back(inputs_[0]);
torch::jit::IValue input_list_ivalue = torch::jit::IValue(input_list);
complex_inputs.push_back(input_list_ivalue);

// Build the matching grouped Input specs for the compile spec.
auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf);
auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive<torch_tensorrt::Input>(input_shape)));

c10::TypePtr elementType = input_shape_ivalue.type();
auto list = c10::impl::GenericList(elementType);
list.push_back(input_shape_ivalue);
list.push_back(input_shape_ivalue);

torch::jit::IValue complex_input_shape(list);
std::tuple<torch::jit::IValue> input_tuple2(complex_input_shape);
torch::jit::IValue complex_input_shape2(input_tuple2);

auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2);
compile_settings.min_block_size = 1;
compile_settings.enabled_precisions = {torch::kHalf};

// Compile module
auto trt_mod = torch_tensorrt::ts::compile(mod, compile_settings);
auto trt_out = trt_mod.forward(complex_inputs);
```
Currently this feature should be considered experimental; APIs may be subject to change or be folded into existing APIs. There are also limitations introduced by using this feature, including the following:
- Not all collection types are supported (e.g. `Dict`, `namedtuple`)
- `require_full_compilation` cannot be used while this feature is in use
- Certain operators are required to run in PyTorch throughout the graph, which may impact performance
- The maximum depth of collections nesting is limited

These limitations will be addressed in subsequent versions.
Adding FX frontend to Torch-TensorRT [Beta]
This release adds FX as one of the supported IRs for converting torch models to TensorRT through the new FX frontend. At a high level, this path transforms the model into (or consumes) an FX graph and, similar to the TorchScript frontend, converts the graph to TensorRT through a library of converters. The key difference is that it is implemented purely in Python. The role of the FX frontend is to supplement the TS lowering path and to provide users better ease of use and easier extensibility in use cases where removing Python as a dependency is not strictly necessary. Detailed user instructions can be found in the documentation.
The FX path examples are located under `//examples/fx`
The FX path unit tests are located under `//py/torch_tensorrt/fx/tests`
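As a minimal sketch of invoking the FX frontend (assuming a CUDA-capable environment with torchvision installed; the model choice is illustrative), selecting it is a matter of passing `ir="fx"` to `torch_tensorrt.compile`:

```python
import torch
import torch_tensorrt
import torchvision.models as models

model = models.resnet18(pretrained=True).eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]

# ir="fx" routes compilation through the FX frontend rather than TorchScript.
trt_model = torch_tensorrt.compile(
    model,
    ir="fx",
    inputs=inputs,
    enabled_precisions={torch.float},
)
print(trt_model(*inputs).shape)
```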
Custom operators and converters in Torch-TensorRT
While both the C++ API and Python API provide systems to include and convert custom operators in your model (for instance those implemented in `torchvision`), `torchtrtc` has been limited to the core op set. In Torch-TensorRT 1.2.0 two new flags have been added to `torchtrtc`:

```
--custom-torch-ops    (repeatable) Shared object/DLL containing custom torch operators
--custom-converters   (repeatable) Shared object/DLL containing custom converters
```

These arguments accept paths to `.so` or DLL files which define custom operators for PyTorch or custom converters for Torch-TensorRT. These files will get `DL_OPEN`'d at runtime to extend the op and converter libraries.
For example:

```sh
torchtrtc tests/modules/ssd_traced.jit.pt ssd_trt.ts --custom-torch-ops=<path to custom library .so file> --custom-converters=<path to custom library .so file> "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]@fp16%contiguous" -p f16
```
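For reference, the Python API can achieve the same effect by loading such libraries before compilation. The following is a sketch with hypothetical library paths; `torch.ops.load_library` dlopens the shared object, which runs its operator/converter registration code:

```python
import torch
import torch_tensorrt

# Hypothetical paths; loading each .so runs its static registration code,
# extending the op and converter libraries just like the torchtrtc flags do.
torch.ops.load_library("path/to/libcustom_torch_ops.so")
torch.ops.load_library("path/to/libcustom_converters.so")

model = torch.jit.load("tests/modules/ssd_traced.jit.pt").eval().cuda()
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input(min_shape=(1, 3, 300, 300),
                                 opt_shape=(1, 3, 512, 512),
                                 max_shape=(1, 3, 1024, 1024),
                                 dtype=torch.half)],
    enabled_precisions={torch.half},
)
```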
Community CMake and Windows support
Thanks to the great work of @gcuendet and others, CMake and, consequently, Windows support has been added to the project! Users on Linux and Windows can now build the C++ API using this system, and `torch_tensorrt_runtime.dll` adds support for executing Torch-TensorRT programs on Windows in both Python and C++. Detailed information on how to use this build system can be found here: https://pytorch.org/TensorRT/getting_started/installation.html
Bazel will continue to be the primary build system for the project, and all testing and distributed builds will be built and run with Bazel (including future official Windows support), so users should consider it the canonical version of Torch-TensorRT. However, we aim to ensure as best we can that the CMake system can build the project properly, including on Windows. Contributions to continue to grow support for this build system, and for Windows as a platform, are definitely welcome.
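As a sketch of runtime-only deployment on Windows (the library name is taken from the CMake build above; the module path and input shape are assumptions), a previously compiled program can be executed by loading the runtime library and then the TorchScript program:

```python
import torch

# Loading the runtime library registers the TensorRT engine execution op,
# so a compiled Torch-TensorRT program can run without the full compiler.
torch.ops.load_library("torch_tensorrt_runtime.dll")

trt_module = torch.jit.load("trt_module.ts").cuda()
out = trt_module(torch.randn(1, 3, 224, 224).cuda())
```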
Known Limitations
- Collections I/O
  - Not all collection types are supported (e.g. `Dict`, `namedtuple`)
  - `require_full_compilation` cannot be used while this feature is in use
  - Certain operators are required to run in PyTorch throughout the graph, which may impact performance
  - The maximum depth of collections nesting is limited
- FX
  - Some FX operators have limited dynamic shape capability. Please check here.
  - Control flow in models cannot be handled
- The Python API is not available via the CMake build system.
Dependencies
- Bazel 5.2.0
- LibTorch 1.12.1
- CUDA 11.6 (on x86_64, by default, newer CUDA 11 supported with compatible PyTorch Build)
- cuDNN 8.4.1.50
- TensorRT 8.4.3.1
Operators Supported (TorchScript)
Operators Currently Supported Through Converters
- aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> (Tensor)
- aten::_convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor)
- aten::abs(Tensor self) -> (Tensor)
- aten::acos(Tensor self) -> (Tensor)
- aten::acosh(Tensor self) -> (Tensor)
- aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> (Tensor)
- aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)
- aten::adaptive_avg_pool3d(Tensor self, int[3] output_size) -> (Tensor)
- aten::adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)
- aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)
- aten::adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor)
- aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)
- aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)
- aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))
- aten::argmax(Tensor self, int dim, bool keepdim=False) -> (Tensor)
- aten::argmin(Tensor self, int dim, bool keepdim=False) -> (Tensor)
- aten::asin(Tensor self) -> (Tensor)
- aten::asinh(Tensor self) -> (Tensor)
- aten::atan(Tensor self) -> (Tensor)
- aten::atanh(Tensor self) -> (Tensor)
- aten::avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=[0], bool ceil_mode=False, bool count_include_pad=True) -> (Tensor)
- aten::avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=[0, 0], bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
- aten::avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=[], bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
- aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta, Tensor? mean, Tensor? var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor)
- aten::bitwise_not(Tensor self) -> (Tensor)
- aten::bmm(Tensor self, Tensor mat2) -> (Tensor)
- aten::cat(Tensor[] tensors, int dim=0) -> (Tensor)
- aten::ceil(Tensor self) -> (Tensor)
- aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> (Tensor)
- aten::clamp_max(Tensor self, Scalar max) -> (Tensor)
- aten::clamp_min(Tensor self, Scalar min) -> (Tensor)
- aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)
- aten::cos(Tensor self) -> (Tensor)
- aten::cosh(Tensor self) -> (Tensor)
- aten::cumsum(Tensor self, int dim, *, int? dtype=None) -> (Tensor)
- aten::div.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::div.Tensor(Tensor self, Tensor other) -> (Tensor)
- aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> (Tensor)
- aten::div_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))
- aten::div_.Tensor(Tensor(a!) self, Tensor other) -> (Tensor(a!))
- aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)
- aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> (Tensor)
- aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::eq.Tensor(Tensor self, Tensor other) -> (Tensor)
- aten::erf(Tensor self) -> (Tensor)
- aten::exp(Tensor self) -> (Tensor)
- aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))
- aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))
- aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor)
- aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor)
- aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)
- aten::floor(Tensor self) -> (Tensor)
- aten::floor_divide(Tensor self, Tensor other) -> (Tensor)
- aten::floor_divide.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::ge.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::ge.Tensor(Tensor self, Tensor other) -> (Tensor)
- aten::gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor)
- aten::gt.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::gt.Tensor(Tensor self, Tensor other) -> (Tensor)
- aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor)
- aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor(a!))
- aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)
- aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> (Tensor)
- aten::layer_norm(Tensor input, int[] normalized_shape, Tensor? gamma, Tensor? beta, float eps, bool cudnn_enabled) -> (Tensor)
- aten::le.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::le.Tensor(Tensor self, Tensor other) -> (Tensor)
- aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> (Tensor)
- aten::leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> (Tensor(a!))
- aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> (Tensor)
- aten::log(Tensor self) -> (Tensor)
- aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)
- aten::lt.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::lt.Tensor(Tensor self, Tensor other) -> (Tensor)
- aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> (Tensor)
- aten::matmul(Tensor self, Tensor other) -> (Tensor)
- aten::max(Tensor self) -> (Tensor)
- aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
- aten::max.other(Tensor self, Tensor other) -> (Tensor)
- aten::max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> (Tensor)
- aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=[0, 0], int[2] dilation=[1, 1], bool ceil_mode=False) -> (Tensor)
- aten::max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=[], int[3] dilation=[], bool ceil_mode=False) -> (Tensor)
- aten::mean(Tensor self, *, int? dtype=None) -> (Tensor)
- aten::mean.dim(Tensor self, int[] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)
- aten::min(Tensor self) -> (Tensor)
- aten::min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
- aten::min.other(Tensor self, Tensor other) -> (Tensor)
- aten::mul.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::mul.Tensor(Tensor self, Tensor other) -> (Tensor)
- aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> (Tensor(a!))
- aten::narrow(Tensor(a) self, int dim, int start, int length) -> (Tensor(a))
- aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> (Tensor(a))
- aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)
- aten::ne.Tensor(Tensor self, Tensor other) -> (Tensor)
- aten::neg(Tensor self) -> (Tensor)
- aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> (Tensor)
- aten::permute(Tensor(a) self, int[] dims) -> (Tensor(a))
- aten::pixel_shuffle(Tensor self, int upscale_factor) -> (Tensor)
- aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)
- aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> (Tensor)
- aten::prelu(Tensor self, Tensor weight) -> (Tensor)
- aten::prod(Tensor self, *, int? dtype=None) -> (Tensor)
- aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)
- aten::reciprocal(Tensor self) -> (Tensor)
- aten::reflection_pad1d(Tensor self, int[2] padding) -> (Tensor)
- aten::reflection_pad2d(Tensor self, int[4] padding) -> (Tensor)
- aten::relu(Tensor input) -> (Tensor)
- aten::relu_(Tensor(a!) self) -> (Tensor(a!))
- aten::repeat(Tensor self, int[] repeats) -> (Tensor)
- aten::repeat_interleave.self_int(Tensor self, int repeats, int? dim=None, *, int? output_size=None) -> (Tensor)
- aten::replication_pad1d(Tensor self, int[2] padding) -> (Tensor)
- aten::replication_pad2d(Tensor self, int[4] padding) -> (Tensor)
- aten::replication_pad3d(Tensor self, int[6] padding) -> (Tensor)
- aten::reshape(Tensor self, int[] shape) -> (Tensor)
- aten::roll(Tensor self, int[1] shifts, int[1] dims=[]) -> (Tensor)
- aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)
- aten::rsub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)
- aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> (Tensor)
- aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> (Tensor)
- aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))
- aten::sigmoid(Tensor input) -> (Tensor)
- aten::sigmoid_(Tensor(a!) self) -> (Tensor(a!))
- aten::sin(Tensor self) -> (Tensor)
- aten::sinh(Tensor self) -> (Tensor)
- aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? end=None, int step=1) -> (Tensor(a))
- aten::softmax.int(Tensor self, int dim, int? dtype=None) -> (Tensor)
- aten::split(Tensor self, int[] split_sizes, int dim=0) -> (Tensor[])
- aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> (Tensor[])
- aten::split.sizes(Tensor(a -> *) self, int[] split_size, int dim=0) -> (Tensor[])
- aten::split_with_sizes(Tensor(a) self, int[] split_sizes, int dim=0) -> (Tensor[])
- aten::sqrt(Tensor self) -> (Tensor)
- aten::square(Tensor self) -> (Tensor)
- aten::squeeze.dim(Tensor(a) self, int dim) -> (Tensor(a))
- aten::stack(Tensor[] tensors, int dim=0) -> (Tensor)
- aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)
- aten::sub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)
- aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))
- aten::sum(Tensor self, *, int? dtype=None) -> (Tensor)
- aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)
- aten::t(Tensor self) -> (Tensor)
- aten::tan(Tensor self) -> (Tensor)
- aten::tanh(Tensor input) -> (Tensor)
- aten::tanh_(Tensor(a!) self) -> (Tensor(a!))
- aten::to.device(Tensor(a) self, Device device, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor(a))
- aten::to.dtype(Tensor self, int dtype, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor)
- aten::to.other(Tensor self, Tensor other, bool non_blocking=False, bool copy=False, int? memory_format=None) -> (Tensor)
- aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> (Tensor(a|b))
- aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
- aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> (Tensor(a))
- aten::unbind.int(Tensor(a -> *) self, int dim=0) -> (Tensor[])
- aten::unsqueeze(Tensor(a) self, int dim) -> (Tensor(a))
- aten::upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> (Tensor)
- aten::upsample_bilinear2d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)
- aten::upsample_linear1d(Tensor self, int[1] output_size, bool align_corners, float? scales=None) -> (Tensor)
- aten::upsample_linear1d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)
- aten::upsample_nearest1d(Tensor self, int[1] output_size, float? scales=None) -> (Tensor)
- aten::upsample_nearest1d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)
- aten::upsample_nearest2d(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None) -> (Tensor)
- aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)
- aten::upsample_nearest3d(Tensor self, int[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)
- aten::upsample_nearest3d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)
- aten::upsample_trilinear3d(Tensor self, int[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)
- aten::upsample_trilinear3d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)
- aten::view(Tensor(a) self, int[] size) -> (Tensor(a))
- trt::const(Tensor self) -> (Tensor)
Operators Currently Supported Through Evaluators
- aten::Bool.float(float b) -> (bool)
- aten::Bool.int(int a) -> (bool)
- aten::Float.Scalar(Scalar a) -> float
- aten::Float.bool(bool a) -> float
- aten::Float.int(int a) -> float
- aten::Int.Scalar(Scalar a) -> int
- aten::Int.bool(bool a) -> int
- aten::Int.float(float a) -> int
- aten::Int.int(int a) -> int
- aten::__and__(int a, int b) -> (bool)
- aten::__and__.bool(bool a, bool b) -> (bool)
- aten::__derive_index(int idx, int start, int step) -> int
- aten::__getitem__.t(t list, int idx) -> (t(*))
- aten::__is__(t1 self, t2 obj) -> bool
- aten::__isnot__(t1 self, t2 obj) -> bool
- aten::__not__(bool self) -> bool
- aten::__or__(int a, int b) -> (bool)
- aten::__range_length(int lo, int hi, int step) -> int
- aten::__round_to_zero_floordiv(int a, int b) -> (int)
- aten::__xor__(int a, int b) -> (bool)
- aten::add.float(float a, float b) -> (float)
- aten::add.int(int a, int b) -> (int)
- aten::add.str(str a, str b) -> (str)
- aten::add_.t(t self, t[] b) -> (t[])
- aten::append.t(t self, t(c -> *) el) -> (t)
- aten::arange(Scalar end, *, int? dtype=None, int? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)
- aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)
- aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor)
- aten::clone(Tensor self, *, int? memory_format=None) -> (Tensor)
- aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> (Tensor(a!))
- aten::dim(Tensor self) -> int
- aten::div.float(float a, float b) -> (float)
- aten::div.int(int a, int b) -> (float)
- aten::eq.bool(bool a, bool b) -> (bool)
- aten::eq.float(float a, float b) -> (bool)
- aten::eq.float_int(float a, int b) -> (bool)
- aten::eq.int(int a, int b) -> (bool)
- aten::eq.int_float(int a, float b) -> (bool)
- aten::eq.str(str a, str b) -> (bool)
- aten::extend.t(t self, t[] other) -> ()
- aten::floor.float(float a) -> (int)
- aten::floor.int(int a) -> (int)
- aten::floordiv.float(float a, float b) -> (int)
- aten::floordiv.int(int a, int b) -> (int)
- aten::format(str self, ...) -> (str)
- aten::ge.bool(bool a, bool b) -> (bool)
- aten::ge.float(float a, float b) -> (bool)
- aten::ge.float_int(float a, int b) -> (bool)
- aten::ge.int(int a, int b) -> (bool)
- aten::ge.int_float(int a, float b) -> (bool)
- aten::gt.bool(bool a, bool b) -> (bool)
- aten::gt.float(float a, float b) -> (bool)
- aten::gt.float_int(float a, int b) -> (bool)
- aten::gt.int(int a, int b) -> (bool)
- aten::gt.int_float(int a, float b) -> (bool)
- aten::is_floating_point(Tensor self) -> (bool)
- aten::le.bool(bool a, bool b) -> (bool)
- aten::le.float(float a, float b) -> (bool)
- aten::le.float_int(float a, int b) -> (bool)
- aten::le.int(int a, int b) -> (bool)
- aten::le.int_float(int a, float b) -> (bool)
- aten::len.t(t[] a) -> (int)
- aten::lt.bool(bool a, bool b) -> (bool)
- aten::lt.float(float a, float b) -> (bool)
- aten::lt.float_int(float a, int b) -> (bool)
- aten::lt.int(int a, int b) -> (bool)
- aten::lt.int_float(int a, float b) -> (bool)
- aten::mul.float(float a, float b) -> (float)
- aten::mul.int(int a, int b) -> (int)
- aten::ne.bool(bool a, bool b) -> (bool)
- aten::ne.float(float a, float b) -> (bool)
- aten::ne.float_int(float a, int b) -> (bool)
- aten::ne.int(int a, int b) -> (bool)
- aten::ne.int_float(int a, float b) -> (bool)
- aten::neg.int(int a) -> (int)
- aten::numel(Tensor self) -> int
- aten::pow.float(float a, float b) -> (float)
- aten::pow.float_int(float a, int b) -> (float)
- aten::pow.int(int a, int b) -> (float)
- aten::pow.int_float(int a, float b) -> (float)
- aten::size(Tensor self) -> (int[])
- aten::size.int(Tensor self, int dim) -> (int)
- aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])
- aten::sqrt.float(float a) -> (float)
- aten::sqrt.int(int a) -> (float)
- aten::sub.float(float a, float b) -> (float)
- aten::sub.int(int a, int b) -> (int)
- aten::tensor(t[] data, *, int? dtype=None, Device? device=None, bool requires_grad=False) -> (Tensor)
- prim::TupleIndex(Any tup, int i) -> (Any)
- prim::dtype(Tensor a) -> (int)
- prim::max.bool(bool a, bool b) -> (bool)
- prim::max.float(float a, float b) -> (float)
- prim::max.float_int(float a, int b) -> (float)
- prim::max.int(int a, int b) -> (int)
- prim::max.int_float(int a, float b) -> (float)
- prim::max.self_int(int[] self) -> (int)
- prim::min.bool(bool a, bool b) -> (bool)
- prim::min.float(float a, float b) -> (float)
- prim::min.float_int(float a, int b) -> (float)
- prim::min.int(int a, int b) -> (int)
- prim::min.int_float(int a, float b) -> (float)
- prim::min.self_int(int[] self) -> (int)
- prim::shape(Tensor a) -> (int[])
What's Changed
- chore: Bump version to 1.2.0a0 by @narendasan in #1044
- feat: Extending nox for cxx11 ABI version by @andi4191 in #1013
- docs: Update the documentation theme to PyTorch by @narendasan in #1063
- Adding Code of Conduct file by @facebook-github-bot in #1061
- Update CONTRIBUTING.md by @frank-wei in #1064
- feat: Optimize hub.py download by @andi4191 in #1022
- Adding an action to automatically assign reviewers and assignees by @narendasan in #1078
- Add PR assigner support by @narendasan in #1080
- (//core): Align with prim::Enter in module fallback by @andi4191 in #991
- (//core): Added a variant for aten::split by @andi4191 in #992
- feat(nox): Replacing session with environment variable by @andi4191 in #1057
- Refactor the internal codebase from fx2trt_oss to torch_tensorrt by @frank-wei in #1104
- format by buildifier by @frank-wei in #1106
- [fx2trt] Modify lower setting class by @frank-wei in #1107
- Modified the notebooks directory's README file by @svenchilton in #1102
- [FX] Sync to OSS by @frank-wei in #1118
- [fx_acc] Add acc_tracer support for torch.mm by @khabinov in #1120
- Added Triton deployment instructions to documentation by @tanayvarshney in #1116
- amending triton deployment docs by @tanayvarshney in #1126
- fix: Update broken repo hyperlink by @lamhoangtung in #1131
- fix: Fix keep_dims functionality for aten::max by @peri044 in #1099
- fix(tests/core/partitioning): Fix tests of refactoring segmentation in partitioning by @peri044 in #1140
- feat(//tests): Update rtol and atol based tolerance for test cases by @andi4191 in #1055
- doc: add the explanation for partition phases on docs by @bowang007 in #1090
- feat (//cpp): Using atol and rtol based tolerance threshold for torchtrtc by @andi4191 in #1052
- CI/CD setup by @frank-wei in #1137
- Update README.md by @frank-wei in #1142
- [fx2trt] Engineholder feature improvement, test fixes by @frank-wei in #1143
- feat (//core/conversion) : Add converter for torch.bitwise_not by @blchu in #1029
- fixed typos by @tanayvarshney in #1098
- [FX] --fx-only does not need to check bazel by @frank-wei in #1147
- [FX] refactor the fx path in compile function by @frank-wei in #1141
- [FX] Create getting_started_with_fx_path.rst by @frank-wei in #1145
- [FX] move example folder by @frank-wei in #1149
- [FX] Sync enhancement done internally at Meta by @yinghai in #1161
- Update config.yml by @frank-wei in #1163
- Use py3 next() syntax by @ptrblck in #1159
- Add missing comma for proper torch versioning in setup.py by @dabauxi in #1164
- [docs] Update link to relative path by @zhiqwang in #1171
- [FX] Changes done internally at Facebook by @frank-wei in #1172
- fix: fix the model name typo error by @bowang007 in #1176
- [FX] Changes done internally at Facebook by @frank-wei in #1178
- [feat]: support slice with dynamic shape by @inocsin in #1110
- [FX] Update getting_started_with_fx_path.rst by @frank-wei in #1184
- [FX] Update README.md by @frank-wei in #1183
- fix: Fix PTQ calibration when there are multiple inputs by @peri044 in #1191
- [FX] Changes done internally at Facebook by @frank-wei in #1194
- [fix]: fix bug in aten::to, when network only have aten::to layer wil… by @inocsin in #1108
- Add .circleci/config.yml by @narendasan in #1153
- feat: Upgrade TRT to 8.4 by @peri044 in #1152
- feat: Update Pytorch version to 1.12 by @peri044 in #1177
- fix: converter renaming already named tensors by @bowang007 in #1167
- feat(//py): Use TensorRT to fill in .so libraries automatically if possible by @narendasan in #1085
- [FX] Changes done internally at Facebook by @frank-wei in #1204
- fix: fix the parsing related model loading bug by @bowang007 in #1148
- feat: support min_block_size != 1 caused fallback nodes re-segmentation by @bowang007 in #1195
- [FX] Changes done internally at Facebook by @frank-wei in #1208
- fix: fix the fallback related issue after merging collection by @bowang007 in #1206
- Add CMake support to build the libraries by @gcuendet in #1058
- Fix typo in EfficientNet-example by @davinnovation in #1217
- fix: fix bug that ListConstruct in TRT subgraph when it's entire graph's output by @bowang007 in #1220
- fix: fix the error that collection input segmented into trt subgraph by @bowang007 in #1225
- feat(//circleci): Adding release automation by @narendasan in #1215
- fix: support int tensor * int scaler in aten::mul by @mfeliz-cruise in #1095
- [FX] Changes done internally at Facebook by @frank-wei in #1221
- Fix errors in unbind and list slice by @mfeliz-cruise in #1088
- Adding a Resnet C++ example by @vinhngx in #1175
- [FX] disable 2 of conv3d and type_as tests by @frank-wei in #1224
- [feat] Add support for integers in aten::abs converter (#35) by @mfeliz-cruise in #1232
- Update PTQ example to fix new compile_spec requirements by @ncomly-nvidia in #1242
- feat: support for grouped inputs by @narendasan in #1201
- feat: Added support for custom torch operators and converters in torchtrtc by @andi4191 in #1219
- Add outputPadding in deconv by @ruoqianguo in #1234
- chore: Apply linting and ignore new bazel dirs by @narendasan in #1223
- added qat-ptq workflow notebook by @tanayvarshney in #1239
- fix: Update cmake for the new collection files by @narendasan in #1246
- chore: ignore dist dir for pre-commit by @narendasan in #1249
- chore: Aligning bazel version for consistency across different docker… by @andi4191 in #1250
- refactor: Changed the hardcoded values to macros for DLA memory sizes by @andi4191 in #1247
- chore: update jetson pytorch baase by @narendasan in #1251
- [feat] Add automatic type promotion to element-wise ops by @mfeliz-cruise in #1240
- Assorted small fixes by @narendasan in #1259
- [FX] remove op_lowering_disallow_list and format revert by @frank-wei in #1261
- fix: fix the "schema not found for node" error by @bowang007 in #1236
- chore: Fix contributing doc by @peri044 in #1268
- feat: support scatter.value and scatter.src by @inocsin in #1252
- Internal workspace workflow by @narendasan in #1269
- Fix typo in README by @davinnovation in #1273
- Support swin/bert with dynamic batch by @Njuapp in #1270
- Update release 1.2 by @narendasan in #1275
- correct sha256sum of cudnn by @Njuapp in #1278
- Update release branch by @narendasan in #1279
- Jetson workspace by @narendasan in #1280
- chore(deps): bump @actions/core from 1.8.2 to 1.9.1 in /.github/actions/assigner by @dependabot in #1287
- [FX] Changes done internally at Facebook by @frank-wei in #1288
- chore: Fix dataloader in finetune_qat script by @andi4191 in #1292
- chore: Truncate long and double for ptq CPP path by @andi4191 in #1291
- feat: Add support for aten::square by @mfeliz-cruise in #1286
- fix: fix misleading skipping partitioning msg by @bowang007 in #1289
- fix: Add int support to constant_pad_nd by @mfeliz-cruise in #1283
- fix: Resolve non-determinism in registerSegmentsOutputs by @mfeliz-cruise in #1284
- docs: Update docgen task by @narendasan in #1294
- update fx notebook by @frank-wei in #1297
- [FX] Changes done internally at Facebook by @frank-wei in #1299
- fix(tools): Fix linter to not depend on docker by @narendasan in #1301
- Update release branch by @narendasan in #1300
- Update release branch by @narendasan in #1307
- Support multiple indices for aten::index.Tensor by @ruoqianguo in #1309
- chore: Adding CMake to the CI by @narendasan in #1310
- feat: Upgrade Pytorch to 1.12.1 and TensorRT to 8.4.3.1 by @peri044 in #1315
- Fix bug: correct the output shape of aten::index.Tensor by @ruoqianguo in #1314
- feat (//core/conversion) : Add converter for torch.repeat_interleave by @blchu in #1313
- chore: Adding NGC build path by @narendasan in #1311
- Update release by @narendasan in #1320
- Update lower.py by @frank-wei in #1324
- fix!: Fixed Windows compilation failures by @andi4191 in #1330
- [feat] Add support for argmax and argmin by @mfeliz-cruise in #1312
- chore: Adding a guideline to build on Windows platform by @andi4191 in #1337
- chore: Fix data loader issues and nox file paths by @peri044 in #1281
- feat(//tools/perf): Refactor perf_run.py, add fx2trt backend support, usage via CLI arguments by @peri044 in #1254
- refactor(//tests) : Refactor the test suite by @peri044 in #1329
- [feat] add support for aten::reciprocal(int) by @mfeliz-cruise in #1308
- Update release branch with latest test fixes by @narendasan in #1339
- [FX] Update getting_started_with_fx_path.rst by @frank-wei in #1342
- Update getting_started_with_fx_path.rst by @frank-wei in #1343
- enable direct call to fx.compile() by @frank-wei in #1344
- fix: add remove_exception pass from torch to fix uninitialized tensor… by @bowang007 in #1345
- chore: apply linting to docs by @narendasan in #1347
- Update release branch by @narendasan in #1348
New Contributors
- @facebook-github-bot made their first contribution in #1061
- @frank-wei made their first contribution in #1064
- @khabinov made their first contribution in #1120
- @blchu made their first contribution in #1029
- @yinghai made their first contribution in #1161
- @ptrblck made their first contribution in #1159
- @dabauxi made their first contribution in #1164
- @zhiqwang made their first contribution in #1171
- @gcuendet made their first contribution in #1058
- @davinnovation made their first contribution in #1217
- @dependabot made their first contribution in #1287
Full Changelog: v1.1.0...v1.2.0